Matrix Multiplication

Modelling the nonlinear runtime \(T_n^b = \mathcal{O}(m^2)\, n^a\), checked with Mathematica

The following C++ programme yields, in particular, the runtime \(T_c^b = \mathcal{O}(1)\, n^3\) of the standard C++ multiplication and the runtime \(T_f^b = \mathcal{O}(m^2)\, n^2\) of FastShift for every \(n \times n\) matrix, with loop number \(a \in \{2, 3\}\), bit length \(b\), and integer (!) recursion depth \(m = \log_2 n\).

// FastShift method programmed with C++ by Boris Haase and Math AI; version October 22, 2025
#include <algorithm>
#include <chrono>
#include <cstddef>
#include <iostream>
#include <random>
#include <utility>
#include <vector>
#include <gmpxx.h>
#include <omp.h>
using namespace std;

struct FlatMatrix {                                                        // FlatMatrixView
    mpz_class* ptr;
    std::size_t row_stride;
    vector<mpz_class> data;

    FlatMatrix(int size) : data(size * size) {
        ptr = data.data();
        row_stride = size;
    }

    FlatMatrix(mpz_class* raw_ptr, int size, std::size_t stride)
        : ptr(raw_ptr), row_stride(stride) {}

    inline mpz_class& operator()(int i, int j) {
        return *(ptr + i * row_stride + j);
    }

    inline const mpz_class& operator()(int i, int j) const {
        return *(ptr + i * row_stride + j);
    }

    FlatMatrix submatrix(int row_offset, int col_offset, int size) const {
        return FlatMatrix(ptr + row_offset * row_stride + col_offset, size, row_stride);
    }
};

// Floor division on GMP integers: q = floor(a / b) and r = a - q * b,
// i.e. r takes the sign of b (or is zero). Thin wrapper around mpz_fdiv_qr.
inline void fdiv_qr(mpz_class &q, mpz_class &r, const mpz_class &a, const mpz_class &b) {
    mpz_ptr quotient  = q.get_mpz_t();
    mpz_ptr remainder = r.get_mpz_t();
    mpz_fdiv_qr(quotient, remainder, a.get_mpz_t(), b.get_mpz_t());
}

// Builds an owning n x n matrix of random signed integers.
// Each magnitude is drawn with get_z_bits(bitlength), so it lies in [0, 2^bitlength).
// Each entry gets an independent random sign.
// Magnitudes and signs both come from one GMP generator seeded with `seed`,
// so the result depends only on (n, bitlength, seed). The global rand() is
// unseeded, shared across calls and not thread-safe, so it is not used here.
FlatMatrix generate_matrix(int n, int bitlength, int seed) {
    FlatMatrix A(n);
    gmp_randclass rng(gmp_randinit_default);
    rng.seed(seed);                                         // Initialising random generator

    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j) {
            mpz_class val = rng.get_z_bits(bitlength);      // Numbers with `bitlength` bits
            if (rng.get_z_bits(1) != 0)                     // One random bit decides the sign
                val = -val;
            A(i, j) = val;
        }
    return A;
}

bool compare_matrices(const FlatMatrix& A, const FlatMatrix& B, int n) {          // Compare
    for (int i = 0; i < n; ++i)
        for (int j = 0; j < n; ++j)
            if (A(i, j) != B(i, j)) {
                cout << "Difference at position [" << i << "][" << j << "]:\n";
                cout << "FastShift = " << A(i, j) << "\n";
                cout << "Standard  = " << B(i, j) << "\n";
                return false;
            }
    return true;
}

// Reference product: returns A * B for n x n matrices (views allowed).
// The loops are tiled into 32 x 32 blocks; inside each tile the order is
// i-k-j, so A(i, k) is reused across a run of B's row k.
FlatMatrix multiply_blocked(const FlatMatrix& A, const FlatMatrix& B, int n) {   // Standard
    const int tile = 32;
    FlatMatrix result(n);
    for (int ib = 0; ib < n; ib += tile) {
        const int i_end = std::min(ib + tile, n);
        for (int jb = 0; jb < n; jb += tile) {
            const int j_end = std::min(jb + tile, n);
            for (int kb = 0; kb < n; kb += tile) {
                const int k_end = std::min(kb + tile, n);
                for (int i = ib; i < i_end; ++i) {
                    for (int k = kb; k < k_end; ++k) {
                        const mpz_class& a_ik = A(i, k);
                        for (int j = jb; j < j_end; ++j)
                            result(i, j) += a_ik * B(k, j);
                    }
                }
            }
        }
    }
    return result;
}

// Packs two n x n matrices into one: R(i, j) = (C(i, j) << r) + D(i, j).
// C and D may be views with arbitrary row strides; R is a new owning matrix.
FlatMatrix add_shiR(const FlatMatrix& C, const FlatMatrix& D, int n, int r) {  // C << r + D
    FlatMatrix R(n);
    for (int i = 0; i < n; ++i) {
        const mpz_class* c_row = C.ptr + i * C.row_stride;
        const mpz_class* d_row = D.ptr + i * D.row_stride;
        mpz_class* out_row = R.ptr + i * R.row_stride;
        for (int j = 0; j < n; ++j) {
            out_row[j] = c_row[j] << r;
            out_row[j] += d_row[j];
        }
    }
    return R;
}

// Base case: fully written-out 2 x 2 product, returns R = A * B.
FlatMatrix mul_twoR(const FlatMatrix& A, const FlatMatrix& B) {       // 2x2 product, unrolled
    const mpz_class& a00 = A(0, 0);
    const mpz_class& a01 = A(0, 1);
    const mpz_class& a10 = A(1, 0);
    const mpz_class& a11 = A(1, 1);
    const mpz_class& b00 = B(0, 0);
    const mpz_class& b01 = B(0, 1);
    const mpz_class& b10 = B(1, 0);
    const mpz_class& b11 = B(1, 1);

    FlatMatrix R(2);
    R(0, 0) = a00 * b00 + a01 * b10;
    R(0, 1) = a00 * b01 + a01 * b11;
    R(1, 0) = a10 * b00 + a11 * b10;
    R(1, 1) = a10 * b01 + a11 * b11;
    return R;
}

// Base case: 4 x 4 product with the inner k-sum written out, returns R = A * B.
FlatMatrix mul_fouR(const FlatMatrix& A, const FlatMatrix& B) {       // 4x4 product, k-sum unrolled
    FlatMatrix R(4);
    for (int row = 0; row < 4; ++row) {
        const mpz_class* a = A.ptr + row * A.row_stride;    // row `row` of A
        for (int col = 0; col < 4; ++col)
            R(row, col) = a[0] * B(0, col) + a[1] * B(1, col)
                        + a[2] * B(2, col) + a[3] * B(3, col);
    }
    return R;
}

// Base case: 8 x 8 product with the inner k-sum written out, returns R = A * B.
FlatMatrix mul_octR(const FlatMatrix& A, const FlatMatrix& B) {       // 8x8 product, k-sum unrolled
    FlatMatrix R(8);
    for (int row = 0; row < 8; ++row) {
        const mpz_class* a = A.ptr + row * A.row_stride;    // row `row` of A
        for (int col = 0; col < 8; ++col) {
            R(row, col) = a[0] * B(0, col) + a[1] * B(1, col)
                        + a[2] * B(2, col) + a[3] * B(3, col)
                        + a[4] * B(4, col) + a[5] * B(5, col)
                        + a[6] * B(6, col) + a[7] * B(7, col);
        }
    }
    return R;
}

// Base case: 16 x 16 product with an explicit k-loop, returns R = A * B.
// Each entry is accumulated in a reused temporary, then stored.
FlatMatrix mul_hexR(const FlatMatrix& A, const FlatMatrix& B) {     // 16x16 product, k-loop
    FlatMatrix R(16);
    mpz_class acc;
    for (int i = 0; i < 16; ++i) {
        for (int j = 0; j < 16; ++j) {
            acc = 0;
            for (int k = 0; k < 16; ++k)
                acc += A(i, k) * B(k, j);
            R(i, j) = acc;
        }
    }
    return R;
}

FlatMatrix div_modR(const FlatMatrix& S11, const FlatMatrix& S12,       // balanced division
                    const FlatMatrix& S21, const FlatMatrix& S22, int n, int s)
{
    const int m = n / 2;
    FlatMatrix R(n);

    const mpz_class B = mpz_class(1) << s;
    const mpz_class H = B >> 1;

    for (int i = 0; i < m; ++i) {
        const mpz_class* s11 = S11.ptr + i * S11.row_stride;
        const mpz_class* s12 = S12.ptr + i * S12.row_stride;
        const mpz_class* s21 = S21.ptr + i * S21.row_stride;
        const mpz_class* s22 = S22.ptr + i * S22.row_stride;

        for (int j = 0; j < m; ++j) {
            mpz_class St = s11[j] + s12[j];                          // top block: S11 + S12
            mpz_class q, r;
            fdiv_qr(q, r, St, B);

            if (r > H) {
                ++q; r -= B;
            } else if (r <= -H) {
                --q; r += B;
            }
            R(i, j)     = q;
            R(i, j + m) = r;

            mpz_class Sb = s21[j] + s22[j];                       // bottom block: S21 + S22
            fdiv_qr(q, r, Sb, B);

            if (r > H) {
                ++q; r -= B;
            } else if (r <= -H) {
                --q; r += B;
            }
            R(i + m, j)     = q;
            R(i + m, j + m) = r;
        }
    }
    return R;
}

// Multiplies the n x n matrices Y and Z with the FastShift scheme; returns Y * Z.
//
// For even n > 16 the column halves of Z are packed into single integers:
//   A = (Z00 << s) + Z01,   B = (Z10 << s) + Z11.
// Then Y00*A + Y01*B == ((Y00*Z00 + Y01*Z10) << s) + (Y00*Z01 + Y01*Z11),
// and likewise Y10*A + Y11*B for the bottom block row.
// The four half-size products are computed recursively with the doubled shift
// t = 2s, and div_modR splits each packed sum back into its two blocks.
//
// Sizes 2, 4, 8 and 16 use unrolled base cases. Sizes that cannot be halved
// exactly (n < 2 or odd n) are multiplied with the exact reference product.
// This prevents endless recursion for n <= 1 and stops odd n from dropping
// the last row and column.
//
// Correctness relies on the caller choosing s large enough that every split
// remainder fits in (-2^(s-1), 2^(s-1)] at each recursion level. main() uses
// s = 2 * bits + log2(n). This is not checked here.
FlatMatrix fastshift(const FlatMatrix& Y, const FlatMatrix& Z, int n, int s) {  // Recursive
    switch (n) {
    case  2: return mul_twoR(Y, Z);
    case  4: return mul_fouR(Y, Z);
    case  8: return mul_octR(Y, Z);
    case 16: return mul_hexR(Y, Z);
    default: break;
    }

    if (n < 2 || n % 2 != 0)
        return multiply_blocked(Y, Z, n);

    const int m = n / 2;
    const int t = s * 2;                      // packed entries are wider one level down

    auto Z00 = Z.submatrix(0, 0, m);
    auto Z01 = Z.submatrix(0, m, m);
    auto Z10 = Z.submatrix(m, 0, m);
    auto Z11 = Z.submatrix(m, m, m);

    FlatMatrix A = add_shiR(Z00, Z01, m, s);  // (Z00 << s) + Z01
    FlatMatrix B = add_shiR(Z10, Z11, m, s);  // (Z10 << s) + Z11

    auto Y00 = Y.submatrix(0, 0, m);
    auto Y01 = Y.submatrix(0, m, m);
    auto Y10 = Y.submatrix(m, 0, m);
    auto Y11 = Y.submatrix(m, m, m);

    // Size-0 placeholders: each is replaced by move-assignment below, so no
    // m x m storage is allocated only to be thrown away.
    FlatMatrix U(0), V(0), W(0), X(0);

    // The four products are independent. They run as OpenMP sections only when
    // the halves are large (m > 32); otherwise the region executes serially.
    #pragma omp parallel sections if(m > 32)
    {
        #pragma omp section
        U = fastshift(Y00, A, m, t);
        #pragma omp section
        V = fastshift(Y01, B, m, t);
        #pragma omp section
        W = fastshift(Y10, A, m, t);
        #pragma omp section
        X = fastshift(Y11, B, m, t);
    }

    // U + V packs the top block row and W + X the bottom block row of Y * Z.
    return div_modR(U, V, W, X, n, s);
}

// Benchmark driver. For n = 2^1 ... 2^12 it:
//   * builds two random n x n matrices with `bits`-bit entries,
//   * times fastshift and the blocked reference product,
//   * reports whether the two results agree.
int main() {                                                               // Main programme
    const int bits = 64;
    std::cout << "\nFastShift for n = 2^m and " << bits << " bits";

    // Whole milliseconds between two steady_clock readings.
    auto elapsed_ms = [](std::chrono::steady_clock::time_point from,
                         std::chrono::steady_clock::time_point to) {
        return std::chrono::duration_cast<std::chrono::milliseconds>(to - from).count();
    };

    for (int m = 1; m <= 12; ++m) {
        const int n = 1 << m;
        const int s = 2 * bits + m;                    // packing shift handed to fastshift
        std::cout << "\n\nMatrices of size: " << n << " x " << n;
        FlatMatrix A = generate_matrix(n, bits,   42 + m);
        FlatMatrix B = generate_matrix(n, bits, 1337 + m);

        const auto fast_start = std::chrono::steady_clock::now();
        FlatMatrix C_fast = fastshift(A, B, n, s);
        const auto fast_stop = std::chrono::steady_clock::now();
        std::cout << "\nRuntime of FastShift: " << elapsed_ms(fast_start, fast_stop) << " ms";

        const auto ref_start = std::chrono::steady_clock::now();
        FlatMatrix C_blocked = multiply_blocked(A, B, n);
        const auto ref_stop = std::chrono::steady_clock::now();
        std::cout << "\nRuntime of C++ Y * Z: " << elapsed_ms(ref_start, ref_stop) << " ms";

        const bool equal = compare_matrices(C_fast, C_blocked, n);
        std::cout << "\nResults of FastShift and C++: " << (equal ? "equal" : "not equal");
    }
    return 0;
}

Runtime comparison for \(n = 2^m\) with 64-bit and 256-bit entries: \(T_c\) (C++) and \(T_f\) (FastShift) in milliseconds, and the ratio \(T_c/T_f\)
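| \(m\) | \(T_c^{64}\) | \(T_f^{64}\) | \(T_c^{64}/T_f^{64}\) | \(T_c^{256}\) | \(T_f^{256}\) | \(T_c^{256}/T_f^{256}\) |
|---:|---:|---:|---:|---:|---:|---:|
| 6 | 22 | 7 | 3.14 | 24 | 12 | 2.00 |
| 7 | 174 | 83 | 2.10 | 189 | 118 | 1.60 |
| 8 | 1394 | 396 | 3.52 | 1522 | 544 | 2.80 |
| 9 | 11001 | 2015 | 5.46 | 12094 | 2777 | 4.36 |
| 10 | 88637 | 10263 | 8.66 | 96657 | 14233 | 6.79 |
| 11 | 709751 | 52598 | 13.49 | 774781 | 87329 | 8.87 |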

For \(m = 12\), FastShift takes 281560 ms with 64-bit entries and 737616 ms with 256-bit entries.

© 2025 by Boris Haase

top