【问题标题】:Tensorflow: Number of nodes in the graph keeps increasing as training goes onTensorflow:随着训练的进行,图中的节点数量不断增加
【发布时间】:2019-04-30 00:39:39
【问题描述】:

我正在 tensorflow 中训练一个卷积模型。在对模型进行了大约 70 个 epoch 的训练后,花了将近 1.5 小时,我无法保存模型。它给了我ValueError: GraphDef cannot be larger than 2GB。我发现随着训练的进行,我图中的节点数量不断增加。

在 epochs 0,3,6,9,图中的节点数分别为 7214, 7238, 7262, 7286。当我使用with tf.Session() as sess: 时,不是将会话传递为sess = tf.Session(),而是在 epoch 0、3、6、9 处的节点数分别为 3982、4006、4030、4054。

this 答案中,据说随着节点被添加到图中,它可能会超过其最大大小。我需要帮助来了解我的图中的节点数量是如何不断增加的。

我使用以下代码训练我的模型:

def runModel(data):
    '''
    Defines cost, optimizer functions, and runs the graph
    '''
    X, y,keep_prob = modelInputs((755, 567, 1),4)
    logits = cnnModel(X,keep_prob)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=logits, labels=y), name="cost")
    optimizer = tf.train.AdamOptimizer(.0001).minimize(cost)
    correct_pred = tf.equal(tf.argmax(logits, 1), tf.argmax(y, 1), name="correct_pred")
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32), name='accuracy')

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    for e in range(12):
        batch_x, batch_y = data.next_batch(30)
        x = tf.reshape(batch_x, [30, 755, 567, 1]).eval(session=sess)
        batch_y = tf.one_hot(batch_y,4).eval(session=sess)
        sess.run(optimizer, feed_dict={X: x, y: batch_y,keep_prob:0.5})
        if e%3==0:
            n = len([n.name for n in tf.get_default_graph().as_graph_def().node])
            print("No.of nodes: ",n,"\n")
            current_cost = sess.run(cost, feed_dict={X: x, y: batch_y,keep_prob:1.0})
            acc = sess.run(accuracy, feed_dict={X: x, y: batch_y,keep_prob:1.0})
            print("At epoch {epoch:>3d}, cost is {a:>10.4f}, accuracy is {b:>8.5f}".format(epoch=e, a=current_cost, b=acc))

是什么导致节点数量增加?

【问题讨论】:

  • 也许你可以在每一步得到新节点的名称,看看它们是哪些节点?也许它只是每次都被复制的输入节点,我不知道......你使用的是什么版本的 tf?
  • @gdelab 我正在使用1.0.1,每个epoch的节点数似乎增加了8个!
  • 是的,但是你能在每一步得到八个新的节点名称吗?也许他们可以帮助了解创建新节点的位置...

标签: python tensorflow


【解决方案1】:

您正在训练循环中创建新节点。特别是,您正在调用tf.reshapetf.one_hot,它们中的每一个都会创建一个(或多个)节点。您可以:

  • 使用占位符作为输入在图表之外创建这些节点,然后仅在循环中评估它们。
  • 不要将 TensorFlow 用于这些操作,而是使用 NumPy 或等效操作。

我会推荐第二个,因为使用 TensorFlow 进行数据准备似乎没有任何好处。你可以有类似的东西:

import numpy as np
# ...
    x = np.reshape(batch_x, [30, 755, 567, 1])
    # ...
    # One way of doing one-hot encoding with NumPy
    classes_arr = np.arange(4).reshape([1] * batch_y.ndims + [-1])
    batch_y = (np.expand_dims(batch_y, -1) == classes_arr).astype(batch_y.dtype)
    # ...

PD:我还建议在 with context manager 中使用 tf.Session() 以确保最后调用其 close() 方法(除非您以后想继续使用相同的会话)。

【讨论】:

    【解决方案2】:

    另一个解决了我类似问题的选项是使用tf.reset_default_graph()

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2022-01-19
      • 1970-01-01
      • 1970-01-01
      • 2019-10-06
      • 2019-12-31
      • 2019-09-24
      • 1970-01-01
      相关资源
      最近更新 更多