目录[-]

Clustering

本页介绍MLlib中的聚类算法。在基于RDD-API聚类指南里还提供了有关这些算法的相关信息。

K-means

K-means是一个常用的聚类算法来将数据点按预定的簇数进行聚集。K-means算法的基本思想是:以空间中k个点为中心进行聚类,对最靠近他们的对象归类。通过迭代的方法,逐次更新各聚类中心的值,直至得到最好的聚类结果。

假设要把样本集分为c个类别,算法描述如下:

(1)适当选择c个类的初始中心;

(2)在第k次迭代中,对任意一个样本,求其到c个中心的距离,将该样本归到距离最短的中心所在的类;

(3)利用均值等方法更新该类的中心值;

(4)对于所有的c个聚类中心,如果利用(2)(3)的迭代法更新后,值保持不变,则迭代结束,否则继续迭代。

MLlib工具包含并行的K-means++算法,称为kmeans||。Kmeans是一个Estimator,它在基础模型之上产生一个KMeansModel。

  • Input Columns(输入列)
Param name(参数名称) Type(s)(类型) Default(默认) Description(描述)
featuresCol Vector "features" Feature vector(特征向量)
  • Output Columns(输出列)
Param name(参数名称) Type(s)(类型) Default(默认) Description(描述)
predictionCol Int "prediction" Predicted cluster center(预测的聚类中心)

Examples

from pyspark.ml.clustering import KMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ClusterExample").getOrCreate()
# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Trains a k-means model.
kmeans = KMeans().setK(2).setSeed(1)
model = kmeans.fit(dataset)

# Evaluate clustering by computing Within Set Sum of Squared Errors.
wssse = model.computeCost(dataset)
print("Within Set Sum of Squared Errors = " + str(wssse))

# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)
spark.stop()

output:

Within Set Sum of Squared Errors = 0.11999999999994547
Cluster Centers: 
[ 0.1  0.1  0.1]
[ 9.1  9.1  9.1]

Find full example code at "examples/src/main/python/ml/kmeans_example.py" in the Spark repo.

Latent Dirichlet allocation(LDA)

LDA(Latent Dirichlet Allocation)是一种文档主题生成模型,也称为一个三层贝叶斯概率模型,包含词、主题和文档三层结构。所谓生成模型,就是说,我们认为一篇文章的每个词都是通过“以一定概率选择了某个主题,并从这个主题中以一定概率选择某个词语”这样一个过程得到。文档到主题服从多项式分布,主题到词服从多项式分布。

LDA是一种非监督机器学习技术,可以用来识别大规模文档集(document collection)或语料库(corpus)中潜藏的主题信息。它采用了词袋(bag of words)的方法,这种方法将每一篇文档视为一个词频向量,从而将文本信息转化为了易于建模的数字信息。但是词袋方法没有考虑词与词之间的顺序,这简化了问题的复杂性,同时也为模型的改进提供了契机。每一篇文档代表了一些主题所构成的一个概率分布,而每一个主题又代表了很多单词所构成的一个概率分布。

LDA被实现为一个Estimator,既支持EMLDAOptimizer和OnlineLDAOptimizer,并生成一个LDAModel作为基础模型。如果需要的话,专家用户可以将EMLDAOptimizer生成的LDAModel映射到一个DistributedLDAModel

Examples

from pyspark.ml.clustering import LDA
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("LDAExample").getOrCreate()
# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_lda_libsvm_data.txt")

# Trains a LDA model.
lda = LDA(k=10, maxIter=10)
model = lda.fit(dataset)

ll = model.logLikelihood(dataset)
lp = model.logPerplexity(dataset)
print("The lower bound on the log likelihood of the entire corpus: " + str(ll))
print("The upper bound on perplexity: " + str(lp))

# Describe topics.
topics = model.describeTopics(3)
print("The topics described by their top-weighted terms:")
topics.show(truncate=False)

# Shows the result
transformed = model.transform(dataset)
transformed.show(truncate=False)
spark.stop()

output:

The lower bound on the log likelihood of the entire corpus: -797.8018456907539
The upper bound on perplexity: 3.068468635551357
The topics described by their top-weighted terms:
+-----+-----------+---------------------------------------------------------------+
|topic|termIndices|termWeights                                                    |
+-----+-----------+---------------------------------------------------------------+
|0    |[0, 4, 7]  |[0.13939487929625935, 0.13346874874963285, 0.11911498796394984]|
|1    |[8, 6, 0]  |[0.09761719173430919, 0.09664530483154511, 0.0959033498887414] |
|2    |[5, 9, 1]  |[0.09763288175177705, 0.0967699480930826, 0.09474971437446654] |
|3    |[6, 2, 5]  |[0.09993087551790403, 0.09802667103524504, 0.09669791743434605]|
|4    |[10, 5, 8] |[0.10838084105098059, 0.1065719519796393, 0.10564271921581836] |
|5    |[2, 5, 3]  |[0.09975664174839147, 0.09917147147531298, 0.09482946730767593]|
|6    |[1, 7, 3]  |[0.1025918379349122, 0.09670884980694468, 0.09661321616852961] |
|7    |[3, 10, 6] |[0.18074276445784626, 0.17140880975201497, 0.11846617165050731]|
|8    |[7, 9, 1]  |[0.10376667278659339, 0.10266984655859988, 0.10261491999135175]|
|9    |[5, 9, 4]  |[0.17217259005160918, 0.11130983487715354, 0.10625585388024414]|
+-----+-----------+---------------------------------------------------------------+

+-----+---------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|label|features                                                       |topicDistribution                                                                                                                                                                                                      |
+-----+---------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|0.0  |(11,[0,1,2,4,5,6,7,10],[1.0,2.0,6.0,2.0,3.0,1.0,1.0,3.0])      |[0.004834482522877391,0.004775061546874506,0.0047750850624618665,0.00477508209536724,0.004775110752126829,0.0047751198765325934,0.0047750802565546275,0.44999380128294686,0.004775119757841731,0.5117460568464164]     |
|1.0  |(11,[0,1,3,4,7,10],[1.0,3.0,1.0,3.0,2.0,1.0])                  |[0.9268994208923648,0.007965511763080765,0.007965521320089061,0.007965447383722308,0.007965587789582014,0.007965461329343004,0.00796558757403698,0.009276986136774072,0.007965614108681227,0.008064861702326028]       |
|2.0  |(11,[0,1,2,5,6,8,9],[1.0,4.0,1.0,4.0,9.0,1.0,2.0])             |[0.004202815262490896,0.004151229704235803,0.004151279248440336,0.004151250849060332,0.004151298320120848,0.004151248811452763,0.004151213592542253,0.6501025149437936,0.00415114952939257,0.3166359997384707]         |
|3.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,3.0,9.0])            |[0.0037170513237456872,0.0036715329471578005,0.0036715360552429213,0.003671511493261907,0.003671797370463146,0.0036715102318871204,0.0036715134308361727,0.9668647838101413,0.003671504403863317,0.003717258933400576] |
|4.0  |(11,[0,1,2,3,4,6,9,10],[3.0,1.0,1.0,9.0,3.0,2.0,1.0,3.0])      |[0.004027376743557338,0.003977866599137274,0.003977850254362953,0.003977835428829377,0.0039778820932092175,0.003977853048840427,0.003977852184563374,0.9641001717255747,0.0039778458818949,0.004027466040030248]       |
|5.0  |(11,[0,1,3,4,5,6,7,8,9],[4.0,2.0,3.0,4.0,5.0,1.0,1.0,1.0,4.0]) |[0.003717509832713523,0.0036716615407946934,0.0036716846624067615,0.0036716395255962085,0.0036717149575019995,0.0036716664005927474,0.0036716667567801204,0.27461258177043146,0.0036716647781321666,0.6959682097750503]|
|6.0  |(11,[0,1,3,6,8,9,10],[2.0,1.0,3.0,5.0,2.0,2.0,9.0])            |[0.0038659082828356533,0.003818570338009387,0.0038185658000222077,0.0038185390646671936,0.003818726199778954,0.0038185379956121677,0.003818554784511252,0.9655379642100607,0.003818526437489602,0.003866106887012979]  |
|7.0  |(11,[0,1,2,3,4,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,1.0,2.0,1.0,3.0])|[0.004394125793389081,0.004340066102223131,0.004340117929521572,0.004340091402319875,0.004340183500883856,0.004340117374988447,0.004340096103563213,0.9608305723851966,0.004340058125232322,0.004394571282681922]      |
|8.0  |(11,[0,1,3,4,5,6,7],[4.0,4.0,3.0,4.0,2.0,1.0,3.0])             |[0.9601715212707249,0.0043400767428901635,0.004340086711133699,0.004340041546373581,0.004340093118553618,0.004340077924408194,0.004340099543124161,0.005053547015193133,0.004340064942938327,0.004394391184660286]     |
|9.0  |(11,[0,1,2,4,6,8,9,10],[2.0,8.0,2.0,3.0,2.0,2.0,7.0,2.0])      |[0.003332384443784424,0.0032914608990001755,0.003291474583522146,0.003291442358715674,0.003291502923651029,0.0032914477446806248,0.003291451230227242,0.9702948142666302,0.0032914840138979083,0.0033325375358905047]  |
|10.0 |(11,[0,1,2,3,5,6,9,10],[1.0,1.0,1.0,9.0,2.0,2.0,3.0,3.0])      |[0.004202933475197545,0.004151218860618786,0.004151338270182237,0.004151288340705789,0.004151431515312671,0.004151332888945593,0.00415129142785515,0.9625342071313459,0.0041512327632204984,0.004203725326615945]      |
|11.0 |(11,[0,1,4,5,6,7,9],[4.0,1.0,4.0,5.0,1.0,3.0,1.0])             |[0.5794463100207559,0.004774699657046339,0.004774740812070836,0.0047746922036681246,0.004774755044701768,0.004774721978296648,0.0047747158288502884,0.0055583559655138,0.004774694223725667,0.38157231426537064]       |
+-----+---------------------------------------------------------------+-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Find full example code at "examples/src/main/python/ml/lda_example.py" in the Spark repo.

Bisecting k-means

二分K均值算法是一种层次聚类算法,使用自顶向下的逼近:所有的观察值开始是一个簇,递归地向下一个层级分裂。分裂依据为选择能最大程度降低聚类代价函数(也就是误差平方和)的簇划分为两个簇。以此进行下去,直到簇的数目等于用户给定的数目k为止。二分K均值常常比传统K均值算法有更快的计算速度,但产生的簇群与传统K均值算法往往也是不同的。

BisectingKMeans是一个Estimator,在基础模型上训练得到BisectingKMeansModel。

Examples

from pyspark.ml.clustering import BisectingKMeans
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("BisectingKMeansExample").getOrCreate()
# Loads data.
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

# Trains a bisecting k-means model.
bkm = BisectingKMeans().setK(2).setSeed(1)
model = bkm.fit(dataset)

# Evaluate clustering.
cost = model.computeCost(dataset)
print("Within Set Sum of Squared Errors = " + str(cost))

# Shows the result.
print("Cluster Centers: ")
centers = model.clusterCenters()
for center in centers:
    print(center)
spark.stop()

output:

Within Set Sum of Squared Errors = 0.11999999999994547
Cluster Centers: 
[ 0.1  0.1  0.1]
[ 9.1  9.1  9.1]

Find full example code at "examples/src/main/python/ml/bisecting_k_means_example.py" in the Spark repo.

Gaussian Mixture Model(GMM)

混合高斯模型描述数据点以一定的概率服从k种高斯子分布的一种混合分布。Spark.ml使用EM算法给出一组样本的极大似然模型。

GaussianMixture被实现为一个Estimator,并生成一个GaussianMixtureModel基本模型。

  • Input Columns
Param name Type(s) Default Description
featuresCol Vector "features" Feature vector
  • Output Columns
Param name Type(s) Default Description
predictionCol Int "prediction" Predicted cluster center
probabilityCol Vector "probability" Probability of each cluster

Examples

from pyspark.ml.clustering import GaussianMixture
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("GaussianMixtureExample").getOrCreate()
# loads data
dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

gmm = GaussianMixture().setK(2).setSeed(538009335)
model = gmm.fit(dataset)

print("Gaussians shown as a DataFrame: ")
model.gaussiansDF.show(truncate=False)
spark.stop()

output:

Gaussians shown as a DataFrame: 
+-------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|mean                                                         |cov                                                                                                                                                                                                     |
+-------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[0.10000000000001552,0.10000000000001552,0.10000000000001552]|0.006666666666806454  0.006666666666806454  0.006666666666806454  
0.006666666666806454  0.006666666666806454  0.006666666666806454  
0.006666666666806454  0.006666666666806454  0.006666666666806454  |
|[9.099999999999984,9.099999999999984,9.099999999999984]      |0.006666666666812185  0.006666666666812185  0.006666666666812185  
0.006666666666812185  0.006666666666812185  0.006666666666812185  
0.006666666666812185  0.006666666666812185  0.006666666666812185  |
+-------------------------------------------------------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+

Find full example code at "examples/src/main/python/ml/gaussian_mixture_example.py" in the Spark repo.

更多相关信息请查阅Spark Clustering文档