【问题标题】:Python Tensorflow itertools groupby: using itertools.groupby() in tf.data.Dataset.filter()Python Tensorflow itertools groupby:在 tf.data.Dataset.filter() 中使用 itertools.groupby()
【发布时间】:2022-02-01 01:01:01
【问题描述】:

我正在尝试对tf.data.Dataset 应用过滤器,以删除其中一组 > 50% 的字符串的任何字符串。这是我的Dataset

import tensorflow as tf


strings = [
    ["ABCDEFGABCDEFG\tUseless\tLabel1"],
    ["AAAAAAAADEFGAB\tUseless\tLabel2"],
    ["HIJKLMNHIJKLMN\tUseless\tLabel3"],
    ["HIJKLMMMMMMMNH\tUseless\tLabel4"],
]
ds = tf.data.Dataset.from_tensor_slices(strings)

def _clean(x):
    x = tf.strings.split(x, "\t")
    return x[0], x[2]

def _filter(x):
    s = tf.strings.bytes_split(x)
    _, _, count = tf.unique_with_counts(s)
    percent = tf.reduce_max(count) / tf.shape(s)[0]
    return tf.less_equal(percent, 0.5)

ds = ds.map(_clean)
ds = ds.filter(lambda x, y: _filter(x))

for x, y in ds:
    tf.print(x, y)

这会产生以下错误:

TypeError: Failed to convert elements of tf.RaggedTensor(values=Tensor("StringsByteSplit/StringSplit:1", shape=(None,), dtype=string), row_splits=Tensor("StringsByteSplit/RaggedFromValueRowIds/RowPartitionFromValueRowIds/concat:0", shape=(None,), dtype=int64)) to Tensor. Consider casting elements to a supported type.

有什么方法可以在tf.data.Dataset 图表中解决这个问题?

【问题讨论】:

    标签: python tensorflow filter tensorflow-datasets tf.data.dataset


    【解决方案1】:

    你可以使用tf.strings解决这个问题:

    import tensorflow as tf
    
    def filter_data(x):
      s = tf.strings.strip(tf.strings.regex_replace(x, '', ' '))
      s = tf.strings.split(s, sep=" ")
      _, _, count = tf.unique_with_counts(s)
      return tf.less_equal(tf.reduce_max(count) / tf.shape(s)[0], 0.25)
    
    ds = tf.data.Dataset.from_tensor_slices([["AAAABBBCC", "Label1"], ["AAAAAABC", "Label2"], ["ABBAABCCCCAB", "Label3"], ["ABDC", "Label4"]])
    ds = ds.map(lambda x: (x[0], x[1]))
    
    ds = ds.filter(lambda x, y: filter_data(x))
    for x, y in ds:
      tf.print(x, y)
    
    "ABDC" "Label4"
    

    但是,我会重新考虑 25% 的阈值,因为您的示例数据集中的所有样本都高于此阈值,因此不会添加到数据集中。我在您的数据集中添加了第四个示例,以表明该方法适用于 tf.less_equal

    AAAABBBCC为例,A出现频率最高(4次),除以字符串的总长度(9),得到4/9=0.44,这意味着它被排除在数据集中。也许这种行为是需要的。无论如何,我只是想通知你。

    【讨论】:

    • 非常感谢您的帮助!!一个问题:原始字符串实际上是制表符分隔的字符串,所以在管道的早期,我使用tf.strings.split(x, "\t") 将它们分成单独的部分。这引发了过滤的问题。我已经编辑了这个问题,所以它可以作为一个更好的例子。
    • 我认为您可以使用tf.strings 工具解决此问题;)
    • 另外,s=tf.strings.bytes_split() 是执行过滤功能前两行的好方法 :)
    • 把你的 clean 函数改成这个,一切都会正常工作:def _clean(x): x = tf.squeeze(tf.strings.split(x, "\t"), axis=0) return x[0], x[2]
    • 不规则张量总是多维的
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2015-12-17
    相关资源
    最近更新 更多