第1关 MLlib介绍
- package com.educoder.bigData.sparksql5;
- import java.util.Arrays;
- import java.util.List;
- import org.apache.spark.ml.Pipeline;
- import org.apache.spark.ml.PipelineModel;
- import org.apache.spark.ml.PipelineStage;
- import org.apache.spark.ml.classification.LogisticRegression;
- import org.apache.spark.ml.feature.HashingTF;
- import org.apache.spark.ml.feature.Tokenizer;
- import org.apache.spark.sql.Dataset;
- import org.apache.spark.sql.Row;
- import org.apache.spark.sql.RowFactory;
- import org.apache.spark.sql.SparkSession;
- import org.apache.spark.sql.types.DataTypes;
- import org.apache.spark.sql.types.Metadata;
- import org.apache.spark.sql.types.StructField;
- import org.apache.spark.sql.types.StructType;
- public class Test1 {
- public static void main(String[] args) {
- SparkSession spark = SparkSession.builder().appName("test1").master("local").getOrCreate();
- List<Row> trainingList = Arrays.asList(
- RowFactory.create(1.0, "a b c d E spark"),
- RowFactory.create(0.0, "b d"),
- RowFactory.create(1.0, "hadoop Mapreduce"),
- RowFactory.create(0.0, "f g h"));
- List<Row> testList = Arrays.asList(
- RowFactory.create(0.0, "spark I j k"),
- RowFactory.create(0.0, "l M n"),
- RowFactory.create(0.0, "f g"),
- RowFactory.create(0.0, "apache hadoop")
- );
- /********* Begin *********/
- StructType schema = new StructType(
- new StructField[] { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
- new StructField("text", DataTypes.StringType, false, Metadata.empty()) });
- Dataset<Row> training = spark.createDataFrame(trainingList, schema);
- Dataset<Row> test = spark.createDataFrame(testList, schema);
- Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
- HashingTF hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features");
- LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.001);
- Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { tokenizer, hashingTF, lr });
- PipelineModel fit = pipeline.fit(training);
- fit.transform(test).select("prediction").show();
- /********* End *********/
- }
- }
第2关 MLlib-垃圾邮件检测
- package com.educoder.bigData.sparksql5;
- import java.util.Arrays;
- import org.apache.spark.api.java.JavaRDD;
- import org.apache.spark.api.java.function.Function;
- import org.apache.spark.ml.Pipeline;
- import org.apache.spark.ml.PipelineModel;
- import org.apache.spark.ml.PipelineStage;
- import org.apache.spark.ml.classification.GBTClassifier;
- import org.apache.spark.ml.feature.StringIndexer;
- import org.apache.spark.ml.feature.Word2Vec;
- import org.apache.spark.sql.Dataset;
- import org.apache.spark.sql.Row;
- import org.apache.spark.sql.RowFactory;
- import org.apache.spark.sql.SparkSession;
- import org.apache.spark.sql.types.DataTypes;
- import org.apache.spark.sql.types.Metadata;
- import org.apache.spark.sql.types.StructField;
- import org.apache.spark.sql.types.StructType;
- public class Case2 {
- public static PipelineModel training(SparkSession spark) {
- /********* Begin *********/
- JavaRDD<Row> map = spark.read().textFile("SMSSpamCollection").toJavaRDD()
- .map(String -> String.split(" ")).map(new Function<String[], Row>() {
- @Override
- public Row call(String[] v1) throws Exception {
- String[] copyOfRange = Arrays.copyOfRange(v1, 1, v1.length);
- Row create = RowFactory.create(v1[0], copyOfRange);
- return create;
- }
- });
- StructType schema = new StructType(new StructField[] {
- new StructField("label", DataTypes.StringType, false, Metadata.empty()),
- new StructField("message", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) });
- Dataset<Row> data = spark.createDataFrame(map, schema);
- StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel");
- Word2Vec word2Vec = new Word2Vec().setInputCol("message").setOutputCol("features");
- GBTClassifier mlpc = new GBTClassifier().setLabelCol("indexedLabel")
- .setFeaturesCol("features");
- Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] { labelIndexer, word2Vec, mlpc});
- PipelineModel fit = pipeline.fit(data);
- /********* End *********/
- return fit;
- }
- }
第3关 MLlib-红酒分类预测
- package com.educoder.bigData.sparksql5;
- import org.apache.spark.api.java.JavaRDD;
- import org.apache.spark.api.java.function.Function;
- import org.apache.spark.ml.Pipeline;
- import org.apache.spark.ml.PipelineModel;
- import org.apache.spark.ml.PipelineStage;
- import org.apache.spark.ml.classification.*;
- import org.apache.spark.ml.linalg.VectorUDT;
- import org.apache.spark.ml.linalg.Vectors;
- import org.apache.spark.sql.Dataset;
- import org.apache.spark.sql.Row;
- import org.apache.spark.sql.RowFactory;
- import org.apache.spark.sql.SparkSession;
- import org.apache.spark.sql.types.DataTypes;
- import org.apache.spark.sql.types.Metadata;
- import org.apache.spark.sql.types.StructField;
- import org.apache.spark.sql.types.StructType;
- public class Case3 {
- public static PipelineModel training(SparkSession spark) {
- /********* Begin *********/
- JavaRDD<Row> javaRDD = spark.read().csv("dataset.csv").toJavaRDD();
- JavaRDD<Row> map = javaRDD.map(new Function<Row, Row>() {
- @Override
- public Row call(Row v1) throws Exception {
- int size = v1.size();
- // 第一列为标签
- double labelDouble = Double.parseDouble(v1.get(0).toString());
- // 获取特征数组
- double[] features = new double[size -1];
- for(int n = 1 ; n < size ; n++) {
- features[n-1] = Double.parseDouble(v1.get(n).toString());
- }
- // 创建 row
- Row create = RowFactory.create(labelDouble,Vectors.dense(features));
- return create;
- }
- });
- StructType schema = new StructType(
- new StructField[] { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
- new StructField("features", new VectorUDT(), false, Metadata.empty()) });
- Dataset<Row> createDataFrame = spark.createDataFrame(map, schema);
- RandomForestClassifier mlpc = new RandomForestClassifier()
- .setLabelCol("label")
- .setFeaturesCol("features");
- Pipeline pipeline = new Pipeline()
- .setStages(new PipelineStage[] { mlpc});
- PipelineModel fit = pipeline.fit(createDataFrame);
- /********* End *********/
- return fit;
- }
- }
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。