Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class ZagryadskovMComplexSpMMCCSSTL : public BaseTask {
private:
inline static void SpMM(const CCS &a, const CCS &b, CCS &c);
inline static void SpMMSymbolic(const CCS &a, const CCS &b, std::vector<int> &col_ptr, int jstart, int jend);
inline static void SpMMNumeric(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero, double eps,
int jstart, int jend);
inline static void SpMMKernel(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero, double eps,
inline static void SpMMNumeric(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero, int jstart,
int jend);
inline static void SpMMKernel(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero,
std::vector<int> &rows, std::vector<std::complex<double>> &acc,
std::vector<int> &marker, int j);
bool ValidationImpl() override;
Expand Down
21 changes: 8 additions & 13 deletions tasks/zagryadskov_m_complex_spmm_ccs/stl/src/ops_stl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@ void ZagryadskovMComplexSpMMCCSSTL::SpMMSymbolic(const CCS &a, const CCS &b, std
}

void ZagryadskovMComplexSpMMCCSSTL::SpMMKernel(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero,
double eps, std::vector<int> &rows,
std::vector<std::complex<double>> &acc, std::vector<int> &marker,
int j) {
std::vector<int> &rows, std::vector<std::complex<double>> &acc,
std::vector<int> &marker, int j) {
rows.clear();
int write_ptr = c.col_ptr[j];

Expand All @@ -58,23 +57,21 @@ void ZagryadskovMComplexSpMMCCSSTL::SpMMKernel(const CCS &a, const CCS &b, CCS &
}

for (int r_idx : rows) {
if (std::norm(acc[r_idx]) > eps * eps) {
c.row_ind[write_ptr] = r_idx;
c.values[write_ptr] = acc[r_idx];
++write_ptr;
}
c.row_ind[write_ptr] = r_idx;
c.values[write_ptr] = acc[r_idx];
++write_ptr;
acc[r_idx] = zero;
}
}

void ZagryadskovMComplexSpMMCCSSTL::SpMMNumeric(const CCS &a, const CCS &b, CCS &c, const std::complex<double> &zero,
double eps, int jstart, int jend) {
int jstart, int jend) {
std::vector<int> marker(a.m, -1);
std::vector<std::complex<double>> acc(a.m, zero);
std::vector<int> rows;

for (int j = jstart; j < jend; ++j) {
SpMMKernel(a, b, c, zero, eps, rows, acc, marker, j);
SpMMKernel(a, b, c, zero, rows, acc, marker, j);
}
}

Expand All @@ -85,7 +82,6 @@ void ZagryadskovMComplexSpMMCCSSTL::SpMM(const CCS &a, const CCS &b, CCS &c) {
std::vector<std::thread> threads(num_threads);

std::complex<double> zero(0.0, 0.0);
const double eps = 1e-14;
c.col_ptr.assign(c.n + 1, 0);

for (int tid = 0; tid < num_threads; ++tid) {
Expand All @@ -107,8 +103,7 @@ void ZagryadskovMComplexSpMMCCSSTL::SpMM(const CCS &a, const CCS &b, CCS &c) {
for (int tid = 0; tid < num_threads; ++tid) {
int jstart = (tid * b.n) / num_threads;
int jend = ((tid + 1) * b.n) / num_threads;
threads[tid] =
std::thread(SpMMNumeric, std::cref(a), std::cref(b), std::ref(c), std::cref(zero), eps, jstart, jend);
threads[tid] = std::thread(SpMMNumeric, std::cref(a), std::cref(b), std::ref(c), std::cref(zero), jstart, jend);
}
for (auto &th : threads) {
th.join();
Expand Down
Loading