아래의 C 코드는 비슷한 질문에 previous answer에 사용 된 알고리즘의 SSE 내장 함수로의 변환입니다.
기본적인 개념은 표준 지수 함수의 계산을 2의 제곱의 계산으로 변환하는 것입니다 : expf (x) = exp2f (x/logf (2.0f)) = exp2f (x * 1.44269504)
. 우리는 t = x * 1.44269504
을 정수 i
과 분수 , 즉 t = i + f
및 0 <= f <= 1
으로 나눕니다. 이제 2 f을 다항식 근사법으로 계산 한 다음 단 정밀도 부동 소수점 결과의 지수 필드에 i
을 더하여 결과를 2 으로 조정할 수 있습니다.
SSE 구현에 존재하는 한 가지 문제점은 i = floorf (t)
을 계산하려고하지만, floor()
기능을 빠르게 계산할 수있는 방법이 없다는 것입니다. 그러나 양수 인 경우 floor(x) == trunc(x)
과 음수 인 경우 floor(x) == trunc(x) - 1
인 경우를 제외하고는 x
이 음수 일 때만 나타납니다. 그러나 코어 근사값은 값인 1.0f
을 처리 할 수 있으므로 음의 인수에 대한 근사를 사용하면 무해합니다. SSE는 단 정밀도 부동 소수점 피연산자를 자름이있는 정수로 변환하는 명령을 제공하므로이 솔루션이 효율적입니다.
Peter Cordes SSE4.1은 빠른 층 기능 _mm_floor_ps()
을 지원하므로 SSE4.1을 사용하는 변형도 아래에 나와 있습니다. SSE 4.1 코드 생성이 활성화되어 있지만 gcc는 모든 툴 체인이 자동으로 __SSE4_1__
매크로를 미리 정의하지는 않습니다.
컴파일러 탐색기 (Godbolt)는 gcc 7.2가 일반 SSE의 경우 sixteen instructions으로, SSE 4.1의 경우 twelve instructions으로 코드를 컴파일하는 것을 보여줍니다.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif
/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, e, p, r;
__m128i i, j;
__m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
__m128 c0 = _mm_set1_ps (0.3371894346f);
__m128 c1 = _mm_set1_ps (0.657636276f);
__m128 c2 = _mm_set1_ps (1.00172476f);
/* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
#ifdef __SSE4_1__
e = _mm_floor_ps (t); /* floor(t) */
i = _mm_cvtps_epi32 (e); /* (int)floor(t) */
#else /* __SSE4_1__*/
i = _mm_cvttps_epi32 (t); /* i = (int)t */
j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
i = _mm_sub_epi32 (i, j); /* (int)t - signbit(t) */
e = _mm_cvtepi32_ps (i); /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
f = _mm_sub_ps (t, e); /* f = t - floor(t) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
int main (void)
{
union {
float f[4];
unsigned int i[4];
} arg, res;
double relerr, maxrelerr = 0.0;
int i, j;
__m128 x, y;
float start[2] = {-0.0f, 0.0f};
float finish[2] = {-87.33654f, 88.72283f};
for (i = 0; i < 2; i++) {
arg.f[0] = start[i];
arg.i[1] = arg.i[0] + 1;
arg.i[2] = arg.i[0] + 2;
arg.i[3] = arg.i[0] + 3;
do {
memcpy (&x, &arg, sizeof(x));
y = fast_exp_sse (x);
memcpy (&res, &y, sizeof(y));
for (j = 0; j < 4; j++) {
double ref = exp ((double)arg.f[j]);
relerr = fabs ((res.f[j] - ref)/ref);
if (relerr > maxrelerr) {
printf ("arg=% 15.8e res=%15.8e ref=%15.8e err=%15.8e\n",
arg.f[j], res.f[j], ref, relerr);
maxrelerr = relerr;
}
}
arg.i[0] += 4;
arg.i[1] += 4;
arg.i[2] += 4;
arg.i[3] += 4;
} while (fabsf (arg.f[3]) < fabsf (finish[i]));
}
printf ("maximum relative errror = %15.8e\n", maxrelerr);
return EXIT_SUCCESS;
}
fast_sse_exp()
위한 대체 디자인의 조정 인자 x/log(2)
의 정수 부분을 추출으로 반올림 모드 상수 "마법"변환을 추가하는 공지의 기법을 사용하여 1.5 * 2 23 행 올바른 비트 위치에서 강제로 반올림 한 다음 다시 동일한 숫자를 뺍니다. SSE 반올림 모드는 추가하는 동안 "가장 가까운 수로 반올림"(기본값으로 설정)입니다. wim은 적극적인 최적화가 사용될 때 일부 컴파일러가 변환 상수 cvt
을 덧셈과 뺄셈을 최적화 할 수 있으므로이 코드 시퀀스의 기능을 방해하므로 생성 된 기계 코드를 검사하는 것이 좋습니다. 2 f의 계산에 대한 근사 간격은 -0.5 <= f <= 0.5
이후 0에 중심을 맞추기 때문에 다른 코어 근사값이 필요합니다.
:
/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, p, r;
__m128i i, j;
const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
const __m128 cvt = _mm_set1_ps (12582912.0f); /* 1.5 * (1 << 23) */
const __m128 c0 = _mm_set1_ps (0.238428936f);
const __m128 c1 = _mm_set1_ps (0.703448006f);
const __m128 c2 = _mm_set1_ps (1.000443142f);
/* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
f = _mm_sub_ps (t, r); /* f = t - rint (t) */
i = _mm_cvtps_epi32 (t); /* i = (int)t */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
질문의 코드에 대한 알고리즘은 영리 IEEE-754 이진 부동 소수점 형식의 반 대수 특성을 활용 니콜 N. Schraudolph의 작품에서 가져온 것으로 보인다 NN Schraudolph. "지수 함수의 빠르고 컴팩트 한 근사법입니다." 신경 계산, 11 (4), May 1999, pp.853-862.
인수 클램핑 코드를 제거한 후에는 SSE 명령어가 세 개로 줄어 듭니다. "마법"보정 상수 486411
은 전체 입력 도메인에서 최대 상대 오차를 최소화하는 데는 적합하지 않습니다. 간단한 이진 검색을 기반으로하면 298765
값이 우수하여 FastExpSse()
의 최대 상대 오차를 3.56e-2로, 상대 오차 최대 값을 1.33e-3 (fast_exp_sse()
)으로 줄입니다.
/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
__m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23)/log(2) */
__m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
__m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
return _mm_castsi128_ps (t);
}
Schraudolph 알고리즘은 기본적으로 [0,1]에 대한 f
~ 1.0 + f
= f를 선형 근사 2를 사용하고, 그 정확도는 차항을 추가함으로써 개선 될 수있다. Schraudolph의 접근 방식의 영리한 부분은 분수에서 정수 부분 i = floor(x * 1.44269504)
을 명시 적으로 분리하지 않고 2 i * 2 f을 계산하는 것입니다. 나는 차 근사에 그 트릭을 확장 할 수있는 방법을 볼 수 없지만, 하나는 확실히 위에서 사용 된 차 근사치 Schraudolph에서 floor()
계산을 결합 할 수 있습니다 :
/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 f, p, r;
__m128i t, j;
const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23)/log(2) */
const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
const __m128 c0 = _mm_set1_ps (0.3371894346f);
const __m128 c1 = _mm_set1_ps (0.657636276f);
const __m128 c2 = _mm_set1_ps (1.00172476f);
t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
j = _mm_and_si128 (t, m); /* j = (int)(floor (x/log(2))) << 23 */
t = _mm_sub_epi32 (t, j);
f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
현재 구현의 정확성 (최대 상대 오차)가 무엇입니까? 그리고 개선 된 버전을 목표로하는 정확성은 무엇입니까? – njuffa
정말 나쁜 것입니다 (나는 누군가가 그것을 발견 할 수 있다면 오류가 있다고 의심합니다). 아마 이해할 수있는 모든 것이 그것을 이길 것입니다. 1 % 미만의 항목은 훌륭합니다! – Royi
참고 : [sse_mathfun] (https://github.com/RJVB/sse_mathfun) 및 [이 답변] (https://stackoverflow.com/a/8907932/253056) ('log2'와 관련이 있지만 대부분의 제안 사항 또한'exp' 함수를 포함합니다). –