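【Posted】:2015-10-07 00:50:49
【Question】:
I am trying to implement LDA using PyMC3.
However, when defining the last part of the model, where words are sampled according to their topics, I keep getting the error: TypeError: list indices must be integers, not TensorVariable
How can I fix this?
The code is as follows: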
## Data Preparation
K = 2 # number of topics
N = 4 # number of words
D = 3 # number of documents
import numpy as np
data = np.array([[1, 1, 1, 1], [1, 1, 1, 1], [0, 0, 0, 0]])
Wd = [len(doc) for doc in data] # length of each document
## Model Specification
from pymc3 import Model, Normal, HalfNormal, Dirichlet, Categorical, constant
lda_model = Model()
with lda_model:
    # Priors for unknown model parameters
    alpha = HalfNormal('alpha', sd=1)
    eta = HalfNormal('eta', sd=1)
    a1 = eta*np.ones(shape=N)
    a2 = alpha*np.ones(shape=K)
    beta = [Dirichlet('beta_%i' % i, a1, shape=N) for i in range(K)]
    theta = [Dirichlet('theta_%s' % i, a2, shape=K) for i in range(D)]
    z = [Categorical('z_%i' % d, p = theta[d], shape=Wd[d]) for d in range(D)]
    # That's when you get the error. It is caused by: beta[z[d][w]]
    # (beta is a Python list, while z[d][w] is a symbolic TensorVariable)
    w = [Categorical('w_%i_%i' % (d, w), p = beta[z[d][w]], observed = data[d, w]) for d in range(D) for w in range(Wd[d])]
Any help would be greatly appreciated!
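For readers who want to see where the TypeError comes from, here is a minimal sketch that reproduces the same kind of failure without PyMC3. The class `SymbolicIndex` is a hypothetical stand-in for a TensorVariable: an object whose concrete integer value is not known while the model is being built. The names `topics` and `z` are illustrative only. On Python 3 the message reads "list indices must be integers or slices, not ...", a slightly different wording from the one quoted in the question.

```python
class SymbolicIndex:
    """Hypothetical stand-in for a symbolic variable (no concrete int value)."""

    def __init__(self, name):
        self.name = name


# A plain Python list, playing the role of `beta` in the question.
topics = ["topic_0", "topic_1"]

# Not an integer, playing the role of `z[d][w]` in the question.
z = SymbolicIndex("z_0")

try:
    topics[z]  # a list can only be indexed by integers or slices
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

The point of the sketch is only that the failure happens in Python's list indexing, before any sampling takes place, so the lookup has to be expressed on something other than a Python list.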
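【Comments】:
- I would turn this into a 2-D array instead of creating lists. Unfortunately, when I tried to do that I ran into github.com/pymc-devs/pymc3/issues/792. Once that issue is resolved, we should try again.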
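The 2-D array idea from the comment can be illustrated with a NumPy analogue. This is only a sketch of the indexing pattern, not a working PyMC3 model: the probability values and topic assignments below are made up, and the names `word_dists` and `beta_list` are hypothetical. In the actual model the analogous step would presumably be to hold the topics in one 2-D tensor rather than a list, which is untested here and may still be affected by the issue linked above.

```python
import numpy as np

K, N = 2, 4  # number of topics, number of words (as in the question)

# Hypothetical topic-word probabilities, one row per topic (values made up).
beta = np.array([[0.7, 0.1, 0.1, 0.1],
                 [0.1, 0.1, 0.1, 0.7]])

# Hypothetical topic assignment for each of the 4 words in one document.
z = np.array([0, 1, 1, 0])

# Indexing a 2-D array with an integer array picks one row per word,
# giving the word distribution that applies to each position.
word_dists = beta[z]
print(word_dists.shape)

# The same lookup on a Python list of rows fails with a TypeError,
# which mirrors the problem in the question.
beta_list = [beta[0], beta[1]]
try:
    beta_list[z]
    list_indexing_works = True
except TypeError:
    list_indexing_works = False
print(list_indexing_works)
```

With a single array, the topic lookup for all words in a document becomes one indexing operation instead of a per-word list access.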