【问题标题】:TF 2.0 error with @tf.function decorator?@tf.function 装饰器的 TF 2.0 错误?
【发布时间】:2019-09-17 00:10:15
【问题描述】:

我有一个非常简单的 Python 控制流语句程序

@tf.function
def mandelbrot(T, max_iter):
    for i in range(10):
        if (tf.abs(T)) >= 4:
                return 5
    return max_iter

T=tf.complex(10.,2.)
mandelbrot(T, 100)

但它不起作用,并引发大量跟踪错误。这么简单的代码有什么问题?

----------------------------------- ---------------------------- AssertionError Traceback(最近调用 最后)在 2 T=tf.complex(10.,2.) 3 ----> 4 曼德布罗(T, 100)

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\def_function.py 在调用(自我,*args,**kwds) 424 # 这是call的第一次调用,所以我们要初始化。 第425章 --> 426 self._initialize(args, kwds, add_initializers_to=initializer_map) 427 如果 self._created_variables: 428尝试:

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\def_function.py 在 _initialize(self, args, kwds, add_initializers_to) 第368章 369 self._stateful_fn._get_concrete_function_internal_garbage_collected(

pylint: disable=protected-access

--> 370 *args, **kwds)) 371 第372章

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\function.py 在 _get_concrete_function_internal_garbage_collected(self, *args, **kwargs) 1311 if self._input_signature: 1312 args, kwargs = None, None -> 1313 graph_function,_,_ = self._maybe_define_function(args, kwargs) 1314 return graph_function 1315

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\function.py 在 _maybe_define_function(self, args, kwargs) 1578 或 call_context_key 不在 self._function_cache.missed 中):1579
self._function_cache.missed.add(call_context_key) -> 1580 graph_function = self._create_graph_function(args, kwargs) 1581 self._function_cache.primary[cache_key] = graph_function 1582 返回graph_function, args, kwargs

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\function.py 在 _create_graph_function(self, args, kwargs, override_flat_arg_shapes)1510 arg_names=arg_names,
第1511章 -> 1512 capture_by_value=self._capture_by_value), 1513 self._function_attributes) 1514

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\framework\func_graph.py 在 func_graph_from_py_func(名称,python_func,args,kwargs,签名, func_graph,签名,autograph_options,add_control_dependencies, arg_names、op_return_value、集合、capture_by_value、 override_flat_arg_shapes) 第692章 693 --> 694 func_outputs = python_func(*func_args, **func_kwargs) 695 696 # 不变量:func_outputs 只包含张量、IndexedSlices,

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\eager\def_function.py 在 Wrapped_fn(*args, **kwds) 315 # wrapped 允许 AutoGraph 交换转换后的函数。我们给予 316 # 函数对自身进行弱引用以避免引用循环。 --> 317 return weak_wrapped_fn().wrapped(*args, **kwds) 第318章 第319章

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\framework\func_graph.py 在包装器中(*args,**kwargs) 第684章 第685章 --> 686), args, kwargs) 687 688 # 包裹装饰器允许像 tf_inspect.getargspec 这样的检查

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\autograph\impl\api.py 在 convert_call(f, owner, options, args, kwargs) 第390章 391 --> 392 结果 = 转换的_f(*有效参数,**kwargs) 393 394 # 转换后的函数的闭包被简单地插入函数的

~\AppData\Local\Temp\tmp95dcry6m.py in tf__mandelbrot(T, max_iter) 20 retval__1, do_return_1 = ag__.if_stmt(cond, if_true, if_false) 21 返回 retval__1, do_return_1 ---> 22 retval_, do_return = ag__.for_stmt(ag__.converted_call(范围, 无, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(tf.function, defun, ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=(), internal_convert_user_code=True), (10,), {}), extra_test, loop_body, (retval_, do_return)) 23 cond_1 = ag__.not_(do_return) 24

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py 在 for_stmt(iter_, extra_test, body, init_state) 79 返回_dataset_for_stmt(iter_,extra_test,body,init_state) 80 其他: ---> 81 返回 _py_for_stmt(iter_, extra_test, body, init_state) 82 83

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py 在 _py_for_stmt(iter_, extra_test, body, init_state) 88 如果 extra_test 不是 None 也不是 extra_test(*state): 89 休息 ---> 90 状态 = 身体(目标,*状态) 91返回状态 92

~\AppData\Local\Temp\tmp95dcry6m.py in loop_body(loop_vars, retval__1, do_return_1) 18 定义 if_false(): 19 返回 retval__1, do_return_1 ---> 20 retval__1, do_return_1 = ag__.if_stmt(cond, if_true, if_false) 21 返回 retval__1, do_return_1 22 retval_,do_return = ag__.for_stmt(ag__.converted_call(范围,无, ag__.ConversionOptions(recursive=True, verbose=0, strip_decorators=(tf.function, defun, ag__.convert, ag__.do_not_convert, ag__.converted_call), force_conversion=False, optional_features=(), internal_convert_user_code=True), (10,), {}), extra_test, loop_body, (retval_, do_return))

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py 在 if_stmt(cond, body, orelse) 第243章 244 如果 tensor_util.is_tensor(cond): --> 245 返回 tf_if_stmt(cond, body, orelse) 246 其他: 第247章

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\autograph\operators\control_flow.py 在 tf_if_stmt(cond, body, orelse) 第254章 255 --> 256 返回 control_flow_ops.cond(cond, protected_body, protected_orelse) 257 258

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\util\deprecation.py 在 new_func(*args, **kwargs) 505 'in a future version' if date is None else ('after %s' % date), 506条指令) --> 507 返回函数(*args,**kwargs) 508 第509章 = _add_deprecated_arg_notice_to_docstring(

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\ops\control_flow_ops.py 在 cond(pred, true_fn, false_fn, strict, name, fn1, fn2) 1916 if (util.EnableControlFlowV2(ops.get_default_graph()) 和 1917
不是 context.executing_eagerly()): -> 1918 return cond_v2.cond_v2(pred, true_fn, false_fn, name) 1919 1920 # 我们需要做 true_fn/false_fn 关键字参数 对于

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\ops\cond_v2.py 在 cond_v2(pred,true_fn,false_fn,名称) 第84章 第85章 ---> 86 名称=范围) 87 88

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\ops\cond_v2.py 在 _build_cond(pred, true_graph, false_graph, true_inputs, false_inputs,名称) 185 个中间输出。 第186章 --> 187 _check_same_outputs(true_graph, false_graph) 188 189 # 将输入添加到 true_graph 和 false_graph 以使它们匹配。请注意

~.conda\envs\alphagpu\lib\site-packages\tensorflow\python\ops\cond_v2.py 在 _check_same_outputs(true_graph, false_graph) 584 错误(str(e)) 585 --> 586 断言 len(true_graph.outputs) == len(false_graph.outputs) 587 为 true_out,zip 中的 false_out(true_graph.outputs,false_graph.outputs): 588 如果 true_out.dtype != false_out.dtype:

断言错误:

【问题讨论】:

    标签: tensorflow


    【解决方案1】:

    看起来 2.0 还不能处理早期的条件返回。我想这将在某个时候得到修复(请随意检查是否有您自己的错误报告/文件),但与此同时,以下内容对我有用。它不允许提前退出,但至少应该给出正确的结果。

    @tf.function
    def mandelbrot(T, max_iter):
        out = max_iter
        for i in range(10):
            if (tf.abs(T)) >= 4:
                out = 5
    
        return out
    
    
    T = tf.complex(10.,2.)
    m = mandelbrot(T, 100)
    

    对于多个T 值,我认为您必须求助于调用tf.where

    def mandelbrot(T, max_iter):
        ones = tf.ones(tf.shape(T), dtype=tf.int64)
        out = ones * max_iter
        fives = ones * 5
    
        for i in range(10):
            out = tf.where(tf.greater_equal(tf.abs(T), 4), fives, out)
        return out
    

    您可以使用tf.while_looptf.TensorArray 做一些更复杂的事情,但我怀疑这会涉及到开销,这会使小问题的处理成本更高(而且代码复杂性也很重要)。

    请注意,这不是 mandelbrot 集的计算方式 - 我假设这是因为您已将其简化为最小示例。 T 从未在此处更新,因此您可以删除 i 上的循环。

    【讨论】:

    • 你说得对,我整理了代码以突出问题的核心。它在 2.0 的文档中说它支持 break、while 和所有主要的 Python 流控制语句。无论如何,我现在就按照你的修复。谢谢。
    • 如何让它适用于批量输入?所以 T 不是单个复数,而是它们的完整列表。特别是,当我将它与 4 进行比较时,我失败了。T>4 产生错误,如何将张量与标量进行比较,if 条件的结果取决于每个单独的比较!
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 2021-12-01
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-08-04
    • 2013-06-08
    • 1970-01-01
    相关资源
    最近更新 更多