2014-09-11 2 views
2

나는 아래의 네트워크를 훈련시키고 적절한 체중을 얻으려고했지만 계속 달리기를 계속한다. 누구든지 코드에서 무엇이 잘못 될 수 있는지 말해 줄 수 있습니까? 여기서 {8,1}은 입력이고, {-1}}은 signum 함수를 사용하여 예상되는 출력입니다.단일 레이어 퍼셉트론 교육?

import java.util.Arrays; 

public class ANN { 

    public static void main(String args[]) { 

     double threshold = 1.2; 
     double learningRate = 0.08; 

     // Init weights 

     double[] weights = { -1.4, 1.8 }; 

     int[][][] trainingData = { 
      {{8, 1}, {-1}}, 
      {{3, 2}, {-1}}, 
      {{6, 3}, {-1}}, 
      {{1, 4}, {-1}}, 
      {{9, 5}, {1}}, 
      {{5, 6}, {1}}, 
      {{2, 7}, {1}}, 
      {{4, 8}, {1}}, 
      {{7, 9}, {1}}, 
     }; 

     // Start training loop 
     while (true) { 
      int errorCount = 0; 
      // Loop over training data 
      for (int i = 0; i < trainingData.length; i++) { 
       System.out.println("Starting weights: " + Arrays.toString(weights)); 
       // Calculate weighted input 
       double weightedSum = 0; 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weightedSum += trainingData[i][0][ii] * weights[ii]; 
       } 

       // Calculate output 
       int output = 0; 
       if (threshold <= weightedSum) { 
        output = 1; 
       } 

       System.out.println("Target output: " + trainingData[i][1][0] 
         + ", " + "Actual Output: " + output); 

       // Calculate error 
       int error = trainingData[i][1][0] - output; 
       System.out.println("Error: " + error); 
       // Increase error count for incorrect output 
       if (error != 0) { 
        errorCount++; 
       } 

       // Update weights 
       for (int ii = 0; ii < trainingData[i][0].length; ii++) { 
        weights[ii] += learningRate * error 
          * trainingData[i][0][ii]; 
       } 

       System.out.println("New weights: " + Arrays.toString(weights)); 
       System.out.println(); 
      } 

      // If there are no errors, stop 
      if (errorCount == 0) { 
       System.out 
         .println("Final weights: " + Arrays.toString(weights)); 
       System.exit(0); 
      } 
     } 
    } 

} 

편집 : 출력을 계산하는 코드 스 니펫에서 문제가 발생한다고 생각합니다. 합계가 임계 값 출력보다 큰 경우 1을 출력하고 그렇지 않으면 0이되도록 반전되어야합니다.

// Calculate output 
       int output = 0; 
       if (weightedSum > threshold) { 
        output = 1; 
       } 

답변

1

난 당신의 코드를 실행하고 (ERRORCOUNT == 0)를 확인하기 직전에 라인 추가 :이 항상 그 신경망을 의미하는 6과 7 사이에 진동이 나타납니다

System.out.println(errorCount); 

을 수행 된 교육의 양에 관계없이 교육 데이터에 대한 무효 추정을 생성합니다. 훈련이 훈련 자료에 대해 100 % 정확하지 않은 경우, 이것은 영원히 지속될 것으로 예상됩니다.

희망이 도움이됩니다!

1

오류는 긍정적이거나 부정적 일 수 있습니다. 첫 번째 실행에서 오류는 -1입니다. 따라서 errorCount가 증가하고 루프를 종료하는 코드가 실행되지 않습니다.

전체 학습을위한 조건은 errorCount가 아니라 오류 자체를 기반으로해야합니다. 오류가 최소 수준 (입력에 따라 설정)에 도달하면 교육이 완료된 것으로 간주됩니다.