spark 3.4.4 呆板学习基于逻辑回归算法及管道流实现鸢尾花分类推测案例 ...

打印 上一主题 下一主题

主题 821|帖子 821|积分 2463

知识点介绍:本文基于鸢尾花案例,实现逻辑回归分类推测案例

逻辑回归算法介绍

1. 定义与基本原理



  • 定义:逻辑回归(Logistic Regression)是一种广泛应用于分类题目的统计学习方法,尽管名字里带有 “回归”,但它重要用于办理二分类题目,也可以通过扩展的方式(如 One-vs-Rest、One-vs-One 等策略)应用于多分类题目。
  • 基本原理:逻辑回归基于线性回归的头脑,首先对输入特性进行线性组合(通过盘算特性向量与权重向量的内积,再加上一个偏置项),得到一个线性的推测值,然后将这个线性推测值通过一个非线性的激活函数(通常是 Sigmoid 函数用于二分类,Softmax 函数用于多分类)进行转换,将结果映射到 0 到 1 之间(二分类时表现属于某一类别的概率,多分类时表现属于各个类别的概率分布),最后根据设定的阈值(一般为 0.5 用于二分类)来判断样本属于哪一类。

例如,对于一个简单的二分类逻辑回归,假设输入特性向量为,权重向量为,偏置项为,则线性组合的盘算公式为:,然后通过 Sigmoid 函数将转换为概率值,如果,则推测样本属于正类,否则属于负类。
2. 应用场景



  • 医疗范畴:判断患者是否患有某种疾病,好比根据患者的年龄、症状、检查指标等特性来推测是否患有糖尿病、心脏病等。例如,通过收集大量已确诊患者和健康人的相干数据,利用逻辑回归构建模子,对新的患者进行疾病风险推测。
  • 金融范畴:评估客户的光荣风险,决定是否给客户发放贷款等。例如,依据客户的收入、资产、光荣历史、负债环境等特性,模子推测客户违约的概率,银行根据这个概率来决定是否批准贷款申请。
  • 市场营销范畴:推测客户是否会购买某种产物或服务,基于客户的斲丧行为、人口统计学特性、浏览历史等,企业可以针对性地开展营销活动,提高营销结果。
3. 优缺点



  • 长处

    • 简单易懂且盘算服从高:逻辑回归的原理相对直观,基于线性组合和简单的函数变换,模子的训练和推测过程盘算复杂度较低,在大规模数据集上也能较快地得到结果。
    • 可表明性强:模子的权重可以直观地反映出每个特性对分类结果的影响程度,通过分析权重的正负和巨细,能够了解特性与目标类别之间的关联关系,有助于业务理解和决策。
    • 模子训练稳定:一般环境下,逻辑回归在公道的参数设置下不轻易出现梯度消失、梯度爆炸等训练不稳定的题目,能够较为稳定地收敛到一个较优的解。

  • 缺点

    • 只能处理线性可分题目(原始特性空间下):如果数据在原始特性空间中黑白线性可分的,逻辑回归的分类结果大概不佳,须要通过人工特性工程(如添加多项式特性等)或者联合核本领等方法来扩展特性空间,使其能够处理非线性关系。
    • 对特性要求较高:须要对输入的特性进行公道的选择和预处理,如果存在冗余、无关或者高度相干的特性,大概会影响模子的性能,导致过拟合或者欠拟合等题目,所以往往须要进行特性选择和特性缩放等操作。
    • 轻易欠拟合:在处理复杂的非线性关系数据时,由于其模子自己的线性本质,相较于一些复杂的非线性模子(如深度学习模子),大概较难拟合数据中的复杂模式,轻易出现欠拟合现象。

4. 紧张参数



  • 最大迭代次数(maxIter):控制模子训练时的迭代次数,用于确保模子能够收敛到一个相对稳定的解。如果设置的值过小,模子大概还未充分学习到数据中的模式就停止训练,导致欠拟合;而设置过大则大概导致训练时间过长,甚至大概出现过拟合的环境,须要根据具体数据集和题目进行公道调解,通常通过交叉验证等方法来探求合适的值。
  • 正则化参数(regParam):为了防止模子过拟合,在丧失函数中添加正则项,正则化参数用于控制正则项的强度。常见的正则化方式有 L1 正则化(Lasso 回归)、L2 正则化(岭回归)以及两者联合的 ElasticNet 正则化。合适的正则化参数可以平衡模子的复杂度和对训练数据的拟合程度,差异数据集的最优值差异,同样须要通过实验来确定。
  • ElasticNet 参数(ElasticNetParam):当利用 ElasticNet 正则化时,该参数用于控制 L1 和 L2 正则化的比例,取值范围在 0 到 1 之间,为 0 时表现只利用 L2 正则化,为 1 时表现只利用 L1 正则化,介于两者之间则是两者按相应比例联合。L1 正则化有助于进行特性选择,能将一些不紧张的特性对应的系数压缩为 0,而 L2 正则化能使模子的系数更加平滑,避免过大的系数值导致过拟合。

Spark ML 管道流(Pipeline)介绍

1. 概念与作用



  • 概念:Spark ML 中的管道流(Pipeline)是一种将多个呆板学习相干的处理步骤按照顺序组合成一个整体工作流的工具。它类似于工业生产中的流水线,数据从一端输入,按照预先设定好的各个阶段依次进行处理,最终在另一端输出处理后的结果,好比完成模子训练或者得到推测数据。
  • 作用

    • 方便流程管理:在现实的呆板学习项目中,通常包含数据读取、数据预处理(如特性缩放、缺失值处理、特性编码等)、特性工程(特性选择、特性组合等)以及模子训练、模子评估等多个步骤,利用 Pipeline 可以清楚地将这些步骤组织起来,形成一个逻辑连贯的整体流程,便于代码的编写、维护和理解。
    • 保证处理顺序:确保各个处理步骤按照正确的顺序执行,避免因为顺序杂乱导致的错误结果。例如,必须先对数据进行特性编码,再将编码后的特性输入到模子中进行训练,Pipeline 能够严酷按照设定的顺序依次调用每个阶段的处理逻辑。
    • 便于模子复用和部署:一旦定义好一个完整的 Pipeline,它可以方便地进行生存和加载,在差异的环境(如开辟环境、测试环境、生产环境)中复用,快速进行模子训练或者利用已训练好的模子进行推测,大大简化了模子部署和应用的过程。

2. 组成部门



  • 阶段(Stage):Pipeline 由多个阶段组成,每个阶段可以是一个数据转换操作(如StringIndexer将字符串标签转换为索引、VectorAssembler将多个特性列组合成特性向量等),也可以是一个呆板学习模子(如LogisticRegression、DecisionTreeClassifier等)。这些阶段按照在Pipeline中设置的顺序依次执行,数据在前一个阶段处理完成后会自动传递到下一个阶段进行后续处理。例如,一个简单的包含数据预处理和模子训练的 Pipeline 大概有以下几个阶段:

    • StringIndexer阶段:对字符串类型的标签列进行索引化处理。
    • VectorAssembler阶段:将多个数值特性列组装成一个特性向量列。
    • LogisticRegression阶段:利用组装好的特性向量列和索引化的标签列进行逻辑回归模子训练。


鸢尾花数据介绍
鸢尾花数据集(Iris Dataset)是一类非常经典且常用的数据集,在呆板学习、数据分析和统计学习等范畴被广泛应用,以下是对它的详细介绍:

1. 数据集泉源

鸢尾花数据集是由美国植物学家埃德加・安德森(Edgar Anderson)在 20 世纪 30 年代收集整理的,后经英国统计学家和生物学家罗纳德・费舍尔(Ronald Fisher)在其 1936 年的论文《The use of multiple measurements in taxonomic problems》中利用并推广开来,成为了分类使命中的标准测试数据集之一。
2. 数据内容



  • 特性维度:该数据集包含了 4 个特性,分别是:

    • 花萼长度(sepal length):通常以厘米为单元,形貌鸢尾花花萼部门的长度。
    • 花萼宽度(sepal width):同样以厘米为单元,对应花萼部门的宽度环境。
    • 花瓣长度(petal length):指鸢尾花花瓣的长度,单元厘米,是区分差异鸢尾花种类的紧张特性之一。
    • 花瓣宽度(petal width):以厘米为单元权衡花瓣的宽窄程度,在判断鸢尾花类别时也起着关键作用。

  • 类别标签:总共有 3 种差异的鸢尾花品种类别,分别是:

    • 山鸢尾(Iris-setosa):其花瓣相对较窄较短,在形态上与别的两种有比较明显的区别。
    • 变色鸢尾(Iris-versicolor):花瓣长度、宽度等特性处于别的两种鸢尾花之间的状态,具有一定的过渡特点。
    • 维吉尼亚鸢尾(Iris-virginica):通常花瓣较为宽大,整体花朵形态与前两者差异。

3. 数据规模

鸢尾花数据集一共包含 150 条样本数据,每种鸢尾花类别各有 50 条样本,整体规模较小,便于快速进行模子的训练、测试以及算法验证等操作,尤其得当初学者理解和掌握分类算法的原理及流程。

代码实现

  1. package cn.lh.pblh123.spark2024.theorycourse.charpter9
  2. import org.apache.spark.ml.Pipeline
  3. import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
  4. import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
  5. import org.apache.spark.ml.feature.{IndexToString, StringIndexer, VectorAssembler, VectorIndexer}
  6. import org.apache.spark.sql.SparkSession
  7. import org.apache.spark.sql.functions.col
  8. import org.apache.spark.sql.types.DoubleType
  9. object SparkMLLogicalRegressionPipeLine {
  10.   def main(args: Array[String]): Unit = {
  11.     if (args.length != 3) {
  12.       System.err.println("Usage: <murl> <inputfile> <modelpath>")
  13.       System.exit(1)
  14.     }
  15.     val murl = args(0)
  16.     val inputfile = args(1)
  17.     val modelpath = args(2)
  18.     val spark = SparkSession.builder().appName(s"${this.getClass.getName}").master(murl).getOrCreate()
  19.     // 加载数据
  20.     val df = spark.read.option("header", true)
  21.       .csv(inputfile)
  22.     df.show(3, false)
  23.     df.printSchema()
  24.     val dfDouble = df.select(col("sepal_length").cast(DoubleType), col("sepal_width").cast(DoubleType),
  25.       col("petal_length").cast(DoubleType), col("petal_width").cast(DoubleType),
  26.       col("species").alias("label")
  27.     )
  28.     dfDouble.printSchema()
  29.     // 特征处理
  30.     // 创建一个VectorAssembler实例,用于将多列特征组合成单一的特征向量
  31.     val assembler = new VectorAssembler().setInputCols(
  32.       Array("sepal_length", "sepal_width", "petal_length", "petal_width")
  33.     ).setOutputCol("features")
  34.     // 使用VectorAssembler转换原始DataFrame,生成一个新的DataFrame,其中包含特征向量和标签列
  35.     val dataFrame = assembler.transform(dfDouble).select("features", "label")
  36.     // 显示转换后的DataFrame的前3行数据,以验证转换结果
  37.     dataFrame.show(3, 0)
  38.     // 获取标签列和特征列
  39.     // 使用StringIndexer将标签列转换为索引形式,以便后续的机器学习算法能够处理
  40.     val labelIndex = new StringIndexer().setInputCol("label").setOutputCol("labelIndex").fit(dataFrame)
  41.     // 使用VectorIndexer对特征列进行索引,这有助于提高机器学习模型的效率和效果
  42.     val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").fit(dataFrame)
  43.     // 创建Logistic回归模型实例,设置标签列和特征列,以及模型的训练参数
  44.     // 最大迭代次数设为100,正则化参数设为0.3,ElasticNet参数设为0.8,这样的设置旨在平衡偏差和方差,避免过拟合
  45.     val logisticRegression = new LogisticRegression().setLabelCol("labelIndex").setFeaturesCol("indexedFeatures")
  46.       .setMaxIter(100).setRegParam(0.3).setElasticNetParam(0.8)
  47.     println("logistricRegression parameters:\n" + logisticRegression.explainParams() + "\n")
  48.     // 设置indexToString转换器
  49.     val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel")
  50.       .setLabels(labelIndex.labels)
  51.     // 设置逻辑回归流水线
  52.     val lrpiple = new Pipeline().setStages(Array(labelIndex, featureIndexer, logisticRegression, labelConverter))
  53.     // 划分训练集和测试集,利用随机种子
  54.     val Array(trainingData, testData) = dataFrame.randomSplit(Array(0.7, 0.3), 1234L)
  55.     trainingData.show(3, 0)
  56.     testData.show(3, 0)
  57.     // 利用流水线训练模型
  58.     val model = lrpiple.fit(trainingData)
  59.     val predictions = model.transform(testData)
  60.     // 显示预测结果
  61.     predictions.select("predictedLabel", "label", "features", "probability").show(5, 0)
  62.     // 评估模型
  63.     // 创建一个MulticlassClassificationEvaluator实例用于评估分类模型的准确性
  64.     // 设置评估器的标签列名为"labelIndex",预测列名为"prediction",并使用"accuracy"作为评估指标
  65.     val evaluator = new MulticlassClassificationEvaluator().setLabelCol("labelIndex").setPredictionCol("prediction")
  66.       .setMetricName("accuracy")
  67.     // 使用评估器计算预测结果的准确性
  68.     val accuracy = evaluator.evaluate(predictions)
  69.     // 打印测试错误率,即1减去准确率
  70.     println("Test Error = " + (1.0 - accuracy))
  71.     // 通过流水线获取模型参数
  72.     val lrModel = model.stages(2).asInstanceOf[LogisticRegressionModel]
  73.     println("Learned classification logistic regression model:\n" + lrModel.summary.totalIterations)
  74.     println("Coefficients: \n" + lrModel.coefficientMatrix)
  75.     println("Intercepts: \n" + lrModel.interceptVector)
  76.     println("logistic regression model num of Classes" + lrModel.numClasses)
  77.     println("logistic regression model num of features" + lrModel.numFeatures)
  78.     // 保存模型
  79.     lrModel.write.overwrite().save(modelpath)
  80.     spark.stop()
  81.   }
  82. }
复制代码
代码执行结果如下:




免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

伤心客

金牌会员
这个人很懒什么都没写!

标签云

快速回复 返回顶部 返回列表