【问题标题】:Tensorflow-probability transform event shape of JointDistributionJointDistribution的Tensorflow-概率变换事件形状
【发布时间】:2021-02-05 15:50:42
【问题描述】:

我想为事件形状为 n 的 n 个分类变量 C_1、..、C_n 创建一个分布。使用 JointDistributionSequentialAutoBatched 事件维度是一个列表 [[],..,[]]。例如对于 n=2

import tensorflow_probability.python.distributions as tfd

probs = [
    [0.8, 0.2], # C_1 in {0,1}
    [0.3, 0.3, 0.4] # C_2 in {0,1,2}
    ]

D = tfd.JointDistributionSequentialAutoBatched([tfd.Categorical(probs=p) for p in probs])

>>> D
<tfp.distributions.JointDistributionSequentialAutoBatched 'JointDistributionSequentialAutoBatched' batch_shape=[] event_shape=[[], []] dtype=[int32, int32]>

如何重塑它以获得事件形状 [2]?

【问题讨论】:

    标签: tensorflow-probability


    【解决方案1】:

    这里可以使用几种不同的方法:

    1. 创建一批分类分布,然后使用tfd.Independent 将批次维度重新解释为事件:
    vector_dist = tfd.Independent(
      tfd.Categorical(
        probs=[
          [0.8, 0.2, 0.0],  # C_1 in {0,1}
          [0.3, 0.3, 0.4]  # C_2 in {0,1,2}
        ]),
      reinterpreted_batch_ndims=1)
    

    在这里,我添加了一个额外的零来填充 probs,以便两个分布都可以由单个 Categorical 对象表示。

    1. 使用 Blockwise 分布,它将其分量分布填充到单个向量中(与 JointDistribution 类相反,后者将它们作为单独的值返回):
    vector_dist = tfd.Blockwise([tfd.Categorical(probs=p) for p in probs])
    
    1. 最接近您问题的直接答案是将Split 双射器(其逆为Concat)应用于联合分布:
    tfb = tfp.bijectors
    D = tfd.JointDistributionSequentialAutoBatched(
      [tfd.Categorical(probs=[p] for p in probs])
    vector_dist = tfb.Invert(tfb.Split(2))(D)
    

    请注意,我不得不尴尬地写 probs=[p] 而不仅仅是 probs=p。这是因为Concat 双射器和tf.concat 一样,不能改变其参数的张量等级——它可以将小向量连接成一个大向量,但不能将标量连接成一个向量——所以我们必须确保其输入本身是向量。如果 TFP 具有类似于 tf.stack / tf.unstackStack 双射器,则可以避免这种情况(目前没有,但没有理由不存在)。

    【讨论】:

    • 感谢您指出 tfd.Blockwise,在我的代码中,我实际上使用的是带有可训练概率的 JointDistributionCoroutine,因此我无法使用解决方案 1。我最终在其文档字符串中使用了 tfd.Blockwise:joint =tfd.JointDistributionCoroutine(model), d = tfd.Blockwise(joint) `
    猜你喜欢
    • 1970-01-01
    • 1970-01-01
    • 1970-01-01
    • 2019-10-07
    • 1970-01-01
    • 1970-01-01
    • 2023-03-30
    • 1970-01-01
    • 2018-02-14
    相关资源
    最近更新 更多