【问题标题】:Keras variable() memory leakKeras 变量()内存泄漏
【发布时间】:2019-04-11 18:50:18
【问题描述】:

我是 Keras 和一般 tensorflow 的新手,但遇到了问题。我正在使用一些损失函数(主要是 binary_crossentropy 和 mean_squared_error)来计算预测后的损失。由于 Keras 只接受它自己的变量类型,我正在创建一个并将其作为参数提供。这个场景是在循环中执行的(带睡眠),如下所示:

获取适当的数据->预测->计算丢失的数据->返回。

由于我有多个遵循此模式的模型,因此我创建了 tensorflow 图和会话以防止冲突(在导出模型的权重时,我遇到了单个图形和会话的问题,因此我必须为每个模型创建不同的模型)。

但是,现在内存无法控制地上升,在几次迭代中从几个 MiB 到 700MiB。我知道 Keras 的 clear_session() 和 gc.collect(),我在每次迭代结束时都使用它们,但问题仍然存在。这里我提供了一个来自项目的代码sn-p,它不是实际的代码。我创建了单独的脚本来隔离问题:

import tensorflow as tf

from keras import backend as K
from keras.losses import binary_crossentropy, mean_squared_error

from time import time, sleep
import gc
from numpy.random import rand

from os import getpid
from psutil import Process

from csv import DictWriter
from keras import backend as K

this_process = Process(getpid())

graph = tf.Graph()
sess = tf.Session(graph=graph)

cnt = 0
max_c = 500

with open('/home/quark/Desktop/python-test/leak-7.csv', 'a') as file:
    writer = DictWriter(file, fieldnames=['time', 'mem'])
    writer.writeheader()

    while cnt < max_c:  
        with graph.as_default(), sess.as_default():         
            y_true = K.variable(rand(36, 6))
            y_pred = K.variable(rand(36, 6))

            rec_loss = K.eval(binary_crossentropy(y_true, y_pred))
            val_loss = K.eval(mean_squared_error(y_true, y_pred))

            writer.writerow({
                'time': int(time()),
                'mem': this_process.memory_info().rss
            })

        K.clear_session()
        gc.collect()

        cnt += 1
        print(max_c - cnt)
        sleep(0.1)

此外,我还添加了内存使用图: Keras memory leak

感谢任何帮助。

【问题讨论】:

  • 您能添加所需的导入吗?我相信您正在混合使用 tf 和 keras 命令。
  • 是的,我们可以运行一个完整的例子会很好。
  • 我已经更新了代码。

标签: python tensorflow memory-leaks keras


【解决方案1】:

最后,我从where 语句中删除了K.variable() 代码。这样,变量是默认图表的一部分,稍后会被K.clear_session() 清除。

【讨论】:

    【解决方案2】:

    我刚刚删除了 with 语句(可能是一些 tf 代码),我没有看到任何泄漏。我相信 keras 会话和 tf 默认会话之间存在差异。所以你没有用K.clear_session() 清除正确的会话。可能使用tf.reset_default_graph() 也可以。

    while True: 
        y_true = K.variable(rand(36, 6))
        y_pred = K.variable(rand(36, 6))
    
        val_loss = K.eval(binary_crossentropy(y_true, y_pred))
        rec_loss = K.eval(mean_squared_error(y_true, y_pred))
    
        K.clear_session()
        gc.collect()
    
        sleep(0.1)
    

    【讨论】:

    • 我知道删除with 语句可以解决问题,但恐怕我不能省略它,因为我有不同的图表。无论如何,谢谢,我会尝试tf.reset_default_graph()
    • 那么你只需要重置相应的图表,而不仅仅是默认的
    猜你喜欢
    • 2019-01-09
    • 2018-10-24
    • 1970-01-01
    • 1970-01-01
    • 2021-01-16
    • 1970-01-01
    • 2012-01-29
    • 1970-01-01
    • 2020-01-27
    相关资源
    最近更新 更多