【问题标题】:Tensorflow: Replace one op with another(maybe even 2 ops)Tensorflow:用另一个替换一个操作(甚至可能是 2 个操作)
【发布时间】:2018-10-11 12:31:37
【问题描述】:

我的目标是构建一个脚本,使用 TF 的图形编辑器将一个操作更改为另一个操作。到目前为止,我尝试制作一个仅更改 Conv2D 的输入内核权重的脚本,但无济于事,因为界面非常混乱。

with tf.Session() as sess:

    model_filename = sys.argv[1]

    with gfile.FastGFile(model_filename, 'r') as f:

        graph_def = graph_pb2.GraphDef()
        text_format.Merge(f.read(), graph_def)
        importer.import_graph_def(graph_def)

        #my_sgv = ge.sgv("Conv2D", graph=tf.get_default_graph())
        #print my_sgv

        convs = find_conv2d_ops(tf.get_default_graph())
        print convs

        my_sgv = ge.sgv(convs)
        print my_sgv

        conv_tensor = tf.get_default_graph().get_tensor_by_name(convs[0].name + ':0')
        conv_weights_input = tf.get_default_graph().get_tensor_by_name(convs[0].inputs[1].name)

        weights_new = tf.Variable(tf.truncated_normal([1, 1, 1, 8], stddev=0.03),
                                  name='Wnew')

        ge.graph_replace(conv_tensor, {conv_weights_input: weights_new})

错误是“输入需要是张量:”。有人可以提供一些见解吗?

【问题讨论】:

    标签: tensorflow graph replace editor


    【解决方案1】:

    由于您处理的是tf.Variable,因此您不需要使用图形编辑器。 tf.assign 就足够了。

    你可以像下面这样使用它:

    assign_op = tf.assign(conv_weights_input, weights_new)
    with tf.Session() as sess:
        sess.run(assign_op)
    

    如果您希望分出操作而不是权重。考虑以下示例(修改自 this 示例):

    import tensorflow as tf
    import tensorflow.contrib.graph_editor as ge
    
    def build():
        a_pl = tf.placeholder(dtype=tf.float32, name="a")
        b_pl = tf.placeholder(dtype=tf.float32, name="b")
        c = tf.add(a_pl, b_pl, name="c")
    
    build() #or load graph from disc
    
    a = tf.constant(1.0, shape=[2, 3], name="a_const")
    b = tf.constant(2.0, shape=[2, 3], name="b_const")
    
    a_pl = tf.get_default_graph().get_tensor_by_name("a:0")
    b_pl = tf.get_default_graph().get_tensor_by_name("b:0")
    c = tf.get_default_graph().get_tensor_by_name("c:0")
    
    c_ = ge.graph_replace(c, {a_pl: a, b_pl: b})
    
    with tf.Session() as sess:
        #no need for placeholders
        print(sess.run(c_))
        #will give error since a_pl and b_pl have no value
        print(sess.run(c))
    

    您的代码的问题在于您处理的是 wights,而不是张量。上面例子的症结在于,第一个参数是目标张量(输出张量),它具有要替换的张量作为依赖项。第二个参数是您要替换的实际张量。

    同样值得注意的是conv_weights_input实际上是一个张量,其中weights_new是一个tf.Variable。我相信您想要的是将weights_new 替换为具有随机权重初始化的新conv 操作。

    【讨论】:

    • 谢谢!但是如果我想在 Conv2D 的输出中插入其他操作呢?或者,作为开始,将 Conv2D 替换为另一个操作?
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2022-01-23
    • 2017-08-26
    • 1970-01-01
    • 2020-03-10
    • 2021-08-16
    • 1970-01-01
    相关资源
    最近更新 更多