2013-03-09 2 views
11

SSE 명령어를 사용하여 두 개의 16 바이트 숫자를 비교하기 위해 함수 int compare_16bytes(__m128i lhs, __m128i rhs)을 작성했습니다.이 함수는 비교 수행 후 얼마나 많은 바이트가 동일한지를 반환합니다.두 배열 사이의 동일한 바이트 수를 빠르게 계산하십시오.

이제 임의의 길이의 두 바이트 배열을 비교하기 위해 위의 함수를 사용하고 싶습니다. 길이가 16 바이트의 배수가 아니기 때문에이 문제를 해결해야합니다. 아래 함수의 구현을 어떻게 완료 할 수 있습니까? 아래의 기능을 어떻게 향상시킬 수 있습니까?

int fast_compare(const char* s, const char* t, int length) 
{ 
    int result = 0; 

    const char* sPtr = s; 
    const char* tPtr = t; 

    while(...) 
    { 
     const __m128i* lhs = (const __m128i*)sPtr; 
     const __m128i* rhs = (const __m128i*)tPtr; 

     // compare the next 16 bytes of s and t 
     result += compare_16bytes(*lhs,*rhs); 

     sPtr += 16; 
     tPtr += 16; 
    } 

    return result; 
} 
+2

는 패딩이 달라야 루프 우측에 좌측 및 것들 (길이/16 시간), 및 패드 0을 사용. –

+1

'while (길이> = 16) {/ * 함수 사용 */길이 - = 16; } 길이 (최대 길이)/* 길이 (최대 15) 바이트를 비교하는 버전 사용 * /; ' – pmg

+1

FYI 이것은 종종 [* 해밍 거리 *]라고합니다 (http://en.wikipedia.org/wiki/Hamming_distance) - 검색 용어로 유용 할 수 있습니다. –

답변

6

@Mysticial 위의 의견에 말했듯이, 수행 비교하고 수직으로 요약 한 후 바로 메인 루프의 끝에서 수평으로 요약 :

#include <stdio.h> 
#include <stdlib.h> 
#include <time.h> 
#include <emmintrin.h> 

// reference implementation 
int fast_compare_ref(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    for (i = 0; i < length; ++i) 
    { 
     if (s[i] == t[i]) 
      result++; 
    } 
    return result; 
} 

// optimised implementation 
int fast_compare(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int i; 

    __m128i vsum = _mm_set1_epi32(0); 
    for (i = 0; i < length - 15; i += 16) 
    { 
     __m128i vs, vt, v, vh, vl, vtemp; 

     vs = _mm_loadu_si128((__m128i *)&s[i]); // load 16 chars from input 
     vt = _mm_loadu_si128((__m128i *)&t[i]); 
     v = _mm_cmpeq_epi8(vs, vt);    // compare 
     vh = _mm_unpackhi_epi8(v, v);   // unpack compare result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(v, v); 
     vtemp = _mm_madd_epi16(vh, vh);   // accumulate 16 bit vectors into 4 x 32 bit partial sums 
     vsum = _mm_add_epi32(vsum, vtemp); 
     vtemp = _mm_madd_epi16(vl, vl); 
     vsum = _mm_add_epi32(vsum, vtemp); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (i < length) 
    { 
     result += fast_compare_ref(&s[i], &t[i], length - i); 
    } 

    return result; 
} 

// test harness 
int main(void) 
{ 
    const int n = 1000000; 
    char *s = malloc(n); 
    char *t = malloc(n); 
    int i, result_ref, result; 

    srand(time(NULL)); 

    for (i = 0; i < n; ++i) 
    { 
     s[i] = rand(); 
     t[i] = rand(); 
    } 

    result_ref = fast_compare_ref(s, t, n); 
    result = fast_compare(s, t, n); 

    printf("result_ref = %d, result = %d\n", result_ref, result);; 

    return 0; 
} 

컴파일하고 위의 테스트 장치 실행

을 우리가 풀고 _mm_madd_epi16를 사용하여 16 비트를 축적 위의 SSE 코드의 하나의 가능성이 아닌 명백한 속임수가 있다는
$ gcc -Wall -O3 -msse3 fast_compare.c -o fast_compare 
$ ./fast_compare 
result_ref = 3955, result = 3955 
$ ./fast_compare 
result_ref = 3947, result = 3947 
$ ./fast_compare 
result_ref = 3945, result = 3945 

0/-1 값은 32 비트 부분 합계입니다. -1*-1 = 1 (및 0*0 = 0 물론)을 활용합니다. 여기서는 실제로 곱셈을하지 않고 하나의 명령어로 풀고 요약하는 것입니다.


UPDATE : 아래의 코멘트에 언급 한 바와 같이,이 솔루션이 최적이 아닌 - 난 그냥 상당히 최적의 16 비트 솔루션을 가져다가 8 개 비트 데이터를 작동하도록 풀고 16 비트에 8 비트를 추가했다. 그러나 8 비트 데이터의 경우보다 효율적인 방법이 있습니다. psadbw/_mm_sad_epu8을 사용하십시오. 나는이 대답을 후손을 위해 남겨두고, 16 비트 데이터로 이런 종류의 일을하고 싶어하는 사람들을 위해, 그러나 실제로 입력 데이터를 푸는 것을 요구하지 않는 다른 해답 중 하나는 받아 들여진 대답이어야한다.

+0

훌륭합니다! 제대로 작동합니다! 더욱이, 두 벡터's'와't'가 _aligned_입니까? 정렬이란 무엇입니까? – enzom83

+1

위의 예에서'_mm_loadu_si128'을 사용했기 때문에 정렬에 대해서는 중요하지 않습니다. 's'와't '가 16 바이트 정렬이라는 것을 보장 할 수 있다면, 특히 구형 CPU에서 성능 향상을 위해'_mm_loadu_si128' 대신'_mm_load_si128'을 사용하십시오. –

+0

_mm_setzero_si128()은 vsum을 제로화하는 데있어 _mm_set1_epi32 (0)보다 빠를 수 있습니다. – leecbaker

1

SSE의 정수 비교는 모두 0이거나 모두 1 인 바이트를 생성합니다. 집계하려면 먼저 비교 결과를 7만큼 오른쪽으로 이동시킨 다음 결과 벡터에 추가해야합니다. 결국, 결과 벡터의 요소를 합하여 결과 벡터를 줄여야합니다. 이 감소는 스칼라 코드 또는 일련의 추가/교대로 수행되어야합니다. 보통이 부분은 문제가 될만한 가치가 없습니다.

3

16 x uint8 요소의 부분 합계를 사용하면 성능이 향상 될 수 있습니다.
나는 내부 루프와 외부 루프로 루프를 나눴다.
내부 루프 sum uint8 요소 (각 uint8 요소는 최대 255 "1"을 합계 할 수 있음).
작은 트릭 : _mm_cmpeq_epi8은 같은 원소를 0xFF로 설정하고 (char) 0xFF = -1이므로 합계에서 결과를 뺄 수 있습니다 (1을 더하려면 -1 빼기). 큰 입력에 대한

int fast_compare2(const char *s, const char *t, int length) 
{ 
    int result = 0; 
    int inner_length = length; 
    int i; 
    int j = 0; 

    //Points beginning of 4080 elements block. 
    const char *s0 = s; 
    const char *t0 = t; 


    __m128i vsum = _mm_setzero_si128(); 

    //Outer loop sum result of 4080 sums. 
    for (i = 0; i < length; i += 4080) 
    { 
     __m128i vsum_uint8 = _mm_setzero_si128(); //16 uint8 sum elements (each uint8 element can sum up to 255). 
     __m128i vh, vl, vhl, vhl_lo, vhl_hi; 

     //Points beginning of 4080 elements block. 
     s0 = s + i; 
     t0 = t + i; 

     if (i + 4080 <= length) 
     { 
      inner_length = 4080; 
     } 
     else 
     { 
      inner_length = length - i; 
     } 

     //Inner loop - sum up to 4080 (compared) results. 
     //Each uint8 element can sum up to 255. 16 uint8 elements can sum up to 255*16 = 4080 (compared) results. 
     ////////////////////////////////////////////////////////////////////////// 
     for (j = 0; j < inner_length-15; j += 16) 
     { 
       __m128i vs, vt, v; 

       vs = _mm_loadu_si128((__m128i *)&s0[j]); // load 16 chars from input 
       vt = _mm_loadu_si128((__m128i *)&t0[j]); 
       v = _mm_cmpeq_epi8(vs, vt);    // compare - set to 0xFF where equal, and 0 otherwise. 

       //Consider this: (char)0xFF = (-1) 
       vsum_uint8 = _mm_sub_epi8(vsum_uint8, v); //Subtract the comparison result - subtract (-1) where equal. 
     } 
     ////////////////////////////////////////////////////////////////////////// 

     vh = _mm_unpackhi_epi8(vsum_uint8, _mm_setzero_si128());  // unpack result into 2 x 8 x 16 bit vectors 
     vl = _mm_unpacklo_epi8(vsum_uint8, _mm_setzero_si128()); 
     vhl = _mm_add_epi16(vh, vl); //Sum high and low as uint16 elements. 

     vhl_hi = _mm_unpackhi_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 
     vhl_lo = _mm_unpacklo_epi16(vhl, _mm_setzero_si128()); //unpack sum of vh an vl into 2 x 4 x 32 bit vectors 

     vsum = _mm_add_epi32(vsum, vhl_hi); 
     vsum = _mm_add_epi32(vsum, vhl_lo); 
    } 

    // get sum of 4 x 32 bit partial sums 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 8)); 
    vsum = _mm_add_epi32(vsum, _mm_srli_si128(vsum, 4)); 
    result = _mm_cvtsi128_si32(vsum); 

    // handle any residual bytes (< 16) 
    if (j < inner_length) 
    { 
     result += fast_compare_ref(&s0[j], &t0[j], inner_length - j); 
    } 

    return result; 
} 
+0

흠, 나는 Paul 's에 대해 논하기 전에 새로운 대답을 보았어야했다. 나는 같은 것을 제안했다 (psubb는 안쪽 루프 안에있다). 이것은'psadbw'를 사용하여'vsum_uint8'의 가로 합계를해야한다는 것을 제외하고는 제가 의미했던 바입니다 (Paul의 대답에 대한 저의 의견을보십시오). –

+0

나는 수평 합계를 사용하려고 생각했지만 SSE2 호환성을 유지하기로 결정했다. – Rotem

+0

당신은'phaddd'에 대해 이야기하고 있습니까? 그건 내가 말한 것이 아니다. 'phaddd'의 [장점은 코드 크기입니다] (http : // stackoverflow.com/questions/6996764/가장 빠른 방법 - 수평 - 부동 - 벡터 - 합계 - x86/35270026 # 35270026) 현재 CPU에서. 또한 SSE2 지침 만 사용하는이 질문에 대한 내 대답을 참조하십시오. –

2

가장 빠른 방법은 내부 루프는 벡터의 바이트 요소 앞에 수평으로 합계 탈옥, pcmpeqb/psubb입니다 로템의 대답입니다 : 여기

는 fast_compare 내 최적화 된 버전입니다 어큐뮬레이터가 오버플로됩니다. 모두 0 벡터에 대해 psadbw으로 부호없는 바이트의 hsum을 수행합니다.대신 모든 제로의 0x7f의 벡터에 대한 루프, psadbw에서 레지스터 많은 압력이없는 경우 줄이기/중첩 루프없이

은 최선의 선택은 아마도

pcmpeqb -> vector of 0 or 0xFF elements 
psadbw -> two 64bit sums of (0*no_matches + 0xFF*matches) 
paddq  -> accumulate the psadbw result in a vector accumulator 

#outside the loop: 
horizontal sum 
divide the result by 255 

입니다.

  • psadbw(0x00, set1(0x7f)) =>sum += 0x7f
  • psadbw(0xff, set1(0x7f)) =>

sum += 0x80 그래서 대신 단지 n * 0x7f을 뺄 필요는 컴파일러 (실제 div없이 효율적으로 수행되어야한다) (255)에 의해 분할 n은 요소의 수입니다.

또한 Nehalem 및 Atom에서 느린 경우 128 비트 * 수가 32 비트 정수를 오버플로 할 것으로 예상하지 않으면 paddd (_mm_add_epi32)을 사용할 수 있습니다.

이것은 바울 R의 pcmpeqb/배 punpck/2 배 pmaddwd/2 배 paddw 아주 잘 비교합니다. 나머지 바이트가 16 미만인 경우 거짓 동등한 패딩을 카운트하지 않도록

관련 문제