首页 > 编程笔记 > Python笔记 阅读:5

K-Means++聚类算法详解(附带实例)

K-Means++ 算法是 K-Means 算法的改进版,主要是为了选择出更优的初始聚类中心。

K-Means++ 算法的基本思路如下:
【实例】通过例子说明 K-Means++ 算法是如何选取初始聚类中心的。
解析:数据集中共有 8 个样本,分布情况如下图所示:


图 1 样本分布情况

假设经过 K-Means++ 算法第 1 步后,6 号点(1,2)被选择为第一个初始聚类中心,在进行第 2 步时每个样本的 D(x) 和被选择为第二个聚类中心的概率如下表所示。

表:D(x) 和被选择为第二个聚类中心的概率

其中的 P(x) 就是每个样本被选为下一个聚类中心的概率。最后一行的 Sum 是概率 P(x) 的累加和。

用轮盘法选择出第二个聚类中心的方法:随机产生出一个 [0,1] 的随机数,判断它属于哪个区间,那么该区间对应的序号就被选择作为第二个聚类中心。例如 1 号点的区间为 [0,0.2],2 号点的区间为 (0.2,0.525]等。如果给出的随机数是 0.45,那么 2 号点就是第二个聚类中心。

从上表可以直观地看到第二个初始聚类中心是 1 号、2 号、3 号、4 号中的一个,因为这 4 个点的累计概率为 0.9,占了很大一部分比例。

从图 1 中也可以看到,这 4 个点正好是离第一个初始聚类中心 6 号点较远的 4 个点。这也验证了 K-Means 算法的改进思想,即离当前已有聚类中心较远的点有更大的概率被选为下一个聚类中心。

代码如下:
import time
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_blobs

# 支持中文与负号
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

# ---------- 工具函数 ----------
def distEclud(vecA, vecB):
    """计算两个向量的欧氏距离"""
    return np.sqrt(np.sum(np.power(vecA - vecB, 2)))

def get_closest_dist(point, centroids):
    """计算样本点到当前已有聚类中心的最短距离"""
    min_dist = np.inf
    for centroid in centroids:
        dist = distEclud(np.array(centroid), np.array(point))
        if dist < min_dist:
            min_dist = dist
    return min_dist

def RWS(P, r):
    """轮盘法选择下一个聚类中心索引"""
    q = 0.0
    for i in range(len(P)):
        q += P[i]
        if i == len(P) - 1:          # 避免浮点误差
            q = 1.0
        if r <= q:
            return i

def getCent(dataSet, k):
    """K-Means++ 生成 k 个初始质心"""
    m, n = dataSet.shape
    centroids = np.mat(np.zeros((k, n)))
    # 随机选第一个中心
    idx = np.random.randint(0, m)
    centroids[0, :] = dataSet[idx, :]

    d = np.mat(np.zeros((m, 1)))
    for j in range(1, k):
        for i in range(m):
            d[i, 0] = get_closest_dist(dataSet[i], centroids)
        P = np.square(d) / np.square(d).sum()
        r = np.random.random()
        chosen = RWS(P, r)
        centroids[j, :] = dataSet[chosen, :]
    return centroids

def kMeans_plus2(dataSet, k, distMeas=distEclud):
    """K-Means++ 聚类主算法"""
    m = dataSet.shape[0]
    clusterAssment = np.mat(np.zeros((m, 2)))
    centroids = getCent(dataSet, k)

    clusterChanged = True
    while clusterChanged:
        clusterChanged = False
        for i in range(m):
            minDist, minIndex = np.inf, -1
            for j in range(k):
                dist = distMeas(centroids[j, :], dataSet[i, :])
                if dist < minDist:
                    minDist, minIndex = dist, j
            if clusterAssment[i, 0] != minIndex:
                clusterChanged = True
            clusterAssment[i, :] = minIndex, minDist ** 2
        # 更新质心
        for cent in range(k):
            ptsInClust = dataSet[np.nonzero(clusterAssment[:, 0].A == cent)[0]]
            centroids[cent, :] = np.mean(ptsInClust, axis=0)
    return centroids, clusterAssment

def plotResult(myCentroids, clustAssing, X):
    """可视化结果"""
    centroids = myCentroids.A
    y_kmeans = clustAssing[:, 0].A[:, 0]

    plt.subplot(1, 2, 1)
    plt.scatter(X[:, 0], X[:, 1], s=50)
    plt.title("未聚类前的数据分布")

    plt.subplot(1, 2, 2)
    plt.scatter(X[:, 0], X[:, 1], c=y_kmeans, s=50, cmap='viridis')
    plt.scatter(centroids[:, 0], centroids[:, 1], c='red', s=200, alpha=0.6, marker='*')
    plt.title("用K-Means++算法原理聚类的效果")
    plt.subplots_adjust(wspace=0.5)
    plt.show()

def load_data_make_blobs():
    """生成模拟数据"""
    k = 5
    X, _ = make_blobs(n_samples=1000, n_features=2, centers=k, random_state=1)
    return X, k

# ---------- 主程序 ----------
if __name__ == '__main__':
    X, k = load_data_make_blobs()
    t0 = time.time()
    myCentroids, clustAssing = kMeans_plus2(X, k)
    print("用K-Means++算法原理聚类耗时:", time.time() - t0)
    plotResult(myCentroids, clustAssing, X)
运行程序,输出如下:
[[-12.22482514  -5.65268215]
[-1.05724063   4.82677207]
[-4.97357093   -3.40117757]
[-8.04704314   -8.60053627]
[-9.11308264   -7.66934446]]
...
[[-10.66631851  -3.50135699]
[-1.80530913   2.66422491]
[-5.90187775   -2.8993223]
[-7.05942132   -8.07760549]
[-9.02638865   -4.33638277]]
效果如下图所示:


图 3 K-Means++算法原理聚类效果

用 K-Means++ 算法原理聚类耗时:4.481517791748047

相关文章