【问题标题】:Java Backpropagation Algorithm is very slowJava反向传播算法非常慢
【发布时间】:2016-02-02 18:41:16
【问题描述】:

我有一个大问题。我尝试创建一个神经网络,并希望使用反向传播算法对其进行训练。我在这里http://mattmazur.com/2015/03/17/a-step-by-step-backpropagation-example/ 找到了本教程,并尝试用 Java 重新创建它。当我使用他使用的训练数据时,我得到了和他一样的结果。 如果没有反向传播,我的 TotalError 几乎和他的一样。当我像他一样使用反向传播 10 000 次时,我得到的错误几乎相同。但他使用 2 个输入神经元、2 个隐藏神经元和 2 个输出,但我想将这个神经网络用于 OCR,所以我肯定需要更多的神经元。但是,如果我使用例如 49 个输入神经元、49 个隐藏神经元和 2 个输出神经元,改变权重需要很长时间才能得到一个小错误。 (我相信这需要永远......)。我的学习率为 0.5。在我的网络的构造器中,我生成神经元并给它们提供与教程中相同的训练数据,为了用更多的神经元测试它,我给它们随机的权重、输入和目标。所以我不能将它用于许多神经元,它需要很长时间还是我的代码有问题?我应该增加学习率、偏差还是起始权重? 希望你能帮助我。

package de.Marcel.NeuralNetwork;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Random;

public class Network {
    private ArrayList<Neuron> inputUnit, hiddenUnit, outputUnit;

    private double[] inHiWeigth, hiOutWeigth;
    private double hiddenBias, outputBias;

    private double learningRate;

    public Network(double learningRate) {
        this.inputUnit = new ArrayList<Neuron>();
        this.hiddenUnit = new ArrayList<Neuron>();
        this.outputUnit = new ArrayList<Neuron>();

        this.learningRate = learningRate;

        generateNeurons(2,2,2);

        calculateTotalNetInputForHiddenUnit();
        calculateTotalNetInputForOutputUnit();
    }

    public double calcuteLateTotalError () {
        double e = 0;
        for(Neuron n : outputUnit) {
            e += 0.5 * Math.pow(Math.max(n.getTarget(), n.getOutput()) - Math.min(n.getTarget(), n.getOutput()), 2.0);
        }

        return e;
    }

    private void generateNeurons(int input, int hidden, int output) {
        // generate inputNeurons
        for (int i = 0; i < input; i++) {
            Neuron neuron = new Neuron();

            // for testing give each neuron an input
            if(i == 0) {
                neuron.setInput(0.05d);
            } else if(i == 1) {
                neuron.setOutput(0.10d);
            }

            inputUnit.add(neuron);
        }

        // generate hiddenNeurons
        for (int i = 0; i < hidden; i++) {
            Neuron neuron = new Neuron();

            hiddenUnit.add(neuron);
        }

        // generate outputNeurons
        for (int i = 0; i < output; i++) {
            Neuron neuron = new Neuron();

            if(i == 0) {
                neuron.setTarget(0.01d);
            } else if(i == 1) {
                neuron.setTarget(0.99d);
            }

            outputUnit.add(neuron);
        }

        // generate Bias
        hiddenBias = 0.35;
        outputBias = 0.6;

        // generate connections
        double startWeigth = 0.15;
        // generate inHiWeigths
        inHiWeigth = new double[inputUnit.size() * hiddenUnit.size()];
        for (int i = 0; i < inputUnit.size() * hiddenUnit.size(); i += hiddenUnit.size()) {
            for (int x = 0; x < hiddenUnit.size(); x++) {
                int z = i + x;
                inHiWeigth[z] = round(startWeigth, 2, BigDecimal.ROUND_HALF_UP);

                startWeigth += 0.05;
            }
        }

        // generate hiOutWeigths
        hiOutWeigth = new double[hiddenUnit.size() * outputUnit.size()];
        startWeigth += 0.05;
        for (int i = 0; i < hiddenUnit.size() * outputUnit.size(); i += outputUnit.size()) {
            for (int x = 0; x < outputUnit.size(); x++) {
                int z = i + x;
                hiOutWeigth[z] = round(startWeigth, 2, BigDecimal.ROUND_HALF_UP);

                startWeigth += 0.05;
            }
        }
    }

    private double round(double unrounded, int precision, int roundingMode)
    {
        BigDecimal bd = new BigDecimal(unrounded);
        BigDecimal rounded = bd.setScale(precision, roundingMode);
        return rounded.doubleValue();
    }

    private void calculateTotalNetInputForHiddenUnit() {
        // calculate totalnetinput for each hidden neuron
        for (int s = 0; s < hiddenUnit.size(); s++) {
            double net = 0;
            int x = (inHiWeigth.length / inputUnit.size());

            // calculate toAdd
            for (int i = 0; i < x; i++) {
                int v = i + s * x;
                double weigth = inHiWeigth[v];
                double toAdd = weigth * inputUnit.get(i).getInput();
                net += toAdd;
            }

            // add bias
            net += hiddenBias * 1;
            net = net *-1;
            double output =  (1.0 / (1.0 + (double)Math.exp(net)));
            hiddenUnit.get(s).setOutput(output);
        }
    }

    private void calculateTotalNetInputForOutputUnit() {
        // calculate totalnetinput for each hidden neuron
        for (int s = 0; s < outputUnit.size(); s++) {
            double net = 0;
            int x = (hiOutWeigth.length / hiddenUnit.size());

            // calculate toAdd
            for (int i = 0; i < x; i++) {
                int v = i + s * x;
                double weigth = hiOutWeigth[v];
                double outputOfH = hiddenUnit.get(s).getOutput();
                double toAdd = weigth * outputOfH;
                net += toAdd;
            }

            // add bias
            net += outputBias * 1;
            net = net *-1;
            double output = (double) (1.0 / (1.0 + Math.exp(net)));
            outputUnit.get(s).setOutput(output);
        }
    }

    private void backPropagate() {
        // calculate ouputNeuron weigthChanges
        double[] oldWeigthsHiOut = hiOutWeigth;
        double[] newWeights = new double[hiOutWeigth.length];
        for (int i = 0; i < hiddenUnit.size(); i += 1) {
            double together = 0;
            double[] newOuts = new double[hiddenUnit.size()];
            for (int x = 0; x < outputUnit.size(); x++) {
                int z = x * hiddenUnit.size() + i;
                double weigth = oldWeigthsHiOut[z];
                double target = outputUnit.get(x).getTarget();
                double output = outputUnit.get(x).getOutput();

                double totalErrorChangeRespectOutput = -(target - output);
                double partialDerivativeLogisticFunction = output * (1 - output);
                double totalNetInputChangeWithRespect = hiddenUnit.get(x).getOutput();
                double puttedAllTogether = totalErrorChangeRespectOutput * partialDerivativeLogisticFunction
                        * totalNetInputChangeWithRespect;
                double weigthChange = weigth - learningRate * puttedAllTogether;

                // set new weigth
                newWeights[z] = weigthChange;
                together += (totalErrorChangeRespectOutput * partialDerivativeLogisticFunction * weigth);
                double out = hiddenUnit.get(x).getOutput();
                newOuts[x] = out * (1.0 - out);
            }
            for (int t = 0; t < newOuts.length; t++) {
                inHiWeigth[t + i] = (double) (inHiWeigth[t + i] - learningRate * (newOuts[t] * together * inputUnit.get(t).getInput()));
            }
            hiOutWeigth = newWeights;
        }
    }
}

还有我的神经元课:

package de.Marcel.NeuralNetwork;

public class Neuron {
    private double input, output;
    private double target;

    public Neuron () {

    }

    public void setTarget(double target) {
        this.target = target;
    }

    public void setInput (double input) {
        this.input = input;
    }

    public void setOutput(double output) {
        this.output = output;
    }

    public double getInput() {
        return input;
    }

    public double getOutput() {
        return output;
    }

    public double getTarget() {
        return target;
    }
}

【问题讨论】:

  • 许多神经元需要很长时间
  • 您的代码似乎没问题。最多您可以用预加载的变量替换对 .size() 方法的每次调用,以避免成千上万的调用。然后认为深度神经网络可能需要几天的时间来训练。 Google 的 HPC 集群整天都在这样做
  • @FMiscia 但是我没有那么多神经元,如果我增加神经元并训练例如 900000 次,错误不会改变。 :)s
  • 误差随着时间的推移收敛到一个单一的值。如果错误没有改变,要么你每次都达到收敛,要么你的代码有错误。尝试用 10、100、1000 等次对其进行训练,看看每一步之间的误差变化有多大。
  • 有两层,你的复杂度就像 n^n,所以试着一步一步增加你的训练数量。

标签: java algorithm backpropagation


【解决方案1】:

想一想:你有 10,000 次传播通过 49->49->2 个神经元。在输入层和隐藏层之间,您有 49 * 49 个链接要传播,因此您的部分代码正在执行大约 2400 万次(10,000 * 49 * 49)。这需要时间。你可以尝试 100 次传播,看看需要多长时间,只是为了给你一个想法。

可以做一些事情来提高性能,例如使用普通数组而不是 ArrayList,但对于Code Review 站点来说,这是一个更好的主题。另外,不要指望这会带来巨大的改进。

【讨论】:

    【解决方案2】:

    您的反向传播代码的复杂度为 O(h*o + h^2) * 10000,其中 h 是隐藏神经元的数量,o 是输出神经元的数量。原因如下。

    您有一个循环执行所有隐藏的神经元...

    for (int i = 0; i < hiddenUnit.size(); i += 1) {
    

    ...包含对所有输出神经元执行的另一个循环...

    for (int x = 0; x < outputUnit.size(); x++) {
    

    ...还有一个额外的内部循环,对所有隐藏的神经元再次执行...

    double[] newOuts = new double[hiddenUnit.size()];
    for (int t = 0; t < newOuts.length; t++) {
    

    ...你执行所有这些一万次。在此之上添加 O(i + h + o) [初始对象创建] + O(i*h + o*h) [初始权重] + O(h*i) [计算净输入] + O(h* o) [计算净产出]。

    难怪它需要永远;您的代码中到处都是嵌套循环。如果您希望它运行得更快,请将这些因素排除在外 - 例如,结合对象创建和初始化 - 或减少神经元的数量。但是显着减少反向传播调用的数量是让这个运行更快的最好方法。

    【讨论】:

      猜你喜欢
      • 2012-03-14
      • 2017-06-09
      • 2014-01-10
      • 2013-11-28
      • 1970-01-01
      • 1970-01-01
      • 2016-10-16
      • 2017-05-12
      相关资源
      最近更新 更多