【问题标题】:Is there a faster way to mask an array?有没有更快的方法来屏蔽数组?
【发布时间】:2020-05-10 10:29:32
【问题描述】:

我有一个 numpy 数组,我需要屏蔽它。 我的函数如下所示:

def mask_arr(arr, min, max):
    for i in range(arr.size-1):
        if arr[i] < min:
            arr[i] = 0
        elif arr[i] > max:
            arr[i] = 1
        else:
            arr[i] = 10

问题是数组很大,屏蔽它需要很长时间。 我怎样才能获得相同的结果但更快?

【问题讨论】:

    标签: python performance numpy


    【解决方案1】:

    您可以使用嵌套的np.where,如下所示:

    import numpy as np
    q = np.random.rand(4,4)
    # array([[0.86305369, 0.88477713, 0.58776518, 0.69122533],
    #   [0.52591559, 0.33155238, 0.50139987, 0.66812239],
    #   [0.83240284, 0.70147098, 0.17118681, 0.59652636],
    #   [0.82031661, 0.32032657, 0.55088698, 0.28931661]])
    np.where(q > 0.8, 1, np.where(q < 0.3, 0, 10))
    # array([[ 1,  1, 10, 10],
    #   [10, 10, 10, 10],
    #   [ 1, 10,  0, 10],
    #   [ 1, 10, 10,  0]])
    

    编辑:

    根据您的问题,如果您想在数组元素不大于maxVal 或小于minVal 的情况下更改值,您可以执行或您想要的任何其他逻辑:

    import numpy as np
    q = q = np.random.rand(4,4)
    minVal = 0.3
    maxVal = 0.9
    qq = np.where(q > 0.8, 1, np.where(q < 0.3, 0, 2 * q))
    

    q 在哪里:

    [[0.63604995 0.18637738 0.90680287 0.64617278]
     [0.97435344 0.04670638 0.3510053  0.71613776]
     [0.17973416 0.50296747 0.35085383 0.853201  ]
     [0.27820978 0.69438172 0.96186074 0.96625938]]
    

    qq 是:

    [[1.27209991 0.         1.         1.29234556]
     [1.         0.         0.7020106  1.43227553]
     [0.         1.00593493 0.70170767 1.        ]
     [0.         1.38876345 1.         1.        ]]
    

    【讨论】:

      【解决方案2】:

      解决方案

      您可以根据自己的规则使用三个简单的分配。这使用了numpy 中可用的本机矢量化,因此与您尝试过的相比会更快。

      # minval, maxval = 0.3, 0.8
      condition = np.logical_and(a>=minval, a<=maxval)
      a[a<minval] = 0 
      a[a>maxval] = 1
      a[condition] = 10 # if a constant value of 10
      a[condition] *= 2 # if each element gets multiplied by 2
      

      输出

      [[10.  0. 10.  1.  0.]
       [10. 10. 10.  0. 10.]
       [ 1. 10. 10.  1.  1.]
       [ 0.  1. 10. 10.  0.]
       [ 0.  0. 10. 10. 10.]]
      

      虚拟数据

      a = np.random.rand(5,5)
      

      输出

      array([[0.68554168, 0.27430639, 0.4382025 , 0.97162651, 0.16740865],
             [0.32530579, 0.3415287 , 0.45920916, 0.09422211, 0.75247522],
             [0.91621921, 0.65845783, 0.38678723, 0.83644281, 0.95865701],
             [0.26290637, 0.83810284, 0.55327399, 0.3406887 , 0.26173914],
             [0.24974815, 0.08543414, 0.78509214, 0.64663201, 0.61502744]])
      

      便利功能

      由于您提到您还可以将目标元素自乘以两倍,因此我将该功能扩展到绝对赋值(设置值 10)或相对更新(加、减、乘、除)w.r.t数组的当前值。

      def mask_arr(arr, 
                   minval: float = 0.3, 
                   maxval: float = 0.8, 
                   update_type: str = 'abs', 
                   update_value: float = 10, 
                   rel_update_method: str = '*', 
                   mask_floor: float = 0.0, 
                   mesk_ceiling: float = 1.0):
          """Returns the array arr after setting lower-bound (mask_floor), 
          upper-bound (mask_ceiling), and logic-for-in-between-values. 
      
          """
          # minval, maxval = 0.3, 0.8
          condition = np.logical_and(arr>=minval, arr<=maxval)
          arr[arr<minval] = lowerbound 
          arr[arr>maxval] = upperbound
          if update_type=='abs':
              # absolute update 
              arr[condition] = update_value
          if update_type=='rel': 
              # relative update
              if rel_update_method=='+':
                  arr[condition] += update_value
              if rel_update_method=='-':
                  arr[condition] -= update_value
              if rel_update_method=='*':
                  arr[condition] *= update_value
              if rel_update_method=='/':
                  arr[condition] /= update_value
          return arr
      

      示例

      # declare all inputs
      arr = mask_arr(arr, 
                      minval = 0.3, 
                      maxval = 0.8, 
                      update_type = 'rel', 
                      update_value = 2.0, 
                      rel_update_method = '*', 
                      mask_floor = 0.0, 
                      mesk_ceiling = 1.0)
      
      # using defaults for 
      #   mask_floor = 0.0, 
      #   mesk_ceiling = 1.0
      arr = mask_arr(arr, 
                      minval = 0.3, 
                      maxval = 0.8, 
                      update_type = 'rel', 
                      update_value = 2.0, 
                      rel_update_method = '*')
      
      # using defaults as before and 
      # setting a fixed value of 10
      arr = mask_arr(arr, 
                      minval = 0.3, 
                      maxval = 0.8, 
                      update_type = 'abs', 
                      update_value = 10.0)
      

      【讨论】:

      • 如果不是 10,我想放 2*arr[i]
      • 为 arr[i]*2 添加逻辑
      【解决方案3】:

      使用 numpy,您无需为此类操作执行循环。 此外,我建议您不要使用 'min' 和 'max' 作为变量名,因为它们是保留名称。

      试试下面的

      arr[arr < min_val]=0
      arr[arr > max_val]=1
      arr[(arr<=max_val) & (arr>=min_val)]=10
      

      【讨论】:

        猜你喜欢
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 2014-09-23
        • 2015-04-22
        • 2021-05-06
        • 1970-01-01
        • 2020-02-26
        • 2012-04-01
        相关资源
        最近更新 更多