【问题标题】:Compute trace of matrix product using numpy/pytorch broadcasting使用 numpy/pytorch 广播计算矩阵乘积的轨迹
【发布时间】:2018-04-24 01:33:53
【问题描述】:

令 A 为 (nxm)-矩阵,M 为 (mxm)-矩阵。为矩阵的迹写 tr(),我需要计算 tr(AM(A^T))。但是,最终的跟踪操作会丢弃大部分计算。我可以使用 numpy 或 pytorch 的广播规则只计算 AM(A^T) 的必要对角线吗?

更新: 这是我在 PyTorch 中计算对角线的解决方案:

torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)

【问题讨论】:

标签: python pytorch array-broadcasting


【解决方案1】:

您必须至少计算两个矩阵乘积之一。随后,您可以在此处使用其中一个答案:What is the best way to compute the trace of a matrix product in numpy?

【讨论】:

    最近更新 更多