#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/*
 * 1-based element access: column x (1..w), row y (1..h), stored row-major.
 * Expands to an lvalue, so it can be assigned to. No bounds checking is done.
 */
#define AT(mat, x, y) ((mat)->data[((x)-1) + (((y)-1) * ((mat)->w))])

/*
 * Print a matrix row by row, each element followed by a comma, then a blank
 * line. Wrapped in do/while(0) so it behaves as a single statement; the loop
 * counters use trailing underscores so they cannot capture caller variables.
 */
#define PRINT_MAT(mat)                                     \
    do {                                                   \
        for (int i_ = 1; i_ <= (mat)->h; ++i_) {           \
            for (int j_ = 1; j_ <= (mat)->w; ++j_) {       \
                printf("%d,", AT((mat), (j_), (i_)));      \
            }                                              \
            printf("\n");                                  \
        }                                                  \
        printf("\n");                                      \
    } while (0)

/* Dense row-major matrix of ints with w columns and h rows. */
typedef struct {
    uint8_t w, h;
    int *data; /* w * h elements, owned by the MAT */
} MAT;

MAT *new(uint8_t w, uint8_t h);
MAT *mul(const MAT *a, const MAT *b);
static void del(MAT *m);

int main(void)
{
    int status = EXIT_FAILURE;
    MAT *a = new(3, 2);
    MAT *b = new(3, 3);
    MAT *mult = NULL;

    if (a == NULL || b == NULL) {
        fputs("allocation failed\n", stderr);
        goto out;
    }

    /* a: 3 columns x 2 rows */
    AT(a, 1, 1) = 0; AT(a, 2, 1) = 1; AT(a, 3, 1) = 2;
    AT(a, 1, 2) = 3; AT(a, 2, 2) = 4; AT(a, 3, 2) = 5;
    PRINT_MAT(a);

    /* b: 3x3 identity */
    AT(b, 1, 1) = 1; AT(b, 2, 1) = 0; AT(b, 3, 1) = 0;
    AT(b, 1, 2) = 0; AT(b, 2, 2) = 1; AT(b, 3, 2) = 0;
    AT(b, 1, 3) = 0; AT(b, 2, 3) = 0; AT(b, 3, 3) = 1;
    PRINT_MAT(b);

    mult = mul(a, b);
    if (mult == NULL) {
        fputs("multiplication failed\n", stderr);
        goto out;
    }
    PRINT_MAT(mult);
    status = EXIT_SUCCESS;

out:
    del(mult);
    del(b);
    del(a);
    return status;
}

/*
 * Allocate a w-column by h-row matrix with every element set to zero.
 * Returns NULL on allocation failure. The caller owns the result and
 * releases it with del().
 */
MAT *new(uint8_t w, uint8_t h)
{
    MAT *ret = malloc(sizeof *ret);
    if (ret == NULL) {
        return NULL;
    }

    size_t count = (size_t)w * h;
    ret->w = w;
    ret->h = h;
    ret->data = calloc(count, sizeof *ret->data);
    /* calloc(0, ...) may legitimately return NULL, so only a non-empty
     * request counts as a failure. */
    if (ret->data == NULL && count != 0) {
        free(ret);
        return NULL;
    }
    return ret;
}

/*
 * Return the matrix product a * b as a new matrix with b->w columns and
 * a->h rows. Returns NULL if either operand is NULL, the inner dimensions
 * differ (a->w != b->h), or allocation fails. The caller owns the result.
 */
MAT *mul(const MAT *a, const MAT *b)
{
    if (a == NULL || b == NULL || a->w != b->h) {
        return NULL;
    }

    MAT *ret = new(b->w, a->h);
    if (ret == NULL) {
        return NULL;
    }

    for (int i = 1; i <= a->h; ++i) {
        for (int j = 1; j <= b->w; ++j) {
            int tmp = 0;
            for (int k = 1; k <= a->w; ++k) {
                tmp += AT(a, k, i) * AT(b, j, k);
            }
            AT(ret, j, i) = tmp;
        }
    }
    return ret;
}

/* Release a matrix obtained from new() or mul(). NULL is accepted and ignored. */
static void del(MAT *m)
{
    if (m == NULL) {
        return;
    }
    free(m->data);
    free(m);
}