【发布时间】:2020-10-19 13:07:41
【问题描述】:
美好的一天。我正在尝试在 MNIST 数据集上的 tensorflow 上运行 CNN 示例,作为 tensorflow 新手介绍的一部分。我不得不使用“import tensorflow.compat.v1 as tf”来启用示例中使用的会话调用功能。但是,我在以下代码的末尾遇到了一些麻烦,无法通过它(在此处显示完整代码)。
import tensorflow.compat.v1 as tf
import tensorflow_datasets as tfds
import numpy as np
tf.compat.v1.disable_eager_execution()
import keras
sess = tf.InteractiveSession()
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10])
# Weight tensor
W = tf.Variable(tf.zeros([784, 10],tf.float32))
# Bias tensor
b = tf.Variable(tf.zeros([10],tf.float32))
# run the op initialize_all_variables using an interactive session
sess.run(tf.global_variables_initializer())
# mathematical operation to add weights and biases to the inputs
tf.matmul(x,W) + b
y = tf.nn.softmax(tf.matmul(x,W) + b)
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
mnist_data, info = tfds.load("mnist", with_info=True, as_supervised=True)
train, test = mnist_data["train"], mnist_data["test"]
#Load 50 training examples for each training iteration
import keras
for i in range(1000):
batch = train.next_batch(50)
train_step.run(feed_dict={x: batch[0], y_: batch[1]})
然后我得到了 ffg。错误:
AttributeError Traceback(最近一次调用最后一次) 在 2 导入 keras 3 for i in range(1000): ----> 4 批次 = train.next_batch(50) 5 train_step.run(feed_dict={x: batch[0], y_: batch[1]})
AttributeError: 'PrefetchDataset' 对象没有属性 'next_batch'
我在 StackOverflow 上看到了解决此问题的各种响应,例如为 next_batch 编写自己的调用函数,例如:
def next_batch(num, data,labels):
'''
Return a total of `num` random samples and labels.
'''
idx = np.arange(0 , len(data),len(labels))
np.random.shuffle(idx)
idx = idx[:num]
data_shuffle = [data[ i] for i in idx]
labels_shuffle = [labels[ i] for i in idx]
return np.asarray(data_shuffle), np.asarray(labels_shuffle)
不幸的是,我收到以下错误:
AttributeError: 'dict' 对象没有属性 'next_batch'
我将非常感谢您对此提供一些指导。
提前谢谢你。
【问题讨论】:
-
到目前为止,“最明智的”选项恕我直言,忽略过时的教程和正确的 TF2 风格的代码;你会过得更好。
标签: python tensorflow