COMBINATORIAL_BLAS 1.6
 
Loading...
Searching...
No Matches
BlockSpGEMM.h
Go to the documentation of this file.
1#ifndef _BLOCK_SPGEMM_H_
2#define _BLOCK_SPGEMM_H_
3
4#include "CombBLAS.h"
5
6
7namespace combblas
8{
9
10
11template <typename IT,
12 typename NTA,
13 typename DERA,
14 typename NTB,
15 typename DERB>
16struct BlockSpGEMM
17{
18
19private:
20
21 std::vector<std::vector<SpParMat<IT, NTA, DERA>>> A_blocks_;
22 std::vector<std::vector<SpParMat<IT, NTB, DERB>>> B_blocks_;
23 int br_, bc_, bi_, cur_block_;
24 IT nr_, nc_;
25
26
27
28
29public:
30
33 int br,
34 int bc,
35 int bi = 1
36 ) :
37 br_(br), bc_(bc), bi_(bi), cur_block_(0)
38 {
39 A_blocks_ = A.BlockSplit(br_, bi_);
40 B_blocks_ = B.BlockSplit(bi_, bc_);
41 nr_ = A.getnrow();
42 nc_ = B.getncol();
43 }
44
45
46
47 template<typename SR,
48 typename NTC,
49 typename DERC>
52 {
53 assert(bi_ == 1);
54
55 int rbid = cur_block_ / bc_;
56 int cbid = cur_block_ % bc_;
57 ++cur_block_;
58
59 IT bs = nr_ / br_;
60 IT r = nr_ % br_;
61 roffset = (std::min(static_cast<IT>(rbid), r)*(bs+1)) +
62 ((rbid < r ? 0 : rbid-r)*bs);
63 // (std::max(static_cast<IT>(0), rbid-r)*bs);
64
65 bs = nc_ / bc_;
66 r = nc_ % bc_;
67 coffset = (std::min(static_cast<IT>(cbid), r)*(bs+1)) +
68 ((cbid < r ? 0 : cbid-r)*bs);
69 // (std::max(static_cast<IT>(0), cbid-r)*bs);
70
72 (A_blocks_[rbid][0], B_blocks_[0][cbid], false, false);
73 }
74
75
76
77 bool
79 {
80 return cur_block_ < br_*bc_;
81 }
82
83
84
85 template<typename SR,
86 typename NTC,
87 typename DERC>
90 {
91 assert(bi_ == 1);
92
93 IT bs = nr_ / br_;
94 IT r = nr_ % br_;
95 roffset = (std::min(static_cast<IT>(rbid), r)*(bs+1)) +
96 ((rbid < r ? 0 : rbid-r)*bs);
97 // (std::max(static_cast<IT>(0), rbid-r)*bs);
98
99 bs = nc_ / bc_;
100 r = nc_ % bc_;
101 coffset = (std::min(static_cast<IT>(cbid), r)*(bs+1)) +
102 ((cbid < r ? 0 : cbid-r)*bs);
103 // (std::max(static_cast<IT>(0), cbid-r)*bs);
104
106 (A_blocks_[rbid][0], B_blocks_[0][cbid], false, false);
107 }
108
109
110
111 std::vector<IT>
113 {
114 IT bs = nr_ / br_;
115 IT r = nr_ % br_;
116 if (!is_row)
117 {
118 bs = nc_ / bc_;
119 r = nc_ % bc_;
120 }
121
122 int nblocks = (is_row ? br_ : bc_);
123 std::vector<IT> offsets(nblocks+1);
124 for (int b = 0; b < nblocks; ++b)
125 offsets[b] = (std::min(static_cast<IT>(b), r)*(bs+1)) +
126 ((b < r ? 0 : b-r)*bs);
127 offsets[nblocks] = (is_row ? nr_ : nc_);
128
129 return offsets;
130 }
131};
132
133
134}
135
136
137#endif
int64_t IT
SelectMaxSRing< bool, int64_t > SR
Definition SpMMError.cpp:18
Definition test.cpp:53
double A
SpParMat< IT, NTC, DERC > getBlockId(int rbid, int cbid, IT &roffset, IT &coffset)
Definition BlockSpGEMM.h:89
std::vector< IT > getBlockOffsets(bool is_row)
BlockSpGEMM(SpParMat< IT, NTA, DERA > &A, SpParMat< IT, NTB, DERB > &B, int br, int bc, int bi=1)
Definition BlockSpGEMM.h:31
SpParMat< IT, NTC, DERC > getNextBlock(IT &roffset, IT &coffset)
Definition BlockSpGEMM.h:51