【问题标题】:assign on a tf.Variable tensor slice在 tf.Variable 张量切片上赋值
【发布时间】:2016-06-02 05:53:29
【问题描述】:

我正在尝试执行以下操作

state[0,:] = state[0,:].assign( 0.9*prev_state + 0.1*( tf.matmul(inputs, weights) + biases ) )
for i in xrange(1,BATCH_SIZE):
    state[i,:] = state[i,:].assign( 0.9*state[i-1,:] + 0.1*( tf.matmul(inputs, weights) + biases ) )
prev_state = prev_state.assign( state[BATCH_SIZE-1,:] )

state = tf.Variable(tf.zeros([BATCH_SIZE, HIDDEN_1]), name='inner_state')
prev_state = tf.Variable(tf.zeros([HIDDEN_1]), name='previous_inner_state')

作为this question 的后续行动。我收到一个错误,指出 Tensor 没有 assign 方法。

Variable 张量的切片上调用assign 方法的正确方法是什么?


当前完整代码:

import tensorflow as tf
import math
import numpy as np

INPUTS = 10
HIDDEN_1 = 20
BATCH_SIZE = 3


def create_graph(inputs, state, prev_state):
    with tf.name_scope('h1'):
        weights = tf.Variable(
        tf.truncated_normal([INPUTS, HIDDEN_1],
                            stddev=1.0 / math.sqrt(float(INPUTS))),
        name='weights')
        biases = tf.Variable(tf.zeros([HIDDEN_1]), name='biases')

        updated_state = tf.scatter_update(state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs[0,:], weights) + biases))
        for i in xrange(1, BATCH_SIZE):
          updated_state = tf.scatter_update(
              updated_state, [i], 0.9 * updated_state[i-1, :] + 0.1 * (tf.matmul(inputs[i,:], weights) + biases))

        prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])
        output = tf.nn.relu(updated_state)
    return output

def data_iter():
    while True:
        idxs = np.random.rand(BATCH_SIZE, INPUTS)
        yield idxs

with tf.Graph().as_default():
    inputs = tf.placeholder(tf.float32, shape=(BATCH_SIZE, INPUTS))
    state = tf.Variable(tf.zeros([BATCH_SIZE, HIDDEN_1]), name='inner_state')
    prev_state = tf.Variable(tf.zeros([HIDDEN_1]), name='previous_inner_state')

    output = create_graph(inputs, state, prev_state)

    sess = tf.Session()
    # Run the Op to initialize the variables.
    init = tf.initialize_all_variables()
    sess.run(init)
    iter_ = data_iter()
    for i in xrange(0, 2):
        print ("iteration: ",i)
        input_data = iter_.next()
        out = sess.run(output, feed_dict={ inputs: input_data})

【问题讨论】:

    标签: python-2.7 tensorflow


    【解决方案1】:

    Tensorflow Variable 对象对更新切片的支持有限,使用 tf.scatter_update()tf.scatter_add()tf.scatter_sub() 操作。这些操作中的每一个都允许您指定一个变量、一个切片索引向量(表示变量第 0 维中的索引,表示要突变的连续切片)和一个值张量(表示要应用于的新值)变量,在相应的切片索引处)。

    要更新变量的单行,您可以使用tf.scatter_update()。例如,要更新state 的第 0 行,您可以:

    updated_state = tf.scatter_update(
        state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs, weights) + biases))
    

    要链接多个更新,您可以使用从tf.scatter_update() 返回的可变updated_state 张量:

    for i in xrange(1, BATCH_SIZE):
      updated_state = tf.scatter_update(
          updated_state, [i], 0.9 * updated_state[i-1, :] + ...)
    
    prev_state = prev_state.assign(updated_state[BATCH_SIZE-1, :])
    

    最后,您可以评估生成的updated_state.op 以将所有更新应用到state

    sess.run(updated_state.op)  # or `sess.run(updated_state)` to fetch the result
    

    PS。您可能会发现使用tf.scan() 来计算中间状态会更有效,只需将prev_state 物化在一个变量中即可。

    【讨论】:

    • 谢谢,我有点困惑在这种情况下如何使用scan,我不应该先用 ops 构建一个图表,然后在输出节点上使用 scan 吗?
    • 我遇到的另一个问题是,由于我正在分解批处理元素维度上的图表,tf.matmul(inputs, weights) 不仅可以工作,而且updated_state = tf.scatter_update(state, [0], 0.9 * prev_state + 0.1 * (tf.matmul(inputs[0,:], weights) + biases)) 行给出:@987654343 @。我将使用完整代码更新 OP
    猜你喜欢
    • 2019-06-02
    • 1970-01-01
    • 1970-01-01
    • 2016-07-04
    • 2017-12-01
    • 2021-01-15
    • 2018-03-15
    • 2018-08-05
    • 1970-01-01
    相关资源
    最近更新 更多