【发布时间】:2018-02-27 01:49:28
【问题描述】:
我正在尝试写一个可以批量获取数据的函数,类似于tensorflow的next_batch。
next_batch 可以在这里看到: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/learn/python/learn/datasets/mnist.py
这是我写的代码。
class Sampler:
def __init__(self, data):
self.x, self.y = data
self.N, = self.y.shape
self.start = 0
self.shuffle = np.arange(self.N)
np.random.shuffle(self.shuffle)
self.x = self.x[self.shuffle]
self.y = self.y[self.shuffle]
def sample(self, s):
start = self.start
end = np.minimum(start+s, self.N)
data = (self.x[start:end], self.y[start:end])
self.start += s
if self.start >= self.N - 1:
self.start = 0
np.random.shuffle(self.shuffle)
self.x = self.x[self.shuffle]
self.y = self.y[self.shuffle]
return data
我觉得这是一种自然的方法,但是虽然我可以使用 next_batch 获得 99% 以上的分类准确率,但使用我的“样本”函数只能获得 50% 左右。
谁能帮我理解发生了什么?
【问题讨论】:
-
据我所知,您的代码与 mnist 示例中的 next_batch 函数几乎完全相同。唯一的区别是示例中的 DataSet 类将输入数据从 (x,y,z,1) 展平为 (x,y*z),然后还将所有数据从 [0,256] 归一化为 [0,1]。这些都不会立即影响准确性,但取决于您的训练方式,它们可能会产生影响。
-
非常感谢——这解决了我的问题。我会将此标记为正确答案,但这是一条评论,所以我认为我不能这样做。请随时将其写到答案中,我会检查标记它! :) 再次感谢。
-
np,谢谢你的 5 美元 :) jkjk
标签: python tensorflow tensorflow-datasets