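[Posted]: 2022-02-01 01:01:01
[Question]:
I'm trying to apply a filter to a tf.data.Dataset that removes any string in which a single character (group) makes up more than 50% of the string. Here is my Dataset: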
import tensorflow as tf

strings = [
    ["ABCDEFGABCDEFG\tUseless\tLabel1"],
    ["AAAAAAAADEFGAB\tUseless\tLabel2"],
    ["HIJKLMNHIJKLMN\tUseless\tLabel3"],
    ["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
ds = tf.data.Dataset.from_tensor_slices(strings)

def _clean(x):
    x = tf.strings.split(x, "\t")
    return x[0], x[2]

def _filter(x):
    s = tf.strings.bytes_split(x)
    _, _, count = tf.unique_with_counts(s)
    percent = tf.reduce_max(count) / tf.shape(s)[0]
    return tf.less_equal(percent, 0.5)

ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))
for x, y in ds:
    tf.print(x, y)
This produces the following error:
TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("StringsByteSplit/StringSplit:1", shape=(None,), dtype=string), row_splits=Tensor("StringsByteSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type.
Is there a way to solve this inside the tf.data.Dataset graph?
[Discussion]:
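One possible reading of the error: because each entry in `strings` is itself a one-element list, every dataset element has shape `(1,)` rather than being a scalar. On non-scalar input, `tf.strings.split` and `tf.strings.bytes_split` return a `tf.RaggedTensor`, which `tf.unique_with_counts` cannot convert to a dense tensor. The sketch below is not a confirmed answer from the thread; it assumes the extra nesting is not needed and feeds scalar strings instead, so that both split ops return ordinary 1-D tensors. The explicit `tf.cast` calls and the `kept_labels` variable are additions for illustration.

```python
import tensorflow as tf

# Assumption: a flat list of scalar strings replaces the nested lists from the question.
strings = [
    "ABCDEFGABCDEFG\tUseless\tLabel1",
    "AAAAAAAADEFGAB\tUseless\tLabel2",
    "HIJKLMNHIJKLMN\tUseless\tLabel3",
    "HIJKLMMMMMMMNH\tUseless\tLabel4",
]
ds = tf.data.Dataset.from_tensor_slices(strings)

def _clean(x):
    # For a scalar string, tf.strings.split returns a dense 1-D tensor.
    parts = tf.strings.split(x, "\t")
    return parts[0], parts[2]

def _filter(x):
    # For a scalar string, bytes_split returns a dense 1-D tensor of single bytes.
    chars = tf.strings.bytes_split(x)
    _, _, count = tf.unique_with_counts(chars)
    # Cast explicitly so the ratio is computed in float32.
    percent = tf.cast(tf.reduce_max(count), tf.float32) / tf.cast(tf.size(chars), tf.float32)
    return percent <= 0.5

ds = ds.map(_clean).filter(lambda x, y: _filter(x))

# Collect the surviving labels (hypothetical helper variable for inspection).
kept_labels = [y.numpy().decode() for _, y in ds]
print(kept_labels)
```

With this data, only the `Label2` row should be dropped, since `A` accounts for 9 of its 14 characters. The `Label4` row sits at exactly 50% (7 of 14), so it passes the `<=` comparison. If the nested shape has to stay, selecting the single entry inside `_clean` before splitting (for example `x[0]`) should have the same effect.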
Tags: python tensorflow filter tensorflow-datasets tf.data.dataset