【问题标题】:Fastest way to find the nearest element in the array of real numbers在实数数组中找到最近元素的最快方法
【发布时间】:2019-07-07 07:06:23
【问题描述】:

对于每个元素的给定实数数组,找到比当前元素小不超过0.5的元素个数并写入新数组。

例如:

原始数组:

[0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]

结果数组:

[0,   0,   1,   2,    3,   0,   1]

解决这个问题的算法和方法是什么?

重要的是只在负方向上选择点的邻域,这使得无法使用KdtreeBalltree 算法。

我的所有问题都是here,我尝试编写代码。

【问题讨论】:

  • 你能展示你尝试过的东西吗
  • 我的所有问题都是here,我尝试了代码。
  • 数字是否总是排序的?
  • @max9111 不,它们没有排序

标签: python search tree binary-search-tree kdtree


【解决方案1】:

下面的方法虽然逻辑简单,写起来容易,但是速度慢。我们可以通过使用修饰的Numba 函数来加速它。这会将简单的循环任务加速到接近汇编语言的速度。

使用pip install numba 安装 Numba。

from numba import jit
import numpy as np

# Create a numpy array of length 10000 with float values between 0 and 10
random_values = np.random.uniform(0.0,10.0,size=(100*100,))

@jit(nopython=True, nogil=True)
def find_nearest(input):
  result = []
  for e in input:
    counter = 0
    for j in input:
      if j >= (e-0.5) and j < e:
        counter += 1
    result.append(counter)
  return result

result = find_nearest(random_values)

请注意,为测试用例返回预期结果:

test = [0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7]
result = find_nearest(test)
print result

返回:

[0, 0, 1, 2, 3, 0, 1]

【讨论】:

  • 这是一个非常缓慢的决定,难度为 O(n^2)。
  • 帖子中没有提到速度优化....但是是的,这可能不是最快的。
  • 也不是他想要的,因为他只对将当前数字与数组左侧进行比较感兴趣。
  • 真的吗?我不这样读问题....“点的邻域仅在负方向上选择”意味着搜索邻域距离为 0.5,但仅限于仅搜索 x - 0.5 内的值(其中 x是目标值)。而不是 x +/- 0.5。
  • @Иван 我已经修改了我的答案以获得更好的速度优化。
【解决方案2】:

这将解决您的特定任务。

def find_nearest_element(original_array):
    result_array = []
    for e in original_array:
        result_array.append(len(original_array[(e-0.5 < original_array) & (e > original_array)]))
    return result_array

original_array = np.array([0.1, 0.7, 0.8, 0.85, 0.9, 1.5, 1.7])
print(find_nearest_element(original_array))

输出:

[0, 0, 1, 2, 3, 0, 1]

编辑:对于较小的数组(ca. len 10000),使用掩码显然比使用 numba 的版本快。对于更大的阵列,使用 Numba 的版本更快。因此,这取决于您要进行的数组大小。

一些运行时比较(以秒为单位):

For smaller arrays(size=250):
Using Numba 0.2569999694824219
Using mask 0.0350041389465332
For bigger arrays(size=40000):
Using Numba 1.4619991779327393
Using mask 4.280000686645508

我的设备上的盈亏平衡点大约是 10000(两者都需要大约 0.33 秒)。

【讨论】:

  • 真的很有趣,感谢时间比较! (虽然,也许这不应该被格式化为代码块?)
  • 我也这么认为。但是如果我不把它放在代码块中,行之间的距离会变得非常大:/至少它是一个脚本的输出^^
  • 你的意思是“都需要大约 0.33 秒”吗?
【解决方案3】:

对于有序数组,这个问题很容易解决。您只需向后搜索并计算所有大于实际数字半径的数字。如果不再满足该条件,您可以跳出内部循环(这样可以节省大量时间)。

示例

import numpy as np
from scipy import spatial
import numba as nb

@nb.njit(parallel=True)
def get_counts_2(Points_sorted,ind,r):
  counts=np.zeros(Points_sorted.shape[0],dtype=np.int64)
  for i in nb.prange(0,Points_sorted.shape[0]):
    count=0
    for j in range(i-1,0,-1):
      if (Points_sorted[i]-r<Points_sorted[j]):
        count+=1
      else:
        break
    counts[ind[i]]=count
  return counts

时间

r=0.001
Points=np.random.rand(1_000_000)

t1=time.time()
ind=np.argsort(Points)
Points_sorted=Points[ind]
counts=get_counts_2(Points_sorted,ind,r)
print(time.time()-t1)
#0.29s

【讨论】:

    猜你喜欢
    • 1970-01-01
    • 2011-12-09
    • 1970-01-01
    • 2014-04-19
    • 2013-12-26
    • 2021-10-28
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    相关资源
    最近更新 更多