1
이 코드를 사용하여 theta를 찾지 못했습니다.간단한 선형 회귀 분석에 어떤 문제가 있습니까?
문제를 시각화하는 데 도움이되는 플롯 코드가 추가되었습니다.
은 나를
감사
을 코드의 짧은 블록의 버그를 발견 도와주세요import numpy as np
import matplotlib.pyplot as plt
N = 20
def arr(n):
return np.arange(n) + 1
def linear(features, y):
x = np.vstack(features).T
xT = np.transpose(x)
xTx = xT.dot(x)
return np.linalg.inv(xTx).dot(xT).dot(y)
def plot(x, y, dots_y):
plt.plot(x, y)
plt.plot(x, dots_y, marker='o', linestyle=' ', color='r')
plt.show()
y = arr(N) ** 2 + 3
theta = linear((np.ones(N), arr(N), arr(N) ** 2), y)
plot(arr(N), arr(N) ** theta[1] + theta[0], y)