【Question title】: Using mixture of Multivariate Normal distributions with Tensorflow-probability.layers
【Posted】: 2020-02-12 02:38:23
【Question】:

I am trying to use TensorFlow Probability layers to build a mixture of multivariate normal distributions. It works fine when I use the IndependentNormal layer, but when I use the MultivariateNormalTriL layer I run into a problem with event_shape. I combine these layers with the MixtureSameFamily layer. The following code should illustrate my problem and should run in Google Colab:

import tensorflow as tf
import tensorflow_probability as tfp
import tensorflow.keras as keras
tfpl = tfp.layers

print(tf.__version__)
# >> '1.15.0-rc3'
# but I get the same result with extra warnings in 1.14.0

print(tfp.__version__)
# >> '0.7.0'

print(tfpl.MultivariateNormalTriL(100)(
    keras.layers.Input(shape=(tfpl.MultivariateNormalTriL.params_size(100),))
))

# >> tfp.distributions.MultivariateNormalTriL("multivariate_normal_tri_l_4/MultivariateNormalTriL/MultivariateNormalTriL/", 
#    batch_shape=[?], event_shape=[100], dtype=float32)


print(tfpl.IndependentNormal((100,))(
    keras.layers.Input(shape=(tfpl.IndependentNormal.params_size(100),))
))

# >> tfp.distributions.Independent("Independentindependent_normal_2/IndependentNormal/Normal/", 
#    batch_shape=[?], event_shape=[100], dtype=float32)


print(tfpl.MixtureSameFamily(16, tfpl.MultivariateNormalTriL(100))(
    keras.layers.Input(shape=(16*tfpl.MultivariateNormalTriL.params_size(100),))
))

# >> tfp.distributions.MixtureSameFamily("mixture_same_family_2/MixtureSameFamily/MixtureSameFamily/", 
#    batch_shape=[?], event_shape=[?], dtype=float32)


print(tfpl.MixtureSameFamily(16, tfpl.IndependentNormal((100,)))(
    keras.layers.Input(shape=(16*tfpl.IndependentNormal.params_size(100,),))
))

# >> tfp.distributions.MixtureSameFamily("mixture_same_family_3/MixtureSameFamily/MixtureSameFamily/", 
#    batch_shape=[?], event_shape=[100], dtype=float32)

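Although MultivariateNormalTriL and IndependentNormal have the same batch_shape and event_shape, combining each of them with MixtureSameFamily produces different event shapes.

So my question is: why do they produce different event shapes, and how can I get a layer for a mixture of multivariate normal distributions with different (not necessarily diagonal) covariance matrices and event_shape=[100]?

Edit: the same happens with tensorflow-probability version 0.8.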
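
The input sizes used in the snippets above can be checked by hand. The following is a minimal sketch in plain Python, assuming the usual parameter layouts: a location vector plus a lower-triangular scale for MultivariateNormalTriL, a location and a scale per dimension for IndependentNormal, and one logit per component plus the component parameters for the mixture. The helper names are hypothetical and are not part of the TFP API.

```python
def mvn_tril_params_size(event_size):
    # loc (event_size entries) + lower-triangular scale (event_size * (event_size + 1) / 2 entries)
    return event_size + event_size * (event_size + 1) // 2


def independent_normal_params_size(event_size):
    # one loc and one scale per dimension
    return 2 * event_size


def mixture_params_size(num_components, component_params_size):
    # one mixture logit per component + the parameters of every component
    return num_components + num_components * component_params_size


k, d = 16, 100
print(mvn_tril_params_size(d))                          # 5150
print(independent_normal_params_size(d))                # 200
print(k * mvn_tril_params_size(d))                      # 82400, the size used in the question
print(mixture_params_size(k, mvn_tril_params_size(d)))  # 82416, including the 16 logits
```

Under this layout, an input of size 16 * params_size(100) is 16 entries short of what the mixture expects. As far as I can tell, `tfpl.MixtureSameFamily.params_size(num_components, component_params_size)` returns the size that includes the logits, so it may be worth checking whether using it changes the inferred event_shape.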

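【Question discussion】:

    Tags: python tensorflow keras mixture-model tensorflow-probability


    【Answer 1】:

    I had misunderstood how the MixtureSameFamily layer works, so after reading the code of all the layers involved, I came up with the following solution: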

    import tensorflow as tf
    import tensorflow_probability as tfp
    import tensorflow.compat.v1 as tf1
    import numpy as np
    
    tfl = tfp.layers
    tfd = tfp.distributions
    
    
    class MixtureMultivariateNormalTriL(tfl.DistributionLambda):
        """ Creates a mixture of multivariate normal distributions through tfd.Mixture """
    
        def __init__(self, num_components, event_size, validate_args=False, scale='default', **kwargs):
            """
            Initialize the MixtureMultivariateNormalTriL layer
            :param num_components: Number of component distributions in the mixture (int)
            :param event_size: Scalar `int` representing the size of single draw from this
            distribution.
            :param validate_args: Python `bool`, default `False`. When `True` distribution
            parameters are checked for validity despite possibly degrading runtime
            performance. When `False` invalid inputs may silently render incorrect
            outputs.
            Default value: False
            :param scale: type of tfp.bijectors.ScaleTriL used for the multivariate normal distribution.
            If 'default', we use tfp.bijectors.ScaleTriL(
                    diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype()),
                    validate_args=validate_args)
                (using the same convention as in tfpl.MultivariateNormalTriL)
            If `exponential`, we use scale_tril = tfp.bijectors.ScaleTriL(
                    diag_bijector=tfp.bijectors.Exp(),
                    diag_shift=None,
                    validate_args=validate_args
                )
            Alternatively a tfp.bijectors.ScaleTriL object can be passed.
            Default value: "default"
            """
            kwargs.pop('make_distribution_fn', None)
    
            super().__init__(
                lambda t: MixtureMultivariateNormalTriL.new(t, num_components, event_size, validate_args, scale),
                **kwargs
            )
            self._event_size = event_size
            self._num_components = num_components
            self._validate_args = validate_args
            self._scale = scale
    
        @staticmethod
        def new(params, num_components, event_size, validate_args=False, scale='default', name=None):
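            # params is expected to have shape (batch_size, num_components + num_components * component_params_size):
            # the mixture logits come first, followed by the flattened parameters of each component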
            with tf1.name_scope(name, 'MixtureMultivariateNormalTriL',
                                [params, num_components, event_size]):
                params = tf.convert_to_tensor(value=params, name='params', dtype_hint=tf.float32)
    
                num_components = tf.convert_to_tensor(
                    value=num_components, name='num_components', dtype_hint=tf.int32)
    
                mixture_dist = tfd.Categorical(logits=params[..., :num_components])
    
                component_params = tf.reshape(
                    params[..., num_components:],
                    tf.concat([tf.shape(input=params)[:-1], [num_components, -1]],
                              axis=0))  # the parameters for the various components
    
                params_per_component = tf.unstack(component_params, axis=1)
    
                if scale == "default":
                    scale_tril = tfp.bijectors.ScaleTriL(
                        diag_shift=np.array(1e-5, params.dtype.as_numpy_dtype()),
                        validate_args=validate_args)  # use same conventions as MultivariateNormalTriL
                elif scale == "exponential":
                    scale_tril = tfp.bijectors.ScaleTriL(
                        diag_bijector=tfp.bijectors.Exp(validate_args=validate_args),
                        diag_shift=None,
                        validate_args=validate_args
                    )
                else:
                    assert isinstance(scale, tfp.bijectors.ScaleTriL)
                    scale_tril = scale
    
                # for some reason, tfp doesn't manage to infer the event_shape of our distributions;
                # applying the following bijector helps remedy this
                reshape = tfp.bijectors.Reshape(event_shape_out=(event_size,))
    
                distributions = [
                    reshape(
                        tfd.MultivariateNormalTriL(
                            loc=par[..., :event_size],
                            scale_tril=scale_tril(par[..., event_size:]),
                            validate_args=validate_args
                        )
                    )
                    for par in params_per_component
                ]
    
                return tfd.Mixture(
                    mixture_dist,
                    distributions,
                    validate_args=validate_args
                )
    
        @staticmethod
        def params_size(num_components, event_size, name=None):
            with tf1.name_scope(name, "MixtureMultivariateNormalTriL_params_size",
                                [num_components, event_size]):
                return num_components + num_components * tfl.MultivariateNormalTriL.params_size(event_size)
    
        def get_config(self):
            base_config = super().get_config()
            base_config["num_components"] = self._num_components
            base_config["event_size"] = self._event_size
            base_config["scale"] = self._scale
            base_config["validate_args"] = self._validate_args
            return base_config
    
    

    That said, I am still working on testing it thoroughly.
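
The way `new` slices the flat parameter vector can be illustrated without TensorFlow. The following NumPy sketch mirrors the slicing and reshaping steps of the layer above on a dummy array; it only checks shapes and does not build any distribution. The function name `split_mixture_params` is hypothetical.

```python
import numpy as np


def split_mixture_params(params, num_components, event_size):
    """Split a flat parameter array the same way the layer's `new` method does."""
    # the first num_components entries are the mixture logits
    logits = params[..., :num_components]
    # the rest is reshaped to (..., num_components, params_per_component)
    component_params = params[..., num_components:].reshape(
        params.shape[:-1] + (num_components, -1))
    # per component: the first event_size entries are loc,
    # the remainder is the flattened lower-triangular scale
    loc = component_params[..., :event_size]
    scale_flat = component_params[..., event_size:]
    return logits, loc, scale_flat


batch, k, d = 4, 3, 5
per_component = d + d * (d + 1) // 2  # 5 + 15 = 20
params = np.zeros((batch, k + k * per_component), dtype=np.float32)

logits, loc, scale_flat = split_mixture_params(params, k, d)
print(logits.shape, loc.shape, scale_flat.shape)  # (4, 3) (4, 3, 5) (4, 3, 15)
```

The flattened scale has d * (d + 1) / 2 entries per component, which is what the ScaleTriL bijector needs to fill a d x d lower-triangular matrix.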
