【Question Title】: Memory leak with Keras Lambda layer
【Posted】: 2021-01-16 05:57:44
【Question】:

I need to split the channels of a tensor in order to apply a different normalization to each split. To do so, I use Keras's Lambda layer:

# split the channels in two (first part for IN, second for BN)
x_in = Lambda(lambda x: x[:, :, :, :split_index])(x)
x_bn = Lambda(lambda x: x[:, :, :, split_index:])(x)

# apply IN and BN on their respective group of channels
x_in = InstanceNormalization(axis=3)(x_in)
x_bn = BatchNormalization(axis=3)(x_bn)

# concatenate outputs of IN and BN
x = Concatenate(axis=3)([x_in, x_bn])

Everything works as expected (see the model.summary() output below), but RAM keeps growing at every iteration, which points to a memory leak; a sketch of how this growth can be monitored follows the summary.

Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 832, 832, 1)  0
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 832, 832, 32) 320         input_1[0][0]
__________________________________________________________________________________________________
lambda_1 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
lambda_2 (Lambda)               (None, 832, 832, 16) 0           conv1[0][0]
__________________________________________________________________________________________________
instance_normalization_1 (Insta (None, 832, 832, 16) 32          lambda_1[0][0]
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 832, 832, 16) 64          lambda_2[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 832, 832, 32) 0           instance_normalization_1[0][0]
                                                                 batch_normalization_1[0][0]
__________________________________________________________________________________________________
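A minimal sketch of how the per-iteration memory growth can be observed; it assumes the psutil package is available and is not part of the original code:

import os

import keras
import psutil


class MemoryLogger(keras.callbacks.Callback):
    """Print the resident set size of the process after each epoch."""

    def on_epoch_end(self, epoch, logs=None):
        # resident memory of the current process, in megabytes
        rss_mb = psutil.Process(os.getpid()).memory_info().rss / 1024 ** 2
        print('epoch %d: %.1f MB resident' % (epoch, rss_mb))

# usage: model.fit(..., callbacks=[MemoryLogger()])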

I'm confident the leak comes from the Lambda layers, because I tried an alternative strategy where I don't split at all: I apply both normalizations independently to all channels, then add the feature maps together. With this code I see no memory leak:

# apply IN and BN on the input tensor independently
x_in = InstanceNormalization(axis=3)(x)
x_bn = BatchNormalization(axis=3)(x)

# addition of the feature maps output by IN and BN
x = Add()([x_in, x_bn])

Any ideas on how to fix this memory leak? I'm using Keras 2.2.4 with TensorFlow 1.15.3, and I currently cannot upgrade to TF 2 or tf.keras.

【Comments】:

    Tags: python tensorflow keras memory-leaks


    【Solution 1】:

    You may want to consider using a custom layer instead of a Lambda layer.
    It may be that the Keras Lambda layer has some glitch here.

    【Comments】:

      【Solution 2】:

      Thibault Bacqueyrisses' answer is right: the memory leak disappears with a custom layer!

      Here is my implementation:

      import keras

      class Crop(keras.layers.Layer):
          def __init__(self, dim, start, end, **kwargs):
              """
              Slice the tensor along dimension `dim`, keeping what is between
              `start` and `end`.
              Args
                  dim   (int)   : dimension along which to slice (counting the batch dim)
                  start (int)   : index where the crop starts
                  end   (int)   : index where the crop stops
              """
              super(Crop, self).__init__(**kwargs)
              self.dimension = dim
              self.start = start
              self.end = end
      
          def call(self, inputs):
              # plain tensor slicing along the requested dimension
              if self.dimension == 0:
                  return inputs[self.start:self.end]
              if self.dimension == 1:
                  return inputs[:, self.start:self.end]
              if self.dimension == 2:
                  return inputs[:, :, self.start:self.end]
              if self.dimension == 3:
                  return inputs[:, :, :, self.start:self.end]
              if self.dimension == 4:
                  return inputs[:, :, :, :, self.start:self.end]
      
          def compute_output_shape(self, input_shape):
              # only the cropped dimension changes size
              output_shape = list(input_shape)
              output_shape[self.dimension] = self.end - self.start
              return tuple(output_shape)
      
          def get_config(self):
              # expose the constructor arguments so the layer can be serialized
              config = {
                  'dim': self.dimension,
                  'start': self.start,
                  'end': self.end,
              }
              base_config = super(Crop, self).get_config()
              return dict(list(base_config.items()) + list(config.items()))
      
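      For reference, a minimal usage sketch of how the original Lambda split can be rewritten with this layer; split_index and n_channels (the total number of channels of x, 32 in the summary above) are assumptions taken from the question's setup:

      # replace the two Lambda slices with the custom Crop layer
      x_in = Crop(3, 0, split_index)(x)
      x_bn = Crop(3, split_index, n_channels)(x)

      # the rest of the block is unchanged
      x_in = InstanceNormalization(axis=3)(x_in)
      x_bn = BatchNormalization(axis=3)(x_bn)
      x = Concatenate(axis=3)([x_in, x_bn])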

      【Comments】:
