enum_name
Loading...
Searching...
No Matches
mpi_reporter.h
1#ifndef DOCTEST_MPI_REPORTER_H
2#define DOCTEST_MPI_REPORTER_H
3
4// #include <doctest/doctest.h>
#include <algorithm>
#include <cstring>
#include <fstream>
#include <iomanip>
#include <mutex>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

#include "mpi.h"
12
13namespace doctest {
14
// Counter defined outside this header. MpiConsoleReporter::test_run_end reads it
// on rank 0 and, when it is positive, prints a warning about tests that were
// skipped because they require more MPI processes than are available.
extern int nb_test_cases_skipped_insufficient_procs;
// Defined outside this header. Its value is printed in the same warning as the
// number of available MPI processes (presumably the size of MPI_COMM_WORLD -
// confirm against the definition).
int mpi_comm_world_size();
17
18namespace {
19
// https://stackoverflow.com/a/11826666/1583122
/* \brief Stream buffer that discards everything written to it
 *  Every write is reported as successful so that a stream built on top of this
 *  buffer never enters a failed state.
 */
struct NullBuffer : std::streambuf {
  // Single-character path. The streambuf contract treats a returned eof() as
  // failure, so eof() itself is mapped to a non-eof value via not_eof().
  int_type overflow(int_type c) override { return traits_type::not_eof(c); }

  // Bulk path: claim that all `count` characters were consumed without
  // forwarding them one by one to overflow().
  std::streamsize xsputn(const char_type* /*chars*/, std::streamsize count) override {
    return count;
  }
};
24class NullStream : public std::ostream {
25 public:
26 NullStream()
27 : std::ostream(&nullBuff)
28 {}
29 private:
30 NullBuffer nullBuff = {};
31};
// Discarding stream instance; MpiConsoleReporter::replace_by_null_if_not_rank_0
// returns it on every rank other than 0 of MPI_COMM_WORLD.
static NullStream nullStream;
33
34
35/* \brief Extends the ConsoleReporter of doctest
36 * Each process writes its results to its own file
37 * Intended to be used when a test assertion fails and the user wants to know exactly what happens on which process
38 */
39struct MpiFileReporter : public ConsoleReporter {
40 std::ofstream logfile_stream = {};
41
42 MpiFileReporter(const ContextOptions& co)
43 : ConsoleReporter(co,logfile_stream)
44 {
45 int rank = 0;
46 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
47
48 std::string logfile_name = "doctest_" + std::to_string(rank) + ".log";
49
50 logfile_stream = std::ofstream(logfile_name.c_str(), std::fstream::out);
51 }
52};
53
54
55/* \brief Extends the ConsoleReporter of doctest
56 * Allows to manage the execution of tests in a parallel framework
57 * All results are collected on rank 0
58 */
59struct MpiConsoleReporter : public ConsoleReporter {
60private:
61 static std::ostream& replace_by_null_if_not_rank_0(std::ostream* os) {
62 int rank = 0;
63 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
64 if (rank==0) {
65 return *os;
66 } else {
67 return nullStream;
68 }
69 }
70 std::vector<std::pair<std::string, int>> m_failure_str_queue = {};
71public:
72 MpiConsoleReporter(const ContextOptions& co)
73 : ConsoleReporter(co,replace_by_null_if_not_rank_0(co.cout))
74 {}
75
76 std::string file_line_to_string(const char* file, int line,
77 const char* tail = ""){
78 std::stringstream ss;
79 ss << skipPathFromFilename(file)
80 << (opt.gnu_file_line ? ":" : "(")
81 << (opt.no_line_numbers ? 0 : line) // 0 or the real num depending on the option
82 << (opt.gnu_file_line ? ":" : "):") << tail;
83 return ss.str();
84 }
85
86 void test_run_end(const TestRunStats& p) override {
87 ConsoleReporter::test_run_end(p);
88
89 const bool anythingFailed = p.numTestCasesFailed > 0 || p.numAssertsFailed > 0;
90
91 // -----------------------------------------------------
92 // > Gather information in rank 0
93 int n_rank, rank;
94 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
95 MPI_Comm_size(MPI_COMM_WORLD, &n_rank);
96
97 int g_numAsserts = 0;
98 int g_numAssertsFailed = 0;
99 int g_numTestCasesFailed = 0;
100
101 MPI_Reduce(&p.numAsserts , &g_numAsserts , 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
102 MPI_Reduce(&p.numAssertsFailed , &g_numAssertsFailed , 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
103 MPI_Reduce(&p.numTestCasesFailed, &g_numTestCasesFailed, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
104
105 std::vector<int> numAssertsFailedByRank;
106 if(rank == 0){
107 numAssertsFailedByRank.resize(static_cast<std::size_t>(n_rank));
108 }
109
110 MPI_Gather(&p.numAssertsFailed, 1, MPI_INT, numAssertsFailedByRank.data(), 1, MPI_INT, 0, MPI_COMM_WORLD);
111
112 if(rank == 0) {
113 separator_to_stream();
114 s << Color::Cyan << "[doctest] " << Color::None << "assertions on all processes: " << std::setw(6)
115 << g_numAsserts << " | "
116 << ((g_numAsserts == 0 || anythingFailed) ? Color::None : Color::Green)
117 << std::setw(6) << (g_numAsserts - g_numAssertsFailed) << " passed" << Color::None
118 << " | " << (g_numAssertsFailed > 0 ? Color::Red : Color::None) << std::setw(6)
119 << g_numAssertsFailed << " failed" << Color::None << " |\n";
120 if (nb_test_cases_skipped_insufficient_procs>0) {
121 s << Color::Cyan << "[doctest] " << Color::Yellow << "WARNING: Skipped ";
122 if (nb_test_cases_skipped_insufficient_procs>1) {
123 s << nb_test_cases_skipped_insufficient_procs << " tests requiring more than ";
124 } else {
125 s << nb_test_cases_skipped_insufficient_procs << " test requiring more than ";
126 }
127 if (mpi_comm_world_size()>1) {
128 s << mpi_comm_world_size() << " MPI processes to run\n";
129 } else {
130 s << mpi_comm_world_size() << " MPI process to run\n";
131 }
132 }
133
134 separator_to_stream();
135 if(g_numAssertsFailed > 0){
136
137 s << Color::Cyan << "[doctest] " << Color::None << "fail on rank:" << std::setw(6) << "\n";
138 for(std::size_t i = 0; i < numAssertsFailedByRank.size(); ++i){
139 if( numAssertsFailedByRank[i] > 0 ){
140 s << std::setw(16) << " -> On rank [" << i << "] with " << numAssertsFailedByRank[i] << " test failed" << std::endl;
141 }
142 }
143 }
144 s << Color::Cyan << "[doctest] " << Color::None
145 << "Status: " << (g_numTestCasesFailed > 0 ? Color::Red : Color::Green)
146 << ((g_numTestCasesFailed > 0) ? "FAILURE!" : "SUCCESS!") << Color::None << std::endl;
147 }
148 }
149
150 void test_case_end(const CurrentTestCaseStats& st) override {
151 if (is_mpi_test_case()) {
152 // function called by every rank at the end of a test
153 // if failed assertions happened, they have been sent to rank 0
154 // here rank zero gathers them and prints them all
155
156 int rank;
157 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
158
159 std::vector<MPI_Request> requests;
160 requests.reserve(m_failure_str_queue.size()); // avoid realloc & copy of MPI_Request
161 for (const std::pair<std::string, int> &failure : m_failure_str_queue)
162 {
163 const std::string & failure_str = failure.first;
164 const int failure_line = failure.second;
165
166 int failure_msg_size = static_cast<int>(failure_str.size());
167
168 requests.push_back(MPI_REQUEST_NULL);
169 MPI_Isend(failure_str.c_str(), failure_msg_size, MPI_BYTE,
170 0, failure_line, MPI_COMM_WORLD, &requests.back()); // Tag = file line
171 }
172
173
174 // Compute the number of assert with fail among all procs
175 const int nb_fail_asserts = static_cast<int>(m_failure_str_queue.size());
176 int nb_fail_asserts_glob = 0;
177 MPI_Reduce(&nb_fail_asserts, &nb_fail_asserts_glob, 1, MPI_INT, MPI_SUM, 0, MPI_COMM_WORLD);
178
179 if(rank == 0) {
180 MPI_Status status;
181 MPI_Status status_recv;
182
183 using id_string = std::pair<int,std::string>;
184 std::vector<id_string> msgs(static_cast<std::size_t>(nb_fail_asserts_glob));
185
186 for (std::size_t i=0; i<static_cast<std::size_t>(nb_fail_asserts_glob); ++i) {
187 MPI_Probe(MPI_ANY_SOURCE, MPI_ANY_TAG, MPI_COMM_WORLD, &status);
188
189 int count;
190 MPI_Get_count(&status, MPI_BYTE, &count);
191
192 std::string recv_msg(static_cast<std::size_t>(count),'\0');
193 void* recv_msg_data = const_cast<char*>(recv_msg.data()); // const_cast needed. Non-const .data() exists in C++11 though...
194 MPI_Recv(recv_msg_data, count, MPI_BYTE, status.MPI_SOURCE,
195 status.MPI_TAG, MPI_COMM_WORLD, &status_recv);
196
197 msgs[i] = {status.MPI_SOURCE,recv_msg};
198 }
199
200 std::sort(begin(msgs),end(msgs),[](const id_string& x, const id_string& y){ return x.first < y.first; });
201
202 // print
203 if (nb_fail_asserts_glob>0) {
204 separator_to_stream();
205 file_line_to_stream(tc->m_file.c_str(), static_cast<int>(tc->m_line), "\n");
206 if(tc->m_test_suite && tc->m_test_suite[0] != '\0')
207 s << Color::Yellow << "TEST SUITE: " << Color::None << tc->m_test_suite << "\n";
208 if(strncmp(tc->m_name, " Scenario:", 11) != 0)
209 s << Color::Yellow << "TEST CASE: ";
210 s << Color::None << tc->m_name << "\n\n";
211 for(const auto& msg : msgs) {
212 s << msg.second;
213 }
214 s << "\n";
215 }
216 }
217
218 MPI_Waitall(static_cast<int>(requests.size()), requests.data(), MPI_STATUSES_IGNORE);
219 m_failure_str_queue.clear();
220 }
221
222 ConsoleReporter::test_case_end(st);
223 }
224
225 bool is_mpi_test_case() const {
226 return tc->m_description != nullptr
227 && std::string(tc->m_description) == std::string("MPI_TEST_CASE");
228 }
229
230 void log_assert(const AssertData& rb) override {
231 if (!is_mpi_test_case()) {
232 ConsoleReporter::log_assert(rb);
233 } else {
234 int rank;
235 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
236
237
238 if(!rb.m_failed && !opt.success)
239 return;
240
241 std::lock_guard<std::mutex> lock(mutex);
242
243 std::stringstream failure_msg;
244 failure_msg << Color::Red << "On rank [" << rank << "] : " << Color::None;
245 failure_msg << file_line_to_string(rb.m_file, rb.m_line, " ");
246
247 if((rb.m_at & (assertType::is_throws_as | assertType::is_throws_with)) ==0){
248 failure_msg << Color::Cyan
249 << assertString(rb.m_at)
250 << "( " << rb.m_expr << " ) "
251 << Color::None
252
253 << (!rb.m_failed ? "is correct!\n" : "is NOT correct!\n")
254 << " values: "
255 << assertString(rb.m_at)
256 << "( " << rb.m_decomp.c_str() << " )\n";
257 }
258
259 m_failure_str_queue.push_back({failure_msg.str(), rb.m_line});
260 }
261 }
262}; // MpiConsoleReporter
263
// Make both reporters available under the names given as first argument.
// "1" is the priority - used for ordering when multiple reporters/listeners are used
REGISTER_REPORTER("MpiConsoleReporter", 1, MpiConsoleReporter);
REGISTER_REPORTER("MpiFileReporter", 1, MpiFileReporter);
267
268} // anonymous
269} // doctest
270
271#endif // DOCTEST_MPI_REPORTER_H
Definition doctest.h:530