2017-01-21 1 views
1

학교 프로젝트 용 OCR 프로그램을 만들어야하므로 위키 백과의 도움으로 Backpropagation 알고리즘을 만들었습니다. 내 네트워크를 훈련시키기 위해 며칠 전에 추출한 MNIST 데이터베이스를 사용하여 실제 이미지 파일을 갖습니다. 그러나 이제 오류는 항상 약 237이며 잠시 훈련을 받으면 오류 및 가중치가 NaN이됩니다. 내 코드에 어떤 문제가 있습니까?신경망 : Backpropagation이 작동하지 않음 (Java)

A screenshot of my images folder

가 여기 내 네트워크를 훈련해야 내 주요 클래스입니다 :

package de.Marcel.NeuralNetwork; 

public class Neuron { 
    private double input, output; 

public Neuron() { 

} 

public void setInput(double input) { 
    this.input = input; 
} 

public void setOutput(double output) { 
    this.output = output; 
} 

public double getInput() { 
    return input; 
} 

public double getOutput() { 
    return output; 
} 

}

:

package de.Marcel.NeuralNetwork; 

import java.awt.Color; 
import java.awt.image.BufferedImage; 
import java.io.File; 
import java.io.IOException; 

import javax.imageio.ImageIO; 

public class OCR { 
    public static void main(String[] args) throws IOException { 
     // create network 
     NeuralNetwork net = new NeuralNetwork(784, 450, 5, 0.2); 

    // load Images 
    File file = new File("images"); 

    int images= 0; 
    double error = 0; 
    for (File f : file.listFiles()) { 
     BufferedImage image = ImageIO.read(f); 

     int t = -1; 
     double[] pixels = new double[784]; 
     for (int x = 0; x < image.getWidth(); x++) { 
      for (int y = 0; y < image.getHeight(); y++) { 
       t++; 
       Color c = new Color(image.getRGB(x, y)); 

       if (c.getRed() == 0 && c.getGreen() == 0 && c.getBlue() == 0) { 
        pixels[t] = 1; 
       } else if (c.getRed() == 255 && c.getGreen() == 255 && c.getBlue() == 255) { 
        pixels[t] = 0; 
       } 
      } 
     } 

     try { 
      if (f.getName().startsWith("1")) { 
       net.learn(pixels, new double[] { 1, 0, 0, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("2")) { 
       net.learn(pixels, new double[] { 0, 1, 0, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("3")) { 
       net.learn(pixels, new double[] { 0, 0, 1, 0, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("4")) { 
       net.learn(pixels, new double[] { 0, 0, 0, 1, 0 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("5")) { 
       net.learn(pixels, new double[] { 0, 0, 0, 0, 1 }); 
       error += net.getError(); 

       images++; 
      } else if (f.getName().startsWith("6")) { 
       break; 
      } 
     } catch (Exception e) { 
      e.printStackTrace(); 
     } 
    } 

    error = error/iterations; 

    System.out.println("Trained images: " + images); 
    System.out.println("Error: " + error); 

    //save 
    System.out.println("Save"); 
    try { 
     net.saveNetwork("network.nnet"); 
    } catch (Exception e) { 
     e.printStackTrace(); 
    } 
} 
} 

...이 내 신경 세포 클래스입니다. .. 그리고 마지막으로 내 NeuralNetwork

package de.Marcel.NeuralNetwork; 

import java.io.File; 
import java.io.FileWriter; 
import java.util.Random; 

public class NeuralNetwork { 
    private Neuron[] inputNeurons, hiddenNeurons, outputNeurons; 
    private double[] weightMatrix1, weightMatrix2; 
    private double learningRate, error; 

public NeuralNetwork(int inputCount, int hiddenCount, int outputCount, double learningRate) { 
    this.learningRate = learningRate; 

    // create Neurons 
    // create Input 
    this.inputNeurons = new Neuron[inputCount]; 
    for (int i = 0; i < inputCount; i++) { 
     this.inputNeurons[i] = new Neuron(); 
    } 
    // createHidden 
    this.hiddenNeurons = new Neuron[hiddenCount]; 
    for (int i = 0; i < hiddenCount; i++) { 
     this.hiddenNeurons[i] = new Neuron(); 
    } 
    // createOutput 
    this.outputNeurons = new Neuron[outputCount]; 
    for (int i = 0; i < outputCount; i++) { 
     this.outputNeurons[i] = new Neuron(); 
    } 

    // create weights 
    Random random = new Random(); 
    // weightMatrix1 
    this.weightMatrix1 = new double[inputCount * hiddenCount]; 
    for (int i = 0; i < inputCount * hiddenCount; i++) { 
     this.weightMatrix1[i] = (random.nextDouble() * 2 - 1)/0.25; 
    } 
    // weightMatrix2 
    this.weightMatrix2 = new double[hiddenCount * outputCount]; 
    for (int i = 0; i < hiddenCount * outputCount; i++) { 
     this.weightMatrix2[i] = (random.nextDouble() * 2 - 1)/0.25; 
    } 
} 

public void calculate(double[] input) throws Exception { 
    // verfiy input length 
    if (input.length == inputNeurons.length) { 
     // forwardPropagation 
     // set input array as input and output of input neurons 
     for (int i = 0; i < input.length; i++) { 
      inputNeurons[i].setInput(input[i]); 
      inputNeurons[i].setOutput(input[i]); 
     } 

     // calculate output of hiddenNeurons 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int i = 0; i < inputNeurons.length; i++) { 
       Neuron iNeuron = inputNeurons[i]; 
       totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i]; 
      } 

      // set input 
      hNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      hNeuron.setOutput(calculatedOutput); 
     } 

     // calculate output of outputNeurons 
     for (int o = 0; o < outputNeurons.length; o++) { 
      Neuron oNeuron = outputNeurons[o]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 
       totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h]; 
      } 

      // set input 
      oNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      oNeuron.setOutput(calculatedOutput); 
     } 
    } else { 
     throw new Exception("[NeuralNetwork] input array is either too small or to big"); 
    } 
} 

public void learn(double[] input, double[] output) throws Exception { 
    double partialOutput = 0; 

    // verfiy input length 
    if (input.length == inputNeurons.length) { 
     // forwardPropagation 
     // set input array as input and output of input neurons 
     for (int i = 0; i < input.length; i++) { 
      inputNeurons[i].setInput(input[i]); 
      inputNeurons[i].setOutput(input[i]); 
     } 

     // calculate output of hiddenNeurons 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int i = 0; i < inputNeurons.length; i++) { 
       Neuron iNeuron = inputNeurons[i]; 
       totalInput += iNeuron.getOutput() * weightMatrix1[h * inputNeurons.length + i]; 
      } 

      // set input 
      hNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      hNeuron.setOutput(calculatedOutput); 
     } 

     // calculate output of outputNeurons 
     for (int o = 0; o < outputNeurons.length; o++) { 
      Neuron oNeuron = outputNeurons[o]; 
      double totalInput = 0; 

      // sum up totalInput of Neuron 
      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 
       totalInput += hNeuron.getOutput() * weightMatrix2[o * hiddenNeurons.length + h]; 
      } 

      // set input 
      oNeuron.setInput(totalInput); 

      // calculate output by applying sigmoid 
      double calculatedOutput = sigmoid(totalInput); 

      // set output 
      oNeuron.setOutput(calculatedOutput); 
     } 

     // backPropagation 
     double totalError = 0; 
     // calculate weights in matrix2 
     for (int h = 0; h < hiddenNeurons.length; h++) { 
      Neuron hNeuron = hiddenNeurons[h]; 

      for (int o = 0; o < outputNeurons.length; o++) { 
       Neuron oNeuron = outputNeurons[o]; 

       // calculate weight 
       double delta = learningRate * derivativeSigmoid(oNeuron.getInput()) 
         * (output[o] - oNeuron.getOutput()) * hNeuron.getOutput(); 

       // set new weight 
       weightMatrix2[h + o * hiddenNeurons.length] = weightMatrix2[h + o * hiddenNeurons.length] + delta; 

       // update partial output 
       partialOutput += (derivativeSigmoid(oNeuron.getInput()) * (output[o] - oNeuron.getOutput()) 
         * weightMatrix2[h + o * hiddenNeurons.length]); 

       //calculate error 
       totalError += Math.pow((output[o] - oNeuron.getOutput()), 2); 
      } 
     } 

     //set error 
     this.error = 0.5 * totalError; 

     // calculate weights in matrix1 
     for (int i = 0; i < inputNeurons.length; i++) { 
      Neuron iNeuron = inputNeurons[i]; 

      for (int h = 0; h < hiddenNeurons.length; h++) { 
       Neuron hNeuron = hiddenNeurons[h]; 

       // calculate weight 
       double delta = learningRate * derivativeSigmoid(hNeuron.getInput()) * partialOutput 
         * (iNeuron.getOutput()); 

       // set new weight 
       weightMatrix1[i + h * inputNeurons.length] = weightMatrix1[i + h * inputNeurons.length] + delta; 
      } 
     } 
    } else { 
     throw new Exception("[NeuralNetwork] input array is either too small or to big"); 
    } 
} 

// save Network 
public void saveNetwork(String fileName) throws Exception { 
    File file = new File(fileName); 
    FileWriter writer = new FileWriter(file); 

    writer.write("weightmatrix1:"); 
    writer.write(System.lineSeparator()); 

    // write weightMatrix1 
    for (double d : weightMatrix1) { 
     writer.write(d + "-"); 
    } 

    writer.write(System.lineSeparator()); 
    writer.write("weightmatrix2:"); 
    writer.write(System.lineSeparator()); 

    // write weightMatrix2 
    for (double d : weightMatrix2) { 
     writer.write(d + "-"); 
    } 

    // save 
    writer.close(); 
} 

// sigmoid function 
private double sigmoid(double input) { 
    return Math.exp(input * (-1)); 
} 

private double derivativeSigmoid(double input) { 
    return sigmoid(input) * (1 - sigmoid(input)); 
} 

public double getError() { 
    return error; 
} 
} 
+0

NaN은 숫자가 아님을 의미합니다. 그것은 0으로 나눌 때 발생합니다. 예를 들어, double error = 10/0; 오류는 NaN과 같습니다. 또한 anotherVar = 1 + 오류를 두 번 수행하면; anotherVar도 NaN이됩니다. – Zack

답변

0

귀하의 시그 모이 드 기능이 올바르지 않습니다. 1/(1 + exp (-x))이어야합니다.

여전히 NaN 오류가 발생하면 함수를 과도하게 사용할 수 있습니다. 특히 큰 숫자 (즉, -10보다 작고 10보다 큰 숫자)의 경우에는 과도한 오류 일 수 있습니다.

sigmoid (x)의 미리 계산 된 값 배열을 사용하면 더 큰 데이터 집합에 대해이 문제를 방지 할 수 있으며 프로그램을보다 효율적으로 실행할 수 있습니다.

희망이 도움이됩니다.

관련 문제