【问题标题】:Sampling from multivariate normal distribution in JAX gives type error从 JAX 中的多元正态分布采样会产生类型错误
【发布时间】:2021-02-02 17:34:15
【问题描述】:

我正在尝试使用 JAX 从多元正态分布生成样本:

import jax
import jax.numpy as jnp
import numpy as np

key = random.PRNGKey(0)
cov = np.array([[1.2, 0.4], [0.4, 1.0]])
mean = np.array([3,-1])
x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T

但是,当我运行代码时,出现以下错误:

TypeError                                 Traceback (most recent call last)
<ipython-input-25-1397bf923fa4> in <module>()
      2 cov = np.array([[1.2, 0.4], [0.4, 1.0]])
      3 mean = np.array([3,-1])
----> 4 x1,x2 = jax.random.multivariate_normal(key, mean, cov, 5000).T

1 frames
/usr/local/lib/python3.6/dist-packages/jax/core.py in canonicalize_shape(shape)
   1159          "got {}.")
   1160   if any(isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
-> 1161          and not isinstance(get_aval(x), ConcreteArray) for x in shape):
   1162     msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
   1163             "smaller subfunctions.")

TypeError: 'int' object is not iterable

我不确定问题出在哪里,因为相同的语法适用于 Numpy 中的等效函数

【问题讨论】:

    标签: python numpy normal-distribution jax


    【解决方案1】:

    jax.random 模块中,大多数形状必须明确地是元组。所以不要使用形状5000,而是使用(5000,)

    x1,x2 = jax.random.multivariate_normal(key, mean, cov, (5000,)).T
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 2019-01-19
      • 1970-01-01
      • 2020-12-18
      • 2020-06-03
      • 1970-01-01
      • 2015-01-28
      相关资源
      最近更新 更多