Skip to content

Commit d4fdcab

Browse files
fix: add thread safety primitives to internal shard list
1 parent 74b3865 commit d4fdcab

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

include/dpp/cluster.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,11 @@ class DPP_EXPORT cluster {
152152
*/
153153
std::shared_mutex named_commands_mutex;
154154

155+
/**
156+
* @brief Mutex for protection of shards list
157+
*/
158+
mutable std::shared_mutex shards_mutex;
159+
155160
/**
156161
* @brief Typedef for slashcommand handler type
157162
*/
@@ -558,9 +563,9 @@ class DPP_EXPORT cluster {
558563
/**
559564
* @brief Get the list of shards
560565
*
561-
* @return shard_list& Reference to map of shards for this cluster
566+
* @return shard_list map of shards for this cluster
562567
*/
563-
const shard_list& get_shards();
568+
shard_list get_shards() const;
564569

565570
/**
566571
* @brief Sets the request timeout.

src/dpp/cluster.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,11 @@ void cluster::start(start_type return_after) {
229229
if (now >= shard_reconnect_time) {
230230
/* This shard needs to be reconnected */
231231
reconnections.erase(reconnect);
232-
discord_client* old = shards[shard_id];
232+
discord_client* old = nullptr;
233+
{
234+
std::shared_lock lk(shards_mutex);
235+
old = shards[shard_id];
236+
}
233237
/* These values must be copied to the new connection
234238
* to attempt to resume it
235239
*/
@@ -238,6 +242,7 @@ void cluster::start(start_type return_after) {
238242
log(ll_info, "Reconnecting shard " + std::to_string(shard_id));
239243
/* Make a new resumed connection based off the old one */
240244
try {
245+
std::unique_lock lk(shards_mutex);
241246
if (shards[shard_id] != nullptr) {
242247
log(ll_trace, "Attempting resume...");
243248
shards[shard_id] = nullptr;
@@ -255,6 +260,7 @@ void cluster::start(start_type return_after) {
255260
shards[shard_id]->run();
256261
}
257262
catch (const std::exception& e) {
263+
std::unique_lock lk(shards_mutex);
258264
log(ll_info, "Exception when reconnecting shard " + std::to_string(shard_id) + ": " + std::string(e.what()));
259265
delete shards[shard_id];
260266
delete old;
@@ -340,6 +346,7 @@ void cluster::start(start_type return_after) {
340346
if (s % maxclusters == cluster_id) {
341347
/* Each discord_client is inserted into the socket engine when we call run() */
342348
try {
349+
std::unique_lock lk(shards_mutex);
343350
this->shards[s] = new discord_client(this, s, numshards, token, intents, compressed, ws_mode);
344351
this->shards[s]->run();
345352
}
@@ -455,6 +462,7 @@ void cluster::shutdown() {
455462
next_timer = {};
456463
}
457464

465+
std::unique_lock lk(shards_mutex);
458466
/* Terminate shards */
459467
for (const auto& sh : shards) {
460468
delete sh.second;
@@ -581,6 +589,7 @@ void cluster::set_presence(const dpp::presence &p) {
581589
}
582590

583591
json pres = p.to_json();
592+
std::shared_lock lk(shards_mutex);
584593
for (auto& s : shards) {
585594
if (s.second->is_connected()) {
586595
s.second->queue_message(s.second->jsonobj_to_string(pres));
@@ -610,15 +619,16 @@ std::string cluster::get_audit_reason() {
610619
}
611620

612621
discord_client* cluster::get_shard(uint32_t id) const {
622+
std::shared_lock lk(shards_mutex);
613623
auto i = shards.find(id);
614624
if (i != shards.end()) {
615625
return i->second;
616-
} else {
617-
return nullptr;
618626
}
627+
return nullptr;
619628
}
620629

621-
const shard_list& cluster::get_shards() {
630+
shard_list cluster::get_shards() const {
631+
std::shared_lock lk(shards_mutex);
622632
return shards;
623633
}
624634

0 commit comments

Comments
 (0)