【发布时间】:2020-10-01 14:18:12
【问题描述】:
我在tf.data.Dataset 中使用基于typing.NamedTuple 的元素类型。下面是一个例子。
# You can run all the code in this question by pasting all
# the code blocks consecutively into a Python file
import tensorflow as tf
from typing import *
from random import *
from pprint import *
class Coord(NamedTuple):
x: float
y: float
@classmethod
def random(cls): return cls(gauss(10., 1.), gauss(10., 1.))
class Box(NamedTuple):
min: Coord
max: Coord
@classmethod
def random(cls): return cls(Coord.random(), Coord.random())
class Boxes(NamedTuple):
boxes: List[Box]
@classmethod
def random(cls): return cls([Box.random() for _ in range(randint(3, 5))])
def test_dataset():
for _ in range(randint(3, 5)): yield Boxes.random()
tf_dataset = tf.data.Dataset.from_generator(test_dataset, output_types=(tf.float32,))
您可能知道,tf.data.Dataset.from_generator() 将数据集元素(最初具有 Boxes 类型)转换为具有 (None, 2, 2) 形状的 tf.Tensor 的单元素元组。例如,数据集的一个元素可能是以下项目:
(<tf.Tensor: shape=(4, 2, 2), dtype=float32, numpy=
array([[[11.642379, 9.937152],
[ 8.998009, 8.387287]],
[[10.649337, 10.028358],
[ 8.507834, 9.84779 ]],
[[11.10263 , 11.3706 ],
[ 9.20623 , 10.44905 ]],
[[ 9.591406, 9.560486],
[ 9.461394, 9.256082]]], dtype=float32)>,)
我有非@tf.function-annotated 的常规 Python 函数,可以将数据转换为其原始类型,例如以下函数:
def flip_boxes(boxes: Boxes):
def flip_coord(c: Coord): return Coord(-c.x, c.y)
def flip_box(b: Box): return Box(flip_coord(b.min), flip_coord(b.max))
return Boxes(boxes=list(map(flip_box, boxes.boxes)))
我想通过 tf.data.Dataset.map(map_func) 函数将此 Python 函数(以及其他类似函数)应用于此 tf.data.Dataset。 Dataset.map 期望 map_func 是一个函数,以 tf.Tensor 格式获取数据集元素类型的成员。原始元素类型是Boxes,它有一个成员,最初是boxes: List[Box]。创建数据集时,该列表将转换为上面的(4, 2, 2)-shape 张量。当tf.data.Dataset.map() 调用map_func 时,它不会被转换回来,而是直接将张量作为第一个参数传递给map_func。 (如果Boxes 有更多成员,这些成员将作为单独的参数传递给map_func,并且它们不会作为单个元组传递。)
问题:我要实现什么适配器函数才能使常规 Python 函数(如 flip_boxes)与 tf.data.Dataset.map() 一起使用?
我尝试迭代并使用tf.split 从输入tf.Tensor 中恢复List[Boxes],但我遇到了下面作为cmets 列出的错误消息。
# Question: How do I implement this function?
def to_tf_mappable_function(fn: Callable) -> Callable:
def function(tensor: tf.Tensor):
boxes: List[Box] = [Box(Coord(10.0, 0.0), Coord(10.0, 0.0)), Box(Coord(10.0, 0.0), Coord(10.0, 0.0))]
# TODO calculate `boxes` from `tensor`, not use this dummy constant above
# Trivial Python code does not work, it results in this error on the commented-out line:
# OperatorNotAllowedInGraphError: iterating over `tf.Tensor` is not allowed:
# AutoGraph is disabled in this function. Try decorating it directly with @tf.function.
# boxes = [Box(Coord(row[0][0], row[0][1]), Coord(row[1][0], row[1][1])) for row in tensor]
# Decorating any of flip_boxes, to_tf_mappable_function and to_tf_mappable_function.function
# does not eliminate the error.
# I thought tf.split might help, but it results in this error on the commented-out line:
# ValueError: Rank-0 tensors are not supported as the num_or_size_splits argument to split.
# Argument provided: Tensor("cond/Identity:0", shape=(), dtype=int32)
# boxes = tf.split(tensor, len(tensor))
return fn(Boxes(boxes))
return function
tf_dataset = tf_dataset.map(to_tf_mappable_function(flip_boxes))
# The line above should be morally equivalent to `dataset = map(flip_boxes, dataset)`,
# given a `dataset: Iterable[Boxes]` and the builtin `map` function in Python.
也许我没有问对正确的问题,但请让我放松一下。
* 高级任务是以有效的方式将flip_boxes 和类似功能应用于tf.data.Dataset
* 我被卡住的地方是从tf.Tensor 中恢复List[Box],它的形状与框坐标列表完全相同,所以也许我的问题应该针对这个问题。
【问题讨论】:
标签: python tensorflow tensorflow2.0