[Posted]: 2017-06-19 03:04:01
[Question]:
Hi, I'm trying to create a `BasicDecoder` with a `GreedyEmbeddingHelper`, but it raises this error:
```
TypeError: helper must be a Helper, received: <class 'helper.GreedyEmbeddingHelper'>
```
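As far as I can tell from the TF 1.x source, `BasicDecoder`'s constructor checks `isinstance(helper, helper_py.Helper)`, where `helper_py` is TensorFlow's own helper module, and raises exactly this `TypeError` otherwise. The `<class 'helper.GreedyEmbeddingHelper'>` in the message suggests the class came from a top-level module named `helper` (for example a local copy of TensorFlow's `helper.py`) rather than from `tf.contrib.seq2seq`. The minimal pure-Python sketch below (no TensorFlow; `HELPER_SOURCE`, `load_module` and `build_decoder` are hypothetical stand-ins) shows why such a copy fails the check: running the same source twice produces two unrelated `Helper` classes.

```python
import types

# Hypothetical stand-in for the source of a helper module: a base class
# and one concrete subclass (TensorFlow's real module is far larger).
HELPER_SOURCE = """
class Helper(object):
    pass

class GreedyEmbeddingHelper(Helper):
    pass
"""


def load_module(name, source):
    """Execute `source` as a brand-new module object called `name`."""
    module = types.ModuleType(name)
    exec(source, module.__dict__)
    return module


# Two modules built from the same source: the "library" one and a local copy.
library_helper = load_module("tensorflow_helper", HELPER_SOURCE)
local_helper = load_module("helper", HELPER_SOURCE)


def build_decoder(helper):
    """Mimics the decoder's check against the library's own Helper class."""
    if not isinstance(helper, library_helper.Helper):
        raise TypeError("helper must be a Helper, received: %s" % type(helper))
    return "ok"


print(build_decoder(library_helper.GreedyEmbeddingHelper()))  # ok

# Same source, but a different class object, so the check fails.
try:
    build_decoder(local_helper.GreedyEmbeddingHelper())
except TypeError as err:
    print(err)  # helper must be a Helper, received: <class 'helper.GreedyEmbeddingHelper'>
```

For the real code, the corresponding fix is to build the helper from TensorFlow's own module, e.g. `tf.contrib.seq2seq.GreedyEmbeddingHelper(...)` in TF 1.x, instead of from a separately imported `helper` module.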
Here is a simplified version of my code:
```python
elif self.mode == 'decode':
    # Start_tokens: [batch_size,] `int32` vector
    start_tokens = tf.ones([self.batch_size, self.dimension], tf.float32) * 0.1337
    end_token = 0.1337

    def project_inputs(inputs):
        print(inputs.shape)
        return input_layer(inputs)

    if not self.use_beamsearch_decode:
        # Helper to feed inputs for greedy decoding: uses the argmax of the output
        decoding_helper = helper.GreedyEmbeddingHelper(start_tokens=start_tokens,
                                                       end_token=end_token,
                                                       embedding=project_inputs)
        # Basic decoder performs greedy decoding at each time step
        print("building greedy decoder..")
        inference_decoder = seq2seq.BasicDecoder(cell=self.decoder_cell,
                                                 helper=decoding_helper,
                                                 initial_state=self.decoder_initial_state,
                                                 output_layer=output_layer)
    else:
        # Beamsearch is used to approximately find the most likely translation
        print("building beamsearch decoder..")
        inference_decoder = beam_search_decoder.BeamSearchDecoder(cell=self.decoder_cell,
                                                                   embedding=project_inputs,
                                                                   start_tokens=start_tokens,
                                                                   end_token=end_token,
                                                                   initial_state=self.decoder_initial_state,
                                                                   beam_width=self.beam_width,
                                                                   output_layer=output_layer,)
```
I don't know how to fix this: `Helper` is an abstract class, so passing an actual `Helper` instance seems impossible.
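Being abstract only means `Helper` itself cannot be instantiated; `isinstance` still accepts instances of concrete subclasses, so the decoder never needs a bare `Helper`, only a subclass instance whose base is the same `Helper` class it checks against. A minimal sketch with Python's `abc` module (`Helper` and `GreedyHelper` here are hypothetical stand-ins, not the TensorFlow classes):

```python
import abc


class Helper(abc.ABC):
    """Hypothetical abstract base, standing in for the library's Helper."""

    @abc.abstractmethod
    def sample(self, time, outputs, state):
        raise NotImplementedError


class GreedyHelper(Helper):
    """Concrete subclass: implements every abstract method."""

    def sample(self, time, outputs, state):
        # Greedy choice: index of the highest score at this step.
        return max(range(len(outputs)), key=outputs.__getitem__)


# The abstract base itself cannot be instantiated...
try:
    Helper()
    instantiated = True
except TypeError:
    instantiated = False

# ...but an instance of a concrete subclass still counts as a Helper.
greedy = GreedyHelper()
print(instantiated)                             # False
print(isinstance(greedy, Helper))               # True
print(greedy.sample(0, [0.1, 0.7, 0.2], None))  # 1
```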
[Comments]:
- How are you importing `GreedyEmbeddingHelper`? As `helper.GreedyEmbeddingHelper`?? Maybe this path can help: tensorflow.org/api_docs/python/tf/contrib/seq2seq/…
Tags: python tensorflow deep-learning lstm