【发布时间】:2017-06-29 15:00:44
【问题描述】:
我的问题大致如下。给定一个数字矩阵 X,其中每一行是一个项目。我想根据除自身之外的所有行中的 L2 距离找到每一行的最近邻居。我尝试阅读官方文档,但对于如何实现这一点仍然有些困惑。有人可以给我一些提示吗?
我的代码如下
function l2_dist(v1, v2)
return sqrt(sum((v1 - v2) .^ 2))
end
function main(Mat, dist_fun)
n = size(Mat, 1)
Dist = SharedArray{Float64}(n) #[Inf for i in 1:n]
Id = SharedArray{Int64}(n) #[-1 for i in 1:n]
@parallel for i = 1:n
Dist[i] = Inf
Id[i] = 0
end
Threads.@threads for i in 1:n
for j in 1:n
if i != j
println(i, j)
dist_temp = dist_fun(Mat[i, :], Mat[j, :])
if dist_temp < Dist[i]
println("Dist updated!")
Dist[i] = dist_temp
Id[i] = j
end
end
end
end
return Dict("Dist" => Dist, "Id" => Id)
end
n = 4000
p = 30
X = [rand() for i in 1:n, j in 1:p];
main(X[1:30, :], l2_dist)
@time N = main(X, l2_dist)
我正在尝试将所有 i (即计算每行最小值)分布在不同的核心上。但是上面的版本显然不能正常工作。它甚至比顺序版本还要慢。有人可以指出我正确的方向吗?谢谢。
【问题讨论】:
-
我建议在以任何其他方式优化之前将该字典从内部循环中取出。字典对运行时有害。如果你真的想要一个字典,可以在没有字典的情况下创建数组,然后在循环后将它们添加到字典中。这可能会有所帮助。
-
@ChrisRackauckas 谢谢。我想我可以只创建两个数组,并且只在最后一步使它们成为字典。你对并行化部分有什么提示吗?
-
“您对并行化部分有任何提示吗?”,除了使用
Threads.@threads对其进行多线程处理,然后通过每个线程使用不同的数组并随后合并使其成为线程安全?不,应该很标准。如果我有时间,我可以把它写下来。 -
@ChrisRackauckas 非常感谢。我刚刚尝试按照您对字典的建议修改代码(现在在问题中更新)。只是这样做给了我 20 倍的加速。太棒了!
-
是否打算使用
dist_fun(Mat[i, ], Mat[j, ])而不是dist_fun(Mat[i,:], Mat[j,:])?前者要快得多,但后者实际上给出了正确的答案。在表达式之前添加@views可以提高速度/内存。
标签: parallel-processing julia distance-matrix