【问题标题】:Modifying shape of tensor in tensorflow checkpoint在张量流检查点中修改张量的形状
【发布时间】:2018-06-16 17:00:17
【问题描述】:

我有一个 tensorflow 检查点,在使用常规例程 tf.train.Saver()saver.restore(session, 'my_checkpoint.ckpt') 重新定义与其对应的图形后,我可以加载它。

但是,现在,我想修改网络的第一层以接受形状为 [200, 200, 1] 而不是 [200, 200, 10] 的输入。

为此,我想通过对第三维求和,将第一层对应的张量的形状从[3, 3, 10, 32](3x3 内核,10 个输入通道,32 个输出通道)修改为[3, 3, 1, 32]

我该怎么做?

【问题讨论】:

    标签: tensorflow tf-slim


    【解决方案1】:

    你可以使用tensorflow::BundleReader读取源码ckpt,并使用tensorflow::BundleWriter重写。

    tensorflow::BundleReader reader(Env::Default(), model_path_prefix);
    std::vector<std::string> tensor_names;
    reader.Seek("");
    reader.Next();
    for (; reader.Valid(); reader.Next()) {
        tensor_names.emplace_back(reader.key());
    }
    tensorflow::BundleWriter writer(Env::Default(), new_model_path_prefix);   
    for (auto &tensor_name : tensor_names) {
            DataType dtype;
            TensorShape shape;        
            
            reader.LookupDtypeAndShape(tensor_name, &dtype, &shape);
            Tensor val(dtype, shape);
            Status bool_ret  = reader.Lookup(tensor_name, &val);
            std::cout << tensor_name << " " << val.DebugString() << std::endl;
            // modify dtype and shape. padding Tensor
            TensorSlice slice(new_shape.dims());
            writer.AddSlice(tensor_name, new_shape, slice, new_val);
        }
    }
    writer.Finish();
    

    【讨论】:

      【解决方案2】:

      我找到了一种方法,但方法并不那么简单。 给定一个检查点,我们可以将其转换为序列化的 numpy 数组(或我们可能发现适合保存 numpy 数组字典的任何其他格式),如下所示:

      checkpoint = {}
      with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          saver.restore(sess, 'my_checkpoint.ckpt')
          for x in tf.global_variables():
              checkpoint[x.name] = x.eval()
          np.save('checkpoint.npy', checkpoint)
      

      可能会有一些异常需要处理,但让我们保持代码简单。

      然后,我们可以对 numpy 数组执行任何我们喜欢的操作:

      checkpoint = np.load('checkpoint.npy')
      checkpoint = ...
      np.save('checkpoint.npy', checkpoint)
      

      最后,我们可以在构建图表后手动加载权重,如下所示:

      with tf.Session() as sess:
          sess.run(tf.global_variables_initializer())
          checkpoint = np.load('checkpoint.npy').item()
          for key, data in checkpoint.iteritems():
              var_scope = ... # to be extracted from key
              var_name = ...  # 
              with tf.variable_scope(var_scope, reuse=True):
                  var = tf.get_variable(var_name)
                  sess.run(var.assign(data))
      

      如果有更直接的方法,我会全力以赴!

      【讨论】:

        猜你喜欢
        • 2021-10-22
        • 2018-12-13
        • 1970-01-01
        • 2016-12-12
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        • 1970-01-01
        相关资源
        最近更新 更多