COMBINATORIAL_BLAS
1.6
Loading...
Searching...
No Matches
KTipsTest.cpp
Go to the documentation of this file.
1
#include <mpi.h>
2
#include <iostream>
3
#include <algorithm>
4
#include <functional>
5
#include <string>
6
#include <vector>
7
#include "CombBLAS/CombBLAS.h"
8
9
using namespace
combblas
;
10
11
template
<
class
IT>
12
struct
KTipsSR
13
{
14
static
IT
id
() {
return
static_cast<
IT
>
(0); }
15
static
bool
returnedSAID
() {
return
false
; }
16
static
MPI_Op
mpi_op
() {
return
MPI_LOR
; }
17
static
IT
add
(
const
IT
&
arg1
,
const
IT
&
arg2
) {
return
(
arg1
||
arg2
); }
18
static
IT
multiply
(
const
IT
&
arg1
,
const
IT
&
arg2
) {
return
(
arg1
&&
arg2
); }
19
static
void
axpy
(
IT
a,
const
IT
&
x
,
IT
&
y
) {
y
=
add
(
y
,
multiply
(a,
x
)); }
20
};
21
22
template
<
class
IT,
class
NT,
class
DER>
23
FullyDistVec<IT,IT>
LastNzRowIdxPerCol
(
const
SpParMat<IT,NT,DER>
&
A
)
24
{
25
std::shared_ptr<CommGrid>
grid
=
A
.getcommgrid();
26
int
myrank =
grid
->GetRank();
27
int
myproccol =
grid
->GetRankInProcRow();
28
int
myprocrow =
grid
->GetRankInProcCol();
29
30
MPI_Comm
ColWorld
=
grid
->GetColWorld();
31
32
IT
total_rows
=
A
.getnrow();
33
IT
total_cols
=
A
.getncol();
34
35
int
procrows =
grid
->GetGridRows();
36
int
proccols =
grid
->GetGridCols();
37
38
IT
rows_perproc
=
total_rows
/ procrows;
39
IT
cols_perproc
=
total_cols
/ proccols;
40
41
IT
row_offset
= myprocrow *
rows_perproc
;
42
IT
col_offset
= myproccol *
cols_perproc
;
43
44
DER
*spSeq =
A
.seqptr();
45
46
IT
localcols
= spSeq->getncol();
47
std::vector<IT>
local_colidx
(
localcols
,
static_cast<
IT
>
(-1));
48
49
for
(
auto
colit
= spSeq->begcol();
colit
!= spSeq->endcol(); ++
colit
)
50
{
51
auto
nzit
= spSeq->begnz(
colit
);
52
if
(
nzit
!= spSeq->endnz(
colit
))
53
local_colidx
[
colit
.colid()] =
nzit
.rowid() +
row_offset
;
54
}
55
56
MPI_Allreduce
(
MPI_IN_PLACE
,
local_colidx
.data(),
static_cast<
int
>
(
localcols
),
MPIType<IT>
(),
MPI_MAX
,
ColWorld
);
57
58
std::vector<IT>
fillarr
;
59
60
if
(!myprocrow)
61
for
(
auto
itr
=
local_colidx
.begin();
itr
!=
local_colidx
.end(); ++
itr
)
62
fillarr
.push_back(*
itr
);
63
64
return
FullyDistVec<IT,IT>
(
fillarr
,
grid
);
65
}
66
67
template
<
class
IT,
class
NT,
class
DER>
68
SpParMat<IT,NT,DER>
FrontierMat
(
const
SpParMat<IT,NT,DER>
&
A
,
const
FullyDistSpVec<IT,IT>
&
sources
,
const
NT
&
initval
)
69
{
70
FullyDistVec<IT,IT>
ri
=
sources
.FindInds([](
int
arg1
) {
return
true
; });
71
FullyDistVec<IT,IT>
ci
(
A
.getcommgrid());
72
ci
.iota(
sources
.getnnz(),
static_cast<
IT
>
(0));
73
return
SpParMat<IT,NT,DER>
(
A
.getnrow(),
sources
.getnnz(),
ri
,
ci
,
initval
,
false
);
74
}
75
76
int
main
(
int
argc
,
char
*
argv
[])
77
{
78
int
myrank,
nprocs
;
79
MPI_Init
(&
argc
, &
argv
);
80
MPI_Comm_rank
(
MPI_COMM_WORLD
, &myrank);
81
MPI_Comm_size
(
MPI_COMM_WORLD
, &
nprocs
);
82
83
if
(
argc
!= 3)
84
{
85
if
(!myrank)
86
std::cerr <<
"Usage: ./KTipsTest <Matrix> <l>"
<< std::endl;
87
MPI_Finalize
();
88
return
-1;
89
}
90
91
{
92
int
l
=
atoi
(
argv
[2]);
93
94
std::shared_ptr<CommGrid>
fullWorld
;
95
fullWorld
.reset(
new
CommGrid
(
MPI_COMM_WORLD
, 0, 0));
96
97
SpParMat<int,int, SpDCCols<int,int>
>
A
(
fullWorld
);
98
99
A
.ParallelReadMM(std::string(
argv
[1]),
false
,
maximum<int>
());
100
101
FullyDistVec<int,int>
D
=
A
.Reduce(
Column
, std::plus<int>(),
static_cast<
int
>
(0));
102
FullyDistSpVec<int,int>
R
=
D
.Find(std::bind2nd(std::equal_to<int>(),
static_cast<
int
>
(1)));
103
104
SpParMat<int,int, SpDCCols<int,int>
>
F0
=
FrontierMat
(
A
,
R
,
static_cast<
int
>
(1));
105
106
SpParMat<int,int, SpDCCols<int,int>
>
F1
=
PSpGEMM<KTipsSR<int>
>(
A
,
F0
);
107
SpParMat<int,int, SpDCCols<int,int>
>
V
=
F0
;
108
V
+=
F1
;
109
110
FullyDistVec<int,int>
TipSources
(
A
.getcommgrid(),
F0
.getncol(),
static_cast<
int
>
(-1));
111
FullyDistVec<int,int>
TipDests
(
A
.getcommgrid(),
F0
.getncol(),
static_cast<
int
>
(-1));
112
113
for
(
int
k = 1; k <=
l
; ++k)
114
{
115
SpParMat<int,int, SpDCCols<int,int>
>
F2
=
PSpGEMM<KTipsSR<int>
>(
A
,
F1
);
116
F2
.SetDifference(
V
);
117
V
+=
F2
;
118
119
FullyDistVec<int,int>
Ns
=
F2
.Reduce(
Column
, std::plus<int>(),
static_cast<
int
>
(0));
120
121
FullyDistSpVec<int,int>
Tc
=
Ns
.Find(std::bind2nd(std::greater_equal<int>(),
static_cast<
int
>
(2)));
122
FullyDistSpVec<int,int>
Td
=
Ns
.Find(std::bind2nd(std::not_equal_to<int>(),
static_cast<
int
>
(1)));
123
124
FullyDistVec<int,int>
C0
=
LastNzRowIdxPerCol
(
F0
);
125
FullyDistVec<int,int>
C1
=
LastNzRowIdxPerCol
(
F1
);
126
127
FullyDistSpVec<int,int>
kSources
=
C0
.GGet(
Tc
, [](
const
int
arg1
,
const
int
arg2
) {
return
arg2
; },
static_cast<
int
>
(-1));
128
FullyDistSpVec<int,int>
kDests
=
C1
.GGet(
Tc
, [](
const
int
arg1
,
const
int
arg2
) {
return
arg2
; },
static_cast<
int
>
(-1));
129
130
TipSources
.Set(
kSources
);
131
TipDests
.Set(
kDests
);
132
133
F1
.PruneColumnByIndex(
Td
);
134
F2
.PruneColumnByIndex(
Td
);
135
136
F0
=
F1
;
137
F1
=
F2
;
138
}
139
140
TipSources
.DebugPrint();
141
TipDests
.DebugPrint();
142
}
143
144
MPI_Finalize
();
145
return
0;
146
}
main
int main()
Definition
Driver.cpp:12
LastNzRowIdxPerCol
FullyDistVec< IT, IT > LastNzRowIdxPerCol(const SpParMat< IT, NT, DER > &A)
Definition
KTipsTest.cpp:23
FrontierMat
SpParMat< IT, NT, DER > FrontierMat(const SpParMat< IT, NT, DER > &A, const FullyDistSpVec< IT, IT > &sources, const NT &initval)
Definition
KTipsTest.cpp:68
combblas::CommGrid
Definition
CommGrid.h:45
combblas::DistEdgeList
Definition
DistEdgeList.h:82
nprocs
int nprocs
Definition
comms.cpp:55
combblas
Definition
CCGrid.h:4
combblas::Column
@ Column
Definition
SpDefs.h:115
A
double A
D
double D
Definition
options.h:15
KTipsSR
Definition
KTipsTest.cpp:13
KTipsSR::axpy
static void axpy(IT a, const IT &x, IT &y)
Definition
KTipsTest.cpp:19
KTipsSR::mpi_op
static MPI_Op mpi_op()
Definition
KTipsTest.cpp:16
KTipsSR::returnedSAID
static bool returnedSAID()
Definition
KTipsTest.cpp:15
KTipsSR::id
static IT id()
Definition
KTipsTest.cpp:14
KTipsSR::add
static IT add(const IT &arg1, const IT &arg2)
Definition
KTipsTest.cpp:17
KTipsSR::multiply
static IT multiply(const IT &arg1, const IT &arg2)
Definition
KTipsTest.cpp:18
ReleaseTests
KTipsTest.cpp
Generated by
1.9.8