【问题标题】:Trying to understand what is happening in this Python Function试图了解这个 Python 函数中发生了什么
【发布时间】:2018-05-12 15:35:13
【问题描述】:
def closest_centroid(points, centroids):
    """returns an array containing the index to the nearest centroid for each point"""
    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))
    return np.argmin(distances, axis=0)

有人能解释一下这个函数的具体工作原理吗?我目前得到points,它看起来像:

31998888119     0.94     34
23423423422     0.45     43
....

等等。在这个 numpy 数组中,points[1] 将是长 ID,而 points[2]0.94points[3] 将是 34 的第一个条目。

Centroids 只是从这个特定数组中随机选择的:

def initialize_centroids(points, k):
    """returns k centroids from the initial points"""
    centroids = points.copy()
    np.random.shuffle(centroids)
    return centroids[:k] 

现在我想从 points 的值中获取欧几里得距离,忽略 ID 的第一列和 centroids(再次忽略第一列)。我不完全理解distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2)) 行的语法。为什么我们在第三列中求和,而新轴的减法:np.newaxis?我还应该沿着哪个轴使np.argmin 工作?

【问题讨论】:

    标签: python numpy k-means euclidean-distance


    【解决方案1】:

    考虑尺寸会有所帮助。假设k=4有10个点,所以points.shape = (10,3)

    接下来,centroids = initialize_centroids(points, 4) 返回一个维度为 (4,3) 的对象。

    让我们从内部分解这条线:

    distances = np.sqrt(((points - centroids[:, np.newaxis])**2).sum(axis=2))

    1. 我们想从每个点中减去每个质心。因为pointscentroids 是二维的,所以每个points - centroid 都是二维的。如果只有 1 个质心,那么我们就可以了。但是我们有 4 个质心!所以我们需要为每个质心执行points - centroids。因此,我们需要另一个维度来存储它。因此添加了np.newaxis

    2. 我们将其平方是因为它是一个距离,因此我们希望将负数转换为正数(也因为我们正在最小化欧几里得距离)。

    3. 我们没有对第三列求和。事实上,我们正在为每个点、每个质心求和点和质心之间的差异。

    4. np.argmin() 找到距离最短的质心。因此,对于每个质心,对于每个点,找到最小索引(因此 argmin 而不是 min)。该索引是分配给该点的质心。

    这是一个例子:

    points = np.array([
    [   1, 2, 4],
    [   1, 1, 3],
    [   1, 6, 2],
    [   6, 2, 3],
    [   7, 2, 3],
    [   1, 9, 6],
    [   6, 9, 1],
    [   3, 8, 6],
    [   10, 9, 6],
    [   0, 2, 0],
    ])
    
    centroids = initialize_centroids(points, 4)
    
    print(centroids)
    array([[10,  9,  6],
       [ 3,  8,  6],
       [ 6,  2,  3],
       [ 1,  1,  3]])
    
    distances = (pts - centroids[:, np.newaxis])**2
    
    print(distances)
    array([[[ 81,  49,   4],
        [ 81,  64,   9],
        [ 81,   9,  16],
        [ 16,  49,   9],
        [  9,  49,   9],
        [ 81,   0,   0],
        [ 16,   0,  25],
        [ 49,   1,   0],
        [  0,   0,   0],
        [100,  49,  36]],
    
       [[  4,  36,   4],
        [  4,  49,   9],
        [  4,   4,  16],
        [  9,  36,   9],
        [ 16,  36,   9],
        [  4,   1,   0],
        [  9,   1,  25],
        [  0,   0,   0],
        [ 49,   1,   0],
        [  9,  36,  36]],
    
       [[ 25,   0,   1],
        [ 25,   1,   0],
        [ 25,  16,   1],
        [  0,   0,   0],
        [  1,   0,   0],
        [ 25,  49,   9],
        [  0,  49,   4],
        [  9,  36,   9],
        [ 16,  49,   9],
        [ 36,   0,   9]],
    
       [[  0,   1,   1],
        [  0,   0,   0],
        [  0,  25,   1],
        [ 25,   1,   0],
        [ 36,   1,   0],
        [  0,  64,   9],
        [ 25,  64,   4],
        [  4,  49,   9],
        [ 81,  64,   9],
        [  1,   1,   9]]])
    
    print(distances.sum(axis=2))
    array([[134, 154, 106,  74,  67,  81,  41,  50,   0, 185],
       [ 44,  62,  24,  54,  61,   5,  35,   0,  50,  81],
       [ 26,  26,  42,   0,   1,  83,  53,  54,  74,  45],
       [  2,   0,  26,  26,  37,  73,  93,  62, 154,  11]])
    
    # The minimum of the first 4 centroids is index 3. The minimum of the second 4 centroids is index 3 again.
    
    print(np.argmin(distances.sum(axis=2), axis=0))
    array([3, 3, 1, 2, 2, 1, 1, 1, 0, 3])
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2014-01-25
      • 1970-01-01
      • 2020-03-12
      • 1970-01-01
      • 1970-01-01
      • 2017-08-08
      • 1970-01-01
      相关资源
      最近更新 更多