【问题标题】:Is there a similar function in tensorflow like load_state_dict() in Pytorch?tensorflow 中是否有类似 Pytorch 中的 load_state_dict() 的函数?
【发布时间】:2019-11-14 01:33:11
【问题描述】:

就像已经描述的那样,我想知道在 tensorflow 中是否有类似 Pytorch 中的 load_state_dict() 函数。演示一个场景,请参考以下代码:

# Suppose we have two correctly initialized neural networks: net2 and net1
# Using Pytorch
net2.load_state_dict(net1.state_dict())

有人知道吗?

【问题讨论】:

    标签: tensorflow neural-network deep-learning pytorch


    【解决方案1】:

    下面的代码可能有助于在 tensorflow 中实现同样的效果:

    保存模型

    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
    tf.add_to_collection('vars', w1)
    tf.add_to_collection('vars', w2)
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'my-model')
    # `save` method will call `export_meta_graph` implicitly.
    # you will get saved graph files:my-model.meta
    
    
    

    恢复模型

    sess = tf.Session()
    new_saver = tf.train.import_meta_graph('my-model.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('./'))
    all_vars = tf.get_collection('vars')
    for v in all_vars:
        v_ = sess.run(v)
        print(v_)
    

    【讨论】:

    • 我正是这样解决了我的问题(几个月前)!无论如何,谢谢你的回答!
    猜你喜欢
    • 1970-01-01
    • 2018-12-23
    • 1970-01-01
    • 2019-12-02
    • 1970-01-01
    • 1970-01-01
    • 2013-05-03
    • 2011-08-01
    • 2014-06-14
    相关资源
    最近更新 更多