Mean Shift是一种强大的无监督学习算法, 用于对数据点进行聚类. 它将数据点的分布视为概率密度函数, 并尝试在特征空间中找到模式. 这些模式基本上是对应于局部最大值的点. Mean Shift算法的主要优点是我们不需要事先知道簇的数量.

假设我们有一组输入点, 我们正在尝试在其中找到集群, 而不知道我们正在寻找多少个集群. Mean Shift算法将这些点从概率密度函数中进行采样. 如果数据点中存在簇, 则它们对应于该概率密度函数的峰值. 该算法从随机点开始, 并逐渐收敛于这些峰值. 您可以在http://homepages.inf.ed.ac.uk/rbf/CVonline/LOCAL_COPIES/TUZEL1/MeanShift.pdf了解更多信息.

怎么做...?

  • 该食谱的完整代码在mean_shift.py文件中给出. 我们来看看它是如何构建的. 创建一个新的Python文件, 并导入几个必需的包:
import numpy as np
from sklearn.cluster import MeanShift, estimate_bandwidth
import matplotlib.pyplot as plt
import utilities
  • 从data_multivar.txt文件中导入数据:
# Load data from input file
X = utilities.load_data('data_multivar.txt')
  • 通过指定输入参数构建Mean Shift聚类模型:
# Estimating the bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.1, n_samples=len(X))

# Compute clustering with MeanShift
meanshift_estimator = MeanShift(bandwidth=bandwidth, bin_seeding=True)
  • 训练模型:
meanshift_estimator.fit(X)
  • 获取标签:
labels = meanshift_estimator.labels_
  • 从模型中提取聚类的质心并打印出聚类数:
centroids = meanshift_estimator.cluster_centers_
num_clusters = len(np.unique(labels))

print ("Number of clusters in input data =", num_clusters)
  • 图形可视化:
###########################################################
# Plot the points and centroids
plt.figure()

# specify marker shapes for different clusters
markers = '.*xv'

for i, marker in zip(range(num_clusters), markers):
    # plot the points belong to the current cluster
    plt.scatter(X[labels == i, 0], X[labels == i, 1], marker=marker, color='k')

    # plot the centroid of the current cluster
    centroid = centroids[i]
    plt.plot(centroid[0], centroid[1], marker='o', markerfacecolor='k',
             markeredgecolor='k', markersize=15)

plt.title('Clusters and their centroids')
plt.show()

results matching ""

    No results matching ""