【发布时间】:2018-11-23 12:23:25
【问题描述】:
我在尝试让 TensorFlow 的 map_fn 在我的 GPU 上运行时遇到了一个奇怪的问题。这是一个最小的损坏示例:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.map_fn(test_func, test_range, dtype=tf.float32))
print(test)
这会导致错误:
InvalidArgumentError:无法分配设备进行操作 “map/TensorArray_1”:无法满足明确的设备规范 '' 因为该节点与一组需要 不兼容的设备 '/device:GPU:0' 托管调试信息:托管 组有以下类型和设备:TensorArrayScatterV3:CPU TensorArrayGatherV3:GPU CPU 范围:GPU CPU TensorArrayWriteV3:CPU TensorArraySizeV3:GPU CPU TensorArrayReadV3:CPU 输入:GPU CPU TensorArrayV3:CPU 常数:GPU CPU
托管成员和用户请求的设备:
map/TensorArrayStack/range/delta (Const)
map/TensorArrayStack/range/start (Const) map/TensorArray_1 (TensorArrayV3) map/while/TensorArrayWrite/TensorArrayWriteV3/Enter (回车) /device:GPU:0 map/TensorArrayStack/TensorArraySizeV3 (TensorArraySizeV3) 地图/TensorArrayStack/范围(范围)
地图/TensorArrayStack/TensorArrayGatherV3 (TensorArrayGatherV3)
map/TensorArray (TensorArrayV3) map/while/TensorArrayReadV3/Enter (回车) /device:GPU:0 Const (Const) /device:GPU:0
地图/TensorArrayUnstack/TensorArrayScatter/TensorArrayScatterV3 (TensorArrayScatterV3) /device:GPU:0 map/while/TensorArrayReadV3 (TensorArrayReadV3) /device:GPU:0
映射/同时/TensorArrayWrite/TensorArrayWriteV3 (TensorArrayWriteV3) /device:GPU:0[[节点:map/TensorArray_1 = TensorArrayV3clear_after_read=true, dtype=DT_FLOAT,dynamic_size=false,element_shape=, 相同元素形状=真, tensor_array_name=""]]
代码在我的 CPU 上运行时的行为符合预期,以及简单的操作,例如:
import numpy as np
import tensorflow as tf
with tf.Session() as sess:
with tf.device("/gpu:0"):
def test_func(i):
return i
test_range = tf.constant(np.arange(5))
test = sess.run(tf.add(test_range, test_range))
print(test)
在我的 GPU 上运行良好。 This post 似乎描述了一个类似的问题。有没有人有任何提示?该帖子的答案暗示map_fn 应该可以在 GPU 上正常工作。我在 Arch Linux 上的 Python 3.6.4 上运行 TensorFlow 1.8.0 版,在 GeForce GTX 1050 上运行 CUDA 9.0 版和 cuDNN 7.0 版。
谢谢!
【问题讨论】:
标签: python python-3.x tensorflow gpu