【问题标题】:multivariateNormal distribution with n-batch > 1n-batch > 1 的多元正态分布
【发布时间】:2019-07-28 07:24:11
【问题描述】:

我试图将How to use a MultiVariateNormal distribution in the latest version of Tensorflow 中给出的示例推广到二维的正态分布,但不止一批。当我运行以下命令时:

from tensorflow_probability import distributions as tfd
import tensorflow as tf

tf.compat.v1.enable_eager_execution()

mu = [[1, 2],
        [-1,-2]]

cov = [[1, 3./5],
        [3./5, 2]]

cov = [cov, cov] # for demonstration purpose, use same cov for both batches

mvn = tfd.MultivariateNormalFullCovariance(
        loc=mu,
        covariance_matrix=cov)

# generate the pdf
X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
idx = tf.concat([tf.reshape(X, [-1, 1]), tf.reshape(Y,[-1,1])], axis =1)
prob = tf.reshape(mvn.prob(idx), tf.shape(X))

我收到不兼容的形状错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Incompatible shapes: [3600,2] vs. [2,2] [Op:Sub] name: MultivariateNormalFullCovariance/log_prob/affine_linear_operator/inverse/sub/

我对文档 (https://www.tensorflow.org/api_docs/python/tf/contrib/distributions/MultivariateNormalFullCovariance) 的理解是,要计算 pdf,需要一个 [n_observation, n_dimensions] 张量(本例中就是这种情况:idx.shape = TensorShape([Dimension(3600), Dimension(2)]))。是不是我的数学算错了?

【问题讨论】:

    标签: python tensorflow tensorflow-probability


    【解决方案1】:

    您需要在倒数第二个位置的 idx 张量中添加一个批处理轴,因为 60x60 无法针对 mvn.batch_shape(2,) 广播。

    # TF/TFP Imports
    !pip install --quiet tfp-nightly tf-nightly
    import tensorflow.compat.v2 as tf
    tf.enable_v2_behavior()
    import tensorflow_probability as tfp
    tfd = tfp.distributions
    
    mu = [[1, 2],
          [-1, -2]]
    
    cov = [[1, 3./5],
           [3./5, 2]]
    
    cov = [cov, cov] # for demonstration purpose, use same cov for both batches
    
    mvn = tfd.MultivariateNormalFullCovariance(
        loc=mu, covariance_matrix=cov)
    print(mvn.batch_shape, mvn.event_shape)
    
    # generate the pdf
    X, Y = tf.meshgrid(tf.range(-3, 3, 0.1), tf.range(-3, 3, 0.1))
    print(X.shape)
    idx = tf.stack([X, Y], axis=-1)[..., tf.newaxis, :]
    print(idx.shape)
    
    probs = mvn.prob(idx)
    print(probs.shape)
    

    输出:

    (2,) (2,)   # mvn.batch_shape, mvn.event_shape
    (60, 60)    # X.shape
    (60, 60, 1, 2)   # idx.shape == X.shape + (1 "broadcast against batch", 2 "event")
    (60, 60, 2)  # probs.shape == X.shape + (2 "mvn batch shape")
    

    【讨论】:

    • 谢谢。现在对我来说很有意义!我最终想要的是混合这两种发行版。 tfd.MixtureSameFamily 成功了
    猜你喜欢
    • 1970-01-01
    • 2020-09-25
    • 2015-03-05
    • 2015-09-05
    • 2021-07-25
    • 2018-09-01
    • 1970-01-01
    • 1970-01-01
    • 2021-11-14
    相关资源
    最近更新 更多