1

크기가 매우 큽니다 (행렬에서 최대 10 억 개의 요소)라고 가정합니다. 행렬 벡터 제품에 대한 캐시를 인식하지 못하는 알고리즘을 어떻게 구현합니까? 위키 피 디아를 기반으로하면 재귀 적으로 나누고 정복해야하지만 오버 헤드가 많을 것이라고 생각합니다. 그렇게하는 것이 효율적일까요?행렬 및 벡터 곱셈 최적화 알고리즘

질문과 대답을 따르 OpenMP with matrices and vectors

+0

[scicomp.SE] (HTTP를하지만 순진 구현을하는 것은, 그것은 적절한 SGEMV 라이브러리 호출보다 간단한 더블 루프보다 느리게, 그리고 방법 느리다 : //scicomp.stackexchange.com)? –

답변

3

그래서,에 대한 조정 BLAS 라이브러리에 찾아 링크를 항상 도처에있다 "나는 빨리이 기본 선형 대수 작업을 어떻게해야합니까"라는 질문에 대한 대답하여 플랫폼. 예 : GotoBLAS (그의 작업은 OpenBLAS으로 계속 진행 중임) 또는 느린 자동 조정 ATLAS 또는 Intel의 MKL과 같은 상용 패키지. 선형 대수학은 매우 많은 다른 작업에서 매우 기본이기 때문에 다양한 플랫폼에서 이러한 패키지를 최적화하는 데 엄청난 노력이 필요하며 경쟁 할 몇 가지 오후 작업에서 뭔가를 생각해 볼 기회가 없습니다. 일반 조밀 한 행렬 - 벡터 곱셈을 찾고자하는 특정 서브 루틴 호출은 SGEMV/DGEMV/CGEMV/ZGEMV입니다.

캐시-잊기 알고리즘, 또는 오토 튜닝은, 당신이 당신의 시스템의 특정 캐시 아키텍처에 대한 조정을 방해 할 수없는 경우위한 - 일반적으로, 잘 될 수도 있지만, 사람들 BLAS을 위해 그렇게 기꺼이 때문에 루틴을 사용하고 튜닝 된 결과를 사용 가능하게하면 해당 루틴을 사용하는 것이 가장 좋습니다.

GEMV에 대한 메모리 액세스 패턴은 나누기와 정복 (매트릭스 전치의 표준 경우와 동일)이 필요하지 않을 정도로 충분히 간단합니다. 캐시 차단 크기를 찾아 사용하면됩니다. GEMV (y = Ax)에서, 당신은 여전히 ​​전체 매트릭스를 한 번 스캔해야하므로 재사용 (그리고 따라서 효과적인 캐시 사용)을 위해 할 일이 없지만 가능한 한 많이 x를 재 시도해 볼 수 있습니다 한 번 (행 수) 시간 대신에 - 그리고 당신은 여전히 ​​캐시 친화적 인 A에 대한 액세스를 원합니다. 따라서 캐시를 차단하는 명백한 방법은 블록을 따라 분해하는 것입니다.

A x -> [ A11 | A12 ] | x1 | = | A11 x1 + A12 x2 | 
     [ A21 | A22 ] | x2 | | A21 x1 + A22 x2 | 

그리고 확실히 재귀 적으로 수행 할 수 있습니다.

$ ./gemv 
Testing for N=4096 
Double Loop: time = 0.024995, error = 0.000000 
Divide and conquer: time = 0.299945, error = 0.000000 
SGEMV: time = 0.013998, error = 0.000000 

코드는 다음과 같습니다 :

#include <stdio.h> 
#include <stdlib.h> 
#include <sys/time.h> 
#include "mkl.h" 

float **alloc2d(int n, int m) { 
    float *data = malloc(n*m*sizeof(float)); 
    float **array = malloc(n*sizeof(float *)); 
    for (int i=0; i<n; i++) 
     array[i] = &(data[i*m]); 
    return array; 
} 

void tick(struct timeval *t) { 
    gettimeofday(t, NULL); 
} 

/* returns time in seconds from now to time described by t */ 
double tock(struct timeval *t) { 
    struct timeval now; 
    gettimeofday(&now, NULL); 
    return (double)(now.tv_sec - t->tv_sec) + ((double)(now.tv_usec - t->tv_usec)/1000000.); 
} 

float checkans(float *y, int n) { 
    float err = 0.; 
    for (int i=0; i<n; i++) 
     err += (y[i] - 1.*i)*(y[i] - 1.*i); 
    return err; 
} 

/* assume square matrix */ 
void divConquerGEMV(float **a, float *x, float *y, int n, 
        int startr, int endr, int startc, int endc) { 

    int nr = endr - startr + 1; 
    int nc = endc - startc + 1; 

    if (nr == 1 && nc == 1) { 
     y[startc] += a[startr][startc] * x[startr]; 
    } else { 
     int midr = (endr + startr+1)/2; 
     int midc = (endc + startc+1)/2; 
     divConquerGEMV(a, x, y, n, startr, midr-1, startc, midc-1); 
     divConquerGEMV(a, x, y, n, midr, endr, startc, midc-1); 
     divConquerGEMV(a, x, y, n, startr, midr-1, midc, endc); 
     divConquerGEMV(a, x, y, n, midr, endr, midc, endc); 
    } 
} 
int main(int argc, char **argv) { 
    const int n=4096; 
    float **a = alloc2d(n,n); 
    float *x = malloc(n*sizeof(float)); 
    float *y = malloc(n*sizeof(float)); 
    struct timeval clock; 
    double eltime; 

    printf("Testing for N=%d\n", n); 

    for (int i=0; i<n; i++) { 
     x[i] = 1.*i; 
     for (int j=0; j<n; j++) 
      a[i][j] = 0.; 
     a[i][i] = 1.; 
    } 

    /* naive double loop */ 
    tick(&clock); 
    for (int i=0; i<n; i++) { 
     y[i] = 0.; 
     for (int j=0; j<n; j++) { 
      y[i] += a[i][j]*x[j]; 
     } 
    } 
    eltime = tock(&clock); 
    printf("Double Loop: time = %lf, error = %f\n", eltime, checkans(y,n)); 

    for (int i=0; i<n; i++) y[i] = 0.; 

    /* naive divide and conquer */ 
    tick(&clock); 
    divConquerGEMV(a, x, y, n, 0, n-1, 0, n-1); 
    eltime = tock(&clock); 
    printf("Divide and conquer: time = %lf, error = %f\n", eltime, checkans(y,n)); 

    /* decent GEMV implementation */ 
    tick(&clock); 

    float alpha = 1.; 
    float beta = 0.; 
    int incrx=1; 
    int incry=1; 
    char trans='N'; 

    sgemv(&trans,&n,&n,&alpha,&(a[0][0]),&n,x,&incrx,&beta,y,&incry); 
    eltime = tock(&clock); 
    printf("SGEMV: time = %lf, error = %f\n", eltime, checkans(y,n)); 

    return 0; 
}