【Question title】: Using TensorFlow, how can I load weights generated by an LSTM into a CudnnLSTM model?
【Posted】: 2020-05-27 15:52:45
【Question description】:

I trained an LSTM model with TensorFlow. Can I load the weights produced by that LSTM into a CudnnLSTM model? My LSTM code is:

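import tensorflow as tf  # TensorFlow 1.x (tf.contrib was removed in 2.x)

# dynamic_rnn defaults to time_major=False: input_seq is [batch, time, input_size]
lstm_cell = tf.contrib.rnn.LSTMCell(hidden_size)
# The final state is an LSTMStateTuple ordered (c, h)
outputs, (c, h) = tf.nn.dynamic_rnn(lstm_cell,
                                    input_seq,
                                    dtype=tf.float32)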
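When comparing the two models, it helps to have a reference for what `LSTMCell` computes at each step. Below is a minimal NumPy sketch that assumes the default tanh activation, no peepholes, and no projection.

- The kernel multiplies the concatenation `[x, h]`.
- Its four column blocks are the gates in the order i, j, f, o.
- `forget_bias` (1.0 by default) is added to the forget gate at run time. It is not stored in the bias variable.

The function name `lstm_cell_step` is illustrative only.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def lstm_cell_step(x, h, c, kernel, bias, forget_bias=1.0):
    """One step of a TF-style LSTMCell (tanh activation, no peepholes or projection).

    x: [input_size]; h, c: [num_units]
    kernel: [input_size + num_units, 4 * num_units]; bias: [4 * num_units]
    Gate blocks along the last axis are ordered i, j, f, o.
    """
    z = np.concatenate([x, h]) @ kernel + bias
    i, j, f, o = np.split(z, 4)
    # forget_bias is applied here, at run time, not stored in `bias`
    new_c = c * sigmoid(f + forget_bias) + sigmoid(i) * np.tanh(j)
    new_h = sigmoid(o) * np.tanh(new_c)
    return new_h, new_c
```

Because of the runtime `forget_bias`, copying only the stored `bias` values into another implementation leaves the forget gate shifted by 1.0.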

The CudnnLSTM code is:

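from tensorflow.contrib import cudnn_rnn  # TensorFlow 1.x, runs on GPU only

# CudnnLSTM defaults to time-major input: [time, batch, input_size]
cudnn_cell_fw = cudnn_rnn.CudnnLSTM(num_layers=1,
                                    num_units=hidden_size,
                                    direction=cudnn_rnn.CUDNN_RNN_UNIDIRECTION,
                                    input_mode=cudnn_rnn.CUDNN_INPUT_LINEAR_MODE,
                                    dtype=tf.float32)
# Note the state order here is (h, c), the reverse of LSTMStateTuple (c, h)
outputs, (h, c) = cudnn_cell_fw(inputs=input_seq)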
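cuDNN parameterizes the same LSTM differently:

- Each gate has a separate input matrix (W) and recurrent matrix (R).
- Gates are ordered i, f, c, o, where c plays the role of `LSTMCell`'s j.
- Each gate has two bias vectors that are simply added together.
- There is no built-in forget bias.

The sketch below evaluates one step from a flat buffer. It assumes the buffer follows cuDNN's canonical order:

- W_i, W_f, W_c, W_o, then R_i, R_f, R_c, R_o, each stored row-major as [num_units, in_dim];
- then the four input biases;
- then the four recurrent biases.

TensorFlow treats the CudnnLSTM `opaque_kernel` buffer as opaque, so this layout is an assumption, not a documented guarantee. `cudnn_lstm_step` is an illustrative name.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def cudnn_lstm_step(x, h, c, params, input_size, num_units):
    """One LSTM step from a flat parameter buffer in (assumed) cuDNN canonical order."""
    n = num_units
    pos = 0

    def take(size, shape):
        # Read the next `size` values from the buffer and reshape them
        nonlocal pos
        block = params[pos:pos + size].reshape(shape)
        pos += size
        return block

    W = [take(n * input_size, (n, input_size)) for _ in range(4)]  # W_i, W_f, W_c, W_o
    R = [take(n * n, (n, n)) for _ in range(4)]                    # R_i, R_f, R_c, R_o
    b_w = [take(n, (n,)) for _ in range(4)]                        # input biases
    b_r = [take(n, (n,)) for _ in range(4)]                        # recurrent biases
    if pos != params.size:
        raise ValueError("buffer size does not match the assumed canonical layout")

    # Gate pre-activations in cuDNN order i, f, c, o; both biases are summed
    i, f, g, o = (W[k] @ x + R[k] @ h + b_w[k] + b_r[k] for k in range(4))
    new_c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)  # no extra forget bias
    new_h = sigmoid(o) * np.tanh(new_c)
    return new_h, new_c
```

The total buffer length under this layout is 4 · num_units · (input_size + num_units + 2). That matches the `(weight_shape_0 + 2) * weight_shape_1` size computed in the answer, so a correct total size alone does not show that the values are in the right places.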

【Comments】:

Tags: tensorflow lstm


【Answer 1】:

I tried extracting the LSTM's weights and biases like this:

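import numpy as np
from tensorflow.python.framework import tensor_util

# get_graphdef() is a helper (not shown) that parses a frozen .pb file into a GraphDef
frozen_graph_path = './train_one_lstm_model.pb'
frozen_graphdef = get_graphdef(frozen_graph_path)
for node in frozen_graphdef.node:
    # After freezing, the variables are Const nodes that hold their values
    if node.name == 'rnn/lstm_cell/kernel':  # [input_size + hidden_size, 4 * hidden_size]
        lstm_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)
    if node.name == 'rnn/lstm_cell/bias':    # [4 * hidden_size]
        lstm_bias = tensor_util.MakeNdarray(node.attr['value'].tensor)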

Then I wrote them into the CudnnLSTM node:

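# lstm_weight: [input_size + hidden_size, 4 * hidden_size]; lstm_bias: [4 * hidden_size]
weight_shape_0, weight_shape_1 = lstm_weight.shape
# Buffer size = all kernel entries + two bias vectors of length 4 * hidden_size
new_cudnn_weight = np.zeros((weight_shape_0 + 2) * weight_shape_1, dtype=np.float32)
# Copy the kernel, flattened in row-major order
new_cudnn_weight[:lstm_weight.size] = lstm_weight.ravel()
# Copy the bias into the first bias slot; the second bias slot stays zero
index = lstm_weight.size
new_cudnn_weight[index:index + weight_shape_1] = lstm_bias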
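The code above copies the kernel into the buffer in its TF layout: [input_size + hidden_size, 4 * hidden_size], row-major, gates i, j, f, o. It also leaves the second bias slot at zero. If the buffer follows cuDNN's canonical order, five changes are needed:

1. Split each gate's block into its input part and its recurrent part.
2. Transpose each part to [hidden_size, in_dim].
3. Reorder the gates from i, j, f, o to i, f, c, o.
4. Place all input matrices before all recurrent matrices.
5. Fold `LSTMCell`'s runtime `forget_bias` (1.0 by default) into the forget-gate bias.

The NumPy sketch below does this repacking. The function name `lstm_cell_to_cudnn_canonical` is hypothetical, and the canonical layout is an assumption. TensorFlow treats the CudnnLSTM buffer as opaque, so in real code its own canonical/opaque conversion utilities are the safer route.

```python
import numpy as np


def lstm_cell_to_cudnn_canonical(kernel, bias, input_size, num_units, forget_bias=1.0):
    """Repack a TF LSTMCell kernel/bias into an (assumed) cuDNN canonical flat buffer.

    kernel: [input_size + num_units, 4 * num_units], gate blocks ordered i, j, f, o
    bias:   [4 * num_units], same gate order
    Output: W_i, W_f, W_c, W_o, R_i, R_f, R_c, R_o (each row-major [num_units, in_dim]),
            then 4 input biases and 4 recurrent biases, as one float32 vector.
    """
    if kernel.shape != (input_size + num_units, 4 * num_units):
        raise ValueError("unexpected kernel shape %s" % (kernel.shape,))
    w_in = np.split(kernel[:input_size], 4, axis=1)   # per gate: [input_size, num_units]
    w_rec = np.split(kernel[input_size:], 4, axis=1)  # per gate: [num_units, num_units]
    b = np.split(bias, 4)
    order = [0, 2, 1, 3]  # TF (i, j, f, o) -> cuDNN (i, f, c, o)

    # Transpose each gate block so it maps x (or h) to that gate's pre-activation
    weights = [w_in[k].T.ravel() for k in order] + [w_rec[k].T.ravel() for k in order]
    b_input = [b[k].copy() for k in order]
    b_input[1] = b_input[1] + forget_bias  # fold LSTMCell's runtime forget_bias into f
    # cuDNN adds the input and recurrent biases, so zero recurrent biases are equivalent
    b_recurrent = [np.zeros(num_units) for _ in range(4)]
    return np.concatenate(weights + b_input + b_recurrent).astype(np.float32)
```

Under this assumption, `new_cudnn_weight` would be built as `lstm_cell_to_cudnn_canonical(lstm_weight, lstm_bias, input_size, hidden_size)` instead of by flattening, where `input_size` is the feature dimension of `input_seq`. The result has the same length as before, so only the arrangement of the values changes.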

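# Load the frozen CudnnLSTM graph and overwrite its opaque parameter buffer
frozen_graph_path = './train_one_culstm_model.pb'
frozen_graphdef = get_graphdef(frozen_graph_path)

for node in frozen_graphdef.node:
    if node.name == 'cudnn_lstm/opaque_kernel':
        # Keep the original buffer for comparison (its size should match new_cudnn_weight)
        ori_cudnn_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)
        node.attr['value'].tensor.CopyFrom(tensor_util.make_tensor_proto(new_cudnn_weight))
        # Read the buffer back to confirm the new values were written
        new_cudnn_weight = tensor_util.MakeNdarray(node.attr['value'].tensor)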

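The CudnnLSTM node therefore holds the same weight and bias values as the LSTM. However, when I feed both models the same input, the outputs are different.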

【Comments】:
