[Question Title]: How to test the correctness of a Keras custom layer?
[Posted]: 2018-07-21 09:34:37
[Question]:

After creating a Keras custom layer with trainable weights, how can I test the correctness of the code? The Keras manual doesn't seem to cover this.

For example, to test the expected behavior of a function, one can write a unit test. How can we do the same for a Keras custom layer?

[Comments]:

    Tags: deep-learning keras keras-layer


    [Solution 1]:

    You can still do something like a unit test: compute the output of your custom layer for a given input and verify it against a manually computed expected output.

    Say your custom layer Custom takes (None, 3, 200) as its input shape and returns (None, 3):

    from keras.layers import Input
    from keras.models import Model
    
    inp = Input(shape=(3, 200))
    out = Custom()(inp)
    model = Model(inp, out)
    
    output = model.predict(your_input)
    

    You can then verify the layer output output against the expected output for a known input your_input.
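A minimal sketch of that idea, assuming a hypothetical custom layer TimeMean (a stand-in for Custom) that averages over the last axis, mapping (None, 3, 200) to (None, 3); the layer name and the reduction it performs are illustrative, not from the original answer:

```python
import numpy as np
import tensorflow as tf

# Hypothetical custom layer standing in for "Custom":
# averages over the last axis, (None, 3, 200) -> (None, 3).
class TimeMean(tf.keras.layers.Layer):
    def call(self, inputs):
        return tf.reduce_mean(inputs, axis=-1)

inp = tf.keras.Input(shape=(3, 200))
out = TimeMean()(inp)
model = tf.keras.Model(inp, out)

# known input and a manually computed reference output
x = np.random.rand(5, 3, 200).astype("float32")
y = model.predict(x, verbose=0)

assert y.shape == (5, 3)
np.testing.assert_allclose(y, x.mean(axis=-1), rtol=1e-4, atol=1e-5)
```

The reference output here is computed independently with NumPy, which is the point of the test: the layer's result is checked against math the layer itself did not perform.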

    [Discussion]:

    • My layer has variables that are updated after every call. Is there a way to run a test after each call in my layer?
    • You can even use this code inside a callback provided by Keras. So write a custom callback and run the testing logic above in it.
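A sketch of the callback approach from the comment above, using a hypothetical layer CountingDense that increments a non-trainable weight on every call, and a custom callback that inspects that state after each training batch (both names are invented for illustration; run_eagerly=True keeps the count deterministic):

```python
import numpy as np
import tensorflow as tf

# Hypothetical layer that updates an internal non-trainable variable on every call.
class CountingDense(tf.keras.layers.Layer):
    def __init__(self, units):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(int(input_shape[-1]), self.units), initializer="zeros")
        self.calls = self.add_weight(
            shape=(), initializer="zeros", trainable=False, name="calls")

    def call(self, inputs):
        self.calls.assign_add(1.0)  # state update per call
        return tf.matmul(inputs, self.kernel)

# Custom callback that checks the layer's state after every training batch.
class StateCheck(tf.keras.callbacks.Callback):
    def __init__(self, layer):
        super().__init__()
        self.layer = layer
        self.seen = []

    def on_train_batch_end(self, batch, logs=None):
        self.seen.append(float(self.layer.calls.numpy()))

layer = CountingDense(1)
model = tf.keras.Sequential([layer])
# run_eagerly=True so call() executes exactly once per batch
model.compile("sgd", "mse", run_eagerly=True)

x = np.random.rand(8, 4).astype("float32")
y = np.zeros((8, 1), dtype="float32")
cb = StateCheck(layer)
model.fit(x, y, batch_size=2, epochs=1, callbacks=[cb], verbose=0)
# cb.seen now holds the counter value observed after each of the 4 batches
```

Real assertions about the state (e.g. that the counter grows monotonically) would go inside on_train_batch_end.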
    [Solution 2]:

    There is layer_test in the Keras test utilities: https://github.com/keras-team/keras/blob/master/keras/utils/test_utils.py

    It provides the following routine, which tests shapes, actual results, serialization, and training:

    # imports required by this excerpt (as used in keras/utils/test_utils.py)
    import numpy as np
    from numpy.testing import assert_allclose

    from keras import backend as K
    from keras.layers import Input
    from keras.models import Model
    from keras.utils.generic_utils import has_arg

    def layer_test(layer_cls, kwargs={}, input_shape=None, input_dtype=None,
                   input_data=None, expected_output=None,
                   expected_output_dtype=None, fixed_batch_size=False):
        """Test routine for a layer with a single input tensor
        and single output tensor.
        """
        # generate input data
        if input_data is None:
            assert input_shape
            if not input_dtype:
                input_dtype = K.floatx()
            input_data_shape = list(input_shape)
            for i, e in enumerate(input_data_shape):
                if e is None:
                    input_data_shape[i] = np.random.randint(1, 4)
            input_data = (10 * np.random.random(input_data_shape))
            input_data = input_data.astype(input_dtype)
        else:
            if input_shape is None:
                input_shape = input_data.shape
            if input_dtype is None:
                input_dtype = input_data.dtype
        if expected_output_dtype is None:
            expected_output_dtype = input_dtype
    
        # instantiation
        layer = layer_cls(**kwargs)
    
        # test get_weights , set_weights at layer level
        weights = layer.get_weights()
        layer.set_weights(weights)
    
        expected_output_shape = layer.compute_output_shape(input_shape)
    
        # test in functional API
        if fixed_batch_size:
            x = Input(batch_shape=input_shape, dtype=input_dtype)
        else:
            x = Input(shape=input_shape[1:], dtype=input_dtype)
        y = layer(x)
        assert K.dtype(y) == expected_output_dtype
    
        # check with the functional API
        model = Model(x, y)
    
        actual_output = model.predict(input_data)
        actual_output_shape = actual_output.shape
        for expected_dim, actual_dim in zip(expected_output_shape,
                                            actual_output_shape):
            if expected_dim is not None:
                assert expected_dim == actual_dim
    
        if expected_output is not None:
            assert_allclose(actual_output, expected_output, rtol=1e-3)
    
        # test serialization, weight setting at model level
        model_config = model.get_config()
        recovered_model = model.__class__.from_config(model_config)
        if model.weights:
            weights = model.get_weights()
            recovered_model.set_weights(weights)
            _output = recovered_model.predict(input_data)
            assert_allclose(_output, actual_output, rtol=1e-3)
    
        # test training mode (e.g. useful when the layer has a
        # different behavior at training and testing time).
        if has_arg(layer.call, 'training'):
            model.compile('rmsprop', 'mse')
            model.train_on_batch(input_data, actual_output)
    
        # test instantiation from layer config
        layer_config = layer.get_config()
        layer_config['batch_input_shape'] = input_shape
        layer = layer.__class__.from_config(layer_config)
    
        # for further checks in the caller function
        return actual_output
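The import path of layer_test has moved around across Keras versions, so as a hedged sketch the same core checks (output shape against compute_output_shape, and a serialization round-trip) can be inlined directly, shown here against a built-in Dense layer via tf.keras:

```python
import numpy as np
import tensorflow as tf

layer = tf.keras.layers.Dense(3)
inp = tf.keras.Input(shape=(4,))
model = tf.keras.Model(inp, layer(inp))

data = np.random.rand(2, 4).astype("float32")
out = model.predict(data, verbose=0)

# shape check, mirroring layer_test's compute_output_shape comparison
assert tuple(out.shape) == tuple(layer.compute_output_shape((2, 4)))

# serialization round-trip, mirroring layer_test's config/weights check
recovered = tf.keras.Model.from_config(model.get_config())
recovered.set_weights(model.get_weights())
np.testing.assert_allclose(recovered.predict(data, verbose=0), out, rtol=1e-3)
```

For a custom layer, the same round-trip additionally exercises its get_config implementation, which is exactly where serialization bugs tend to hide.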
    

    [Discussion]:
