【问题标题】:TensorFlow in Python for MNIST datasets('dict' object has no attribute 'train')用于 MNIST 数据集的 Python 中的 TensorFlow(“dict”对象没有属性“train”)
【发布时间】:2021-09-26 11:57:14
【问题描述】:

我尝试在 Python 中将 TensorFlow 用于 MNIST 数据集,如下所示。在这里,我使用了神经网络的逻辑回归模型进行训练。但是我的代码有错误报告。

#import tensorflow
import tensorflow as tf
import tensorflow_datasets
mnist = tensorflow_datasets.load('mnist')

batch_size=100
n_batch=mnist.train.num_examples//batch_size

#PLACEHOLDER
#simple neural network
x=tf.placeholder(tf.float32,[None,784])
y=tf.placeholder(tf.float32,[None,784])
prediction=tf.nn.softmax(tf.matmul(x,W)+b)
#Define loss function
loss=tf.reduce_mean(tf.square(y-prediction))

#Gradient Descent
train_step=tf.train.GradientDescentOptimizer(0,2).minimize(loss)

init=tf.global_variables_initializer()

#Check accuracy of model
#These lines are used for converting one hot coding back to the original label form.
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))

#accuracy
accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#training
#print predicted and true lables of first 10 test samples
with tf.compat.v1.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys=mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})

            acc=sess.run(accuracy,feed_dict=({x:mnist.test.images,y:mnist.test.labels}))
    print("Iter"+str(epoch)+",Testing accuarcy"+str(acc))

但是报错:

AttributeError: 'dict' object has no attribute 'train'

【问题讨论】:

  • print(type(mnist)) 的输出是什么?根据张量流数据集docs 应该是tf.data.Dataset

标签: python tensorflow


【解决方案1】:

mnist 对象是带有键 train 和 test 的字典。所以,这就是为什么你不能像 mnist.train 那样访问它

像 mnist['train'] 一样访问它。

您可以使用 len(mnist['train']) 获取 num_examples

【讨论】:

  • 对不起,我添加了我的完整代码。
  • 使用这个代替 tensorflow.load_datasets() from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
  • 替换mnist = tensorflow_datasets.load('mnist')?
  • 我使用了input_data mnist = input_data.read_data_sets("MNIST_data", one_hot=True) ,但出现错误'SyntaxError: invalid syntax`
猜你喜欢
  • 2021-05-14
  • 2018-10-13
  • 1970-01-01
  • 2018-01-29
  • 1970-01-01
  • 2017-07-10
  • 1970-01-01
  • 2020-09-11
相关资源
最近更新 更多