COMBINATORIAL_BLAS
1.6
Loading...
Searching...
No Matches
test_mpipspgemm.cpp
Go to the documentation of this file.
1
#include <mpi.h>
2
#include <
sys/time.h
>
3
#include <iostream>
4
#include <functional>
5
#include <algorithm>
6
#include <vector>
7
#include <string>
8
#include <sstream>
9
#include <
stdint.h
>
10
#include <cmath>
11
#include "CombBLAS/CombBLAS.h"
12
#include "
Glue.h
"
13
#include "
CCGrid.h
"
14
#include "
Reductions.h
"
15
#include "
Multiplier.h
"
16
#include "
SplitMatDist.h
"
17
18
using namespace std;
using namespace combblas;

// Global timing accumulators (seconds).  They are defined here and
// presumably updated by the multiply/reduce kernels in Multiplier.h and
// Reductions.h (extern-linked) — TODO confirm against those headers.
double comm_bcast;        // communication: broadcast phase
double comm_reduce;       // communication: reduction phase
double comp_summa;        // computation: local SUMMA multiply
double comp_reduce;       // computation: local merge during reduction
double comp_result;       // computation: assembling the final result
double comp_reduce_layer; // computation: cross-layer reduction
double comp_split;        // computation: splitting the input matrices
double comp_trans;        // computation: local transpose
double comm_split;        // communication: splitting the input matrices

// Number of timed multiplication repetitions per selected algorithm.
#define ITERS 1
32
33
int
main
(
int
argc
,
char
*
argv
[])
34
{
35
int
provided
;
36
//MPI_Init_thread(&argc, &argv, MPI_THREAD_SINGLE, &provided);
37
38
39
MPI_Init_thread
(&
argc
, &
argv
,
MPI_THREAD_SERIALIZED
, &
provided
);
40
if
(
provided
<
MPI_THREAD_SERIALIZED
)
41
{
42
printf
(
"ERROR: The MPI library does not have MPI_THREAD_SERIALIZED support\n"
);
43
MPI_Abort
(
MPI_COMM_WORLD
, 1);
44
}
45
46
47
int
nprocs
, myrank;
48
MPI_Comm_size
(
MPI_COMM_WORLD
,&
nprocs
);
49
MPI_Comm_rank
(
MPI_COMM_WORLD
,&myrank);
50
51
if
(
argc
!= 8)
52
{
53
cout
<<
argc
<<
endl
;
54
if
(myrank == 0)
55
{
56
printf
(
"Usage (input): ./mpipspgemm <GridRows> <GridCols> <Layers> <matA> <matB> <matC> <algo>\n"
);
57
printf
(
"Example: ./mpipspgemm 4 4 2 matA.mtx matB.mtx matB.mtx threaded\n"
);
58
printf
(
"algo: outer | column | threaded \n"
);
59
}
60
return
-1;
61
}
62
63
unsigned
GRROWS
= (
unsigned
)
atoi
(
argv
[1]);
64
unsigned
GRCOLS
= (
unsigned
)
atoi
(
argv
[2]);
65
unsigned
C_FACTOR
= (
unsigned
)
atoi
(
argv
[3]);
66
CCGrid
CMG
(
C_FACTOR
,
GRCOLS
);
67
68
if
(
GRROWS
!=
GRCOLS
)
69
{
70
SpParHelper::Print
(
"This version of the Combinatorial BLAS only works on a square logical processor grid\n"
);
71
MPI_Barrier
(
MPI_COMM_WORLD
);
72
MPI_Abort
(
MPI_COMM_WORLD
, 1);
73
}
74
75
int
layer_length
=
GRROWS
*
GRCOLS
;
76
if
(
layer_length
*
C_FACTOR
!=
nprocs
)
77
{
78
SpParHelper::Print
(
"The product of <GridRows> <GridCols> <Replicas> does not match the number of processes\n"
);
79
MPI_Barrier
(
MPI_COMM_WORLD
);
80
MPI_Abort
(
MPI_COMM_WORLD
, 1);
81
}
82
83
84
85
SpDCCols<int32_t, double>
splitA
,
splitB
,
controlC
;
86
SpDCCols<int32_t, double>
*
splitC
;
87
string
type;
88
89
90
string
fileA
(
argv
[4]);
91
string
fileB
(
argv
[5]);
92
string
fileC
(
argv
[6]);
93
94
{
95
shared_ptr<CommGrid>
layerGrid
;
96
layerGrid
.reset(
new
CommGrid
(
CMG
.layerWorld, 0, 0) );
97
FullyDistVec<int32_t, int32_t>
p(
layerGrid
);
// permutation vector defined on layers
98
99
double
t01
=
MPI_Wtime
();
100
101
SpDCCols<int32_t, double>
*
A
=
ReadMat<double>
(
fileA
,
CMG
,
true
, p);
102
SpDCCols<int32_t, double>
*
B
=
ReadMat<double>
(
fileB
,
CMG
,
true
, p);
103
SpDCCols<int32_t, double>
*
C
=
ReadMat<double>
(
fileC
,
CMG
,
true
, p);
104
105
SplitMat
(
CMG
,
A
,
splitA
,
false
);
106
SplitMat
(
CMG
,
B
,
splitB
,
true
);
107
SplitMat
(
CMG
,
C
,
controlC
,
false
);
108
109
if
(myrank == 0)
cout
<<
"Matrices read and replicated along layers : time "
<<
MPI_Wtime
() -
t01
<<
endl
;
110
111
type =
string
(
argv
[7]);
112
113
if
(type ==
string
(
"outer"
))
114
{
115
for
(
int
k=0; k<
ITERS
; k++)
116
{
117
splitB
.Transpose();
// locally "transpose" [ABAB: check correctness]
118
splitC
=
multiply
(
splitA
,
splitB
,
CMG
,
true
,
false
);
// outer product
119
if
(
controlC
== *
splitC
)
120
SpParHelper::Print
(
"Outer product multiplication working correctly\n"
);
121
else
122
SpParHelper::Print
(
"ERROR in Outer product multiplication, go fix it!\n"
);
123
delete
splitC
;
124
}
125
126
}
127
else
if
(type ==
string
(
"column"
))
128
{
129
130
for
(
int
k=0; k<
ITERS
; k++)
131
{
132
splitC
=
multiply
(
splitA
,
splitB
,
CMG
,
false
,
false
);
133
if
(
controlC
== *
splitC
)
134
SpParHelper::Print
(
"Col-heap multiplication working correctly\n"
);
135
else
136
SpParHelper::Print
(
"ERROR in Col-heap multiplication, go fix it!\n"
);
137
138
delete
splitC
;
139
}
140
141
}
142
else
// default threaded
143
{
144
for
(
int
k=0; k<
ITERS
; k++)
145
{
146
splitC
=
multiply
(
splitA
,
splitB
,
CMG
,
false
,
true
);
147
if
(
controlC
== *
splitC
)
148
SpParHelper::Print
(
"Col-heap-threaded multiplication working correctly\n"
);
149
else
150
SpParHelper::Print
(
"ERROR in Col-heap-threaded multiplication, go fix it!\n"
);
151
delete
splitC
;
152
}
153
}
154
}
155
156
MPI_Finalize
();
157
return
0;
158
}
159
160
CCGrid.h
main
int main()
Definition
Driver.cpp:12
Glue.h
Multiplier.h
Reductions.h
SplitMatDist.h
B
Definition
test.cpp:53
combblas::CCGrid
Definition
CCGrid.h:7
combblas::CommGrid
Definition
CommGrid.h:45
combblas::DistEdgeList
Definition
DistEdgeList.h:82
combblas::SpParHelper::Print
static void Print(const std::string &s)
Definition
SpParHelper.cpp:836
nprocs
int nprocs
Definition
comms.cpp:55
combblas
Definition
CCGrid.h:4
combblas::SplitMat
void SplitMat(CCGrid &CMG, SpDCCols< IT, NT > *localmat, SpDCCols< IT, NT > &splitmat, bool rowsplit=false)
Definition
SplitMatDist.h:144
combblas::multiply
SpDCCols< IT, NT > * multiply(SpDCCols< IT, NT > &splitA, SpDCCols< IT, NT > &splitB, CCGrid &CMG, bool isBT, bool threaded)
Definition
Multiplier.h:11
A
double A
C
double C
Definition
options.h:15
stdint.h
comm_reduce
double comm_reduce
Definition
test_mpipspgemm.cpp:22
comp_split
double comp_split
Definition
test_mpipspgemm.cpp:27
ITERS
#define ITERS
Definition
test_mpipspgemm.cpp:31
comm_bcast
double comm_bcast
Definition
test_mpipspgemm.cpp:21
comm_split
double comm_split
Definition
test_mpipspgemm.cpp:29
comp_reduce
double comp_reduce
Definition
test_mpipspgemm.cpp:24
comp_result
double comp_result
Definition
test_mpipspgemm.cpp:25
comp_trans
double comp_trans
Definition
test_mpipspgemm.cpp:28
comp_summa
double comp_summa
Definition
test_mpipspgemm.cpp:23
comp_reduce_layer
double comp_reduce_layer
Definition
test_mpipspgemm.cpp:26
time.h
3DSpGEMM
test_mpipspgemm.cpp
Generated by
1.9.8