1#ifndef _MULTIWAY_MERGE_H_
2#define _MULTIWAY_MERGE_H_
// prefixSum -- visible excerpt of the body of `T* prefixSum(T* in, int size, int nthreads)`.
// Only some statements of the body are visible here; the notes below describe
// those statements and hedge everything that depends on lines not shown.
// One accumulator slot per thread plus one leading slot (nthreads+1 entries).
27 std::vector<T> tsum(nthreads+1);
// Result buffer holding size+1 elements, allocated with new[]. Callers in this
// file read element [size] of the result and release the buffer with delete [].
29 T* out =
new T[
size+1];
// Index of the calling thread inside the current OpenMP team.
38 ithread = omp_get_thread_num();
// First statically scheduled worksharing loop over the `size` inputs.
// NOTE(review): loop body not visible -- presumably a per-thread running sum; confirm.
43#pragma omp for schedule(static)
45 for (
int i=0; i<
size; i++)
// Store this thread's `sum` one slot to the right of its own index.
51 tsum[ithread+1] = sum;
// Visit tsum slots 0..ithread.
// NOTE(review): presumably accumulates the offset contributed by lower-numbered threads; body not visible.
56 for(
int i=0; i<(ithread+1); i++)
// Second statically scheduled pass over the same index range.
// NOTE(review): presumably adds the per-thread offset to each output entry; confirm.
62#pragma omp for schedule(static)
64 for (
int i=0; i<
size; i++)
/**
 * findColSplitters: computes nsplits+1 offsets into spTuples->tuples that cut the
 * tuple array at evenly spaced column boundaries.
 *
 * splitters[0] is 0, splitters[nsplits] is getnnz(), and each interior entry is
 * the lower_bound position of the tuple (row 0, column i*(ncol/nsplits)) under
 * ColLexiCompare.
 * NOTE(review): std::lower_bound needs the tuples to be ordered consistently
 * with ColLexiCompare -- presumably column-major order; verify with callers.
 *
 * RT is the element type of the returned vector; IT and NT are the index and
 * value types of the tuples.
 */
84template <
typename RT,
typename IT,
typename NT>
// nsplits ranges need nsplits+1 boundaries.
87 std::vector<RT> splitters(nsplits+1);
88 splitters[0] =
static_cast<RT
>(0);
89 ColLexiCompare<IT,NT> comp;
// Each interior boundary is searched independently, one loop iteration per boundary.
91#pragma omp parallel for
93 for(
int i=1; i< nsplits; i++)
// Integer division: any remainder columns end up in the last range.
95 IT cur_col = i * (spTuples->getncol()/nsplits);
// Probe tuple with row 0 and a default-constructed value.
96 std::tuple<IT,IT,NT> search_tuple(0, cur_col,
NT());
97 std::tuple<IT,IT,NT>* it = std::lower_bound (spTuples->tuples, spTuples->tuples + spTuples->getnnz(), search_tuple, comp);
// Convert the pointer into an offset from the start of the tuple array.
98 splitters[i] = (RT) (it - spTuples->tuples);
// Final boundary is one past the last tuple.
100 splitters[nsplits] = spTuples->getnnz();
/**
 * findColSplittersFinger: computes the same nsplits+1 column-boundary offsets as
 * findColSplitters, but the visible loop carries no OpenMP pragma and searches
 * between the `start` and `end` pointers.
 *
 * NOTE(review): the name suggests a finger search that narrows [start, end)
 * after each hit; no visible line updates `start` or `end`, so as shown every
 * search covers the whole tuple array. Confirm against the full body.
 * NOTE(review): std::lower_bound needs the tuples to be ordered consistently
 * with ColLexiCompare; verify with callers.
 */
107 template <
typename RT,
typename IT,
typename NT>
// nsplits ranges need nsplits+1 boundaries; the first one is offset 0.
110 std::vector<RT> splitters(nsplits+1);
111 splitters[0] =
static_cast<RT
>(0);
112 ColLexiCompare<IT,NT> comp;
// Search window: initially the whole tuple array.
114 std::tuple<IT,IT,NT>* start = spTuples->tuples;
115 std::tuple<IT,IT,NT>* end = spTuples->tuples + spTuples->getnnz();
116 for(
int i=1; i< nsplits; i++)
// Integer division: any remainder columns end up in the last range.
118 IT cur_col = i * (spTuples->getncol()/nsplits);
// Probe tuple with row 0 and a default-constructed value.
119 std::tuple<IT,IT,NT> search_tuple(0, cur_col,
NT());
120 std::tuple<IT,IT,NT>* it = std::lower_bound (start, end, search_tuple, comp);
// Offsets are always measured from the start of the array, not from `start`.
121 splitters[i] = (RT) (it - spTuples->tuples);
// Final boundary is one past the last tuple.
124 splitters[nsplits] = spTuples->getnnz();
/**
 * SerialMergeNNZ: walks all input lists in heap order and detects where the
 * (row, column) pair of the next tuple differs from the previously taken one.
 *
 * The heap holds one (row, column, list index) entry per non-empty list;
 * curptr[i] is the read position inside list i.
 * NOTE(review): the signature, the declarations of hsize/estnnz, the loop
 * header and the return are not visible; presumably estnnz counts the distinct
 * (row, column) pairs and is returned. Confirm against the full body.
 * NOTE(review): std::not2 is deprecated in C++17 and removed in C++20
 * (std::not_fn is the replacement).
 */
130template<
class IT,
class NT>
133 int nlists = ArrSpTups.size();
// The third tuple field is the source list index (int), not a numeric value.
134 ColLexiCompare<IT,int> heapcomp;
135 std::vector<std::tuple<IT, IT, int>> heap(nlists);
// Read cursor for each list, all starting at 0.
136 std::vector<IT> curptr(nlists,
static_cast<IT>(0));
// Seed the heap with the first tuple of every non-empty list.
138 for(
int i=0; i< nlists; ++i)
140 if(ArrSpTups[i]->getnnz()>0)
142 heap[hsize++] = std::make_tuple(std::get<0>(ArrSpTups[i]->tuples[0]), std::get<1>(ArrSpTups[i]->tuples[0]), i);
// The comparator is negated, so the heap front is intended to be the smallest
// entry under heapcomp.
146 std::make_heap(heap.data(), heap.data()+hsize, std::not2(heapcomp));
// Last tuple taken; compared with the next heap entry to detect a new (row, column).
148 std::tuple<IT, IT, NT> curTuple;
// pop_heap moves the front entry to position hsize-1, where it is read below.
152 std::pop_heap(heap.data(), heap.data() + hsize, std::not2(heapcomp));
153 int source = std::get<2>(heap[hsize-1]);
// True for the very first entry or whenever the row or the column changes.
154 if( (estnnz ==0) || (std::get<0>(curTuple) != std::get<0>(heap[hsize-1])) || (std::get<1>(curTuple) != std::get<1>(heap[hsize-1])))
156 curTuple = ArrSpTups[source]->tuples[curptr[source]];
// If the source list still has tuples, put its next one back into the heap.
// NOTE(review): the increment of curptr[source] is not visible in this excerpt.
160 if(curptr[source] != ArrSpTups[source]->getnnz())
162 heap[hsize-1] = std::make_tuple(std::get<0>(ArrSpTups[source]->tuples[curptr[source]]),
163 std::get<1>(ArrSpTups[source]->tuples[curptr[source]]), source);
164 std::push_heap(heap.data(), heap.data()+hsize, std::not2(heapcomp));
/**
 * SerialMerge: heap-based merge of several tuple lists into the caller-provided
 * array `ntuples`, combining entries with equal (row, column) via SR::add.
 *
 * @param ArrSpTups  input lists; each is read front to back through curptr.
 * @param ntuples    output array. NOTE(review): it must be large enough for the
 *                   merged result; callers in this file size it from a prior
 *                   counting pass -- confirm.
 *
 * NOTE(review): declarations of hsize, estnnz and cnz and the loop header are
 * not visible in this excerpt.
 * NOTE(review): std::not2 is deprecated in C++17 and removed in C++20.
 */
184template<
class SR,
class IT,
class NT>
185void SerialMerge(
const std::vector<SpTuples<IT,NT> *> & ArrSpTups, std::tuple<IT, IT, NT> * ntuples)
187 int nlists = ArrSpTups.size();
// The third tuple field of a heap entry is the source list index.
188 ColLexiCompare<IT,int> heapcomp;
189 std::vector<std::tuple<IT, IT, int>> heap(nlists);
// Read cursor for each list, all starting at 0.
190 std::vector<IT> curptr(nlists,
static_cast<IT>(0));
// Seed the heap with the first tuple of every non-empty list and add up the
// total number of input tuples in estnnz.
193 for(
int i=0; i< nlists; ++i)
195 if(ArrSpTups[i]->getnnz()>0)
197 estnnz += ArrSpTups[i]->getnnz();
198 heap[hsize++] = std::make_tuple(std::get<0>(ArrSpTups[i]->tuples[0]), std::get<1>(ArrSpTups[i]->tuples[0]), i);
// The comparator is negated, so the heap front is intended to be the smallest
// entry under heapcomp.
202 std::make_heap(heap.data(), heap.data()+hsize, std::not2(heapcomp));
// pop_heap moves the front entry to position hsize-1, where it is read below.
// The visible part of the following condition tests whether the last written
// output tuple has the same (row, column) as that entry.
207 std::pop_heap(heap.data(), heap.data() + hsize, std::not2(heapcomp));
208 int source = std::get<2>(heap[hsize-1]);
211 ((std::get<0>(ntuples[cnz-1]) == std::get<0>(heap[hsize-1])) && (std::get<1>(ntuples[cnz-1]) == std::get<1>(heap[hsize-1]))) )
// Same coordinates: fold the value into the last output tuple and advance the cursor.
213 std::get<2>(ntuples[cnz-1]) = SR::add(std::get<2>(ntuples[cnz-1]), ArrSpTups[source]->numvalue(curptr[source]++));
// Otherwise append the tuple as a new output entry and advance the cursor.
217 ntuples[cnz++] = ArrSpTups[source]->tuples[curptr[source]++];
// If the source list still has tuples, put its next one back into the heap.
220 if(curptr[source] != ArrSpTups[source]->getnnz())
222 heap[hsize-1] = std::make_tuple(std::get<0>(ArrSpTups[source]->tuples[curptr[source]]),
223 std::get<1>(ArrSpTups[source]->tuples[curptr[source]]), source);
224 std::push_heap(heap.data(), heap.data()+hsize, std::not2(heapcomp));
/**
 * SerialMergeNNZHash: symbolic pass that counts, for every column in
 * [startCol, endCol), how many distinct row indices occur across all input lists.
 *
 * Per column, a power-of-two sized hash table of row indices is rebuilt and
 * each input tuple of that column is looked up or inserted.
 *
 * @param ArrSpTups     input lists; curptr[i] is the read position in list i.
 * @param totnnz        out: running total of the per-column counts.
 * @param maxnnzPerCol  out: largest per-column count seen.
 * @param startCol      first global column handled by this call.
 * @param endCol        one past the last global column handled by this call.
 * @return the per-column count array (allocated with new[]; release with delete []).
 *         NOTE(review): the return statement itself is not visible in this excerpt.
 *
 * Fix: the hash table stores row-index keys, so its element type is IT. It
 * used to be NT (the value type), which corrupts keys and the -1 "empty"
 * sentinel whenever NT cannot represent every row index (e.g. bool or float).
 *
 * NOTE(review): the probing loop, the counter increments and the initialisation
 * of totnnz/maxnnzPerCol are on lines not visible here.
 */
237 template<
class IT,
class NT>
241 int nlists = ArrSpTups.size();
// Number of columns covered by this call.
242 IT ncols = endCol - startCol;
// Read cursor for each list, all starting at 0.
243 std::vector<IT> curptr(nlists,
static_cast<IT>(0));
244 const IT minHashTableSize = 16;
// Multiplier applied to the key before masking with (table size - 1).
245 const IT hashScale = 107;
// Table of row-index keys; -1 marks an empty slot. Grown on demand below.
246 std::vector<IT> globalHashVec(minHashTableSize);
// Value-initialised (zeroed) per-column counters.
250 IT* colnnzC =
new IT[ncols]();
254 for(
IT col = 0; col<ncols; col++)
256 IT globalCol = col + startCol;
// Scan ahead with a local copy of each cursor over the tuples of this column.
// NOTE(review): presumably counts them into nnzcol; loop body not visible.
260 for(
int i=0; i<nlists; i++)
262 IT curidx = curptr[i];
263 while((ArrSpTups[i]->getnnz()>curidx) && (ArrSpTups[i]->colindex(curidx++) == globalCol))
// Choose a table size of at least nnzcol, starting from the minimum.
// NOTE(review): the growth step is not visible; the masking with (ht_size-1)
// below requires ht_size to stay a power of two.
268 size_t ht_size = minHashTableSize;
269 while(ht_size < nnzcol)
// The table is only ever enlarged, never shrunk.
273 if(globalHashVec.size() < ht_size)
274 globalHashVec.resize(ht_size);
// Mark the first ht_size slots as empty.
276 for(
size_t j=0; j < ht_size; ++j)
278 globalHashVec[j] = -1;
// Insert the row index of every tuple of this column from every list.
280 for(
int i=0; i<nlists; i++)
283 while((ArrSpTups[i]->getnnz()>curptr[i]) && (ArrSpTups[i]->colindex(curptr[i]) == globalCol))
285 IT key = ArrSpTups[i]->rowindex(curptr[i]);
286 IT hash = (key*hashScale) & (ht_size-1);
// Key already present in this slot.
290 if (globalHashVec[hash] == key)
// Empty slot: claim it for this key.
294 else if (globalHashVec[hash] == -1)
296 globalHashVec[
hash] = key;
// Fold this column's count into the two output parameters.
308 totnnz += colnnzC[col];
309 if(colnnzC[col] > maxnnzPerCol) maxnnzPerCol = colnnzC[col];
/**
 * SerialMergeHash: numeric pass that merges, column by column, the tuples of
 * all input lists within [startCol, endCol) into `ntuples`, combining values
 * that share a row index with SR::add.
 *
 * @param ArrSpTups  input lists; curptr[i] is the read position in list i.
 * @param ntuples    output array, written sequentially through outptr.
 * @param colnnz     per-column output counts, used to size the hash table.
 * @param maxcolnnz  upper bound used to pre-size the table (twice this value,
 *                   or minHashTableSize if larger).
 * @param startCol   first global column handled by this call.
 * @param endCol     one past the last global column handled by this call.
 * @param sorted     NOTE(review): presumably selects the integerSort branch
 *                   below versus the plain table scan; the branch itself is not visible.
 *
 * NOTE(review): keys are stored as uint32_t while row indices are IT; row
 * indices that do not fit in 32 bits, or that equal the uint32_t value of -1
 * (the "empty" marker), would be mishandled.
 * NOTE(review): no resize of globalHashVec is visible in this function, so
 * ht_size must never exceed the initial size -- this relies on colnnz[col]
 * being at most maxcolnnz; confirm.
 */
320 template<
class SR,
class IT,
class NT>
321 void SerialMergeHash(
const std::vector<SpTuples<IT,NT> *> & ArrSpTups, std::tuple<IT, IT, NT> * ntuples,
IT* colnnz,
IT maxcolnnz,
IT startCol,
IT endCol,
bool sorted)
323 int nlists = ArrSpTups.size();
324 IT ncols = endCol - startCol;
// Read cursor for each list, all starting at 0.
326 std::vector<IT> curptr(nlists,
static_cast<IT>(0));
328 const IT minHashTableSize = 16;
// Multiplier applied to the key before masking with (table size - 1).
329 const IT hashScale = 107;
// (key, value) slots, allocated once for the whole call.
331 std::vector< std::pair<uint32_t,NT>> globalHashVec(std::max(minHashTableSize, maxcolnnz*2));
333 for(
IT col = 0; col<ncols; col++)
335 IT globalCol = col + startCol;
// Choose a table size of at least colnnz[col], starting from the minimum.
// NOTE(review): the growth step is not visible; masking with (ht_size-1)
// below requires ht_size to stay a power of two.
336 size_t ht_size = minHashTableSize;
337 while(ht_size < colnnz[col])
// Mark the first ht_size slots as empty (-1 converted to uint32_t).
341 for(
size_t j=0; j < ht_size; ++j)
343 globalHashVec[j].first = -1;
// Insert or accumulate every tuple of this column from every list.
345 for(
int i=0; i<nlists; i++)
347 while((ArrSpTups[i]->getnnz()>curptr[i]) && (ArrSpTups[i]->colindex(curptr[i]) == globalCol))
349 IT key = ArrSpTups[i]->rowindex(curptr[i]);
350 IT hash = (key*hashScale) & (ht_size-1);
354 NT curval = ArrSpTups[i]->numvalue(curptr[i]);
// Key already present: combine the values through the semiring addition.
355 if (globalHashVec[hash].first == key)
357 globalHashVec[
hash].second = SR::add(curval, globalHashVec[hash].second);
// Empty slot: store the key together with its first value.
360 else if (globalHashVec[hash].first == -1)
362 globalHashVec[
hash].first = key;
363 globalHashVec[
hash].second = curval;
// Compact the occupied slots to the front of the table (index counts them).
378 for (
size_t j=0; j < ht_size; ++j)
380 if (globalHashVec[j].first != -1)
382 globalHashVec[index++] = globalHashVec[j];
// Sort the compacted entries, then emit them as (row, globalCol, value).
// NOTE(review): integerSort is defined elsewhere; presumably it orders by the key field.
385 integerSort<NT>(globalHashVec.data(), index);
389 for (
size_t j=0; j < index; ++j)
391 ntuples[outptr++]= std::make_tuple(globalHashVec[j].first, globalCol, globalHashVec[j].second);
// Alternative emission: scan the table in slot order and emit occupied slots as they are.
396 for (
size_t j=0; j < ht_size; ++j)
398 if (globalHashVec[j].first != -1)
400 ntuples[outptr++]= std::make_tuple(globalHashVec[j].first, globalCol, globalHashVec[j].second);
/**
 * MultiwayMerge: merges several SpTuples into one newly allocated SpTuples by
 * cutting the column range into nsplits pieces and running SerialMerge<SR> on
 * each piece in an OpenMP parallel loop.
 *
 * @param ArrSpTups  input lists.
 * @param mdim       number of rows; compared against getnrow() of every input.
 * @param ndim       number of columns; compared against getncol() of every input.
 * @param delarrs    NOTE(review): presumably requests deletion of the inputs in
 *                   the final loop; the statement using it is not visible here.
 * @return a heap-allocated SpTuples; ownership passes to the caller.
 *
 * NOTE(review): conditions guarding the early returns, the declaration of
 * nthreads, and the code filling mergedNnzPerSplit are on lines not visible here.
 * NOTE(review): several loops index with `int` against getnnz(), which
 * presumably returns IT; very large inputs could overflow the index.
 */
411template<
class SR,
class IT,
class NT>
412SpTuples<IT, NT>*
MultiwayMerge( std::vector<SpTuples<IT,NT> *> & ArrSpTups,
IT mdim = 0,
IT ndim = 0,
bool delarrs =
false )
415 int nlists = ArrSpTups.size();
// Early exit with an empty result of the requested dimensions.
// NOTE(review): presumably taken when there are no input lists; condition not visible.
418 return new SpTuples<IT,NT>(0, mdim, ndim);
// Copy of the first list's tuples into a fresh new[] array, returned as the result.
// NOTE(review): presumably the single-list case; condition not visible.
431 std::tuple<IT, IT, NT>* mergeTups =
new std::tuple<IT, IT, NT>[ArrSpTups[0]->getnnz()];
433#pragma omp parallel for
435 for(
int i=0; i<ArrSpTups[0]->getnnz(); i++)
436 mergeTups[i] = ArrSpTups[0]->tuples[i];
438 return new SpTuples<IT,NT> (ArrSpTups[0]->getnnz(), mdim, ndim, mergeTups,
false);
// Every input must have exactly mdim rows and ndim columns; otherwise an error
// is printed and an empty 0x0 result is returned.
443 for(
int i=0; i< nlists; ++i)
445 if((mdim != ArrSpTups[i]->getnrow()) || ndim != ArrSpTups[i]->getncol())
447 std::cerr <<
"Dimensions of SpTuples do not match on multiwayMerge()" << std::endl;
448 return new SpTuples<IT,NT>(0,0,0);
// omp_get_num_threads reports the team size of the enclosing parallel region
// (1 outside of one). NOTE(review): the enclosing pragma is not visible here.
456 nthreads = omp_get_num_threads();
// Four column ranges per thread, but never more ranges than columns.
459 int nsplits = 4*nthreads;
460 nsplits = std::min(nsplits, (
int)ndim);
// colPtrs[j][i] is the offset in list j at which column range i begins.
461 std::vector< std::vector<IT> > colPtrs;
462 for(
int i=0; i< nlists; i++)
464 colPtrs.push_back(findColSplitters<IT>(ArrSpTups[i], nsplits));
467 std::vector<IT> mergedNnzPerSplit(nsplits);
468 std::vector<IT> inputNnzPerSplit(nsplits);
// First parallel pass: per column range, wrap each list's slice in an SpTuples
// and total the input tuple count.
// NOTE(review): the statement filling mergedNnzPerSplit[i] (presumably via
// SerialMergeNNZ) is not visible.
471#pragma omp parallel for schedule(dynamic)
473 for(
int i=0; i< nsplits; i++)
475 std::vector<SpTuples<IT,NT> *> listSplitTups(nlists);
476 IT t =
static_cast<IT>(0);
477 for(
int j=0; j< nlists; ++j)
479 IT curnnz= colPtrs[j][i+1] - colPtrs[j][i];
// The slice points into the parent's tuple array rather than copying it.
// NOTE(review): the trailing `true` presumably marks the tuples as already sorted; confirm.
480 listSplitTups[j] =
new SpTuples<IT, NT> (curnnz, mdim, ndim, ArrSpTups[j]->tuples + colPtrs[j][i],
true);
481 t += colPtrs[j][i+1] - colPtrs[j][i];
484 inputNnzPerSplit[i] = t;
// Exclusive prefix sum of the merged counts: mdisp[i] is where range i starts
// in the output buffer, and mdisp[nsplits] is the total.
487 std::vector<IT> mdisp(nsplits+1,0);
488 for(
int i=0; i<nsplits; ++i)
489 mdisp[i+1] = mdisp[i] + mergedNnzPerSplit[i];
490 IT mergedNnzAll = mdisp[nsplits];
// Diagnostic: compression ratio of the merge.
// NOTE(review): divides by mergedNnzAll; if that is 0 the ratio is inf or NaN.
494 IT inputNnzAll = std::accumulate(inputNnzPerSplit.begin(), inputNnzPerSplit.end(),
static_cast<IT>(0));
495 double ratio = inputNnzAll / (double) mergedNnzAll;
496 std::ostringstream outs;
497 outs <<
"Multiwaymerge: inputNnz/mergedNnz = " << ratio << std::endl;
// Output buffer for all ranges, allocated with new[].
504 std::tuple<IT, IT, NT> * mergeBuf =
new std::tuple<IT, IT, NT>[mergedNnzAll];
// Second parallel pass: merge each column range into its own region of mergeBuf.
507#pragma omp parallel for schedule(dynamic)
509 for(
int i=0; i< nsplits; i++)
511 std::vector<SpTuples<IT,NT> *> listSplitTups(nlists);
512 for(
int j=0; j< nlists; ++j)
514 IT curnnz= colPtrs[j][i+1] - colPtrs[j][i];
515 listSplitTups[j] =
new SpTuples<IT, NT> (curnnz, mdim, ndim, ArrSpTups[j]->tuples + colPtrs[j][i],
true);
517 SerialMerge<SR>(listSplitTups, mergeBuf + mdisp[i]);
// Loop over the input lists. NOTE(review): body not visible -- presumably the
// delarrs-controlled cleanup of ArrSpTups.
520 for(
int i=0; i< nlists; i++)
// Hand the merged buffer to a new SpTuples.
// NOTE(review): the meaning of the two trailing flags is defined by the SpTuples constructor, not shown here.
525 return new SpTuples<IT, NT> (mergedNnzAll, mdim, ndim, mergeBuf,
true,
false);
/**
 * MultiwayMergeHash: merges several SpTuples into one newly allocated SpTuples
 * using per-column hash tables. The column range is cut into nsplits pieces;
 * a symbolic pass (SerialMergeNNZHash) sizes the output and a numeric pass
 * (SerialMergeHash<SR>) fills it, both as OpenMP parallel loops over the pieces.
 *
 * @param ArrSpTups  input lists.
 * @param mdim       number of rows; compared against getnrow() of every input.
 * @param ndim       number of columns; compared against getncol() of every input.
 * @param delarrs    NOTE(review): presumably requests deletion of the inputs in
 *                   the final loop; the statement using it is not visible here.
 * @param sorted     forwarded to SerialMergeHash.
 * @return a heap-allocated SpTuples; ownership passes to the caller.
 *
 * Fix: the per-split count arrays come from SerialMergeNNZHash, which allocates
 * them with new IT[ncols](); they are now released with delete [] instead of
 * scalar delete.
 *
 * NOTE(review): conditions guarding the early returns and the declarations of
 * nprocs, myrank and nthreads are on lines not visible here.
 * NOTE(review): sizeof applied to an array type with a run-time bound is a
 * compiler extension, not standard C++.
 */
536 template<
class SR,
class IT,
class NT>
537 SpTuples<IT, NT>*
MultiwayMergeHash( std::vector<SpTuples<IT,NT> *> & ArrSpTups,
IT mdim = 0,
IT ndim = 0,
bool delarrs =
false,
bool sorted=
true )
// Query the size of and this process's rank in MPI_COMM_WORLD.
540 MPI_Comm_size(MPI_COMM_WORLD,&
nprocs);
541 MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
543 int nlists = ArrSpTups.size();
// Early exit with an empty result of the requested dimensions.
// NOTE(review): presumably taken when there are no input lists; condition not visible.
546 return new SpTuples<IT,NT>(0, mdim, ndim);
// Copy of the first list's tuples into raw storage from ::operator new,
// returned as the result.
// NOTE(review): presumably the single-list case; condition not visible.
556 std::tuple<IT, IT, NT>* mergeTups =
static_cast<std::tuple<IT, IT, NT>*
>
557 (::operator
new (
sizeof(std::tuple<IT, IT, NT>[ArrSpTups[0]->getnnz()])));
559#pragma omp parallel for
561 for(
int i=0; i<ArrSpTups[0]->getnnz(); i++)
562 mergeTups[i] = ArrSpTups[0]->tuples[i];
567 return new SpTuples<IT,NT> (ArrSpTups[0]->getnnz(), mdim, ndim, mergeTups,
true,
true);
// Every input must have exactly mdim rows and ndim columns; otherwise an error
// is printed and an empty 0x0 result is returned.
572 for(
int i=0; i< nlists; ++i)
574 if((mdim != ArrSpTups[i]->getnrow()) || ndim != ArrSpTups[i]->getncol())
576 std::cerr <<
"Dimensions of SpTuples do not match on multiwayMerge()" << std::endl;
577 return new SpTuples<IT,NT>(0,0,0);
// omp_get_num_threads reports the team size of the enclosing parallel region
// (1 outside of one). NOTE(review): the enclosing pragma is not visible here.
585 nthreads = omp_get_num_threads();
// Four column ranges per thread, but never more ranges than columns.
588 int nsplits = 4*nthreads;
589 nsplits = std::min(nsplits, (
int)ndim);
// colPtrs[j][i] is the offset in list j at which column range i begins.
590 std::vector< std::vector<IT> > colPtrs(nlists);
// One task per input list; each writes only its own colPtrs[j].
594#pragma omp parallel for
596 for(
int j=0; j< nlists; j++)
598 colPtrs[j]=findColSplittersFinger<IT>(ArrSpTups[j], nsplits);
// listSplitTups[i][j] views the slice of list j that falls in column range i.
606 std::vector<std::vector<SpTuples<IT,NT> *>> listSplitTups(nsplits);
608 for(
int i=0; i< nsplits; ++i)
610 listSplitTups[i].resize(nlists);
612 for(
int j=0; j< nlists; ++j)
614 IT curnnz= colPtrs[j][i+1] - colPtrs[j][i];
// The slice points into the parent's tuple array rather than copying it.
615 listSplitTups[i][j] =
new SpTuples<IT, NT> (curnnz, mdim, ndim, ArrSpTups[j]->tuples + colPtrs[j][i],
true);
// Per-range results of the symbolic pass.
620 std::vector<IT> mergedNnzPerSplit(nsplits);
621 std::vector<IT> mergedNnzPerSplit1(nsplits);
622 std::vector<IT> maxNnzPerColumnSplit(nsplits);
623 std::vector<IT*> nnzPerColSplit(nsplits);
// Symbolic pass: count output entries per column for every range.
627#pragma omp parallel for schedule(dynamic)
629 for(
int i=0; i< nsplits; i++)
// Ranges are ndim/nsplits columns wide; the last one absorbs the remainder.
631 IT startCol = i* (ndim/nsplits);
632 IT endCol = (i+1)* (ndim/nsplits);
633 if(i == (nsplits-1)) endCol = ndim;
635 nnzPerColSplit[i] =
SerialMergeNNZHash(listSplitTups[i], mergedNnzPerSplit[i], maxNnzPerColumnSplit[i], startCol, endCol);
// Exclusive prefix sum of the merged counts: mdisp[i] is where range i starts
// in the output buffer, and mdisp[nsplits] is the total.
638 std::vector<IT> mdisp(nsplits+1,0);
639 for(
int i=0; i<nsplits; ++i)
640 mdisp[i+1] = mdisp[i] + mergedNnzPerSplit[i];
641 IT mergedNnzAll = mdisp[nsplits];
// Raw, unconstructed storage for the merged tuples.
644 std::tuple<IT, IT, NT> * mergeBuf =
static_cast<std::tuple<IT, IT, NT>*
> (::operator
new (
sizeof(std::tuple<IT, IT, NT>[mergedNnzAll])));
// Numeric pass: each range is merged into its own region of mergeBuf.
651#pragma omp parallel for schedule(dynamic)
653 for(
int i=0; i< nsplits; i++)
656 IT startCol = i* (ndim/nsplits);
657 IT endCol = (i+1)* (ndim/nsplits);
658 if(i == (nsplits-1)) endCol = ndim;
659 SerialMergeHash<SR>(listSplitTups[i], mergeBuf + mdisp[i], nnzPerColSplit[i], maxNnzPerColumnSplit[i], startCol, endCol, sorted);
// Release the per-range count arrays and the slice wrappers.
664 for(
int i=0; i< nsplits; ++i)
// Array form: the buffer was created with new IT[ncols]().
666 delete [] nnzPerColSplit[i];
667 for(
int j=0; j< nlists; ++j)
// Flag the slice before deleting it.
// NOTE(review): presumably this stops the SpTuples destructor from freeing the
// tuple storage shared with the parent list; confirm in SpTuples.
669 listSplitTups[i][j]->tuples_deleted =
true;
670 delete listSplitTups[i][j];
// Loop over the input lists. NOTE(review): body not visible -- presumably the
// delarrs-controlled cleanup of ArrSpTups.
674 for(
int i=0; i< nlists; i++)
// Hand the merged buffer to a new SpTuples.
// NOTE(review): the meaning of the two trailing flags is defined by the SpTuples constructor, not shown here.
683 return new SpTuples<IT, NT> (mergedNnzAll, mdim, ndim, mergeBuf,
true,
true);
/**
 * MultiwayMergeHashSliding: hash-based multiway merge in which a column whose
 * input size exceeds maxHashTableSize is processed in several row windows, so
 * that each hash table stays bounded.
 *
 * Visible phases:
 *   1. build per-list column pointers (colPtrs) by binary search;
 *   2. count input entries per column (flopsPerCol) and derive a tentative
 *      number of row windows per column (nWindowPerColSymbolic);
 *   3. pick column split points that balance input entries across nsplits;
 *   4. symbolic pass: count distinct row indices per tentative window;
 *   5. coalesce tentative windows into final windows;
 *   6. numeric pass: accumulate values per window into mergeBuf;
 *   7. release all scratch buffers and return the merged SpTuples.
 *
 * @param ArrSpTups         input lists.
 * @param mdim              number of rows; compared against getnrow() of every input.
 * @param ndim              number of columns; compared against getncol() of every input.
 * @param delarrs           NOTE(review): presumably guards deletion of the
 *                          inputs near the end; the guard is not visible here.
 * @param sorted            NOTE(review): presumably selects the integerSort
 *                          emission branches; the branches are not visible here.
 * @param maxHashTableSize  bound used to choose the number of row windows.
 * @return a heap-allocated SpTuples; ownership passes to the caller.
 *
 * Fixes:
 *   - colSplitters is an IT array, so it is now allocated and cast as IT. The
 *     previous size_t cast and sizeof only compiled when IT was size_t.
 *   - The symbolic-pass hash table stores row-index keys, so its element type
 *     is IT rather than the value type NT.
 *
 * NOTE(review): the visible numeric pass combines values with operator+= rather
 * than SR::add, unlike SerialMergeHash; confirm this is intended.
 * NOTE(review): the numeric-pass table stores keys as uint32_t; row indices
 * that do not fit in 32 bits would be mishandled.
 * NOTE(review): sizeof applied to an array type with a run-time bound is a
 * compiler extension, not standard C++.
 * NOTE(review): many braces, conditions and single statements of this function
 * are on lines not visible in this excerpt.
 */
692 template<
class SR,
class IT,
class NT>
693 SpTuples<IT, NT>*
MultiwayMergeHashSliding( std::vector<SpTuples<IT,NT> *> & ArrSpTups,
IT mdim = 0,
IT ndim = 0,
bool delarrs =
false,
bool sorted=
true,
IT maxHashTableSize = 16384)
// Query the size of and this process's rank in MPI_COMM_WORLD.
696 MPI_Comm_size(MPI_COMM_WORLD,&
nprocs);
697 MPI_Comm_rank(MPI_COMM_WORLD,&myrank);
699 int nlists = ArrSpTups.size();
// Early exit with an empty result of the requested dimensions.
// NOTE(review): presumably taken when there are no input lists; condition not visible.
702 return new SpTuples<IT,NT>(0, mdim, ndim);
// Copy of the first list's tuples into raw storage, returned as the result.
// NOTE(review): presumably the single-list case; condition not visible.
712 std::tuple<IT, IT, NT>* mergeTups =
static_cast<std::tuple<IT, IT, NT>*
>
713 (::operator
new (
sizeof(std::tuple<IT, IT, NT>[ArrSpTups[0]->getnnz()])));
715#pragma omp parallel for
717 for(
int i=0; i<ArrSpTups[0]->getnnz(); i++)
718 mergeTups[i] = ArrSpTups[0]->tuples[i];
723 return new SpTuples<IT,NT> (ArrSpTups[0]->getnnz(), mdim, ndim, mergeTups,
true,
true);
// Every input must have exactly mdim rows and ndim columns; otherwise an error
// is printed and an empty 0x0 result is returned.
728 for(
int i=0; i< nlists; ++i)
730 if((mdim != ArrSpTups[i]->getnrow()) || ndim != ArrSpTups[i]->getncol())
732 std::cerr <<
"Dimensions of SpTuples do not match on MultiwayMergeHashSliding()" << std::endl;
734 return new SpTuples<IT,NT>(0,0,0);
// omp_get_num_threads reports the team size of the enclosing parallel region
// (1 outside of one). NOTE(review): the enclosing pragma is not visible here.
742 nthreads = omp_get_num_threads();
// Four column ranges per thread, but never more ranges than columns.
745 int nsplits = 4*nthreads;
746 nsplits = std::min(nsplits, (
int)ndim);
748 const IT minHashTableSize = 16;
// Multiplier applied to the key before masking with (table size - 1).
750 const IT hashScale = 107;
// Phase 1: colPtrs[l][c] is the offset in list l of the first tuple whose
// column is >= c; colPtrs[l][ndim] is getnnz() of list l.
755 IT** colPtrs =
static_cast<IT**
> (::operator
new (
sizeof(
IT*[nlists])));
756 for(
int l = 0; l < nlists; l++){
757 colPtrs[l] =
static_cast<IT*
> (::operator
new (
sizeof(
IT[ndim+1])));
759 ColLexiCompare<IT,NT> colCmp;
760 RowLexiCompare<IT,NT> rowCmp;
// Each split handles a contiguous, equally wide column range (the last one
// absorbs the remainder) and fills the column pointers of that range.
763#pragma omp parallel for
765 for(
int s = 0; s < nsplits; s++){
766 IT startColSplit = s * (ndim/nsplits);
767 IT endColSplit = (s == (nsplits-1) ) ? ndim : (s+1) * (ndim/nsplits);
768 for(
int l = 0; l < nlists; l++){
// Narrow the search to the tuples of this split's column range first.
769 std::tuple<IT, IT, NT> firstTuple(0, startColSplit,
NT());
770 std::tuple<IT, IT, NT>* first = std::lower_bound(ArrSpTups[l]->tuples, ArrSpTups[l]->tuples+ArrSpTups[l]->getnnz(), firstTuple, colCmp);
771 std::tuple<IT, IT, NT> lastTuple(0, endColSplit,
NT());
772 std::tuple<IT, IT, NT>* last = std::lower_bound(ArrSpTups[l]->tuples, ArrSpTups[l]->tuples+ArrSpTups[l]->getnnz(), lastTuple, colCmp);
773 for(
IT c = startColSplit; c < endColSplit; c++){
// Column 0 always starts at offset 0; other columns are located by binary search.
774 if(c == 0) colPtrs[l][c] = 0;
776 std::tuple<IT, IT, NT> searchTuple(0, c,
NT());
777 std::tuple<IT, IT, NT>* pos = std::lower_bound(first, last, searchTuple, colCmp);
778 colPtrs[l][c] = pos - ArrSpTups[l]->tuples;
// The split owning the last column also writes the end sentinel.
781 if(s == nsplits-1) colPtrs[l][ndim] = ArrSpTups[l]->getnnz();
// Phase 2: per-column input entry counts and tentative window counts.
785 size_t* flopsPerCol =
static_cast<size_t*
> (::operator
new (
sizeof(
size_t[ndim])));
786 IT* nWindowPerColSymbolic =
static_cast<IT*
> (::operator
new (
sizeof(
IT[ndim])));
// NOTE(review): flopsPerCol is raw storage and is accumulated with += below;
// its zero-initialisation is presumably on a line not visible here.
788#pragma omp parallel for
790 for(
IT c = 0; c < ndim; c++){
792 for(
int l = 0; l < nlists; l++){
793 flopsPerCol[c] += colPtrs[l][c+1] - colPtrs[l][c];
// At least one window per column, one more for each full maxHashTableSize entries.
795 nWindowPerColSymbolic[c] = flopsPerCol[c] / maxHashTableSize + 1;
// Phase 3: column split points that balance the input entries over nsplits.
798 size_t* prefixSumFlopsPerCol = prefixSum<size_t>(flopsPerCol, ndim, nthreads);
799 size_t totalFlops = prefixSumFlopsPerCol[ndim];
800 size_t flopsPerSplit = totalFlops / nsplits;
// Element type matches the declared pointer type (IT).
801 IT* colSplitters =
static_cast<IT*
> (::operator
new (
sizeof(
IT[nsplits+1])));
// Split s starts at the first column whose prefix sum reaches s*flopsPerSplit.
808#pragma omp parallel for
810 for(
int s = 0; s < nsplits; s++){
811 size_t searchItem = s * flopsPerSplit;
812 size_t* searchResult = std::lower_bound(prefixSumFlopsPerCol, prefixSumFlopsPerCol + ndim + 1, searchItem);
813 colSplitters[s] = searchResult - prefixSumFlopsPerCol;
815 colSplitters[nsplits] = ndim;
// prefixSumWindowSymbolic[c] is the index of column c's first tentative window.
821 IT* prefixSumWindowSymbolic = prefixSum<IT>(nWindowPerColSymbolic, ndim, nthreads);
// Per tentative window: .first is its starting row, .second its distinct-row count.
823 std::pair<IT, IT>* windowsSymbolic =
static_cast<std::pair<IT, IT>*
> (::operator
new (
sizeof(std::pair<IT, IT>[prefixSumWindowSymbolic[ndim]])));
824 IT* nnzPerCol =
static_cast<IT*
> (::operator
new (
sizeof(
IT[ndim])));
825 IT* nWindowPerCol =
static_cast<IT*
> (::operator
new (
sizeof(
IT[ndim])));
// rowIdsRange[s][l] is scratch for split s: a [first, second) offset range
// into list l for the window currently being processed.
830 std::pair<IT, IT>** rowIdsRange =
static_cast<std::pair<IT, IT>**
> (::operator
new (
sizeof(std::pair<IT, IT>*[nsplits])));
831 for(
int s = 0; s < nsplits; s++){
832 rowIdsRange[s] =
static_cast<std::pair<IT, IT>*
> (::operator
new (
sizeof(std::pair<IT, IT>[nlists])));
// Phase 4: symbolic pass. The table holds row-index keys only, hence element
// type IT; -1 marks an empty slot.
// NOTE(review): presumably declared inside a parallel region so that each
// thread owns its table; the enclosing pragma is not visible here.
839 std::vector<IT> globalHashVec(minHashTableSize);
840 size_t tid = omp_get_thread_num();
842#pragma omp for schedule(dynamic)
844 for(
int s = 0; s < nsplits; s++){
845 IT startCol = colSplitters[s];
846 IT endCol = colSplitters[s+1];
848 for(
IT c = startCol; c < endCol; c++){
// Start from one final window; columns with several tentative windows are
// handled further below.
850 nWindowPerCol[c] = 1;
851 if(nWindowPerColSymbolic[c] == 1){
// Single tentative window covering the whole column, starting at row 0.
854 size_t wsIdx = prefixSumWindowSymbolic[c];
856 windowsSymbolic[wsIdx].first = 0;
857 windowsSymbolic[wsIdx].second = 0;
// Table size of at least the column's input entry count.
// NOTE(review): the growth step is not visible; masking with (htSize-1) below
// requires htSize to stay a power of two.
859 size_t htSize = minHashTableSize;
860 while(htSize < flopsPerCol[c]) {
864 if(globalHashVec.size() < htSize) globalHashVec.resize(htSize);
865 for(
size_t j=0; j < htSize; ++j) {
866 globalHashVec[j] = -1;
// Insert every row index of this column; new keys bump the window's count.
869 for(
int l = 0; l < nlists; l++){
870 for(
IT i = colPtrs[l][c]; i < colPtrs[l][c+1]; i++){
871 IT key = ArrSpTups[l]->rowindex(i);
872 IT hash = (key*hashScale) & (htSize-1);
876 if (globalHashVec[hash] == key) {
880 else if (globalHashVec[hash] == -1) {
882 globalHashVec[
hash] = key;
884 windowsSymbolic[wsIdx].second++;
// Several tentative windows: cut the row range into equally tall pieces, the
// last one absorbing the remainder.
896 IT nrowsPerWindow = mdim / nWindowPerColSymbolic[c];
898 for(
size_t w = 0; w < nWindowPerColSymbolic[c]; w++){
899 IT rowStart = w * nrowsPerWindow;
900 IT rowEnd = (w == nWindowPerColSymbolic[c]-1) ? mdim : (w+1) * nrowsPerWindow;
901 size_t wsIdx = prefixSumWindowSymbolic[c] + w;
903 windowsSymbolic[wsIdx].first = rowStart;
904 windowsSymbolic[wsIdx].second = 0;
// Locate, in every list, the tuples of column c whose row lies in [rowStart, rowEnd).
906 size_t flopsWindow = 0;
907 for(
int l = 0; l < nlists; l++){
908 std::tuple<IT, IT, NT>* first = ArrSpTups[l]->tuples + colPtrs[l][c];
909 std::tuple<IT, IT, NT>* last = ArrSpTups[l]->tuples + colPtrs[l][c+1];
912 std::tuple<IT, IT, NT> searchTuple(rowStart, 0,
NT());
913 first = std::lower_bound(first, last, searchTuple, rowCmp);
917 std::tuple<IT, IT, NT> searchTuple(rowEnd, 0,
NT());
918 last = std::lower_bound(first, last, searchTuple, rowCmp);
921 rowIdsRange[s][l].first = first - (ArrSpTups[l]->tuples);
922 rowIdsRange[s][l].second = last - (ArrSpTups[l]->tuples);
924 flopsWindow += last - first;
// Table size of at least the window's input entry count.
926 size_t htSize = minHashTableSize;
927 while(htSize < flopsWindow) {
931 if(globalHashVec.size() < htSize) globalHashVec.resize(htSize);
932 for(
size_t j=0; j < htSize; ++j) {
933 globalHashVec[j] = -1;
// Insert every row index of this window; new keys bump the window's count.
935 for(
int l = 0; l < nlists; l++){
936 for(
IT i = rowIdsRange[s][l].first; i < rowIdsRange[s][l].second; i++){
937 IT key = ArrSpTups[l]->rowindex(i);
938 IT hash = (key * hashScale) & (htSize-1);
941 if (globalHashVec[hash] == key) {
945 else if (globalHashVec[hash] == -1) {
947 globalHashVec[
hash] = key;
949 windowsSymbolic[wsIdx].second++;
// runningSum tracks the distinct-row total of the final window being formed;
// it restarts when adding the next tentative window would exceed maxHashTableSize.
// NOTE(review): the accompanying updates of nWindowPerCol/nnzPerCol are not visible here.
961 runningSum = windowsSymbolic[wsIdx].second;
964 if(runningSum + windowsSymbolic[wsIdx].second > maxHashTableSize) {
966 runningSum = windowsSymbolic[wsIdx].second;
969 runningSum += windowsSymbolic[wsIdx].second;
// Phase 5: materialise the final windows. prefixSumWindow[c] is the index of
// column c's first final window.
981 IT* prefixSumWindow = prefixSum<IT>(nWindowPerCol, ndim, nthreads);
982 std::pair<IT, IT>* windows =
static_cast<std::pair<IT, IT>*
> (::operator
new (
sizeof(std::pair<IT, IT>[prefixSumWindow[ndim]])));
985#pragma omp parallel for schedule(dynamic)
987 for(
int s = 0; s < nsplits; s++){
988 IT colStart = colSplitters[s];
989 IT colEnd = colSplitters[s+1];
990 for(
IT c = colStart; c < colEnd; c++){
991 IT nWindowSymbolic = nWindowPerColSymbolic[c];
992 IT wsIdx = prefixSumWindowSymbolic[c];
993 IT wcIdx = prefixSumWindow[c];
// The first final window starts as a copy of the first tentative window.
994 windows[wcIdx].first = windowsSymbolic[wsIdx].first;
995 windows[wcIdx].second = windowsSymbolic[wsIdx].second;
// Fold each further tentative window into the current final window, or start a
// new final window when the bound would be exceeded.
// NOTE(review): the increment of wcIdx when a new window starts is not visible here.
997 for(
IT w = 1; w < nWindowSymbolic; w++){
998 wsIdx = prefixSumWindowSymbolic[c] + w;
999 if(windows[wcIdx].second + windowsSymbolic[wsIdx].second > maxHashTableSize){
1001 windows[wcIdx].first = windowsSymbolic[wsIdx].first;
1002 windows[wcIdx].second = windowsSymbolic[wsIdx].second;
1005 windows[wcIdx].second = windows[wcIdx].second + windowsSymbolic[wsIdx].second;
// Phase 6: numeric pass. prefixSumNnzPerCol[c] is where column c's output
// begins in mergeBuf; mergeBuf is raw, unconstructed storage.
1011 IT* prefixSumNnzPerCol = prefixSum<IT>(nnzPerCol, ndim, nthreads);
1012 IT totalNnz = prefixSumNnzPerCol[ndim];
1013 std::tuple<IT, IT, NT> * mergeBuf =
static_cast<std::tuple<IT, IT, NT>*
> (::operator
new (
sizeof(std::tuple<IT, IT, NT>[totalNnz])));
// (key, value) table for the numeric pass; a key of -1 (as uint32_t) marks an empty slot.
1019 std::vector< std::pair<uint32_t,NT> > globalHashVec(minHashTableSize);
1020 size_t tid = omp_get_thread_num();
1022#pragma omp for schedule(dynamic)
1024 for(
int s = 0; s < nsplits; s++){
1025 IT startCol = colSplitters[s];
1026 IT endCol = colSplitters[s+1];
1027 for(
IT c = startCol; c < endCol; c++){
1028 IT nWindow = nWindowPerCol[c];
1029 IT outptr = prefixSumNnzPerCol[c];
// Single-window path: the whole column goes through one table.
// NOTE(review): the condition selecting this path is not visible here.
1031 IT wcIdx = prefixSumWindow[c];
1032 IT nnzWindow = windows[wcIdx].second;
1034 size_t htSize = minHashTableSize;
1035 while(htSize < nnzWindow)
1040 if(globalHashVec.size() < htSize) globalHashVec.resize(htSize);
1041 for(
size_t j=0; j < htSize; ++j)
1043 globalHashVec[j].first = -1;
// Insert or accumulate every tuple of this column from every list.
1046 for(
int l = 0; l < nlists; l++){
1047 for(
IT i = colPtrs[l][c]; i < colPtrs[l][c+1]; i++){
1048 IT key = ArrSpTups[l]->rowindex(i);
1049 IT hash = (key * hashScale) & (htSize-1);
1052 if (globalHashVec[hash].first == key) {
1055 globalHashVec[
hash].second += ArrSpTups[l]->numvalue(i);
1058 else if (globalHashVec[hash].first == -1) {
1061 globalHashVec[
hash].first = key;
1062 globalHashVec[
hash].second = ArrSpTups[l]->numvalue(i);
// Compact the occupied slots to the front, sort them, then emit (row, c, value).
// NOTE(review): unlike the multi-window path below, the copy here does not
// post-increment index; the increment is presumably on a line not visible here.
1074 for (
size_t j=0; j < htSize; j++){
1075 if (globalHashVec[j].first != -1){
1076 globalHashVec[index] = globalHashVec[j];
1080 integerSort<NT>(globalHashVec.data(), index);
1082 for(
size_t j = 0; j < index; j++){
1083 mergeBuf[outptr] = std::tuple<IT, IT, NT>(globalHashVec[j].first, c, globalHashVec[j].second);
// Alternative emission: occupied slots in table order.
// NOTE(review): the advance of outptr is not visible in either emission loop.
1088 for (
size_t j=0; j < htSize; j++){
1089 if (globalHashVec[j].first != -1){
1090 mergeBuf[outptr] = std::tuple<IT, IT, NT>(globalHashVec[j].first, c, globalHashVec[j].second);
// Multi-window path: start each list's range at the full extent of column c;
// .first is advanced as windows consume tuples.
1097 for (
int l = 0; l < nlists; l++){
1098 rowIdsRange[s][l].first = colPtrs[l][c];
1099 rowIdsRange[s][l].second = colPtrs[l][c+1];
1102 for (
size_t w = 0; w < nWindow; w++){
1103 IT wcIdx = prefixSumWindow[c] + w;
1104 IT startRow = windows[wcIdx].first;
// A window ends where the next one starts; the last window ends at mdim.
1105 IT endRow = (w == nWindow-1) ? mdim : windows[wcIdx+1].first;
1106 IT nnzWindow = windows[wcIdx].second;
// Smallest power-of-two table (from the minimum) holding nnzWindow keys.
1108 size_t htSize = minHashTableSize;
1109 while(htSize < nnzWindow) htSize <<= 1;
1110 if(globalHashVec.size() < htSize) globalHashVec.resize(htSize);
1111 for(
size_t j = 0; j < htSize; j++) globalHashVec[j].first = -1;
// Consume tuples from each list until its row index reaches endRow.
1113 for(
int l = 0; l < nlists; l++){
1114 while( rowIdsRange[s][l].first < rowIdsRange[s][l].second ){
1115 IT i = rowIdsRange[s][l].first;
1116 IT key = ArrSpTups[l]->rowindex(i);
1117 if(key >= endRow)
break;
1118 IT hash = (key * hashScale) & (htSize-1);
1120 if (globalHashVec[hash].first == key) {
1123 globalHashVec[
hash].second += ArrSpTups[l]->numvalue(i);
1126 else if (globalHashVec[hash].first == -1) {
1129 globalHashVec[
hash].first = key;
1130 globalHashVec[
hash].second = ArrSpTups[l]->numvalue(i);
1138 rowIdsRange[s][l].first++;
// Compact the occupied slots to the front, sort them, then emit (row, c, value).
1143 for (
size_t j=0; j < htSize; j++){
1144 if (globalHashVec[j].first != -1){
1145 globalHashVec[index++] = globalHashVec[j];
1148 integerSort<NT>(globalHashVec.data(), index);
1150 for(
size_t j = 0; j < index; j++){
1151 mergeBuf[outptr] = std::tuple<IT, IT, NT>(globalHashVec[j].first, c, globalHashVec[j].second);
// Alternative emission: occupied slots in table order.
1156 for (
size_t j=0; j < htSize; j++){
1157 if (globalHashVec[j].first != -1){
1158 mergeBuf[outptr] = std::tuple<IT, IT, NT>(globalHashVec[j].first, c, globalHashVec[j].second);
// Phase 7: cleanup. Arrays returned by prefixSum come from new[]; everything
// obtained from ::operator new is returned through ::operator delete.
1170 delete [] prefixSumFlopsPerCol;
1171 delete [] prefixSumNnzPerCol;
1172 delete [] prefixSumWindowSymbolic;
1173 delete [] prefixSumWindow;
1176 ::operator
delete(colSplitters);
1177 for(
int s = 0; s < nsplits; s++) ::operator
delete(rowIdsRange[s]);
1178 ::operator
delete(rowIdsRange);
1180 ::operator
delete(nWindowPerColSymbolic);
1181 ::operator
delete(windowsSymbolic);
1182 ::operator
delete(nWindowPerCol);
1183 ::operator
delete(windows);
1185 ::operator
delete(flopsPerCol);
1186 ::operator
delete(nnzPerCol);
1188 for(
int l = 0; l < nlists; l++) ::operator
delete(colPtrs[l]);
1189 ::operator
delete(colPtrs);
// Delete the input lists.
// NOTE(review): presumably guarded by delarrs on a line not visible here.
1191 for(
int i=0; i< nlists; i++)
1194 delete ArrSpTups[i];
// Hand the merged buffer to a new SpTuples.
// NOTE(review): the meaning of the two trailing flags is defined by the SpTuples constructor, not shown here.
1200 return new SpTuples<IT, NT> (totalNnz, mdim, ndim, mergeBuf,
true,
true);
static void Print(const std::string &s)
void SerialMerge(const std::vector< SpTuples< IT, NT > * > &ArrSpTups, std::tuple< IT, IT, NT > *ntuples)
IT SerialMergeNNZ(const std::vector< SpTuples< IT, NT > * > &ArrSpTups)
void SerialMergeHash(const std::vector< SpTuples< IT, NT > * > &ArrSpTups, std::tuple< IT, IT, NT > *ntuples, IT *colnnz, IT maxcolnnz, IT startCol, IT endCol, bool sorted)
T * prefixSum(T *in, int size, int nthreads)
SpTuples< IT, NT > * MultiwayMerge(std::vector< SpTuples< IT, NT > * > &ArrSpTups, IT mdim=0, IT ndim=0, bool delarrs=false)
SpTuples< IT, NT > * MultiwayMergeHash(std::vector< SpTuples< IT, NT > * > &ArrSpTups, IT mdim=0, IT ndim=0, bool delarrs=false, bool sorted=true)
IT * SerialMergeNNZHash(const std::vector< SpTuples< IT, NT > * > &ArrSpTups, IT &totnnz, IT &maxnnzPerCol, IT startCol, IT endCol)
std::vector< RT > findColSplitters(SpTuples< IT, NT > *&spTuples, int nsplits)
SpTuples< IT, NT > * MultiwayMergeHashSliding(std::vector< SpTuples< IT, NT > * > &ArrSpTups, IT mdim=0, IT ndim=0, bool delarrs=false, bool sorted=true, IT maxHashTableSize=16384)
std::vector< RT > findColSplittersFinger(SpTuples< IT, NT > *&spTuples, int nsplits)
unsigned int hash(unsigned int a)