【问题标题】:Dynamic while loop Tensorflow动态while循环Tensorflow
【发布时间】:2021-05-11 09:44:35
【问题描述】:

我有这个numpy 函数,我把它简化成这样的:

def _func(new_inputs, X):
    """ Basically any operation. Below is just an example """
    outputs = new_inputs + X
    new_inputs.pop(0)
    return outputs, new_inputs

new_inputs = []
flag = True
while flag:
    outputs, new_inputs = _func(new_inputs, X)
    X = np.concatenate([X, outputs], axis=0)
    if not new_inputs:
        flag = False

我想将其转换为 AutoGraph 可以支持的等效 TF 函数。我正在使用 TF 2.4.1,但我需要在图形计算下运行它,因为我需要在 Beam 上下文下运行它。

这是我的尝试:

def function_to_solve():
    new_inputs = tf.constant([])
    flag = tf.constant(True)
    outputs = tf.constant([])
    X = tf.TensorArray(dtype=tf.float32, infer_shape=False, size=1,
                        dynamic_size=True)

    def _func(new_inputs, X):
        return outputs, new_inputs


    def loop_func(flag, new_inputs, outputs, X, i):
        outputs, new_inputs = _func(new_inputs, X)
        X.write(i, outputs)
        if i==2:
            flag = False
        return flag, new_inputs, outputs, X, i+1


    def condition(flag, *args):
        return flag


    _, _, _, X, _ = tf.while_loop(condition, loop_func, [flag, new_inputs, outputs, X, 0])
    X.concat()
    return X

然后在 Beam 上下文中调用它:

import pprint
import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam.impl as tft_beam
from tensorflow_transform.tf_metadata import dataset_metadata
from tensorflow_transform.tf_metadata import dataset_schema


raw_data = [
      {'x': [1, 2, 0, 0, 0, 0.3, 10, 1, 1.4]},
 ]

raw_data_metadata = dataset_metadata.DatasetMetadata(
    dataset_schema.from_feature_spec({
        'x': tf.io.FixedLenFeature(shape=(9,), dtype=tf.float32),
    }))

def preprocessing_fn(inputs):
    outputs = function_to_solve()
    return {
        'outputs ': outputs
    }

with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata) | tft_beam.AnalyzeAndTransformDataset(
          preprocessing_fn))

transformed_data, transformed_metadata = transformed_dataset

print('\nRaw data:\n{}\n'.format(pprint.pformat(raw_data)))
print('Transformed data:\n{}'.format(pprint.pformat(transformed_data)))

我收到以下错误:

AttributeError: 'TensorArray' object has no attribute 'get_shape'

有人可以帮忙吗?谢谢!

【问题讨论】:

    标签: python tensorflow apache-beam tfx


    【解决方案1】:

    每个tf.TensorArray 对象都可以使用ta.stack() 函数转换为tf.Tensor

    tensor = ta.stack()
    

    因此,无论错误发生在哪里,您都必须将TensorArray 对象转换为常规Tensorget_shape 函数当前在另一个 TF 函数中被调用。

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 2010-10-26
      • 2019-04-28
      • 2017-09-27
      • 2017-05-05
      • 2017-10-03
      • 1970-01-01
      • 2018-02-12
      相关资源
      最近更新 更多