【问题标题】:Finding K-nearest neighbors and its implementation寻找K近邻及其实现
【发布时间】:2015-02-13 01:44:56
【问题描述】:

我正在使用具有欧几里得距离的 KNN 对简单数据进行分类。我已经看到一个示例,说明我想使用 MATLAB knnsearch 函数完成,如下所示:

load fisheriris 
x = meas(:,3:4);
gscatter(x(:,1),x(:,2),species)
newpoint = [5 1.45];
[n,d] = knnsearch(x,newpoint,'k',10);
line(x(n,1),x(n,2),'color',[.5 .5 .5],'marker','o','linestyle','none','markersize',10)

上面的代码采用一个新点,即[5 1.45],并找到最接近新点的 10 个值。谁能给我看一个 MATLAB 算法,详细解释 knnsearch 函数的作用?有没有其他方法可以做到这一点?

【问题讨论】:

  • 这很简单。对于特定点,我们会找到数据与该点之间最近的 10 个点,并返回那些最近的点,这些点是您数据的一部分。通常,欧几里得距离用于一个点的分量用于比较另一点的分量。维基百科上的这篇文章特别有用:en.wikipedia.org/wiki/K-nearest_neighbors_algorithm
  • 噢……你想自己实现这个过程吗?我当然可以为你提供答案。实现算法实际上并没有你想象的那么难。请说明您需要什么。
  • 是的,我正在尝试自己实现“knnsearch”功能,就像我的代码示例一样,谢谢!
  • 没问题。一会儿我给你写一个答案。我在一个没有 MATLAB 来测试我的代码的地方。当我这样做时,我会写一个答案。但是,为了让您开始,基本过程是找到您的测试点与数据矩阵中的所有其他点之间的欧几里得距离。您将距离从最小到最大排序,然后选择产生最小距离的k 点。尽快回复您!
  • 你好 rayryeng,澄清一下;在这种情况下,我的测试点是 newpoint = [5 1.45];正确的?因此,我现在将用我的数据中的其他点计算 EU 距离; x = 测量(:,3:4); fisheriris 数据是 matlab 示例数据,如果有机会请加载并查看。谢谢!

标签: matlab machine-learning classification knn


【解决方案1】:

K-Nearest Neighbor (KNN) 算法的基础是您有一个由N 行和M 列组成的数据矩阵,其中N 是我们拥有的数据点的数量,而@ 987654332@是每个数据点的维数。例如,如果我们将笛卡尔坐标放在数据矩阵中,这通常是N x 2N x 3 矩阵。使用该数据矩阵,您可以提供一个查询点,然后在该数据矩阵中搜索最接近该查询点的k 点。

我们通常使用查询与数据矩阵中其余点之间的欧几里得距离来计算我们的距离。但是,也使用其他距离,如 L1 或城市街区/曼哈顿距离。完成此操作后,您将获得N 欧几里得或曼哈顿距离,它们表示查询与数据集中每个对应点之间的距离。找到这些后,您只需按升序对距离进行排序并检索那些在您的数据集和查询之间具有最小距离的 k 点,即可搜索到查询最近的 k 点。

假设您的数据矩阵存储在x 中,而newpoint 是一个样本点,其中有M 列(即1 x M),这是您将以点形式遵循的一般过程:

  1. newpointx 中每个点之间的欧几里得或曼哈顿距离。
  2. 按升序对这些距离进行排序。
  3. 返回x 中最接近newpointk 数据点。

让我们慢慢做每一步。


步骤#1

某人可能会这样做的一种方法可能是在for 循环中,如下所示:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end

如果你想实现曼哈顿距离,这只是:

N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
    dists(idx) = sum(abs(x(idx,:) - newpoint));
end

dists 将是一个N 元素向量,其中包含xnewpoint 中每个数据点之间的距离。我们在newpointx 中的一个数据点之间进行逐个元素的减法,将差值平方,然后将sum 全部加在一起。然后这个和是平方根的,这就完成了欧几里得距离。对于曼哈顿距离,您将执行逐个元素的减法,取绝对值,然后将所有分量相加。这可能是最容易理解的实现,但也可能是效率最低的……尤其是对于更大的数据集和更大维度的数据。

另一种可能的解决方案是复制newpoint 并使该矩阵与x 大小相同,然后对该矩阵进行逐个元素减法,然后对每一行的所有列求和并执行平方根。因此,我们可以这样做:

N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));

对于曼哈顿距离,你会这样做:

N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);

repmat 采用矩阵或向量并在给定方向上重复它们一定次数。在我们的例子中,我们想要使用我们的newpoint 向量,并将这个N 次堆叠在一起以创建一个N x M 矩阵,其中每一行都是M 元素长。我们将这两个矩阵相减,然后对每个分量求平方。一旦我们这样做了,我们sum 在每一行的所有列上,最后取所有结果的平方根。对于曼哈顿距离,我们做减法,取绝对值,然后求和。

但是,我认为最有效的方法是使用bsxfun。这实质上是通过一个函数调用完成了我们在后台讨论的复制。因此,代码将是这样的:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));

在我看来,这看起来更加简洁明了。对于曼哈顿距离,您可以:

dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);

步骤 #2

现在我们有了距离,我们只需对它们进行排序。我们可以使用sort 对我们的距离进行排序:

[d,ind] = sort(dists);

d 将包含按升序排序的距离,而ind 会告诉您 unsorted 数组中的每个值在 sorted 结果中出现的位置。我们需要使用ind,提取这个向量的第一个k元素,然后使用ind索引到我们的x数据矩阵中,返回最接近newpoint的点。

步骤#3

最后一步是现在返回最接近newpoint 的那些k 数据点。我们可以通过以下方式非常简单地做到这一点:

ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

ind_closest 应该包含原始数据矩阵x 中最接近newpoint 的索引。具体来说,ind_closest 包含您需要从x 中采样的,以获得最接近newpoint 的点。 x_closest 将包含这些实际数据点。


为了您的复制和粘贴乐趣,代码如下所示:

dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

运行您的示例,让我们看看我们的代码在运行中:

load fisheriris 
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;

%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);

通过检查ind_closestx_closest,我们得到了:

>> ind_closest

ind_closest =

   120
    53
    73
   134
    84
    77
    78
    51
    64
    87

>> x_closest

x_closest =

    5.0000    1.5000
    4.9000    1.5000
    4.9000    1.5000
    5.1000    1.5000
    5.1000    1.6000
    4.8000    1.4000
    5.0000    1.7000
    4.7000    1.4000
    4.7000    1.4000
    4.7000    1.5000

如果您运行knnsearch,您将看到您的变量nind_closest 匹配。但是,变量d 返回从newpoint 到每个点x距离,而不是实际数据点本身。如果您想要实际距离,只需在我编写的代码之后执行以下操作:

dist_sorted = d(1:k);

请注意,上述答案仅使用了一批N 示例中的一个查询点。 KNN 经常同时用于多个示例。假设我们有要在 KNN 中测试的 Q 查询点。这将产生一个k x M x Q 矩阵,其中对于每个示例或每个切片,我们返回维度为Mk 最近点。或者,我们可以返回k 最近点的ID,从而生成Q x k 矩阵。让我们计算两者。

一种天真的方法是在循环中应用上述代码并循环遍历每个示例。

如果我们分配一个Q x k 矩阵并应用基于bsxfun 的方法将输出矩阵的每一行设置为数据集中的k 最近点,这样的事情就可以工作,我们将在其中使用Fisher Iris数据集就像我们之前的一样。我们还将保持与上一个示例相同的维度,我将使用四个示例,所以Q = 4M = 2

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);

%// Loop through each point and do logic as seen above:
for ii = 1 : Q
    %// Get the point
    newpoint = newpoints(ii, :);

    %// Use Euclidean
    dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
    [d,ind] = sort(dists);

    %// New - Output the IDs of the match as well as the points themselves
    ind_closest(ii, :) = ind(1 : k).';
    x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end

虽然这很好,但我们可以做得更好。有一种方法可以有效地计算两组向量之间的平方欧几里得距离。如果你想在曼哈顿做这个,我会把它作为一个练习。咨询this blog,假设A 是一个Q1 x M 矩阵,其中每一行是一个维度MQ1 点,B 是一个Q2 x M 矩阵,其中每一行也是一个点维数MQ2 点,我们可以有效地计算距离矩阵D(i, j),其中行i 和列j 的元素表示A 的行i 和行j 之间的距离B 使用以下矩阵公式:

nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation

因此,如果我们让A 是查询点矩阵,B 是包含您的原始数据的数据集,我们可以通过单独排序每一行并确定k 来确定最接近的点 k每行最小的位置。我们还可以另外使用它来自己检索实际点。

因此:

%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];

%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);

nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation 

%// Sort the distances 
[d, ind] = sort(D, 2);

%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);

%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);

我们看到我们使用的计算距离矩阵的逻辑是相同的,但一些变量已经改变以适应示例。我们还使用sort 的两个输入版本对每一行进行独立排序,因此ind 将包含每行的ID,d 将包含相应的距离。然后我们通过简单地将这个矩阵截断为k 列来找出最接近每个查询点的索引。然后我们使用permutereshape 来确定相关的最近点是什么。我们首先使用所有最接近的索引并创建一个点矩阵,将所有 ID 堆叠在一起,因此我们得到一个Q * k x M 矩阵。使用reshapepermute 允许我们创建我们的3D 矩阵,使其成为我们指定的k x M x Q 矩阵。如果您想自己获取实际距离,我们可以索引d 并获取我们需要的内容。为此,您需要使用sub2ind 来获取线性索引,以便我们可以一次性索引到dind_closest 的值已经告诉我们需要访问哪些列。我们需要访问的行只是 1、ktimes、2、ktimes 等,直到Qk 是我们想要返回的点数:

row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);

当我们对上述查询点运行上述代码时,这些是我们得到的索引、点和距离:

>> ind_closest

ind_closest =

   120   134    53    73    84    77    78    51    64    87
   123   119   118   106   132   108   131   136   126   110
   107    62    86   122    71   127   139   115    60    52
    99    65    58    94    60    61    80    44    54    72

>> x_closest

x_closest(:,:,1) =

    5.0000    1.5000
    6.7000    2.0000
    4.5000    1.7000
    3.0000    1.1000
    5.1000    1.5000
    6.9000    2.3000
    4.2000    1.5000
    3.6000    1.3000
    4.9000    1.5000
    6.7000    2.2000


x_closest(:,:,2) =

    4.5000    1.6000
    3.3000    1.0000
    4.9000    1.5000
    6.6000    2.1000
    4.9000    2.0000
    3.3000    1.0000
    5.1000    1.6000
    6.4000    2.0000
    4.8000    1.8000
    3.9000    1.4000


x_closest(:,:,3) =

    4.8000    1.4000
    6.3000    1.8000
    4.8000    1.8000
    3.5000    1.0000
    5.0000    1.7000
    6.1000    1.9000
    4.8000    1.8000
    3.5000    1.0000
    4.7000    1.4000
    6.1000    2.3000


x_closest(:,:,4) =

    5.1000    2.4000
    1.6000    0.6000
    4.7000    1.4000
    6.0000    1.8000
    3.9000    1.4000
    4.0000    1.3000
    4.7000    1.5000
    6.1000    2.5000
    4.5000    1.5000
    4.0000    1.3000

>> dist_sorted

dist_sorted =

    0.0500    0.1118    0.1118    0.1118    0.1803    0.2062    0.2500    0.3041    0.3041    0.3041
    0.3000    0.3162    0.3606    0.4123    0.6000    0.7280    0.9055    0.9487    1.0198    1.0296
    0.9434    1.0198    1.0296    1.0296    1.0630    1.0630    1.0630    1.1045    1.1045    1.1180
    2.6000    2.7203    2.8178    2.8178    2.8320    2.9155    2.9155    2.9275    2.9732    2.9732

要将其与knnsearch 进行比较,您可以为第二个参数指定一个点矩阵,其中每一行都是一个查询点,您将看到此实现与knnsearch 之间的索引和排序距离匹配。


希望这对您有所帮助。祝你好运!

【讨论】:

  • 这确实很有帮助!非常感谢!现在我明白了@rayryeng​​span>
  • @Young_DataAnalyst - 很高兴!如果我帮助了你,请考虑接受我的回答:)。祝你好运!
  • @Kamtal - 酷:)!我很高兴!
  • @rayryeng 如果 x 和 newpoints 在同一对称平面上都是对称的,则 knnsearch 可能会在该平面上返回不对称索引。有没有办法实现对称性?
  • @JuneWang 我不明白你的询问。你能给我举个例子吗?
猜你喜欢
  • 2012-08-03
  • 1970-01-01
  • 2011-06-24
  • 1970-01-01
  • 2015-03-27
相关资源
最近更新 更多