【发布时间】:2017-04-21 14:45:46
【问题描述】:
我一直在将 Matlab 转换为 Python 时遇到问题。我在去年编写的 Matlab 中有代码(正在工作),现在尝试将这些函数转换为 Python。其中 5 个有效,4 个无效。我真的被困住了,希望得到一些帮助。 这是关于估计朴素贝叶斯概率的。这是 Matlab 中的函数:
function [ p_x_y ] = estimate_p_x_y_NB(Xtrain,ytrain,a,b )
% Function calculates probability distribution p(x|y), assuming that x is binary
% and its elements are independent from each other
% Xtrain - training dataset NxD
% ytrain - training dataset class labels 1xN
% p_x_y - binomial distribution estimators - element at position(m,d)
% represents estimator p(x_d=1|y=m) MxD
% N - number of elements in training dataset
D = size(Xtrain,2);
M = length(unique(ytrain));
p_x_y = zeros(M,D);
for i=1:M
for j=1:D
numerator = sum((ytrain==i).*((Xtrain(:,j)==1))')+a-1;
denominator = sum(ytrain==i)+a+b-2;
p_x_y(i,j) = numerator/denominator;
end
end
end
这是我对 Python 的翻译:
def estimate_p_x_y_nb(Xtrain, ytrain, a, b):
"""
:param Xtrain: training data NxD
:param ytrain: class labels for training data 1xN
:param a: parameter a of Beta distribution
:param b: parameter b of Beta distribution
:return: Function calculated probality p(x|y) assuming that x takes binary values and elements
x are independent from each other. Function returns matrix p_x_y that has size MxD.
"""
D = Xtrain.shape[1]
M = len(np.unique(ytrain))
p_x_y = np.zeros((M, D))
for i in range (M):
for j in range(D):
up = np.sum((ytrain == i+1).dot((Xtrain[:, j]==1)).conjugate().T) + a - 1
down = np.sum((ytrain == i+1) + a + b -2)
p_x_y[i,j] = up/down
return p_x_y
追溯:
p_x_y[i,j] = up/down
ValueError: setting an array element with a sequence.
如果您发现该功能有任何问题,我会非常乐意指出。另外,我在up 变量中使用了.dot 而不仅仅是*,因为当它是* 时,我收到了一个关于尺寸不准确的错误,但有了这个,它似乎可以正常工作。谢谢。
【问题讨论】:
-
您是否尝试过将 MATLAB 代码的结果与 Python 在每一行中获得的结果进行比较,看看问题出在哪里?这是缩小问题范围的简单方法。或提供Minimal, Complete, and Verifiable Example。
-
用点产品
dot替换元素产品.*可能是错误的。 numpy 中的元素乘积是*。 -
@kazemakase 好吧,所以你说它应该保留 * ? (在python中?)。然而,当我这样离开时,我得到一个“ValueError:尺寸不匹配”。我想这是因为后来的转置,但它在 Matlab 中工作:(
-
是的。使用
*并找出错误的原因,而不是使用dot使症状消失 :) 我不知道是什么导致了错误——这取决于数组的形状。很可能是由于 Matlab 和 Python 中的广播规则不同。将表达式拆开,找出导致错误的确切原因。