【问题标题】:How to replace For Loops and IF statements with Numpy arrays如何用 Numpy 数组替换 For 循环和 IF 语句
【发布时间】:2023-03-16 08:00:02
【问题描述】:

我有一个这样的 numpy 数组:

[[1, 2], [1, 3], [2, 1], [2, 2], [2, 3], ...]

我想三乘三得到所有“子”数组(即 [X, Y])的组合:

[[1, 1] [1, 1] [1, 1],
 [1, 1] [1, 1] [1, 2],
 [1, 1] [1, 1] [1, 3],
 ...
 [5, 5] [5, 5], [5, 4],
 [5, 5] [5, 5], [5, 5]]

然后,我需要对每个组合应用条件:

  • X1, X2, X3 > 0
  • X1+Y1 <= X2
  • X2+Y2 <= X3
  • [X1, Y1] =! [X2, Y2]
  • [X2, Y2] =! [X3, Y3]
  • ...

由于组合数量众多,我绝对需要避免 for 循环。

知道如何在有效的执行时间内完成这项工作吗?


我当前的带有 for 循环和 if 语句的代码:

mylist 对象类似于 [[1, 1], [1, 2], [2, 1], ...](即列表列表,如 [X, Y])。

组合 = [] 留在我的列表中:

    if left[0] > 0:
        
        for center in mylist:   
            
            if (center[0] > 0 
                and center[0] >= left[0] + left[1]
                and center[1] / left[1] < 2 
                and center[0] / left[0] < 2
                and left[1] / center[1] < 2 
                and left[0] / center[1] < 2 
                and str(left[0]) + "y" + str(left[1]) + "y" != str(center[0]) + "y" + str(center[1]) + "y"
                ):
            
                for right in mylist:   
        
                    if (right[0] > 0 
                        and right[0] >= center[0] + center[1]
                        and right[1] / center[1] < 2 
                        and right[0] / center[0] < 2
                        and center[1] / right[1] < 2 
                        and center[0] / right[0] < 2
                        and str(right[0]) + "y" + str(right[1]) + "y" != str(center[0]) + "y" + str(center[1]) + "y"
                        ):

                        Combination.append([[left[0], left[1]], [center[0], center[1]], [right[0], right[1]])

【问题讨论】:

  • 到目前为止你写过什么代码吗?
  • 我编辑了我的帖子以添加带有 for 循环和 if 语句的原始代码。

标签: python numpy


【解决方案1】:

试试itertoolnumpy 喜欢:

import numpy as np
import itertools


some_list = [[1, 2], [1, 3], [2, 1], [2, 2], [2, 3], [-1,-1]]


# use "itertools.combinations" or "itertools.combinations_with_replacement"
# whatever you want to get in therms of repeting elements.
# Then cast it into a numpy array.
combinations = np.array(list(itertools.combinations_with_replacement(some_list, 3)))


# from here your can do your boolean statements in the numpy sytax for example
# applying your first rule "X1,X2,X3 > 0" could be done with:
first_rule = combinations[:,:,0] > 0
print('boolean array for the first rule "X1,X2,X3 > 0"')
print(np.all(first_rule,axis=1))


# and the second rule "X1 + Y1 <= X2"
second_rule = combinations[:,0,0]+combinations[:,0,1] <= combinations[:,1,0]
print('\n\nboolean array for the first rule "X1 + Y1 <= X2"')
print(second_rule)

我认为它不仅仅是一个规则网格,因为第一个条件 X1,X2,X3 > 0,但是如果它是规则的,那么 meshgrid 是最好的解决方案(参见另一个答案)。

【讨论】:

  • 高效应用条件的绝佳解决方案。但是,对于创建组合,最好使用 np.meshgrid() 作为 creating a numpy array from a python list is slow.
  • 非常正确,但我不确定它是否是常规网格,如果是,则只需应用一次即可保存。但是 ofc np.meshgird 是更优雅的解决方案。
【解决方案2】:

编辑:您甚至不需要itertools,您可以使用numpy 来创建组合,而且速度非常快

# This is your input array from [1,1] to [5,5]
a = np.array(np.meshgrid(np.arange(1,6), np.arange(1,6))).T.reshape(-1,2)

b = np.array(np.meshgrid(a, a, a)).T.reshape(-1, 3, 2)

如您所见,需要 6 毫秒:5.88 ms ± 836 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

你的数组现在看起来像这样:


array([[[1, 1],
        [1, 1],
        [1, 1]],

       [[1, 1],
        [1, 1],
        [1, 2]],

       [[1, 1],
        [1, 1],
        [1, 3]],

       ...,

       [[5, 5],
        [5, 5],
        [5, 3]],

       [[5, 5],
        [5, 5],
        [5, 4]],

       [[5, 5],
        [5, 5],
        [5, 5]]])

因为现在这是一个 numpy 数组,所以您可以安全地使用 for 循环来检查您的条件。例如,row[0,0] 将是您的 X1row[0,1] 将是您的 Y1 等。

for row in b:
    row[0,0] + row[0,1] <= row[1,0]

这也需要很短的时间来执行:10.3 ms ± 278 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

因此,您也可以在其他情况下安全地执行此操作。

【讨论】:

  • 混合了 Mashudan 和 emilanov 的答案,我得到了我想要的东西!非常感谢
【解决方案3】:

from itertools import product a = [[1,1],[1, 2], [1, 3], [2, 1], [2, 2], [2, 3]] perms = np.array(list(product(a, repeat=3))) 这将创建一个形状为(n^3, 3, 2) 的数组,其中na 中的元素数。

现在您可以进行所有花哨的操作...

perms[:, :, 0] > 0
perms[:, 0, 0] + perms[:, 0, 1] <= perms[:, 1, 0]
perms[:, 1, 0] + perms[:, 1, 1] <= perms[:, 2, 0]
perms[:, 0, :] != perms[:, 1, :]
perms[:, 1, :] != perms[:, 2, :]
...

请注意,最后两个表达式将分别检查x1!=x2y1!=y2 并返回形状为(n^3,2) 的结果。但是,如果您的要求是检查这些实例是否作为一个整体不相等,您可以这样做

output = perms[:, 0, :] != perms[:, 1, :] np.logical_or(output[:, 0], output[:, 1])

这将返回形状为 (n^3) 的输出。

【讨论】:

    猜你喜欢
    • 2014-03-30
    • 1970-01-01
    • 1970-01-01
    • 2013-02-06
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-05-30
    • 2023-03-30
    相关资源
    最近更新 更多