K-Means++聚类算法详解(附带实例)
K-Means++ 算法是 K-Means 算法的改进版,主要是为了选择出更优的初始聚类中心。
K-Means++ 算法的基本思路如下:
【实例】通过例子说明 K-Means++ 算法是如何选取初始聚类中心的。
解析:数据集中共有 8 个样本,分布情况如下图所示:

图 1 样本分布情况
假设经过 K-Means++ 算法第 1 步后,6 号点(1,2)被选择为第一个初始聚类中心,在进行第 2 步时每个样本的 D(x) 和被选择为第二个聚类中心的概率如下表所示。
其中的 P(x) 就是每个样本被选为下一个聚类中心的概率。最后一行的 Sum 是概率 P(x) 的累加和。
用轮盘法选择出第二个聚类中心的方法:随机产生出一个 [0,1] 的随机数,判断它属于哪个区间,那么该区间对应的序号就被选择作为第二个聚类中心。例如 1 号点的区间为 [0,0.2],2 号点的区间为 (0.2,0.525]等。如果给出的随机数是 0.45,那么 2 号点就是第二个聚类中心。
从上表可以直观地看到第二个初始聚类中心是 1 号、2 号、3 号、4 号中的一个,因为这 4 个点的累计概率为 0.9,占了很大一部分比例。
从图 1 中也可以看到,这 4 个点正好是离第一个初始聚类中心 6 号点较远的 4 个点。这也验证了 K-Means 算法的改进思想,即离当前已有聚类中心较远的点有更大的概率被选为下一个聚类中心。
代码如下:

图 3 K-Means++算法原理聚类效果
用 K-Means++ 算法原理聚类耗时:4.481517791748047
K-Means++ 算法的基本思路如下:
- 在数据集中随机选择一个样本作为第一个初始聚类中心;
-
选择出其余的聚类中心:
- 计算数据集中的每个样本与已经初始化的聚类中心之间的距离,并选择其中最短的距离,记为 di。
- 以概率选择距离最大的样本作为新的聚类中心,重复上述过程,直到 K 个聚类中心都被确定。
- 对 K 个初始的聚类中心,利用 K-Means 算法计算出最终的聚类中心。
【实例】通过例子说明 K-Means++ 算法是如何选取初始聚类中心的。
解析:数据集中共有 8 个样本,分布情况如下图所示:

图 1 样本分布情况
假设经过 K-Means++ 算法第 1 步后,6 号点(1,2)被选择为第一个初始聚类中心,在进行第 2 步时每个样本的 D(x) 和被选择为第二个聚类中心的概率如下表所示。
表:D(x) 和被选择为第二个聚类中心的概率

其中的 P(x) 就是每个样本被选为下一个聚类中心的概率。最后一行的 Sum 是概率 P(x) 的累加和。
用轮盘法选择出第二个聚类中心的方法:随机产生出一个 [0,1] 的随机数,判断它属于哪个区间,那么该区间对应的序号就被选择作为第二个聚类中心。例如 1 号点的区间为 [0,0.2],2 号点的区间为 (0.2,0.525]等。如果给出的随机数是 0.45,那么 2 号点就是第二个聚类中心。
从上表可以直观地看到第二个初始聚类中心是 1 号、2 号、3 号、4 号中的一个,因为这 4 个点的累计概率为 0.9,占了很大一部分比例。
从图 1 中也可以看到,这 4 个点正好是离第一个初始聚类中心 6 号点较远的 4 个点。这也验证了 K-Means 算法的改进思想,即离当前已有聚类中心较远的点有更大的概率被选为下一个聚类中心。
代码如下:
import time import numpy as np import matplotlib.pyplot as plt from sklearn.datasets import make_blobs # 支持中文与负号 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # ---------- 工具函数 ---------- def distEclud(vecA, vecB): """计算两个向量的欧氏距离""" return np.sqrt(np.sum(np.power(vecA - vecB, 2))) def get_closest_dist(point, centroids): """计算样本点到当前已有聚类中心的最短距离""" min_dist = np.inf for centroid in centroids: dist = distEclud(np.array(centroid), np.array(point)) if dist < min_dist: min_dist = dist return min_dist def RWS(P, r): """轮盘法选择下一个聚类中心索引""" q = 0.0 for i in range(len(P)): q += P[i] if i == len(P) - 1: # 避免浮点误差 q = 1.0 if r <= q: return i def getCent(dataSet, k): """K-Means++ 生成 k 个初始质心""" m, n = dataSet.shape centroids = np.mat(np.zeros((k, n))) # 随机选第一个中心 idx = np.random.randint(0, m) centroids[0, :] = dataSet[idx, :] d = np.mat(np.zeros((m, 1))) for j in range(1, k): for i in range(m): d[i, 0] = get_closest_dist(dataSet[i], centroids) P = np.square(d) / np.square(d).sum() r = np.random.random() chosen = RWS(P, r) centroids[j, :] = dataSet[chosen, :] return centroids def kMeans_plus2(dataSet, k, distMeas=distEclud): """K-Means++ 聚类主算法""" m = dataSet.shape[0] clusterAssment = np.mat(np.zeros((m, 2))) centroids = getCent(dataSet, k) clusterChanged = True while clusterChanged: clusterChanged = False for i in range(m): minDist, minIndex = np.inf, -1 for j in range(k): dist = distMeas(centroids[j, :], dataSet[i, :]) if dist < minDist: minDist, minIndex = dist, j if clusterAssment[i, 0] != minIndex: clusterChanged = True clusterAssment[i, :] = minIndex, minDist ** 2 # 更新质心 for cent in range(k): ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]] centroids[cent, :] = np.mean(ptsInClust, axis=0) return centroids, clusterAssment def plotResult(myCentroids, clustAssing, X): """可视化结果""" centroids = myCentroids.A y_kmeans = clustAssing[:, 0].A[:, 0] plt.subplot(1, 2, 1) plt.scatter(X[:, 0], X[:, 1], s=50) plt.title("未聚类前的数据分布") plt.subplot(1, 2, 2) plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis') plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=200, alpha=0.6, marker='*') plt.title("用K-Means++算法原理聚类的效果") plt.subplots_adjust(wspace=0.5) plt.show() def load_data_make_blobs(): """生成模拟数据""" k = 5 X, _ = make_blobs(n_samples=1000, n_features=2, centers=k, random_state=1) return X, k # ---------- 主程序 ---------- if __name__ == '__main__': X, k = load_data_make_blobs() t0 = time.time() myCentroids, clustAssing = kMeans_plus2(X, k) print("用K-Means++算法原理聚类耗时:", time.time() - t0) plotResult(myCentroids, clustAssing, X)运行程序,输出如下:
[[-12.22482514 -5.65268215] [-1.05724063 4.82677207] [-4.97357093 -3.40117757] [-8.04704314 -8.60053627] [-9.11308264 -7.66934446]] ... [[-10.66631851 -3.50135699] [-1.80530913 2.66422491] [-5.90187775 -2.8993223] [-7.05942132 -8.07760549] [-9.02638865 -4.33638277]]效果如下图所示:

图 3 K-Means++算法原理聚类效果
用 K-Means++ 算法原理聚类耗时:4.481517791748047