头歌 Spark的机器学习-MLlib

打印 上一主题 下一主题

主题 556|帖子 556|积分 1668

第1关 MLlib介绍

  1. package com.educoder.bigData.sparksql5;
  2. import java.util.Arrays;
  3. import java.util.List;
  4. import org.apache.spark.ml.Pipeline;
  5. import org.apache.spark.ml.PipelineModel;
  6. import org.apache.spark.ml.PipelineStage;
  7. import org.apache.spark.ml.classification.LogisticRegression;
  8. import org.apache.spark.ml.feature.HashingTF;
  9. import org.apache.spark.ml.feature.Tokenizer;
  10. import org.apache.spark.sql.Dataset;
  11. import org.apache.spark.sql.Row;
  12. import org.apache.spark.sql.RowFactory;
  13. import org.apache.spark.sql.SparkSession;
  14. import org.apache.spark.sql.types.DataTypes;
  15. import org.apache.spark.sql.types.Metadata;
  16. import org.apache.spark.sql.types.StructField;
  17. import org.apache.spark.sql.types.StructType;
  18. public class Test1 {
  19. public static void main(String[] args) {
  20. SparkSession spark = SparkSession.builder().appName("test1").master("local").getOrCreate();
  21. List<Row> trainingList = Arrays.asList(
  22. RowFactory.create(1.0, "a b c d E spark"),
  23. RowFactory.create(0.0, "b d"),
  24. RowFactory.create(1.0, "hadoop Mapreduce"),
  25. RowFactory.create(0.0, "f g h"));
  26. List<Row> testList = Arrays.asList(
  27. RowFactory.create(0.0, "spark I j k"),
  28. RowFactory.create(0.0, "l M n"),
  29. RowFactory.create(0.0, "f g"),
  30. RowFactory.create(0.0, "apache hadoop")
  31. );
  32. /********* Begin *********/
  33. StructType schema = new StructType(
  34. new StructField[] { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  35. new StructField("text", DataTypes.StringType, false, Metadata.empty()) });
  36. Dataset<Row> training = spark.createDataFrame(trainingList, schema);
  37. Dataset<Row> test = spark.createDataFrame(testList, schema);
  38. Tokenizer tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words");
  39. HashingTF hashingTF = new HashingTF().setNumFeatures(1000).setInputCol("words").setOutputCol("features");
  40. LogisticRegression lr = new LogisticRegression().setMaxIter(10).setRegParam(0.001);
  41. Pipeline pipeline = new Pipeline().setStages(new PipelineStage[] { tokenizer, hashingTF, lr });
  42. PipelineModel fit = pipeline.fit(training);
  43. fit.transform(test).select("prediction").show();
  44. /********* End *********/
  45. }
  46. }
复制代码
第2关 MLlib-垃圾邮件检测

  1. package com.educoder.bigData.sparksql5;
  2. import java.util.Arrays;
  3. import org.apache.spark.api.java.JavaRDD;
  4. import org.apache.spark.api.java.function.Function;
  5. import org.apache.spark.ml.Pipeline;
  6. import org.apache.spark.ml.PipelineModel;
  7. import org.apache.spark.ml.PipelineStage;
  8. import org.apache.spark.ml.classification.GBTClassifier;
  9. import org.apache.spark.ml.feature.StringIndexer;
  10. import org.apache.spark.ml.feature.Word2Vec;
  11. import org.apache.spark.sql.Dataset;
  12. import org.apache.spark.sql.Row;
  13. import org.apache.spark.sql.RowFactory;
  14. import org.apache.spark.sql.SparkSession;
  15. import org.apache.spark.sql.types.DataTypes;
  16. import org.apache.spark.sql.types.Metadata;
  17. import org.apache.spark.sql.types.StructField;
  18. import org.apache.spark.sql.types.StructType;
  19. public class Case2 {
  20. public static PipelineModel training(SparkSession spark) {
  21. /********* Begin *********/
  22. JavaRDD<Row> map = spark.read().textFile("SMSSpamCollection").toJavaRDD()
  23. .map(String -> String.split(" ")).map(new Function<String[], Row>() {
  24. @Override
  25. public Row call(String[] v1) throws Exception {
  26. String[] copyOfRange = Arrays.copyOfRange(v1, 1, v1.length);
  27. Row create = RowFactory.create(v1[0], copyOfRange);
  28. return create;
  29. }
  30. });
  31. StructType schema = new StructType(new StructField[] {
  32. new StructField("label", DataTypes.StringType, false, Metadata.empty()),
  33. new StructField("message", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) });
  34. Dataset<Row> data = spark.createDataFrame(map, schema);
  35. StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel");
  36. Word2Vec word2Vec = new Word2Vec().setInputCol("message").setOutputCol("features");
  37. GBTClassifier mlpc = new GBTClassifier().setLabelCol("indexedLabel")
  38. .setFeaturesCol("features");
  39. Pipeline pipeline = new Pipeline()
  40. .setStages(new PipelineStage[] { labelIndexer, word2Vec, mlpc});
  41. PipelineModel fit = pipeline.fit(data);
  42. /********* End *********/
  43. return fit;
  44. }
  45. }
复制代码
第3关 MLlib-红酒分类预测

  1. package com.educoder.bigData.sparksql5;
  2. import org.apache.spark.api.java.JavaRDD;
  3. import org.apache.spark.api.java.function.Function;
  4. import org.apache.spark.ml.Pipeline;
  5. import org.apache.spark.ml.PipelineModel;
  6. import org.apache.spark.ml.PipelineStage;
  7. import org.apache.spark.ml.classification.*;
  8. import org.apache.spark.ml.linalg.VectorUDT;
  9. import org.apache.spark.ml.linalg.Vectors;
  10. import org.apache.spark.sql.Dataset;
  11. import org.apache.spark.sql.Row;
  12. import org.apache.spark.sql.RowFactory;
  13. import org.apache.spark.sql.SparkSession;
  14. import org.apache.spark.sql.types.DataTypes;
  15. import org.apache.spark.sql.types.Metadata;
  16. import org.apache.spark.sql.types.StructField;
  17. import org.apache.spark.sql.types.StructType;
  18. public class Case3 {
  19. public static PipelineModel training(SparkSession spark) {
  20. /********* Begin *********/
  21. JavaRDD<Row> javaRDD = spark.read().csv("dataset.csv").toJavaRDD();
  22. JavaRDD<Row> map = javaRDD.map(new Function<Row, Row>() {
  23. @Override
  24. public Row call(Row v1) throws Exception {
  25. int size = v1.size();
  26. // 第一列为标签
  27. double labelDouble = Double.parseDouble(v1.get(0).toString());
  28. // 获取特征数组
  29. double[] features = new double[size -1];
  30. for(int n = 1 ; n < size ; n++) {
  31. features[n-1] = Double.parseDouble(v1.get(n).toString());
  32. }
  33. // 创建 row
  34. Row create = RowFactory.create(labelDouble,Vectors.dense(features));
  35. return create;
  36. }
  37. });
  38. StructType schema = new StructType(
  39. new StructField[] { new StructField("label", DataTypes.DoubleType, false, Metadata.empty()),
  40. new StructField("features", new VectorUDT(), false, Metadata.empty()) });
  41. Dataset<Row> createDataFrame = spark.createDataFrame(map, schema);
  42. RandomForestClassifier mlpc = new RandomForestClassifier()
  43. .setLabelCol("label")
  44. .setFeaturesCol("features");
  45. Pipeline pipeline = new Pipeline()
  46. .setStages(new PipelineStage[] { mlpc});
  47. PipelineModel fit = pipeline.fit(createDataFrame);
  48. /********* End *********/
  49. return fit;
  50. }
  51. }
复制代码
免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息请访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

汕尾海湾

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表