【问题标题】:When I try to reshape my training data I get this error.... ValueError: cannot reshape array of size 568 into shape (28,28,3)当我尝试重塑我的训练数据时,我得到了这个错误.... ValueError: cannot reshape array of size 568 into shape (28,28,3)
【发布时间】:2020-02-07 02:06:09
【问题描述】:

这是我在图片中看到的地方:

train = []
imgsize = 28
    for image_name in image_name_list:
        im = cv2.imread(path_string + image_name +'.jpg')
        new = cv2.resize(im,(imgsize, imgsize))
        train.append(new)

在我使用的教程中,我不确定我们为什么要通过调整大小的图像列表将 X、Y 循环到不同的变量中。我假设它是划分训练数据和测试数据:

X = []
Y = []
for features, labels in enumerate(train):
    X.append(features)
    Y.append(labels)
X = np.array(X).reshape(-1, imgsize, imgsize, 3)

我知道最后一个数字表示它是灰度还是 RGB,但我需要颜色,因为我的图像需要颜色

ValueError: 无法将大小为 568 的数组重新整形为 (28,28,3)

【问题讨论】:

  • 调整大小后将灰度转换为3通道BGR图像。所以在new = cv2.resize(im,(imgsize, imgsize)) 之后做一个new = cv2.merge([im, im, im])new = cv2.cvtColor(im, cv2.COLOR_GRAY2BGR)
  • 都不工作,出现同样的错误:(

标签: python opencv tensorflow


【解决方案1】:

似乎不可能将大小为 568 的数组重新整形为 28x28x3...? 28x28x3=2352

【讨论】:

  • 我将如何使它成为可能,我不明白这个概念,它的数组大小有影响吗?
【解决方案2】:

冻结模型时需要设置固定的输入张量大小。

import tensorflow as tf
import os
from tensorflow.python.tools.freeze_graph import freeze_graph
import models
import utils
import image_utils as im
import numpy as np

FLAGS = tf.flags.FLAGS

tf.flags.DEFINE_string('checkpoint_dir', './checkpoints/photo2cartoon', 'checkpoints directory path')
tf.flags.DEFINE_integer('crop_size', '256', 'crop_size, default: 256')

def export_graph(model_name):
  graph = tf.Graph()

  with graph.as_default(): 

    a_real = tf.placeholder(tf.float32,shape=([1,FLAGS.crop_size, FLAGS.crop_size, 3]),name='input_image') # <<<< YOU NEED TO DEFINE THIS 
    #a_real=tf.reshape(a_real,tf.stack([1,FLAGS.crop_size, FLAGS.crop_size, 3]))

    a2b = models.generator(a_real, 'a2b',reuse=False, train=False) 

    saver = tf.train.Saver()

  with tf.Session(graph=graph) as sess:
    sess.run(tf.global_variables_initializer())
    # ------------------------------
    # Save graph nodes to text file
    # ------------------------------
    graph_def=graph.as_graph_def()
    # Remove Const nodes.
    for i in reversed(range(len(graph_def.node))):
      if graph_def.node[i].op == 'Const':
        del graph_def.node[i]
      for attr in ['T', 'data_format', 'Tshape', 'N', 'Tidx', 'Tdim',
                     'use_cudnn_on_gpu', 'Index', 'Tperm', 'is_training',
                     'Tpaddings']:
        if attr in graph_def.node[i].attr:
          del graph_def.node[i].attr[attr]   
    # Save as text.
    tf.train.write_graph(graph_def, "", "text_graph.pbtxt", as_text=True)    

     # ------------------------------
     # Load variables data
     # ------------------------------
    latest_ckpt = utils.load_checkpoint(FLAGS.checkpoint_dir, sess, saver)

    if latest_ckpt is None:
      raise Exception('No checkpoint!')
    else:
      print('Copy variables from % s' % latest_ckpt)    

    # -----------------------------------------
    # Write data for tensorboard for show graph
    # -----------------------------------------
    a_real_ipt = np.zeros(shape=[1, FLAGS.crop_size, FLAGS.crop_size, 3])
    writer = tf.summary.FileWriter('logs', sess.graph)
    writer.close()
    # -----------------------------------------
    # Write graph output
    # ----------------------------------------- 

    # get graph definition
    gd = sess.graph.as_graph_def()

    # fix batch norm nodes
    for node in gd.node:
      if node.op == 'RefSwitch':
        node.op = 'Switch'
        for index in xrange(len(node.input)):
          if 'moving_' in node.input[index]:
            node.input[index] = node.input[index] + '/read'
      elif node.op == 'AssignSub':
        node.op = 'Sub'
        if 'use_locking' in node.attr: del node.attr['use_locking']


    output_graph_def = tf.graph_util.convert_variables_to_constants(sess, gd, ["a2b_generator/Tanh"])
    tf.train.write_graph(output_graph_def, 'pretrained', model_name, as_text=False)

def main(unused_argv):
  print('photo2cartoon.pb')
  export_graph('photo2cartoon.pb')

if __name__ == '__main__':
  tf.app.run()

【讨论】:

  • 我该怎么做?
  • 在我的项目中添加了冻结功能示例。您需要定义像上面的“a_real”这样的占位符。
猜你喜欢
  • 2021-08-23
  • 2017-09-15
  • 2018-01-16
  • 2018-12-31
  • 1970-01-01
  • 1970-01-01
  • 1970-01-01
  • 2021-10-30
  • 2020-05-13
相关资源
最近更新 更多