【问题标题】:How to save the value of a tensor in Tensorflow如何在 Tensorflow 中保存张量的值
【发布时间】:2021-07-05 22:05:07
【问题描述】:

我正在尝试保存一个张量数组,该数组必须使用装饰器@tf.function 计算到一个函数中,这使得函数内的所有张量都成为张量图,因此是不可迭代的对象。例如,在下面的最小代码中,我想知道是否可以使用函数foo() 中的代码将张量保存到文件中。

@tf.function
def foo(x):
    # code for saving x


a=tf.constant([1,2,3])
foo(a)

【问题讨论】:

  • 据我所知,没有办法只保存一个张量。一种方法是执行a.numpy().save('file.npy'),然后在加载后转换回张量。

标签: python tensorflow tensorflow2.0 tensorflow-datasets


【解决方案1】:

好吧,我想您是在图形模式下运行该函数,否则(急切模式执行)您可以在通过张量的.numpy() 函数访问它们后使用 NumPy 或普通的 pythonic 文件处理方式来保存值。
在图形模式下,您可以使用tf.io.write_file() 操作。详细说明前面提到的解决方案,write_file fn 采用单个字符串。下面的示例可能会有所帮助:

a = tf.constant([1,2,3,4,5] , dtype = tf.int32)
b = tf.constant([53.44569] , dtype= tf.float32)
c = tf.constant(0)
# if u wish to write all these tensors on each line ,
# then create a single string out of these.
one_string = tf.strings.format("{}\n{}\n{}\n", (a,b,c))
# {} is a placeholder for each element ina string and thus you would need n PH for n tensors.
# send this string to write_file fn
tf.io.write_file(filename, one_string)

write_filefn 只接受字符串,所以你需要先将所有内容都转换为字符串。 此外,如果您在同一运行中多次调用 write_file fn n 次,则每次调用都将覆盖先前的输出,因此文件将包含最后一次调用的内容。

【讨论】:

    【解决方案2】:

    看看tf.io.write_file。它允许您将张量写入文件。

    读取保存的张量文件对应的函数是tf.io.read_file

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2020-06-22
      • 2016-09-01
      • 2016-09-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多