【问题标题】:mlpack nearest neighbor with cosine distance?具有余弦距离的mlpack最近邻居?
【发布时间】:2017-06-25 04:13:09
【问题描述】:

我想使用 mlpack 中的 NeighborSearch 类对一些表示文档的向量进行 KNN 分类。

我想使用余弦距离,但我遇到了问题。我认为这样做的方法是使用内积度量“IPMetric”并指定 CosineDistance 内核......这就是我所拥有的:

NeighborSearch<NearestNeighborSort, IPMetric<CosineDistance>> nn(X_train);

但我得到以下编译错误:

/usr/include/mlpack/core/tree/hrectbound_impl.hpp:211:15: error: ‘Power’ is not a member of ‘mlpack::metric::IPMetric<mlpack::kernel::CosineDistance>’
 sum += pow((lower + fabs(lower)) + (higher + fabs(higher)),
           ^
/usr/include/mlpack/core/tree/hrectbound_impl.hpp:220:3: error: ‘TakeRoot’ is not a member of ‘mlpack::metric::IPMetric<mlpack::kernel::CosineDistance>’
if (MetricType::TakeRoot)
^

我怀疑问题可能是默认的树类型 KDTree 不支持这个距离度量?如果这是问题所在,是否有适用于 CosineDistance 的树类型?

最后,是否可以使用暴力搜索?我似乎根本找不到不使用树的方法...

谢谢!

【问题讨论】:

    标签: c++ metrics knn mlpack


    【解决方案1】:

    不幸的是,正如您所怀疑的那样,任意度量类型不适用于 KDTree——这是因为 kd-tree 需要一个可以分解为不同维度的距离。但这对于IPMetric 是不可能的。相反,为什么不尝试使用覆盖树呢?树的构建时间可能会稍长一些,但它应该提供相当的性能:

    NeighborSearch<NearestNeighborSort, IPMetric<CosineDistance>, arma::mat,
        tree::StandardCoverTree> nn(X_train);
    

    如果要进行暴力搜索,请在构造函数中指定搜索方式:

    NeighborSearch<NearestNeighborSort, IPMetric<CosineDistance>, arma::mat,
        tree::StandardCoverTree> nn(X_train, NAIVE_MODE);
    

    我希望这会有所帮助;如果我能澄清任何事情,请告诉我。

    【讨论】:

    • Ack,我应该补充两点:我认为tree::BallTree 在这里也适合你,而且,你的任务是找到具有最大余弦相似度的点吗?您使用的设置将找到具有最小余弦相似度的点。
    • 谢谢!!我尝试使用 BallTree,但我仍然遇到编译错误......这是说邻居搜索需要 binary_space_tree,它使用 hrectbound,它需要 LMetric。如果有帮助,我已经分享了我的代码 here 和编译器输出 here。可能是我在使用#includes 做一些愚蠢的事情吗?
    • 我看到您使用的是 mlpack 2.0.2,因此请使用 true 作为构造函数的第二个参数,而不是 NAIVE_MODE。您提供的代码和编译器输出没有加起来;如果正在使用 BallTree,它根本不应该实例化 HRectBound 类。你确定这是一个干净的构建吗?
    • 我想我可能已经弄明白了——NeighborSearch 也将“TraversalType”作为模板参数,默认为 mlpack::tree::BinarySpaceTree,它会生成那些编译错误。对我应该将 TraversalType 设置为什么有任何见解?
    • 这很奇怪,TraversalType 应该默认为给定的 TreeType,而不是 BinarySpaceTree。您可能可以将其设置为BallTree&lt;IPMetric&lt;CosineDistance&gt;, NeighborSearchStat&lt;NearestNeighborSort&gt;&gt;, arma::mat&gt;::DualTreeTraverser,我认为这会起作用。不过,您想在 mlpack github 上打开一个错误吗?我认为你发现了一个问题。