【问题标题】:Training data algorithm not working. (Perceptron)训练数据算法不起作用。 (感知器)
【发布时间】:2017-04-19 05:24:32
【问题描述】:

我正在尝试编写一个算法来训练感知器,但似乎有超过 double 的最大值的值。从昨天开始我就一直想弄明白,但不能。

权重的值似乎超过了变量输出的值。

读入的文本文件格式为: Input variables and the output

/**
* Created by yafael on 12/3/16.
*/
import java.io.*;
import java.util.*;

public class Perceptron {

static double[] weights;
static ArrayList<Integer> inputValues;
static ArrayList<Integer> outputValues;
static int[] inpArray;
static int[] outArray;

public static int numberOfInputValues(String filePath)throws IOException
{
    Scanner valueScanner = new Scanner(new File(filePath));
    int num = valueScanner.nextInt();
    return num;
}

public static void inputs(String filePath)throws IOException
{
    inputValues = new ArrayList<Integer>();
    outputValues = new ArrayList<Integer>();
    Scanner valueScanner = new Scanner(new File(filePath));
    int num = valueScanner.nextInt();

    while (valueScanner.hasNext())
    {
        String temp = valueScanner.next();
        String[] values = temp.split(",");
        for(int i = 0; i < values.length; i++)
        {
            if(i+1 != values.length)
            {
                inputValues.add(Integer.parseInt(values[i]));
            }else
            {
                outputValues.add(Integer.parseInt(values[i]));
            }
        }

    }
    valueScanner.close();
}

public static void trainData(int[] inp, int[] out, int num,int epoch)
{
    weights = new double[num];
    Random r = new Random();
    int i,ep;
    int error = 0;
    /*
     * Initialize weights
     */
    for(i = 0; i < num; i++)
    {
        weights[i] = r.nextDouble();
    }

    for(ep = 1; ep<= epoch; ep++)
    {
        double totalError = 0;

        for(i = 0; i < inp.length/(num); i++)
        {
            double output = calculateOutput(inp, i, weights);
            System.out.println("Output " + (i + 1) + ": " + output);
            //System.out.println("Output: " + output);
            if(output > 0)
            {
                error = out[i] - 1;
            }else
            {
                error = out[i] - 0;
            }

            for(int temp = 0; temp < num; temp++)
            {
                double epCalc = (1000/(double)(1000+ep));
                weights[temp] += epCalc*error*inp[((i*weights.length)+temp)];
                //System.out.println("Epoch calculation: " + epCalc);
                //System.out.println("Output: " + output);
                //System.out.println("error: " + error);
                //System.out.println("input " + ((i*weights.length)+temp) + ": " + inp[(i*weights.length)+temp]);
            }
            totalError += (error*error);
        }
        //System.out.println("Total Error: " + totalError);

        if(totalError == 0)
        {
            System.out.println("In total error");
            for(int temp = 0; temp < num; temp++)
            {
                System.out.println("Weight " +(temp)+ ": " + weights[temp]);
            }

            double x = 0.0;
            for(i = 0; i < inp.length/(num); i++)
            {
                for(int j = 0; j < weights.length; j++)
                {
                    x = inp[((i*num) + j)] * weights[j];
                }
                System.out.println("Output " + (i+1) + ": " + x);
            }
            break;
        }

    }
    if(ep >= 10000)
    {
        System.out.println("Solution not found");
    }
}

public static double calculateOutput(int[] input, int start, double[] weights)
{
    start = start * weights.length;
    double sum = 0.0;
    for(int i = 0; i < weights.length; i++)
    {
        //System.out.println("input[" + (start + i) + "]: " + input[(start+i)]);
        //System.out.println("weights[i]" + weights[i]);
        sum += (double)input[(start + i)] * weights[i];
    }
    return sum - 1.0 ;
}
public static void main(String args[])throws IOException
{
    BufferedReader obj = new BufferedReader(new InputStreamReader(System.in));

    //Read the file path from the user
    String fileName;
    System.out.println("Please enter file path for Execution: ");
    fileName = obj.readLine();

    int numInputValues = numberOfInputValues(fileName);

    //Call the function to store values in the ArrayList<>
    inputs(fileName);
    inpArray = inputValues.stream().mapToInt(i->i).toArray();
    outArray = outputValues.stream().mapToInt(i->i).toArray();

    trainData(inpArray, outArray, numInputValues, 10000);
}
}

【问题讨论】:

  • 您是否遇到溢出错误?您面临的确切问题是什么?您的程序的哪个部分导致了您的问题?
  • @WasiAhmad 感谢您的评论。问题是,我的权重值应该保持在 0.0 - 1.0 的范围内。但它们呈指数级增长。经过几次迭代,方法calculateOutput 返回NaN
  • 你能分享你的完整代码以便我运行和检查吗?顺便说一句,在 else 块中扣除 0 from out[i] 有什么意义。此外,您可以写1000.0 / (1000.0 + ep) 而不是(1000/(double)(1000+ep))
  • 我这样做只是为了便于阅读和理解。一旦我让代码工作,我将清理它。再次感谢您的输入。另外,我已经更新了代码供您参考
  • num 中的trainData 函数是什么?是批量大小还是输入大小?如果是输入大小,那么为什么要从 0 to inp.length/num 运行 for 循环?此外,我可以看到为什么重量值会爆炸!因为您通过乘以 epCalc, output, error and inp[((i*weights.length)+temp)] 来更新权重值,这将是一个很大的值。在神经网络中,我们通常有一个学习参数,我们需要获取梯度并更新权重。您更新权重的方式不合适。我相信你对更新神经网络模型参数的理解是错误的!

标签: java algorithm neural-network


【解决方案1】:

我相信你的代码有问题,所以我给你一个简单的例子,但我相信你会从这段代码中得到帮助来解决你的问题。

import java.util.Random;

public class Perceptron {
    double[] weights;
    double threshold;
    public void Train(double[][] inputs, int[] outputs, double threshold, double lrate, int epoch) {
        this.threshold = threshold;
        int n = inputs[0].length;
        int p = outputs.length;
        weights = new double[n];
        Random r = new Random();

        //initialize weights
        for(int i=0;i<n;i++) {
            weights[i] = r.nextDouble();
        }

        for(int i=0;i<epoch;i++) {
            int totalError = 0;
            for(int j =0;j<p;j++) {
                int output = Output(inputs[j]);
                int error = outputs[j] - output;

                totalError +=error;

                for(int k=0;k<n;k++) {
                    double delta = lrate * inputs[j][k] * error;
                    weights[k] += delta;
                }
            }
            if(totalError == 0)
                break;
        }
    }

    public int Output(double[] input) {
        double sum = 0.0;
        for(int i=0;i<input.length;i++) {
            sum += weights[i]*input[i];
        }

        if(sum>threshold)
            return 1;
        else
            return 0;
    }

    public static void main(String[] args) {
        Perceptron p = new Perceptron();
        double inputs[][] = {{0,0},{0,1},{1,0},{1,1}};
        int outputs[] = {0,0,0,1};

        p.Train(inputs, outputs, 0.2, 0.1, 200);
        System.out.println(p.Output(new double[]{0,0})); // prints 0
        System.out.println(p.Output(new double[]{1,0})); // prints 0
        System.out.println(p.Output(new double[]{0,1})); // prints 0
        System.out.println(p.Output(new double[]{1,1})); // prints 1
    }
}

【讨论】:

    猜你喜欢
    • 2017-05-10
    • 2016-03-02
    • 2021-09-07
    • 2021-04-07
    • 2013-12-23
    • 2013-01-03
    • 2017-03-27
    • 2014-07-22
    • 2017-06-12
    相关资源
    最近更新 更多