0
cuBLAS 함수 cublasSgemm을 사용하여 행렬 곱셈의 간단한 예제를 작성하려고합니다. 내 코드는 아래와 같습니다 :이 코드에서 cublasSgemm을 사용하여 행렬 곱셈에 실패한 이유는 무엇입니까?
int m =100, n = 100;
float * bold1 = new float [m*n];
float * bold2 = new float [m*n];
float * bold3 = new float [m*n];
for (int i = 0; i< m; i++)
for(int j = 0; j <n;j++)
{
bold1[i*n+j]=rand()%10;
bold2[i*n+j]=rand()%10;
}
cudaError_t cudaStat;
cublasStatus_t stat;
cublasHandle_t handle;
const float alpha = 1.0;
const float beta = 0;
float * dev_bold1, * dev_bold2, *dev_bold3;
cudaStat = cudaMalloc ((void**)&bold1, sizeof(float)*m*n);
if(cudaStat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem1";
return cudaStat;
}
cudaStat = cudaMalloc ((void**)&bold2,sizeof(float)*m*n);
if(cudaStat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem2";
return cudaStat;
}
cudaStat = cudaMalloc ((void**)&bold3,sizeof(float)*m*n);
if(cudaStat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem3";
return cudaStat;
}
cublasSetMatrix(m,n,sizeof(float),bold1,m,dev_bold1,m);
cublasSetMatrix(m,n,sizeof(float),bold2,m,dev_bold2,m);
stat = cublasCreate(&handle);
if(stat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem4";
return stat;
}
cout<<stat<<" "<<CUBLAS_STATUS_SUCCESS<<"\n";
stat = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, n ,&alpha, dev_bold1, n, dev_bold2, n, &beta,dev_bold3,m);
if (stat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem5";
return stat;
}
cudaStat = cudaMemcpy(bold3,dev_bold3,sizeof(float)*m*n,cudaMemcpyDeviceToHost);
if (cudaStat != cudaSuccess)
{
cout<<"problem6";
return cudaStat;
}
delete []bold1;
delete []bold2;
cudaFree(dev_bold1);
cudaFree(dev_bold2);
cudaFree(dev_bold3);
이 코드에서는 임의의 숫자로 채워진 행렬 bold1과 bold2에 곱하고 싶습니다. 코드의이 부분과 관련이 코드 반환 "문제 5"
stat = cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, m, n ,&alpha, dev_bold1, n, dev_bold2, n, &beta,dev_bold3,m);
if (stat != CUBLAS_STATUS_SUCCESS)
{
cout<<"problem5";
return stat;
}
또한 내가 인쇄 스탯하고 "13"을 보여줍니다!
아무도 내 코드의 문제점을 이해하는 데 도움이 될 수 있습니까? 감사합니다. 다른 2 cudaMalloc
문에 대한
float * dev_bold1, * dev_bold2, *dev_bold3;
cudaStat = cudaMalloc ((void**)&bold1, sizeof(float)*m*n);
^
|
this should be dev_bold1
마찬가지로 :