[BOJ] 행렬 제곱 C++ 풀이 [분할정복]

Updated:

문제

크기가 N*N인 행렬 A가 주어진다. 이때, A의 B제곱을 구하는 프로그램을 작성하시오. 수가 매우 커질 수 있으니, A^B의 각 원소를 1,000으로 나눈 나머지를 출력한다.

입력

첫째 줄에 행렬의 크기 N과 B가 주어진다. (2 ≤ N ≤ 5, 1 ≤ B ≤ 100,000,000,000)

둘째 줄부터 N개의 줄에 행렬의 각 원소가 주어진다. 행렬의 각 원소는 1,000보다 작거나 같은 자연수 또는 0이다.

출력

첫째 줄부터 N개의 줄에 걸쳐 행렬 A를 B제곱한 결과를 출력한다.

접근 방법

문제 그대로 접근해서 풀면 해결되는 문제이다.

다만 거듭제곱을 A x A x A x … 형태로 B번을 반복해서 곱하게 되면 주어진 제한 시간안에 문제를 해결 할 수 없게 된다.

때문에 고속 거듭제곱 알고리즘을 활용하여 거듭제곱을 계산해 나가야 한다.

#define ll as long long

ll fast_pow(ll a, ll k)
{
    // a는 밑, k는 지수
    ll ans = 1;
    while(k)
    {
        // 지수가 홀수인 경우에만 ans에 밑을 한번 더 곱한다.
        if(b % 2) ans *= a;

        a *= a; // 홀수, 짝수 상관없이 밑을 제곱
        b /= 2; 
    }
    return ans;
}

구현

#include<iostream>
#define ll long long
using namespace std;

ll n, b, num;
int arr[5][5];
int res[5][5];
int tmp[5][5];

void mul(int v1[5][5], int v2[5][5])
{
    for(int i=0; i<n; i++)
    {
        for(int j=0; j<n; j++)
        {
            tmp[i][j] = 0;
            for(int k=0; k<n; k++)
                tmp[i][j] += v1[i][k] * v2[k][j];

            tmp[i][j] %= 1000;
        }
    }

    for(int i=0; i<n; i++)
        for(int j=0; j<n; j++)
            v1[i][j] = tmp[i][j];
}

int main(void)
{
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    cin>>n>>b;

    for(int i=0; i<n; i++)
    {
        for(int j=0; j<n; j++)
            cin>>arr[i][j];
        
        res[i][i] = 1; // 단위 행렬
    }

    while(b)
    {
        if(b % 2 == 1)
        {
            mul(res, arr);
        }
        mul(arr, arr);
        b /= 2;
    }

    for(int i=0; i<n; i++)
    {
        for(int j=0; j<n; j++)
            cout<<res[i][j]<<' ';
        cout<<'\n';
    }
    return 0;
}

오답노트

고속 거듭제곱 알고리즘과 행렬 제곱식 구현을 제대로 하지 못하여 답을 구하지 못하였다.

Leave a comment