【问题标题】:Tensorflow getting data into it (SVHN)Tensorflow 将数据输入其中(SVHN)
【发布时间】:2016-07-22 22:19:20
【问题描述】:

我成功安装了 tensorflow,并按照关于 MNIST 数据的简单教程进行操作。 现在我想建立模型来训练 SVHN 数据。不幸的是,我无法找到如何将数据导入模型的任何地方。 基本上是每个模型的第一步。 数据保存在字典中,参数为“X”键,标签为“y”键。 形状如下:

打印 traindata['X'].shape

(32, 32, 3, 73257)

打印 traindata['y'].shape

(73257, 1)

谁能给我一个提示或链接如何成功地进入 tensorflow?

谢谢

【问题讨论】:

    标签: machine-learning tensorflow


    【解决方案1】:

    我正在使用这个数据集的课程。需要转换为灰度并保持维度,将图像数据缩放为[0,1),将10个标签更改为0,并移动轴使图像的索引在前。最终做了下面的事情。之后,诸如 model.fit(xTrain,yTrain, ... ) 之类的东西起作用了。

        xTrain,xTest=np.mean(xTrain0,axis=2,keepdims=True),np.mean(xTest0,axis=2,keepdims=True);
    xTrain/=255; xTest/=255
    print(f'Min: {xTrain.min()}, Max: {xTrain.max()}')
    yTrain[yTrain>9]=0; yTest[yTest>9]=0
    print(f'Min: {yTrain.min()}, Max: {yTrain.max()}')
    xTrain,xTest=np.moveaxis(xTrain,-1,0),np.moveaxis(xTest,-1,0)
    print(xTrain.shape,xTest.shape)
    

    【讨论】:

      【解决方案2】:

      SVNH的digitStruct.mat是matlab的文件格式,需要转换一下。

      这是将digitStruct.mat转换为json的代码,或者您可以使用scipy.io.loadmat

      # coding: utf-8
      # SVHN extracts data from the digitStruct.mat full numbers files.  The data can be downloaded
      # the Street View House Number (SVHN)  web site: http://ufldl.stanford.edu/housenumbers.
      #
      # This is an A2iA tweak (YG -9 Jan 2014) of the script found here :
      # http://blog.grimwisdom.com/python/street-view-house-numbers-svhn-and-octave
      #
      # The digitStruct.mat files in the full numbers tars (train.tar.gz, test.tar.gz, and extra.tar.gz)
      # are only compatible with matlab.  This Python program can be run at the command line and will generate
      # a json version of the dataset.
      #
      # Command line usage:
      #       SVHN_dataextract.py [-f input] [-o output_without_extension]
      #    >  python SVHN_dataextract.py -f digitStruct.mat -o digitStruct
      #
      # Issues:
      #    The alibility to split in several files has been removed from the original
      #    script.
      #
      
      import tqdm
      import h5py
      import optparse
      from json import JSONEncoder
      
      parser = optparse.OptionParser()
      parser.add_option("-f", dest="fin", help="Matlab full number SVHN input file", default="digitStruct.mat")
      parser.add_option("-o", dest="filePrefix", help="name for the json output file", default="digitStruct")
      options, args = parser.parse_args()
      
      fin = options.fin
      
      
      # The DigitStructFile is just a wrapper around the h5py data.  It basically references
      #    inf:              The input h5 matlab file
      #    digitStructName   The h5 ref to all the file names
      #    digitStructBbox   The h5 ref to all struc data
      class DigitStructFile:
          def __init__(self, inf):
              self.inf = h5py.File(inf, 'r')
              self.digitStructName = self.inf['digitStruct']['name']
              self.digitStructBbox = self.inf['digitStruct']['bbox']
      
          # getName returns the 'name' string for for the n(th) digitStruct.
          def getName(self, n):
              return ''.join([chr(c[0]) for c in self.inf[self.digitStructName[n][0]].value])
      
          # bboxHelper handles the coding difference when there is exactly one bbox or an array of bbox.
          def bboxHelper(self, attr):
              if len(attr) > 1:
                  attr = [self.inf[attr.value[j].item()].value[0][0] for j in range(len(attr))]
              else:
                  attr = [attr.value[0][0]]
              return attr
      
          # getBbox returns a dict of data for the n(th) bbox.
          def getBbox(self, n):
              bbox = {}
              bb = self.digitStructBbox[n].item()
              bbox['height'] = self.bboxHelper(self.inf[bb]["height"])
              bbox['label'] = self.bboxHelper(self.inf[bb]["label"])
              bbox['left'] = self.bboxHelper(self.inf[bb]["left"])
              bbox['top'] = self.bboxHelper(self.inf[bb]["top"])
              bbox['width'] = self.bboxHelper(self.inf[bb]["width"])
              return bbox
      
          def getDigitStructure(self, n):
              s = self.getBbox(n)
              s['name'] = self.getName(n)
              return s
      
          # getAllDigitStructure returns all the digitStruct from the input file.
          def getAllDigitStructure(self):
              print('Starting get all digit structure')
              return [self.getDigitStructure(i) for i in tqdm.tqdm(range(len(self.digitStructName)))]
      
          # Return a restructured version of the dataset (one structure by boxed digit).
          #
          #   Return a list of such dicts :
          #      'filename' : filename of the samples
          #      'boxes' : list of such dicts (one by digit) :
          #          'label' : 1 to 9 corresponding digits. 10 for digit '0' in image.
          #          'left', 'top' : position of bounding box
          #          'width', 'height' : dimension of bounding box
          #
          # Note: We may turn this to a generator, if memory issues arise.
          def getAllDigitStructure_ByDigit(self):
              pictDat = self.getAllDigitStructure()
              result = []
              structCnt = 1
              print('Starting pack josn dict')
              for i in tqdm.tqdm(range(len(pictDat))):
                  item = {'filename': pictDat[i]["name"] }
                  figures = []
                  for j in range(len(pictDat[i]['height'])):
                      figure = dict()
                      figure['height'] = pictDat[i]['height'][j]
                      figure['label']  = pictDat[i]['label'][j]
                      figure['left']   = pictDat[i]['left'][j]
                      figure['top']    = pictDat[i]['top'][j]
                      figure['width']  = pictDat[i]['width'][j]
                      figures.append(figure)
                  structCnt += 1
                  item['boxes'] = figures
                  result.append(item)
              return result
      
      
      dsf = DigitStructFile(fin)
      dataset = dsf.getAllDigitStructure_ByDigit()
      fout = open(options.filePrefix + ".json", 'w')
      fout.write(JSONEncoder(indent=True).encode(dataset))
      fout.close()
      

      之后,您应该编写代码将数据加载到 numpy 中。

      在我看来,您的任务不是将数据加载到 TensorFlow 中,而是将所有图像加载到 numpy 中。因此,您还应该使用PIL 库将图像读取为 numpy 格式。

      【讨论】:

        【解决方案3】:

        TensorFlow 使用这个概念:首先,你定义一个图;接下来,你训练一个图;最后,使用图表。

        在您定义图表的那一刻,您创建了占位符。这就像您的图表的输入节点。但是,此时,这些变量是“空的”,因为它们与您的输入数据无关。

        在训练时和测试时,您都可以通过引用此预定义的输入节点将数据“输入”到图表中。

        这个概念对你来说可能是新的,我建议你学习一些关于它的教程。 TensorFlow 本身有一个不错的页面,名为“Tensorflow Mechanics 101”和“基本用法”。 如果你比较直观,我可以推荐 YouTube 频道“Dan Does Data”,他以幽默的方式探索 TensorFlow 概念。

        如果您更喜欢示例代码,可以考虑this 示例,我在其中为 MNIST 制作了一个小型 CNN。看看占位符“x”和“y_”,它们是您感兴趣的变量。

        【讨论】:

          【解决方案4】:

          我认为一个好主意是将您的数据重塑为:

          打印 traindata['X'].shape
          (73257, 32*32*3) = (73257, 3072)

          我也是 tensorflow 和 python 的新手,但我认为你可以通过使用 numpy 和 np.reshape 来做到这一点

          查看以下文档: http://docs.scipy.org/doc/numpy/reference/generated/numpy.reshape.html

          告诉我它是否适合你:)

          【讨论】:

            猜你喜欢
            • 2017-10-27
            • 1970-01-01
            • 1970-01-01
            • 2019-06-28
            • 2017-11-18
            • 1970-01-01
            • 2018-03-06
            • 2018-02-27
            • 2018-08-17
            相关资源
            最近更新 更多