2017-11-24 2 views
1

이 코드를 사용하여 theta를 찾지 못했습니다.간단한 선형 회귀 분석에 어떤 문제가 있습니까?

문제를 시각화하는 데 도움이되는 플롯 코드가 추가되었습니다.

은 나를

감사

을 코드의 짧은 블록의 버그를 발견 도와주세요

Normal Equation used in linear function

import numpy as np 
import matplotlib.pyplot as plt 

N = 20 

def arr(n): 
    return np.arange(n) + 1 

def linear(features, y): 
    x = np.vstack(features).T 
    xT = np.transpose(x) 
    xTx = xT.dot(x) 
    return np.linalg.inv(xTx).dot(xT).dot(y) 

def plot(x, y, dots_y): 
    plt.plot(x, y) 
    plt.plot(x, dots_y, marker='o', linestyle=' ', color='r') 
    plt.show()  

y = arr(N) ** 2 + 3  
theta = linear((np.ones(N), arr(N), arr(N) ** 2), y) 

plot(arr(N), arr(N) ** theta[1] + theta[0], y) 

Output plot

답변

1

오류는해야하는 플로팅 라인에

plot(arr(N), arr(N)**2 * theta[2] + arr(N) * theta[1] + theta[0], y) 

2 차 다항식 모델에 따라.

또한; 난 당신이 최소 제곱 솔루션의 계산 해설 이유로이 방법을했다 생각하지만, 다음과 같이 실제로 선형 최소 제곱 적합 훨씬 더 짧고 효율적인 코드로, np.linalg.lstsq 얻을 것이다 :

N = 20 
x = np.arange(1, N+1) 
y = x**2 + 3 
basis = np.vstack((x**0, x**1, x**2)).T # basis for the space of quadratic polynomials 
theta = np.linalg.lstsq(basis, y)[0] # least squares approximation to y in this basis 
plt.plot(x, y, 'ro')     # original points 
plt.plot(x, basis.dot(theta))   # best fit 
plt.show() 

fit