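【Published】: 2018-07-05 03:19:35
【Question】:
I'm working on a character-level recurrent neural network. To train it, I copied a text corpus from the internet. This is the block of code that contains the error: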
X = np.zeros((int(len(data)/SEQ_LENGTH), SEQ_LENGTH, VOCAB_SIZE))
y = np.zeros((int(len(data)/SEQ_LENGTH), SEQ_LENGTH, VOCAB_SIZE))
for i in range(0, int(len(data)/SEQ_LENGTH)):
    X_sequence = data[i*SEQ_LENGTH:(i+1)*SEQ_LENGTH]
    X_sequence_ix = [char_to_ix[value] for value in X_sequence]
    input_sequence = np.zeros((SEQ_LENGTH, VOCAB_SIZE))
    for j in range(SEQ_LENGTH):
        input_sequence[j][X_sequence_ix[j]] = 1.
    X[i] = input_sequence

    y_sequence = data[i*SEQ_LENGTH+1:(i+1)*SEQ_LENGTH+1]
    y_sequence_ix = [char_to_ix[value] for value in y_sequence]
    target_sequence = np.zeros((SEQ_LENGTH, VOCAB_SIZE))
    for j in range(SEQ_LENGTH):
        target_sequence[j][y_sequence_ix[j]] = 1
    y[i] = target_sequence
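For reference, each inner loop turns one chunk of text into a SEQ_LENGTH x VOCAB_SIZE one-hot matrix, and the target chunk is the input chunk shifted by one character. A minimal vectorized sketch of the same encoding, assuming char_to_ix maps each character to its position in a sorted vocabulary (the toy corpus and the one_hot helper are hypothetical, not from the linked repository):

```python
import numpy as np

# Hypothetical toy corpus; the real one comes from the linked repository.
data = "hello world"
SEQ_LENGTH = 3

# Assumption: char_to_ix maps each distinct character to its position
# in the sorted vocabulary (0 .. VOCAB_SIZE - 1).
chars = sorted(set(data))
VOCAB_SIZE = len(chars)
char_to_ix = {ch: ix for ix, ch in enumerate(chars)}


def one_hot(text, vocab_size):
    """Return a (len(text), vocab_size) array with a single 1 per row."""
    indices = [char_to_ix[ch] for ch in text]
    # Row k of the identity matrix is the one-hot vector for index k.
    return np.eye(vocab_size)[indices]


x0 = one_hot(data[0:SEQ_LENGTH], VOCAB_SIZE)      # input chunk "hel"
y0 = one_hot(data[1:SEQ_LENGTH + 1], VOCAB_SIZE)  # target chunk "ell"
print(x0.shape, y0.shape)
```

Rows 1 and 2 of x0 equal rows 0 and 1 of y0, which is exactly the one-character shift the training targets rely on.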
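Basically, all I'm doing is converting the characters into their ASCII equivalents. y_sequence is the sequence of characters and y_sequence_ix is the corresponding sequence of ASCII codes. The VOCAB_SIZE variable holds the number of unique characters in the text corpus. The error occurs on this line: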
target_sequence[j][y_sequence_ix[j]] = 1
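One likely cause, not confirmed in the thread: the target slice starts one character later than the input slice, so when len(data) is an exact multiple of SEQ_LENGTH, the last y_sequence holds only SEQ_LENGTH - 1 characters and y_sequence_ix[j] runs past its end at j = SEQ_LENGTH - 1 (IndexError: list index out of range). The input slice for the same chunk is always full length, which would be consistent with the error appearing only on the target line. A small sketch that measures the target slice lengths, using a hypothetical corpus and a hypothetical target_lengths helper; (len(data) - 1) // SEQ_LENGTH is one possible corrected count, which would then also need to size X and y:

```python
# Hypothetical corpus whose length is an exact multiple of SEQ_LENGTH.
data = "abcdefgh"
SEQ_LENGTH = 4


def target_lengths(data, seq_length, n_sequences):
    """Length of each y_sequence slice, sliced as in the question."""
    return [len(data[i * seq_length + 1:(i + 1) * seq_length + 1])
            for i in range(n_sequences)]


posted = int(len(data) / SEQ_LENGTH)   # count used in the question
safe = (len(data) - 1) // SEQ_LENGTH   # leaves room for the +1 shift

print(target_lengths(data, SEQ_LENGTH, posted))  # [4, 3]: last chunk short
print(target_lengths(data, SEQ_LENGTH, safe))    # [4]: every chunk full
```

If the corpus length leaves a remainder of at least one character, the posted count happens to work, which may explain why the same code runs on a different corpus.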
Full source code and text corpus: https://github.com/tanmay-edgelord/charRNN
Please ask for any information you need to answer the question.
Edit
Traceback produced by calling traceback.print_stack():
  File "/usr/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/usr/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel_launcher.py", line 16, in <module>
    app.launch_new_instance()
  File "/usr/local/lib/python3.5/dist-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelapp.py", line 478, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/usr/local/lib/python3.5/dist-packages/tornado/ioloop.py", line 888, in start
    handler_func(fd_obj, events)
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/usr/local/lib/python3.5/dist-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 233, in dispatch_shell
    handler(stream, idents, msg)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
    user_expressions, allow_stdin)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/ipkernel.py", line 208, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/usr/local/lib/python3.5/dist-packages/ipykernel/zmqshell.py", line 537, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2728, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2856, in run_ast_nodes
    if self.run_code(code, result):
  File "/usr/local/lib/python3.5/dist-packages/IPython/core/interactiveshell.py", line 2910, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-15-b47f6c9a5577>", line 2, in <module>
    traceback.print_stack()
【Discussion】:
-
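Double-check the shape and size of target_sequence. If I remember correctly, it should have two dimensions, so target_sequence[j] returns another list (a row), which you then index again with [y_sequence_ix[j]]. If y_sequence_ix[j] is greater than or equal to VOCAB_SIZE, you will go out of range. Tracing back, your y_sequence appears to be based on SEQ_LENGTH rather than VOCAB_SIZE, so this is quite likely the source of the problem. -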
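target_sequence has dimensions 100x80. Also, the length of y_sequence is not the issue, because what I want is to access the index corresponding to the ASCII code of the j-th character in y_sequence and set target_sequence[j][y_sequence_ix[j]] = 1. -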
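Since target_sequence has VOCAB_SIZE columns, every index written into a row must lie in range(VOCAB_SIZE). Raw ASCII codes from ord() can easily exceed that, while positions in a sorted vocabulary cannot. Because the X loop uses the same char_to_ix without failing, the indices are presumably vocabulary positions rather than true ASCII codes. A quick check, sketched with a hypothetical corpus and an assumed char_to_ix:

```python
import numpy as np

data = "to be or not to be"  # hypothetical corpus
chars = sorted(set(data))
VOCAB_SIZE = len(chars)
# Assumed mapping: position in the sorted vocabulary.
char_to_ix = {ch: ix for ix, ch in enumerate(chars)}

ascii_codes = [ord(ch) for ch in data]
vocab_ix = [char_to_ix[ch] for ch in data]
print(VOCAB_SIZE, max(ascii_codes), max(vocab_ix))

row = np.zeros(VOCAB_SIZE)
row[max(vocab_ix)] = 1  # fine: every vocabulary index is < VOCAB_SIZE
try:
    row[max(ascii_codes)] = 1  # ord('t') == 116, far beyond VOCAB_SIZE
except IndexError as exc:
    print("ASCII index rejected:", exc)
```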
Please post the full traceback.
-
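traceback.print_stack() prints the call stack at the point where it is itself called, which is why the output in the edit above ends in the notebook cell rather than at the failing line. The traceback of the exception itself comes from traceback.print_exc() or traceback.format_exc() inside an except block. A minimal sketch, with a hypothetical failing() function standing in for the loop:

```python
import traceback


def failing():
    values = [1, 2, 3]
    return values[3]  # raises IndexError: list index out of range


try:
    failing()
except IndexError:
    # format_exc() returns the traceback of the exception being handled,
    # ending at the line that raised it; print_exc() prints the same text.
    report = traceback.format_exc()
    print(report)
```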
When you caught the error and inspected the data and variables, what did you find?
-
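The variables have the expected lengths, and the ones that must match do match. That's why I'm confused. Also, I copied this code from this link: chunml.github.io/ChunML.github.io/project/…. It evidently worked for him.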
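Lengths that look right in general can still be wrong for a single chunk, so it may help to record the state at the exact iteration that fails. A sketch that wraps the posted target loop in try/except and stores i, j, the slice length, and len(data) % SEQ_LENGTH; the corpus, the assumed char_to_ix, and the failure dictionary are hypothetical:

```python
import numpy as np

# Hypothetical stand-ins for the question's variables.
data = "abcdefgh"
SEQ_LENGTH = 4
chars = sorted(set(data))
VOCAB_SIZE = len(chars)
char_to_ix = {ch: ix for ix, ch in enumerate(chars)}  # assumed mapping

n_sequences = int(len(data) / SEQ_LENGTH)
y = np.zeros((n_sequences, SEQ_LENGTH, VOCAB_SIZE))
failure = None
for i in range(n_sequences):
    y_sequence = data[i*SEQ_LENGTH + 1:(i+1)*SEQ_LENGTH + 1]
    y_sequence_ix = [char_to_ix[value] for value in y_sequence]
    target_sequence = np.zeros((SEQ_LENGTH, VOCAB_SIZE))
    for j in range(SEQ_LENGTH):
        try:
            target_sequence[j][y_sequence_ix[j]] = 1
        except IndexError:
            # Record the state at the moment of failure and stop.
            failure = {"i": i, "j": j,
                       "len(y_sequence)": len(y_sequence),
                       "len(data) % SEQ_LENGTH": len(data) % SEQ_LENGTH}
            break
    if failure:
        break
    y[i] = target_sequence

print(failure)
```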
Tags: python python-3.x list deep-learning recurrent-neural-network