【发布时间】:2018-04-24 01:33:53
【问题描述】:
令 A 为 (nxm)-矩阵,M 为 (mxm)-矩阵。为矩阵的迹写 tr(),我需要计算 tr(AM(A^T))。但是,最终的跟踪操作会丢弃大部分计算。我可以使用 numpy 或 pytorch 的广播规则只计算 AM(A^T) 的必要对角线吗?
更新: 这是我在 PyTorch 中计算对角线的解决方案:
torch.sum(torch.sum(A.t()[:,None,:]*M[:,:,None],0)*A.t(),0)
【问题讨论】:
标签: python pytorch array-broadcasting