K-Nearest Neighbor (KNN) 算法的基础是您有一个由N 行和M 列组成的数据矩阵,其中N 是我们拥有的数据点的数量,而@ 987654332@是每个数据点的维数。例如,如果我们将笛卡尔坐标放在数据矩阵中,这通常是N x 2 或N x 3 矩阵。使用该数据矩阵,您可以提供一个查询点,然后在该数据矩阵中搜索最接近该查询点的k 点。
我们通常使用查询与数据矩阵中其余点之间的欧几里得距离来计算我们的距离。但是,也使用其他距离,如 L1 或城市街区/曼哈顿距离。完成此操作后,您将获得N 欧几里得或曼哈顿距离,它们表示查询与数据集中每个对应点之间的距离。找到这些后,您只需按升序对距离进行排序并检索那些在您的数据集和查询之间具有最小距离的 k 点,即可搜索到查询最近的 k 点。
假设您的数据矩阵存储在x 中,而newpoint 是一个样本点,其中有M 列(即1 x M),这是您将以点形式遵循的一般过程:
- 求
newpoint 和x 中每个点之间的欧几里得或曼哈顿距离。
- 按升序对这些距离进行排序。
- 返回
x 中最接近newpoint 的k 数据点。
让我们慢慢做每一步。
步骤#1
某人可能会这样做的一种方法可能是在for 循环中,如下所示:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sqrt(sum((x(idx,:) - newpoint).^2));
end
如果你想实现曼哈顿距离,这只是:
N = size(x,1);
dists = zeros(N,1);
for idx = 1 : N
dists(idx) = sum(abs(x(idx,:) - newpoint));
end
dists 将是一个N 元素向量,其中包含x 和newpoint 中每个数据点之间的距离。我们在newpoint 和x 中的一个数据点之间进行逐个元素的减法,将差值平方,然后将sum 全部加在一起。然后这个和是平方根的,这就完成了欧几里得距离。对于曼哈顿距离,您将执行逐个元素的减法,取绝对值,然后将所有分量相加。这可能是最容易理解的实现,但也可能是效率最低的……尤其是对于更大的数据集和更大维度的数据。
另一种可能的解决方案是复制newpoint 并使该矩阵与x 大小相同,然后对该矩阵进行逐个元素减法,然后对每一行的所有列求和并执行平方根。因此,我们可以这样做:
N = size(x, 1);
dists = sqrt(sum((x - repmat(newpoint, N, 1)).^2, 2));
对于曼哈顿距离,你会这样做:
N = size(x, 1);
dists = sum(abs(x - repmat(newpoint, N, 1)), 2);
repmat 采用矩阵或向量并在给定方向上重复它们一定次数。在我们的例子中,我们想要使用我们的newpoint 向量,并将这个N 次堆叠在一起以创建一个N x M 矩阵,其中每一行都是M 元素长。我们将这两个矩阵相减,然后对每个分量求平方。一旦我们这样做了,我们sum 在每一行的所有列上,最后取所有结果的平方根。对于曼哈顿距离,我们做减法,取绝对值,然后求和。
但是,我认为最有效的方法是使用bsxfun。这实质上是通过一个函数调用完成了我们在后台讨论的复制。因此,代码将是这样的:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
在我看来,这看起来更加简洁明了。对于曼哈顿距离,您可以:
dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
步骤 #2
现在我们有了距离,我们只需对它们进行排序。我们可以使用sort 对我们的距离进行排序:
[d,ind] = sort(dists);
d 将包含按升序排序的距离,而ind 会告诉您 unsorted 数组中的每个值在 sorted 结果中出现的位置。我们需要使用ind,提取这个向量的第一个k元素,然后使用ind索引到我们的x数据矩阵中,返回最接近newpoint的点。
步骤#3
最后一步是现在返回最接近newpoint 的那些k 数据点。我们可以通过以下方式非常简单地做到这一点:
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
ind_closest 应该包含原始数据矩阵x 中最接近newpoint 的索引。具体来说,ind_closest 包含您需要从x 中采样的行,以获得最接近newpoint 的点。 x_closest 将包含这些实际数据点。
为了您的复制和粘贴乐趣,代码如下所示:
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
%// Or do this for Manhattan
% dists = sum(abs(bsxfun(@minus, x, newpoint)), 2);
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
运行您的示例,让我们看看我们的代码在运行中:
load fisheriris
x = meas(:,3:4);
newpoint = [5 1.45];
k = 10;
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
ind_closest = ind(1:k);
x_closest = x(ind_closest,:);
通过检查ind_closest 和x_closest,我们得到了:
>> ind_closest
ind_closest =
120
53
73
134
84
77
78
51
64
87
>> x_closest
x_closest =
5.0000 1.5000
4.9000 1.5000
4.9000 1.5000
5.1000 1.5000
5.1000 1.6000
4.8000 1.4000
5.0000 1.7000
4.7000 1.4000
4.7000 1.4000
4.7000 1.5000
如果您运行knnsearch,您将看到您的变量n 与ind_closest 匹配。但是,变量d 返回从newpoint 到每个点x 的距离,而不是实际数据点本身。如果您想要实际距离,只需在我编写的代码之后执行以下操作:
dist_sorted = d(1:k);
请注意,上述答案仅使用了一批N 示例中的一个查询点。 KNN 经常同时用于多个示例。假设我们有要在 KNN 中测试的 Q 查询点。这将产生一个k x M x Q 矩阵,其中对于每个示例或每个切片,我们返回维度为M 的k 最近点。或者,我们可以返回k 最近点的ID,从而生成Q x k 矩阵。让我们计算两者。
一种天真的方法是在循环中应用上述代码并循环遍历每个示例。
如果我们分配一个Q x k 矩阵并应用基于bsxfun 的方法将输出矩阵的每一行设置为数据集中的k 最近点,这样的事情就可以工作,我们将在其中使用Fisher Iris数据集就像我们之前的一样。我们还将保持与上一个示例相同的维度,我将使用四个示例,所以Q = 4 和M = 2:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and the output matrices
Q = size(newpoints, 1);
M = size(x, 2);
k = 10;
x_closest = zeros(k, M, Q);
ind_closest = zeros(Q, k);
%// Loop through each point and do logic as seen above:
for ii = 1 : Q
%// Get the point
newpoint = newpoints(ii, :);
%// Use Euclidean
dists = sqrt(sum(bsxfun(@minus, x, newpoint).^2, 2));
[d,ind] = sort(dists);
%// New - Output the IDs of the match as well as the points themselves
ind_closest(ii, :) = ind(1 : k).';
x_closest(:, :, ii) = x(ind_closest(ii, :), :);
end
虽然这很好,但我们可以做得更好。有一种方法可以有效地计算两组向量之间的平方欧几里得距离。如果你想在曼哈顿做这个,我会把它作为一个练习。咨询this blog,假设A 是一个Q1 x M 矩阵,其中每一行是一个维度M 与Q1 点,B 是一个Q2 x M 矩阵,其中每一行也是一个点维数M 和Q2 点,我们可以有效地计算距离矩阵D(i, j),其中行i 和列j 的元素表示A 的行i 和行j 之间的距离B 使用以下矩阵公式:
nA = sum(A.^2, 2); %// Sum of squares for each row of A
nB = sum(B.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*A*B.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation
因此,如果我们让A 是查询点矩阵,B 是包含您的原始数据的数据集,我们可以通过单独排序每一行并确定k 来确定最接近的点 k每行最小的位置。我们还可以另外使用它来自己检索实际点。
因此:
%// Load the data and create the query points
load fisheriris;
x = meas(:,3:4);
newpoints = [5 1.45; 7 2; 4 2.5; 2 3.5];
%// Define k and other variables
k = 10;
Q = size(newpoints, 1);
M = size(x, 2);
nA = sum(newpoints.^2, 2); %// Sum of squares for each row of A
nB = sum(x.^2, 2); %// Sum of squares for each row of B
D = bsxfun(@plus, nA, nB.') - 2*newpoints*x.'; %// Compute distance matrix
D = sqrt(D); %// Compute square root to complete calculation
%// Sort the distances
[d, ind] = sort(D, 2);
%// Get the indices of the closest distances
ind_closest = ind(:, 1:k);
%// Also get the nearest points
x_closest = permute(reshape(x(ind_closest(:), :).', M, k, []), [2 1 3]);
我们看到我们使用的计算距离矩阵的逻辑是相同的,但一些变量已经改变以适应示例。我们还使用sort 的两个输入版本对每一行进行独立排序,因此ind 将包含每行的ID,d 将包含相应的距离。然后我们通过简单地将这个矩阵截断为k 列来找出最接近每个查询点的索引。然后我们使用permute 和reshape 来确定相关的最近点是什么。我们首先使用所有最接近的索引并创建一个点矩阵,将所有 ID 堆叠在一起,因此我们得到一个Q * k x M 矩阵。使用reshape 和permute 允许我们创建我们的3D 矩阵,使其成为我们指定的k x M x Q 矩阵。如果您想自己获取实际距离,我们可以索引d 并获取我们需要的内容。为此,您需要使用sub2ind 来获取线性索引,以便我们可以一次性索引到d。 ind_closest 的值已经告诉我们需要访问哪些列。我们需要访问的行只是 1、ktimes、2、ktimes 等,直到Q。 k 是我们想要返回的点数:
row_indices = repmat((1:Q).', 1, k);
linear_ind = sub2ind(size(d), row_indices, ind_closest);
dist_sorted = D(linear_ind);
当我们对上述查询点运行上述代码时,这些是我们得到的索引、点和距离:
>> ind_closest
ind_closest =
120 134 53 73 84 77 78 51 64 87
123 119 118 106 132 108 131 136 126 110
107 62 86 122 71 127 139 115 60 52
99 65 58 94 60 61 80 44 54 72
>> x_closest
x_closest(:,:,1) =
5.0000 1.5000
6.7000 2.0000
4.5000 1.7000
3.0000 1.1000
5.1000 1.5000
6.9000 2.3000
4.2000 1.5000
3.6000 1.3000
4.9000 1.5000
6.7000 2.2000
x_closest(:,:,2) =
4.5000 1.6000
3.3000 1.0000
4.9000 1.5000
6.6000 2.1000
4.9000 2.0000
3.3000 1.0000
5.1000 1.6000
6.4000 2.0000
4.8000 1.8000
3.9000 1.4000
x_closest(:,:,3) =
4.8000 1.4000
6.3000 1.8000
4.8000 1.8000
3.5000 1.0000
5.0000 1.7000
6.1000 1.9000
4.8000 1.8000
3.5000 1.0000
4.7000 1.4000
6.1000 2.3000
x_closest(:,:,4) =
5.1000 2.4000
1.6000 0.6000
4.7000 1.4000
6.0000 1.8000
3.9000 1.4000
4.0000 1.3000
4.7000 1.5000
6.1000 2.5000
4.5000 1.5000
4.0000 1.3000
>> dist_sorted
dist_sorted =
0.0500 0.1118 0.1118 0.1118 0.1803 0.2062 0.2500 0.3041 0.3041 0.3041
0.3000 0.3162 0.3606 0.4123 0.6000 0.7280 0.9055 0.9487 1.0198 1.0296
0.9434 1.0198 1.0296 1.0296 1.0630 1.0630 1.0630 1.1045 1.1045 1.1180
2.6000 2.7203 2.8178 2.8178 2.8320 2.9155 2.9155 2.9275 2.9732 2.9732
要将其与knnsearch 进行比较,您可以为第二个参数指定一个点矩阵,其中每一行都是一个查询点,您将看到此实现与knnsearch 之间的索引和排序距离匹配。
希望这对您有所帮助。祝你好运!