【问题标题】:Operation has type int32 that does not match type int64操作的 int32 类型与 int64 类型不匹配
【发布时间】:2018-04-23 06:24:37
【问题描述】:

我正在尝试使用tflearn 提供的 DNN 训练一些数据。我的data 变量的形状为(6605, 32),我的labels 数据的形状为(6605,),我在下面将其整形为(6605, 1)...

# Target label used for training
labels = np.array(data[label], dtype=np.float32)

# Reshape target label from (6605,) to (6605, 1)
labels = tf.reshape(labels, shape=[-1, 1])

# Data for training minus the target label.
data = np.array(data.drop(label, axis=1), dtype=np.float32)

# DNN
net = tflearn.input_data(shape=[None, 32])
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 32)
net = tflearn.fully_connected(net, 1, activation='softmax')
net = tflearn.regression(net)

# Define model.
model = tflearn.DNN(net)
model.fit(data, labels, n_epoch=10, batch_size=16, show_metric=True)

运行后出现两个错误,第一个是...

ValueError: Tensor conversion requested dtype int64 for Tensor with dtype int32: 'Tensor("strided_slice/stack_4:0", shape=(1,), dtype=int32)'

第二个错误是……

在处理上述异常的过程中,又发生了一个异常:

TypeError: 'StridedSlice' Op 的输入 'strides' 的 int32 类型与参数 'begin' 的 int64 类型不匹配。

我不知道如何解决这个问题。所以我采取的一种方法是将dtypelabelsdata 更改为int64...

# Target label used for training
labels = np.array(data[label], dtype=np.int64)

# Reshape target label from (6605,) to (6605, 1)
labels = tf.reshape(labels, shape=[-1, 1])

# Data for training minus the target label.
data = np.array(data.drop(label, axis=1), dtype=np.int64)

但如果我这样做,我仍然会遇到同样的错误。我该如何解决这个问题?

【问题讨论】:

  • 你升级到哪个版本的tensorflow?

标签: python-3.x machine-learning tensorflow neural-network tflearn


【解决方案1】:

请注意,错误指出有关整数、int64 和 int32 的 ValueError,因此所讨论的转换应该是 int 而不是 float。

我在代码库中遇到过类似情况,无法更新 tensorflow, 我正在生成一个导致 int 类型 ValueError 的索引。

如果无法升级 tensorflow 的版本,您可以强制转换为 int32,例如:

new_var = tf.cast(old_var, tf.int32)

如果您检查发生错误的确切行,您应该找到罪魁祸首,将其强制转换,它应该可以工作。此解决方案适用于 int 和 float 类型转换(使用 tf.float32)。

【讨论】:

    【解决方案2】:

    此错误似乎是由较旧版本的 tensorflow 生成的。我通过使用...更新了 tensorflow。

    pip3 install tensorflow --upgrade
    

    这消除了int64 错误。在旧版本中转换 long 数据类型(例如 int64)似乎存在问题。

    更新后,由于张量的形状,我得到了不同的错误...

    形状必须为 1 级,但对于输入形状为 [6605,1]、[1,16]、[1,16]、[1] 的“strided_slice”(操作:“StridedSlice”)为 2 级。

    但这是另一个问题。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2022-06-17
      • 1970-01-01
      • 1970-01-01
      • 2019-09-20
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多