Spark Machine Learning: Logistic Regression Classification


Introduction to Logistic Regression

Logistic regression is a supervised learning algorithm used mainly for classification problems.

Although "regression" appears in its name, Logistic Regression is actually a classification model, most commonly used for binary classification. It is popular in industry because it is simple, parallelizable, and highly interpretable.

The essence of logistic regression is to assume the data follow a logistic distribution — that is, the log-odds of the positive class are a linear function of the features — and then estimate the parameters by maximum likelihood.
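Concretely, the model pushes a linear score through the logistic (sigmoid) function to get a probability. A minimal plain-Scala sketch (no Spark; the object name and the weights are illustrative values, not anything the model below will learn):

```scala
// Sketch of the logistic model: P(y = 1 | x) = sigmoid(w . x + b)
object SigmoidSketch {
  // The sigmoid maps any real-valued score into (0, 1)
  def sigmoid(z: Double): Double = 1.0 / (1.0 + math.exp(-z))

  // Probability of the positive class for a feature vector x,
  // given illustrative weights w and intercept b
  def predictProb(w: Array[Double], b: Double, x: Array[Double]): Double =
    sigmoid(w.zip(x).map { case (wi, xi) => wi * xi }.sum + b)

  def main(args: Array[String]): Unit = {
    println(sigmoid(0.0)) // 0.5 exactly at the decision boundary
    // Two exam scores of 80 with made-up weights 0.05 and intercept -6
    println(predictProb(Array(0.05, 0.05), -6.0, Array(80.0, 80.0)))
  }
}
```

Maximum-likelihood training then searches for the `w` and `b` that make the observed labels most probable, which is what Spark's `LogisticRegression.fit` does below.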

A Logistic Regression Example

Here we use a logistic regression model to build a binary classifier that predicts whether a student will pass the third exam, based on that student's scores on the first two exams.

scores.csv

(score on the first exam, score on the second exam, whether the student passed the third exam: 1 = pass, 0 = fail)

score1,score2,result
34.62365962451697,78.0246928153624,0
30.28671076822607,43.89499752400101,0
35.84740876993872,72.90219802708364,0
60.18259938620976,86.30855209546826,1
79.0327360507101,75.3443764369103,1
45.08327747668339,56.3163717815305,0
61.10666453684766,96.51142588489624,1
75.02474556738889,46.55401354116538,1
76.09878670226257,87.42056971926803,1
84.43281996120035,43.53339331072109,1
95.86155507093572,38.22527805795094,0
75.01365838958247,30.60326323428011,0
82.30705337399482,76.48196330235604,1
34.21206097786789,44.20952859866288,0
77.9240914545704,68.9723599933059,1
62.27101367004632,69.95445795447587,1
80.1901807509566,44.82162893218353,1
93.114388797442,38.80067033713209,0
61.83020602312595,50.25610789244621,0
38.78580379679423,64.99568095539578,0
61.379289447425,72.80788731317097,1
85.40451939411645,57.05198397627122,1
52.10797973193984,63.12762376881715,0
52.04540476831827,69.43286012045222,1
40.23689373545111,71.16774802184875,0
54.63510555424817,52.21388588061123,0
33.91550010906887,98.86943574220611,0
64.17698887494485,80.90806058670817,1
74.78925295941542,41.57341522824434,0
34.1836400264419,75.2377203360134,0
83.90239366249155,56.30804621605327,1
51.54772026906181,46.85629026349976,0
94.44336776917852,65.56892160559052,1
82.36875375713919,40.61825515970618,0
51.04775177128865,45.82270145776001,0
62.22267576120188,52.06099194836679,0
77.19303492601364,70.45820000180959,1
97.77159928000232,86.7278223300282,1
62.07306379667647,96.76882412413983,1
91.56497449807442,88.69629254546599,1
79.94481794066932,74.16311935043758,1
99.2725269292572,60.99903099844988,1
90.54671411399852,43.39060180650027,1
34.52451385320009,60.39634245837173,0
50.2864961189907,49.80453881323059,0
49.58667721632031,59.80895099453265,0
97.64563396007767,68.86157272420604,1
32.57720016809309,95.59854761387875,0
74.24869136721598,69.82457122657193,1
71.79646205863379,78.45356224515052,1
75.3956114656803,85.75993667331619,1
35.28611281526193,47.02051394723416,0
56.25381749711624,39.26147251058019,0
30.05882244669796,49.59297386723685,0
44.66826172480893,66.45008614558913,0
66.56089447242954,41.09209807936973,0
40.45755098375164,97.53518548909936,1
49.07256321908844,51.88321182073966,0
80.27957401466998,92.11606081344084,1
66.74671856944039,60.99139402740988,1
32.72283304060323,43.30717306430063,0
64.0393204150601,78.03168802018232,1
72.34649422579923,96.22759296761404,1
60.45788573918959,73.09499809758037,1
58.84095621726802,75.85844831279042,1
99.82785779692128,72.36925193383885,1
47.26426910848174,88.47586499559782,1
50.45815980285988,75.80985952982456,1
60.45555629271532,42.50840943572217,0
82.22666157785568,42.71987853716458,0
88.9138964166533,69.80378889835472,1
94.83450672430196,45.69430680250754,1
67.31925746917527,66.58935317747915,1
57.23870631569862,59.51428198012956,1
80.36675600171273,90.96014789746954,1
68.46852178591112,85.59430710452014,1
42.0754545384731,78.84478600148043,0
75.47770200533905,90.42453899753964,1
78.63542434898018,96.64742716885644,1
52.34800398794107,60.76950525602592,0
94.09433112516793,77.15910509073893,1
90.44855097096364,87.50879176484702,1
55.48216114069585,35.57070347228866,0
74.49269241843041,84.84513684930135,1
89.84580670720979,45.35828361091658,1
83.48916274498238,48.38028579728175,1
42.2617008099817,87.10385094025457,1
99.31500880510394,68.77540947206617,1
55.34001756003703,64.9319380069486,1
74.77589300092767,89.52981289513276,1
69.36458875970939,97.71869196188608,1
39.53833914367223,76.03681085115882,0
53.9710521485623,89.20735013750205,1
69.07014406283025,52.74046973016765,1
67.94685547711617,46.67857410673128,0

Maven Dependencies

<properties>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.2</spark.version>
    <hadoop.version>2.7.6</hadoop.version>
</properties>
<dependencies>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-core_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-mllib_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
</dependencies>

LogisticRegression.scala

import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{StringIndexer, VectorAssembler}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructField, StructType}

/**
  * @Author Daniel
  * @Description Logistic regression example
  *
  **/
object LogisticRegression {
  def main(args: Array[String]): Unit = {
    // Build the Spark session (programming entry point)
    val spark = SparkSession
      .builder
      .master("local[*]")
      .appName("LogisticRegression")
      .getOrCreate()
    val (assembler, logisticRegressionModelLoaded) = getModel(spark)
    // Test data: first- and second-exam scores for 5 students, used to predict the third exam.
    // In the predictions, 0.0 means pass and 1.0 means fail — inverted relative to the CSV,
    // because StringIndexer assigns index 0.0 to the most frequent value (here, result = 1)
    import spark.implicits._
    val df1 = Seq(
      (70.66150955499435, 92.92713789364831),
      (76.97878372747498, 47.57596364975532),
      (67.37202754570876, 42.83843832029179),
      (89.67677575072079, 65.79936592745237),
      (50.534788289883, 48.85581152764205)
    ).toDF("score1", "score2")
    // Transform the sample data, adding the features column
    val df2 = assembler.transform(df1)
    df2.show()
    // The final result is each new student's predicted pass/fail status for the third exam (0.0 = pass, 1.0 = fail)
    val df3 = logisticRegressionModelLoaded.transform(df2)
    df3.show()
    spark.stop()
  }

  // Train the model and return it together with the feature assembler
  def getModel(spark: SparkSession): (VectorAssembler, LogisticRegressionModel) = {
    // Schema of the CSV file
    val schema = StructType(
      StructField("score1", DoubleType, nullable = true) ::
        StructField("score2", DoubleType, nullable = true) ::
        // 0 = fail, 1 = pass
        StructField("result", IntegerType, nullable = true) ::
        Nil
    )

    // Read the data into a DataFrame
    val marksDf = spark.read.format("csv")
      .option("header", value = true)
      .option("delimiter", ",")
      .schema(schema)
      .load("scores.csv")
      // Cache the DataFrame, since it is reused below
      .cache()

    // Columns to combine into the feature vector
    val cols = Array("score1", "score2")

    // Assemble the score columns into a single feature vector
    val assembler = new VectorAssembler()
      .setInputCols(cols)
      .setOutputCol("features")
    // DataFrame with the features column added
    val featureDf = assembler.transform(marksDf)

    // Create a label column from the result column
    // (StringIndexer assigns index 0.0 to the most frequent value)
    val indexer = new StringIndexer()
      .setInputCol("result")
      .setOutputCol("label")
    val labelDf = indexer.fit(featureDf).transform(featureDf)

    val seed = 5043
    // 70% of the data for training, 30% for testing
    val Array(trainingData, testData) = labelDf.randomSplit(Array(0.7, 0.3), seed)
    // Build the logistic regression estimator and fit it on the training set
    val logisticRegression = new LogisticRegression()
      .setMaxIter(100)
      .setRegParam(0.02)
      .setElasticNetParam(0.8)
    val logisticRegressionModel = logisticRegression.fit(trainingData)
    /*
    Transforming the test set adds three new columns to the DataFrame:
    1. rawPrediction
       the raw margin (confidence) for each class, before the sigmoid is applied
    2. probability
       the conditional probability of each class
    3. prediction
       the predicted class, i.e. the index of the largest rawPrediction entry
     */
    val predictionDf = logisticRegressionModel.transform(testData)
    // Evaluate the model by area under the ROC curve
    val evaluator = new BinaryClassificationEvaluator()
      .setLabelCol("label")
      // note: feeding the hard "prediction" column yields a coarser AUC than
      // "rawPrediction" would; kept as-is to match the output shown below
      .setRawPredictionCol("prediction")
      .setMetricName("areaUnderROC")
    // Compute the area under the ROC curve on the test set
    val areaUnderRoc = evaluator.evaluate(predictionDf)
    println("Area under ROC on the test set: " + areaUnderRoc)
    // Save the model
    logisticRegressionModel.write.overwrite()
      .save("score-model")
    // Load the model back
    val logisticRegressionModelLoaded = LogisticRegressionModel
      .load("score-model")
    (assembler, logisticRegressionModelLoaded)
  }
}
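Why the predicted labels come out inverted relative to the CSV: `StringIndexer`'s default ordering is `frequencyDesc`, so the most frequent input value receives index 0.0. In scores.csv, result = 1 (pass) is the majority class, so it maps to label 0.0 and result = 0 maps to label 1.0. A plain-Scala sketch of that ordering rule (no Spark; this only mimics the default behaviour, not Spark's implementation):

```scala
// Mimic StringIndexer's default "frequencyDesc" ordering:
// the most frequent value gets index 0.0, the next 1.0, and so on.
object IndexerSketch {
  def frequencyDescIndex[T](values: Seq[T]): Map[T, Double] =
    values
      .groupBy(identity)                       // value -> its occurrences
      .toSeq
      .sortBy { case (_, occ) => -occ.size }   // most frequent first
      .zipWithIndex
      .map { case ((v, _), i) => v -> i.toDouble }
      .toMap

  def main(args: Array[String]): Unit = {
    // 1 is the majority class here, so it receives index 0.0
    println(frequencyDescIndex(Seq(1, 1, 1, 0, 0))) // Map(1 -> 0.0, 0 -> 1.0)
  }
}
```

This is why the comments in `main` read predictions of 0.0 as "pass": the model is trained against the indexed `label` column, not the raw `result` column.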

Results

Area under ROC on the test set: 0.8928571428571429
+-----------------+-----------------+--------------------+
|           score1|           score2|            features|
+-----------------+-----------------+--------------------+
|70.66150955499435|92.92713789364831|[70.6615095549943...|
|76.97878372747498|47.57596364975532|[76.9787837274749...|
|67.37202754570876|42.83843832029179|[67.3720275457087...|
| 89.6767757507208|65.79936592745237|[89.6767757507208...|
|  50.534788289883|48.85581152764205|[50.534788289883,...|
+-----------------+-----------------+--------------------+

+-----------------+-----------------+--------------------+--------------------+--------------------+----------+
|           score1|           score2|            features|       rawPrediction|         probability|prediction|
+-----------------+-----------------+--------------------+--------------------+--------------------+----------+
|70.66150955499435|92.92713789364831|[70.6615095549943...|[4.42488938425420...|[0.98816618042094...|       0.0|
|76.97878372747498|47.57596364975532|[76.9787837274749...|[0.13401559021765...|[0.53345384278692...|       0.0|
|67.37202754570876|42.83843832029179|[67.3720275457087...|[-1.4919079280137...|[0.18363553054854...|       1.0|
| 89.6767757507208|65.79936592745237|[89.6767757507208...|[3.60597758013620...|[0.97355732638820...|       0.0|
|  50.534788289883|48.85581152764205|[50.534788289883,...|[-2.7578193413834...|[0.05964655921865...|       1.0|
+-----------------+-----------------+--------------------+--------------------+--------------------+----------+

That is, 3 of the 5 students (those with a prediction of 0.0) are predicted to pass the exam.
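Note that the 0.89 figure above is area under the ROC curve, not classification accuracy. A plain-Scala sketch of that metric using the rank-based (Mann–Whitney) formulation — an independent illustration, not Spark's implementation:

```scala
// AUC as the probability that a randomly chosen positive example
// is scored higher than a randomly chosen negative example.
object AucSketch {
  // scores: predicted score per example; labels: 1.0 = positive class
  def areaUnderRoc(scores: Seq[Double], labels: Seq[Double]): Double = {
    val pos = scores.zip(labels).collect { case (s, 1.0) => s }
    val neg = scores.zip(labels).collect { case (s, 0.0) => s }
    val pairs = for (p <- pos; n <- neg)
      yield if (p > n) 1.0 else if (p == n) 0.5 else 0.0
    pairs.sum / (pos.size * neg.size)
  }

  def main(args: Array[String]): Unit = {
    // Perfectly separated scores give AUC = 1.0
    println(areaUnderRoc(Seq(0.9, 0.8, 0.3, 0.2), Seq(1.0, 1.0, 0.0, 0.0)))
    // One positive ranked below a negative gives AUC = 0.75
    println(areaUnderRoc(Seq(0.9, 0.4, 0.6, 0.2), Seq(1.0, 1.0, 0.0, 0.0)))
  }
}
```

An AUC of 1.0 means the classifier ranks every positive above every negative; 0.5 is no better than chance, so the 0.89 above indicates a reasonably good ranking.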
