2012-12-31 2 views
0

글쎄, "알고리즘 소개"에서 숫자가 4.2-6 인 질문입니다. 다음과 같이 설명됩니다.strassen 행렬 곱셈

a kn*n matrix by an n*kn matrix을 얼마나 빨리 곱하면 으로 Strassen's algorithm을 사용할 수 있습니까?

두 행렬을 모두 kn*kn matrix으로 늘리면이 질문에 Strassen의 알고리즘을 적용 할 수 있습니다. 그러나 나는 Math.pow(kn, lg7) running time을 얻을 것이다.

아무에게도 더 좋은 해결책이 있습니까? 모두에게 새해 복 많이 받으세요.

+0

Strassen 알고리즘의 런타임은 온라인에서 쉽게 찾을 수 있습니다. 당신의 문제가 무엇인지 잘 모르겠습니다. – Cratylus

+0

http://stackoverflow.com/questions/1920031/strassens-algorithm-for-matrix-multiplication – bonCodigo

+0

@Cratylus :이 질문은 Strassen 알고리즘의 변형입니다. –

답변

1

Strasens의 알고리즘의 다른 벡터 기반 구현은 여기에서, 그것은 순뿐만 아니라 strssens 모두 번 실행에서을 비교 한을 보여준다

enter code here: 
#include <cstdio> 
#include <iostream> 
#include <cstdlib> 
#include <ctime> 
#include <cassert> 
#include <vector> 
#include <ctime> 
using namespace std; 
void fun(vector<vector<int> >& u , vector<vector<int> >&m , int P , int n) 
{ 


    for(int i = 0 ; i < n ; i++) 
    { 
     vector<int>t ; 
     for(int j = 0 ; j < n ; j++) 
     { 
         switch(P) 
      { 
         case 1: 
         { 
       t.push_back(u[i][j]); 
          break; 
      } 
         case 2: 
         { 
       t.push_back(u[i][j+n]); 
          break; 
       } 
         case 3: 
         { 
       t.push_back(u[i+n][j]); 
          break; 
      } 
         case 4: 
         { 
      t.push_back(u[i+n][j+n]); 
          break; 
      } 
        } 
        } 

        m[i] = t; 
    } 
} 
void normalmul(int n , vector< vector<int> >& u , vector< vector<int> >& v ,  vector< vector<int> >& z) 

{ 
for(int i = 0 ; i < n ; i++) 
{ 
    for(int j = 0 ; j < n ; j++) 
    { 
     z[i][j] = 0; 
     for(int k = 0 ; k < n ; k++) 
     { 
      z[i][j] += (u[i][k] * v[k][j]); 
     } 
    } 
} 
} 

void strassen(int n , vector< vector<int> >& u , vector< vector<int> >& v , vector< vector<int> >& z) 

{ 
if(n == 32) 
{ 
    normalmul(n,u,v,z); 
    return; 
} 
else 
{ 
    int Shiftt = n>>1; 
    vector<vector<int> >AA(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >BB(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >CC(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >DD(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >EE(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >FF(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >GG(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >HH(Shiftt , vector<int>(Shiftt)); 

    vector<vector<int> >A1(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >A2(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >A3(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >A4(Shiftt , vector<int>(Shiftt)); 
    fun(u,AA,1,n>>1); 
    fun(u,BB,2,n>>1); 
    fun(u,CC,3,n>>1); 
    fun(u,DD,4,n>>1); 
    fun(v,EE,1,n>>1); 
    fun(v,FF,2,n>>1); 
    fun(v,GG,3,n>>1); 
    fun(v,HH,4,n>>1); 
    vector<vector<int> >M1(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M2(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M3(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M4(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M5(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M6(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >M7(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >T1(Shiftt , vector<int>(Shiftt)); 
    vector<vector<int> >T2(Shiftt , vector<int>(Shiftt)); 
    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = AA[i][j] + DD[i][j]; 
      T2[i][j] = EE[i][j] + HH[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M1); 


    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = CC[i][j] - AA[i][j]; 
      T2[i][j] = EE[i][j] + FF[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M6); 

    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = BB[i][j] - DD[i][j]; 
      T2[i][j] = GG[i][j] + HH[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M7); 


    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = CC[i][j] + DD[i][j]; 
      T2[i][j] = EE[i][j] ; 
     } 
    } 
    strassen(Shiftt,T1,T2,M2); 


    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = AA[i][j] ; 
      T2[i][j] = FF[i][j] - HH[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M3); 


    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = DD[i][j]; 
      T2[i][j] = GG[i][j] - EE[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M4); 


    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      T1[i][j] = AA[i][j] + BB[i][j]; 
      T2[i][j] = HH[i][j]; 
     } 
    } 
    strassen(Shiftt,T1,T2,M5); 

    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      A1[i][j] = M1[i][j] + M4[i][j] - M5[i][j] + M7[i][j] ; 
      A2[i][j] = M3[i][j] + M5[i][j] ; 
      A3[i][j] = M2[i][j] + M4[i][j] ; 
      A4[i][j] = M1[i][j] - M2[i][j] + M3[i][j] + M6[i][j] ; 
     } 
    } 
    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      z[i][j] = A1[i][j]; 
     } 
    } 
    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      z[i][j+Shiftt] = A2[i][j]; 
     } 
    } 
    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      z[i+Shiftt][j] = A3[i][j]; 
     } 
    } 
    for(int i = 0 ; i < Shiftt ; i++) 
    { 
     for(int j = 0 ; j < Shiftt ; j++) 
     { 
      z[i+Shiftt][j+Shiftt] = A4[i][j]; 
     } 
    } 
} 
} 


int main() 
{ 
int t,n; 
freopen("input_file.txt","r",stdin); 
cin >> t; 
while(t--) 
{ 
    int vl ; 
    scanf("%d",&n); 
    cout << "value of n " << n << endl ;; 
    vector< vector<int> >u(n,vector<int>(n)); 
    vector< vector<int> >v(n,vector<int>(n)); 
    vector< vector<int> >z(n,vector<int>(n)); 
    vector< vector<int> >zz(n,vector<int>(n)); 
    vector<int> temp; 
    for(int i = 0 ; i < n ; i++) 
    { 
      vector<int> temp; 
     for(int j = 0 ; j < n ; j++) 
     { 
      scanf("%d",&vl); 
      temp.push_back(vl); 
     } 
     u[i] = temp; 
    } 
    for(int i = 0 ; i < n ; i++) 
    { 
     vector<int> temp; 
     for(int j = 0 ; j < n ; j++) 
     { 
      scanf("%d",&vl); 
      temp.push_back(vl); 
     } 
     v[i] = temp; 
    } 
    clock_t start , end ; 

    //USING NAIVE APPROACH 

    start = clock(); 
      cout<<"Traditional Algorithm Running Time : "; 
    normalmul(n,u,v,z); 

    end = clock() ; 

    cout<<(double)(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ; 


    /*cout << "ANSWER OF MULTIPLICATION BY NAIVE APPROACH" << endl ; 
    for(int i = 0 ; i < n ; i++) 
    { 
     for(int j = 0 ; j < n ; j++) 
     { 
      cout << z[i][j] << " "; 
     } 
     cout << endl ; 
    }*/ 


    //USING STRASSENS ALGORITHM 

    start = clock() ; 

    strassen(n,u,v,zz); 

    end = clock(); 
      cout<<"Strassen Algorithm Running Time : "; 
    cout<<(double)(end-start)/CLOCKS_PER_SEC<<" seconds"<<endl ; 

    /*cout << "ANSWER BY STRASSENS ALGORITHM " << endl ; 
    for(int i = 0 ; i < n ; i++) 
    { 
     for(int j = 0 ; j < n ; j++) 
     { 
      cout << zz[i][j] << " "; 
     } 
     cout << endl ; 
    }*/ 
} 
return 0; 
    */ IPG_2011006 Abhishek Yadav */ 
} 
+0

: 대단히 감사합니다. –

0

Strassen in C++에서 구현을 볼 수 있으며이 알고리즘은 위키피디아에서도 잘 설명되어 있습니다.

+0

대단히 고맙습니다. 저는 많은 도움을 주셨고, 정말로 도움을 주셔서 감사합니다. –

2

k * 1 벡터에 1 * k 벡터를 곱하는 대신에 생각하십시오. 이것은 k^2 곱셈을 필요로하고 결국 k * k 행렬을 얻습니다. 여기서 유일한 차이점은 벡터의 원소가 n * n 행렬이므로 Strassen의 알고리즘을 사용하여 n * n을 곱하면 O (k^2 n^(log 7)) 개의 스칼라 곱셈을 수행하게됩니다. 행렬.

+0

당신은 나에게 최고의 대답을 주셨습니다. 정말 고마워요. –

0

쎄타 ((K^2 * N)^(로그 (7))) kn * kn 행렬을 제공합니다. 참조 용 here 또는 check this pdf