2016-10-27 1 views
0

scipy.linprog를 사용하여 L1 회귀를 해결하려고하지만 오류가 발생합니다. | 도끼-B |scipy.linprog를 사용하여 L1 회귀 분석 get ValueError : 입력 배열을 shape (20,1)에서 shape (20)로 방송 할 수 없습니다.

import numpy as np 
from sklearn import datasets 
from scipy.optimize import linprog 


def generate_dataset(n, d): 
    A, b, coef = datasets.make_regression(n_samples=n, 
              n_features=d, 
              n_informative=d, 
              noise=10, 
              coef=True, 
              random_state=0) 
    return A, b, coef 



def solver(A, b): 
    n = len(A) 
    m = len(A[0]) 
    c = np.vstack((np.zeros((m, 1)), np.ones((n, 1)))) 
    A_ = np.vstack((np.hstack((A, -np.eye(n))), np.hstack((-A, -np.eye(n))))) 
    b_ = np.vstack((b, -b)) 
    res = linprog(c, A_ub=A_, b_ub=b_) 
    return res 

A, b, coef = generate_dataset(10, 10) 
res = solver(A, b) 
print(res) 
print(coef) 

generate_dataset 기능은 무작위로 10 개 기능 10 개 샘플의 데이터 세트를 생성 한 후 나는 분을 해결하려고합니다. 이는 선형 프로그래밍을 사용하여 절대 편향 회귀를 최소화하는 간단한 문제입니다. 그러나 오류가 발생합니다. 오류는 ValueError: could not broadcast input array from shape (20,1) into shape (20)입니다. 일부 매트릭스의 치수에 문제가있을 것으로 생각하지만, 알아낼 수는 없습니다.

+0

오류가 발생을? 'linprog' 호출에서? 이 함수에 대한 3 가지 입력의 '모양'은 무엇입니까? – hpaulj

+0

'c '가 (m + n, 1) 배열로 만들어진 이유는 무엇입니까? 왜'(m + n,)'이 아니겠습니까? – hpaulj

+0

다른 옵티 마이저로 전환하는 것을 고려하십시오. linprog는 많은 기쁨을 가져다주지 못합니다 (수정 된 예제는 배가 된 차원에서 실패 할 것입니다). GLPK, CBC 및 cvxpy 및 펄프와 같은 우수한 모델링 도구와 같은 LP 솔버가 훨씬 우수합니다. Cobyla 또는 SLSQP와 함께 scipy.optimize.minimize를 사용하면 데이터가 그렇게 크지 않을 수도 있습니다. – sascha

답변

2

다음 줄에 np.squeeze(c)c 교체 :

res = linprog(c, A_ub=A_, b_ub=b_) 

결과 :

status: 0 
    slack: array([ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
     0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
     0.00000000e+00, 1.74947071e-15, 8.81121786e-15, 
     3.08534221e+01, 0.00000000e+00, 0.00000000e+00, 
     0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
     0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
     0.00000000e+00, 0.00000000e+00]) 
success: True 
    fun: 15.426711070042149 
     x: array([ 4.17175117e+01, 3.70399683e+01, 6.20756253e+01, 
     3.77095189e+01, 7.52937664e+01, 6.83052169e+01, 
     2.99644354e+01, 0.00000000e+00, 4.16154976e+00, 
     1.57578313e+01, 0.00000000e+00, 1.45010519e-32, 
     0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 
     1.04491091e-15, 1.77635684e-15, 5.78946776e-16, 
     0.00000000e+00, 1.54267111e+01]) 
message: 'Optimization terminated successfully.' 
    nit: 19 
[ 42.38550486 42.87687009 66.01735375 29.8282326 60.63932141 
    61.8015429 30.15748167 1.91931983 13.54740642 29.00776072]