【问题标题】:Removing all zeros from an array从数组中删除所有零
【发布时间】:2019-02-14 01:24:11
【问题描述】:

我有一个形状为 [120000, 3] 的数组,其中只有前 1500 个元素有用,其他元素为 0。

这里是一个例子

[15.0, 14.0, 13.0]
[11.0, 7.0, 8.0]
[4.0, 1.0, 3.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]
[0.0, 0.0, 0.0]

我必须找到一种方法来删除所有 [0.0, 0.0, 0.0] 的元素。我试图写这个,但它不起作用

for point in points:
        if point[0] == 0.0 and point[1] == 0.0 and point[2] == 0.0:
            np.delete(points, point)

编辑

评论中的所有解决方案都有效,但我给我用过的那个打了绿色勾号。谢谢大家。

【问题讨论】:

    标签: python arrays python-2.7 numpy


    【解决方案1】:

    不要使用 for 循环——它们很慢。在 for 循环中重复调用 np.delete 会导致性能下降。

    相反,创建一个掩码:

    zero_rows = (points == 0).all(1)
    

    这是一个长度为 120000 的数组,它是 True,其中该行中的所有元素都是 0。

    然后找到第一个这样的行:

    first_invalid = np.where(zero_rows)[0][0]
    

    最后,对数组进行切片:

    points[:first_invalid]
    

    【讨论】:

      【解决方案2】:
      x = [[15.0, 14.0, 13.0],
      [11.0, 7.0, 8.0],
      [4.0, 1.0, 3.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0]]
      

      简单的迭代解决方案:

      y = [i for i in x if i != [0.0, 0.0, 0.0]]
      

      更好的解决方案(Python 3.x):

      y = list(filter(lambda a: a != [0.0, 0.0, 0.0], x))
      

      输出:

      [[15.0, 14.0, 13.0], [11.0, 7.0, 8.0], [4.0, 1.0, 3.0]]
      

      【讨论】:

      • 这与 John Zwinck 关于性能的回答相比会很有趣。
      • 为什么你的 (Python 3.x) 解决方案更好(为什么它在 Python 2 中不能正常工作)?
      【解决方案3】:

      有一些相关的方法,分为两个阵营。您可以通过计算单个布尔数组和np.ndarray.all 来使用矢量化方法。或者,您可以通过 for 循环或带有生成器表达式的 next 计算仅包含 0 元素的第一行的索引。

      为了提高性能,我建议您使用 numba 和手动 for 循环。这是一个示例,但请参阅下面的基准测试以获得更有效的变体:

      from numba import jit
      
      @jit(nopython=True)
      def trim_enum_nb(A):
          for idx in range(A.shape[0]):
              if (A[idx]==0).all():
                  break
          return A[:idx]
      

      性能基准测试

      # python 3.6.5, numpy 1.14.3
      
      %timeit trim_enum_loop(A)     # 9.09 ms
      %timeit trim_enum_nb(A)       # 193 µs
      %timeit trim_enum_nb2(A)      # 2.2 µs
      %timeit trim_enum_gen(A)      # 8.89 ms
      %timeit trim_vect(A)          # 3.09 ms
      %timeit trim_searchsorted(A)  # 7.67 µs
      

      测试代码

      设置

      import numpy as np
      from numba import jit
      
      np.random.seed(0)
      
      n = 120000
      k = 1500
      
      A = np.random.randint(1, 10, (n, 3))
      A[k:, :] = 0
      

      函数

      def trim_enum_loop(A):
          for idx, row in enumerate(A):
              if (row==0).all():
                  break
          return A[:idx]
      
      @jit(nopython=True)
      def trim_enum_nb(A):
          for idx in range(A.shape[0]):
              if (A[idx]==0).all():
                  break
          return A[:idx]
      
      @jit(nopython=True)
      def trim_enum_nb2(A):
          for idx in range(A.shape[0]):
              res = False
              for col in range(A.shape[1]):
                  res |= A[idx, col]
                  if res:
                      break
                  return A[:idx]
      
      def trim_enum_gen(A):
          idx = next(idx for idx, row in enumerate(A) if (row==0).all())
          return A[:idx]
      
      def trim_vect(A):
          idx = np.where((A == 0).all(1))[0][0]
          return A[:idx]
      
      def trim_searchsorted(A):
          B = np.frombuffer(A, 'S12')
          idx = A.shape[0] - np.searchsorted(B[::-1], B[-1:], 'right')[0]
          return A[:idx]
      

      检查

      # check all results are the same
      assert (trim_vect(A) == trim_enum_loop(A)).all()
      assert (trim_vect(A) == trim_enum_nb(A)).all()
      assert (trim_vect(A) == trim_enum_nb2(A)).all()
      assert (trim_vect(A) == trim_enum_gen(A)).all()
      assert (trim_vect(A) == trim_searchsorted(A)).all()
      

      【讨论】:

      • 你能给我解释一下 trim_enum_gen 中的 if 吗? .all() 是什么?
      • 参见np.ndarray.allnextgenerator expressions。如果您对它们在此处的使用方式有具体的问题,我可以尝试进一步解释。
      • numba for numba : 用`替换if (A[idx]==0).all(): for j in range(3):\ if v[j]!=0:\ break\ if v[j]==0:\ break快四倍;)
      • @B.M.,好点,已更新。需要一段时间才能习惯 trying 编写嵌套循环。我现在的解决方案似乎比以前的 numba 快 100 倍!
      【解决方案4】:

      知道一切都结束了,只是想我会给出答案:)

      x = [[15.0, 14.0, 13.0],
      [11.0, 7.0, 8.0],
      [4.0, 1.0, 3.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0]]
      

      然后可以进行简单的列表理解

      [i for i in x if all(i)]
      

      和输出:

      [[15.0, 14.0, 13.0],[11.0, 7.0, 8.0],[4.0, 1.0, 3.0]]
      

      需要

      0.0000010866 # seconds or 1.0866 microseconds
      

      花点时间吃一克盐,这真的是不一致的,给我 2 秒的时间来获得更好的估计

      何时:

      x = [[15.0, 14.0, 13.0],
      [11.0, 7.0, 8.0],
      [4.0, 1.0, 3.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0],
      [0.0, 0.0, 0.0]]*(120000//7)
      

      我有时间

      0.01199 # seconds
      

      这个时间很大程度上取决于它们是否为 0,0 会更快,因为它被忽略了。

      【讨论】:

        【解决方案5】:

        对于对数复杂度,可以在逐行转换数据后使用numpy.searchsorted

        B=np.frombuffer(A,'S12')
        index=B.size-np.searchsorted(B[::-1],B[-1:],'right')[0]
        

        index 将是非空项目的数量,如果第一个都不为空。

        测试:

        >>>> %timeit B.size-searchsorted(B[::-1],B[-1:],'right')[0]
        2.2 µs 
        

        【讨论】:

        • 我认为你需要A.shape[0]-np.searchsorted(B[::-1],B[-1:],'right')[0]
        • 还是不错的解决方案 +1,我也将此解决方案添加到我的发帖时间中,希望你没问题。
        【解决方案6】:

        使用vstack的简单迭代解决方案

        import numpy as np
        b = np.empty((0,3), float)
        for elem in a:
            toRemove = np.array([0.0, 0.0, 0.0])
            if(not np.array_equal(elem,toRemove)):
                b=np.vstack((b, elem))
        

        【讨论】:

          猜你喜欢
          • 2014-04-01
          • 2015-04-02
          • 2023-03-10
          • 2019-11-24
          • 1970-01-01
          • 2021-06-18
          • 1970-01-01
          相关资源
          最近更新 更多