sklearn의 MiniBatchDictionaryLearning을 사용하여 사전 학습으로 오류 추적을 구현하고 싶습니다. 그래서 반복을 통해 오류가 감소하는 방식을 기록 할 수 있습니다. 두 가지 방법이 있는데 둘 다 실제로 작동하지 않습니다. 설정 :Python Minibatch 사전 학습
- 입력 데이터 X, NumPy와 배열 형상 (N_SAMPLES, n_features) = (298,143, 300). 이들은 모양의 이미지 (642, 480, 3)에서 생성 된 모양의 패치 (10, 10)입니다.
- 사전 학습 매개 변수 : 열 (또는 원자) 수 = 100, 알파 = 2, 변형 알고리즘 = OMP, 총 번호. 반복 횟수 = 500 (처음에는 작게 유지하고 테스트 사례와 마찬가지로)
계산 오류 : 사전을 학습 한 후 학습 된 사전을 기반으로 원본 이미지를 다시 인코딩합니다. 인코딩과 원본은 같은 모양 (642, 480, 3)의 numpy 배열이므로, 지금은 원소 적으로 유클리드 거리를 사용하고 있습니다 :
err = np.sqrt (np.sum (np.sum (reconstruction - original)) ** 2))
나는이 매개 변수를 사용하여 테스트 실행을했고, 전체 맞게 그 두 가지 방법에에 good.Now를, 그래서 낮은 오류와 꽤 좋은 재건을 생산할 수 있었다 :
방법 1 : 100 회 반복마다 학습 된 사전을 저장하고 오류를 기록하십시오. 500 회 반복의 경우, 이것은 각각 100 회 반복의 5 회 실행을 제공합니다. 각 실행 후 오류를 계산 한 다음 현재 실행 된 사전을 다음 실행을위한 초기화로 사용합니다.
# Fit an initial dictionary, V, as a first run
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = 100,
transform_algorithm='omp')
dl = dico.fit(patches)
V = dl.components_
# Now do another 4 runs.
# Note the warm restart parameter, dict_init = V.
for i in range(n_runs):
print("Run %s..." % i, end = "")
dico = MiniBatchDictionaryLearning(n_components = 100,
alpha = 2,
n_iter = n_iterations,
transform_algorithm='omp',
dict_init = V)
dl = dico.fit(patches)
V = dl.components_
img_r = reconstruct_image(dico, V, patches)
err = np.sqrt(np.sum((img - img_r)**2))
print("Err = %s" % err)
문제점 : 오류가 감소하지 않고 꽤 높습니다. 사전도 잘 배웠습니다.
방법 2 : partial_fit()
방법을 사용하여 입력 데이터 X를 500 배치로 잘라서 부분 피팅합니다.
batch_size = 500
n_batches = X.shape[0] // batch_size
print(n_batches) # 596
for iternum in range(n_batches):
batch = patches[iternum*batch_size : (iternum+1)*batch_size]
V = dico.partial_fit(batch)
문제 :이 방법은 약 5000 배 더 오래 걸리는 것으로 보입니다.
피팅 프로세스에서 오류를 검색하는 방법이 있는지 알고 싶습니다.
안녕, 이것은 아주 좋은 것입니다, 감사합니다 - 그것은 수천 배 느리다 것을 제외하고!) : 저는 오버 헤드의 증가를 일으키는 원인이 무엇인지 잘 모릅니다. 기본적으로, MiniBatchDL 함수를 다시 만들려고하지만 이번에는 사전의 현재 상태를 저장하고 목록에 현재 상태를 추가하여 정기적으로 간격을두고 일시 중지하므로 나중에 오류를 계산할 수 있습니다. – AndreyIto
'n_updates % 100 == 0 :'을'n_updates % 10000 == 0 :'으로 변경하려고 시도했을 수 있습니다. 나는 dict를 업데이트하면서 오류를 계산하는 것의 상대적 오버 헤드가 무엇인지 전혀 모른다. – ogrisel
답장을 보내 주셔서 감사합니다. 내가 한 일은'n_updates % 100 == 0 : print (n_updates)'입니다. 즉, 사전을 아직 저장하지 않아서 계산 시간을 늘릴 수있는 가능성을 사전을 저장하지 못하게 할 수 있습니다 (처음에는 그다지 중요하지 않습니다!). 불행히도 천천히. 내가 알 수있는 한 실제로는 반복되는 작업이 있습니다. 'inner_stats'라는 행렬이 지나가고 있지만, 여전히 확실하지 않습니다. – AndreyIto