【问题标题】:java.lang.IllegalArgumentException: Matrix inner dimensions must agreejava.lang.IllegalArgumentException:矩阵内部尺寸必须一致
【发布时间】:2018-03-14 11:51:15
【问题描述】:

这是我的代码:

package algorithms;
import Jama.Matrix;
import java.io.File;
import java.util.Arrays;
public class ThetaGetter {
    //First column is one, second is price and third is BHK
    private static double[][] variables = {
        {1,1130,2},
        {1,1100,2},
        {1,2055,3},
        {1,1047,2},
        {1,1927,3},
        {1,2667,3},
        {1,1146,2},
        {1,2020,3},
        {1,1190,2},
        {1,2165,3},
        {1,1250,2},
        {1,1185,2},
        {1,2825,4},
        {1,1200,2},
        {1,1580,3},
        {1,3200,3},
        {1,715,1},
        {1,1270,2},
        {1,2403,3},
        {1,1465,3},
        {1,1345,2}
    };

    private static double[][] prices = {
        {69.65},
        {60},
        {115},
        {55},
        {140},
        {225},
        {76.78},
        {120},
        {73.11},
        {140},
        {56},
        {79.39},
        {161},
        {73.69},
        {80},
        {145},
        {34.87},
        {77.72},
        {165},
        {98},
        {82}
    };
    private static Matrix X = new Matrix(variables);
    private static Matrix y = new Matrix(prices);
    public static void main(String[] args) {
        File file = new File("theta.dat");
        if(file.exists()){
            System.out.println("Theta has already been calculated!");
            return;
        }
        //inverse(Tra(X)*X)*tra(X)*y
        Matrix transposeX = X.transpose();
        Matrix inverse = X.times(transposeX).inverse();
        System.out.println(y.getArray().length);
        System.out.println(X.getArray().length);
        Matrix test = inverse.times(transposeX);
        Matrix theta = test.times(y);
        System.out.println(Arrays.deepToString(theta.getArray()));
    }
}

这个算法基本上是尝试获取房价,然后得到一些常数,然后用这些常数来猜测房价。但是我在'Matrix theta = test.times(y);'这一行遇到了一个异常错误消息几乎就是问题所在。尺寸有什么问题吗?两个都有21个项目,所以不知道是怎么回事。

【问题讨论】:

    标签: java matrix machine-learning jama


    【解决方案1】:

    您所犯的错误在以下代码行中:

    Matrix inverse = X.times(transposeX).inverse();
    

    你上面评论的公式是:

    //inverse(Tra(X)*X)*tra(X)*y
    

    但您在代码中实际计算的是: (X*Tra(X) 而不是 Tra(X)*X)

    //inverse(X*Tra(X))*tra(X)*y
    

    如果 X 的维度是 (m,n) 其中

    • m = 行数
    • n = 列数

    Y 的维度是 (m,1),使用上面使用的乘法,您将得到以下结果:

    逆(X * Tra(X)) *Tra(X)*Y = 逆 * Tra(X) * Y = 结果 * y

    inverse((m,n)(n,m))(n,m)*(m,1)= (m,m) * (n,m) => 结果在错误中,因为矩阵乘法的内部维度必须相等

    修复您的代码的方法是替换以下行:

    Matrix inverse = X.times(transposeX).inverse();
    

     Matrix inverse = transposeX.times(X).inverse();
    

    【讨论】:

      猜你喜欢
      • 1970-01-01
      • 2014-11-04
      • 2012-07-12
      • 2011-12-05
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      • 1970-01-01
      相关资源
      最近更新 更多