[Posted]: 2016-03-14 22:43:28
[Problem description]:
So my problem is that I'm running the beginner-level code from the TensorFlow tutorial, modified for my needs. When I run print sess.run(accuracy, feed_dict={x: x_test, y_: y_test}) it used to always print 1.0, but now it always guesses 0 and prints an accuracy of about 93%. When I use tf.argmin(y,1), tf.argmin(y_,1) it guesses all 1s and produces about 7% accuracy. Added together, the two equal 100%. I don't understand how tf.argmin ends up guessing 1 and tf.argmax ends up guessing 0. Obviously something is wrong with the code. Please take a look and let me know what I can do to fix it. I think the code goes wrong during training, but I could be wrong.
import tensorflow as tf
import numpy as np
from numpy import genfromtxt
data = genfromtxt('cs-training.csv',delimiter=',') # Training data
test_data = genfromtxt('cs-test.csv',delimiter=',') # Test data
x_train = []
for i in data:
    x_train.append(i[1:])
x_train = np.array(x_train)

y_train = []
for i in data:
    if i[0] == 0:
        y_train.append([1., i[0]])  # label 0 -> one-hot [1., 0.]
    else:
        y_train.append([0., i[0]])  # label 1 -> one-hot [0., 1.]
y_train = np.array(y_train)

where_are_NaNs = np.isnan(x_train)  # bare isnan is undefined; it lives in numpy
x_train[where_are_NaNs] = 0
x_test = []
for i in test_data:
    x_test.append(i[1:])
x_test = np.array(x_test)

y_test = []
for i in test_data:
    if i[0] == 0:
        y_test.append([1., i[0]])
    else:
        y_test.append([0., i[0]])
y_test = np.array(y_test)

where_are_NaNs = np.isnan(x_test)  # same fix as above: np.isnan, not isnan
x_test[where_are_NaNs] = 0
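The one-hot loops above can also be written in one line with np.eye, and it's worth checking how skewed the labels are before training. A minimal sketch with made-up counts (a 93/7 split, matching the accuracies reported below):

```python
import numpy as np

# Hypothetical labels mimicking a heavily skewed dataset: 93 zeros, 7 ones
labels = np.array([0] * 93 + [1] * 7)

# Compact equivalent of the one-hot loops above:
# label 0 -> [1., 0.], label 1 -> [0., 1.]
one_hot = np.eye(2)[labels]

# Fraction of the majority class: a model that always predicts 0
# already scores this number as its "accuracy"
frac_zero = np.mean(labels == 0)
print(frac_zero)  # 0.93
```

If frac_zero on the real y_train comes out near 0.93, the dataset itself explains the number the accuracy op keeps printing.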
x = tf.placeholder("float", [None, 10])
W = tf.Variable(tf.zeros([10,2]))
b = tf.Variable(tf.zeros([2]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", [None,2])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
print "...Training..."
g = 0
for i in range(len(x_train)):
    sess.run(train_step, feed_dict={x: [x_train[g]], y_: [y_train[g]]})
    g += 1
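One caveat with the loss defined above, separate from the accuracy question: -tf.reduce_sum(y_*tf.log(y)) becomes -inf/NaN as soon as any softmax output reaches exactly 0, which unscaled features like the 63588.0 values here make likely. Clipping before the log guards against that. A minimal numpy sketch of the idea (in TF the equivalent would use tf.clip_by_value; safe_cross_entropy is just an illustrative name):

```python
import numpy as np

def safe_cross_entropy(y_pred, y_true, eps=1e-10):
    # Clip predictions away from exactly 0 so log() cannot return -inf
    clipped = np.clip(y_pred, eps, 1.0)
    return -np.sum(y_true * np.log(clipped))

y_true = np.array([[0., 1.]])
y_pred = np.array([[1., 0.]])  # degenerate softmax output
loss = safe_cross_entropy(y_pred, y_true)
print(np.isfinite(loss))  # True
```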
At this point, if I add print [x_train[g]] and print [y_train[g]] inside the loop, the output looks like this:
[array([ 7.66126609e-01, 4.50000000e+01, 2.00000000e+00,
8.02982129e-01, 9.12000000e+03, 1.30000000e+01,
0.00000000e+00, 6.00000000e+00, 0.00000000e+00,
2.00000000e+00])]
[array([ 0., 1.])]
OK, let's continue.
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print sess.run(accuracy, feed_dict={x: x_test, y_: y_test})
0.929209
That percentage never changes. Whatever one-hot encoding I create for the 2 classes (1 or 0), it guesses all zeros.
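A constant ~93% strongly suggests the model has collapsed to the majority class: if about 93% of the rows are labelled 0, always predicting 0 scores exactly that. A minimal numpy sketch with made-up numbers reproducing the effect:

```python
import numpy as np

# Made-up one-hot labels: 93 negatives ([1, 0]) then 7 positives ([0, 1]),
# mimicking the skew visible in y_train above
y_true = np.eye(2)[[0] * 93 + [1] * 7]

# A model that always predicts class 0, regardless of the input
y_pred = np.tile([0.9, 0.1], (100, 1))

# Same comparison as correct_prediction above, in numpy
accuracy = np.mean(np.argmax(y_pred, 1) == np.argmax(y_true, 1))
print(accuracy)  # 0.93
```

If that matches, the reported accuracy is just the class prior, and remedies such as a class-weighted loss, resampling the minority class, or feature scaling with a smaller learning rate are worth trying.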
Here's a look at the data:
print x_train[:10]
[[ 7.66126609e-01 4.50000000e+01 2.00000000e+00 8.02982129e-01
9.12000000e+03 1.30000000e+01 0.00000000e+00 6.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 9.57151019e-01 4.00000000e+01 0.00000000e+00 1.21876201e-01
2.60000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 6.58180140e-01 3.80000000e+01 1.00000000e+00 8.51133750e-02
3.04200000e+03 2.00000000e+00 1.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.33809776e-01 3.00000000e+01 0.00000000e+00 3.60496820e-02
3.30000000e+03 5.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.07239400e-01 4.90000000e+01 1.00000000e+00 2.49256950e-02
6.35880000e+04 7.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.13178682e-01 7.40000000e+01 0.00000000e+00 3.75606969e-01
3.50000000e+03 3.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 3.05682465e-01 5.70000000e+01 0.00000000e+00 5.71000000e+03
0.00000000e+00 8.00000000e+00 0.00000000e+00 3.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 7.54463648e-01 3.90000000e+01 0.00000000e+00 2.09940017e-01
3.50000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.16950644e-01 2.70000000e+01 0.00000000e+00 4.60000000e+01
0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.89169052e-01 5.70000000e+01 0.00000000e+00 6.06290901e-01
2.36840000e+04 9.00000000e+00 0.00000000e+00 4.00000000e+00
0.00000000e+00 2.00000000e+00]]
print y_train[:10]
[[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]]
print x_test[:20]
[[ 4.83539240e-02 4.40000000e+01 0.00000000e+00 3.02297622e-01
7.48500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 9.10224439e-01 4.20000000e+01 5.00000000e+00 1.72900000e+03
0.00000000e+00 5.00000000e+00 2.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.92682927e-01 5.80000000e+01 0.00000000e+00 3.66480079e-01
3.03600000e+03 7.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 3.11547538e-01 3.30000000e+01 1.00000000e+00 3.55431993e-01
4.67500000e+03 1.10000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 1.00000000e+00]
[ 0.00000000e+00 7.20000000e+01 0.00000000e+00 2.16630600e-03
6.00000000e+03 9.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 2.79217052e-01 4.50000000e+01 1.00000000e+00 4.89921122e-01
6.84500000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 0.00000000e+00 7.80000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 1.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.10363487e-01 2.80000000e+01 0.00000000e+00 4.99451497e-01
6.38000000e+03 8.00000000e+00 0.00000000e+00 2.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 6.36595797e-01 4.40000000e+01 0.00000000e+00 7.85457163e-01
4.16600000e+03 6.00000000e+00 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.41549211e-01 2.60000000e+01 0.00000000e+00 2.68407434e-01
4.25000000e+03 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.14101100e-03 7.80000000e+01 0.00000000e+00 2.26362500e-03
5.74200000e+03 7.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 9.99999900e-01 6.00000000e+01 0.00000000e+00 1.20000000e+02
0.00000000e+00 2.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 6.28525944e-01 4.70000000e+01 0.00000000e+00 1.13100000e+03
0.00000000e+00 5.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 2.00000000e+00]
[ 4.02283095e-01 6.00000000e+01 0.00000000e+00 3.79442065e-01
8.63800000e+03 1.00000000e+01 0.00000000e+00 1.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 5.70997900e-03 8.10000000e+01 0.00000000e+00 2.17382000e-04
2.30000000e+04 4.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.71171849e-01 5.10000000e+01 0.00000000e+00 1.53700000e+03
0.00000000e+00 1.40000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 1.42395210e-02 8.20000000e+01 0.00000000e+00 7.40466500e-03
2.70000000e+03 1.00000000e+01 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]
[ 4.67455800e-02 3.70000000e+01 0.00000000e+00 1.48010090e-02
9.12000000e+03 8.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 4.00000000e+00]
[ 9.99999900e-01 4.70000000e+01 0.00000000e+00 3.54604127e-01
1.10000000e+04 1.10000000e+01 0.00000000e+00 2.00000000e+00
0.00000000e+00 3.00000000e+00]
[ 8.96417860e-02 2.70000000e+01 0.00000000e+00 8.14664000e-03
5.40000000e+03 6.00000000e+00 0.00000000e+00 0.00000000e+00
0.00000000e+00 0.00000000e+00]]
print y_test[:20]
[[ 1. 0.]
[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 0. 1.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]
[ 1. 0.]]
[Comments]:
-
Are you always testing on the same dataset? If you shuffle them, is the result still 1.0?
-
Yes, it's still 1.0.
Tags: python csv numpy tensorflow