Skip to content

Commit 076886f

Browse files
committed
websocket: avoid races, remove ping goroutine, use a single mutex
1 parent 070d66a commit 076886f

File tree

1 file changed

+33
-70
lines changed

1 file changed

+33
-70
lines changed

dnscrypt-proxy/monitoring_ui.go

Lines changed: 33 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,9 @@ type MonitoringUI struct {
106106
clientsMutex sync.Mutex
107107
proxy *Proxy
108108

109+
// Mutex for all WebSocket write operations to prevent races
110+
writesMutex sync.Mutex
111+
109112
// WebSocket broadcast rate limiting
110113
broadcastMutex sync.Mutex
111114
lastBroadcast time.Time
@@ -891,32 +894,24 @@ func (ui *MonitoringUI) handleWebSocket(w http.ResponseWriter, r *http.Request)
891894
return
892895
}
893896

894-
// Configure upgrader with more permissive settings
895-
upgrader := websocket.Upgrader{
896-
ReadBufferSize: 1024,
897-
WriteBufferSize: 1024,
898-
CheckOrigin: func(r *http.Request) bool {
899-
return true // Allow all origins
900-
},
901-
}
902-
903-
conn, err := upgrader.Upgrade(w, r, nil)
897+
conn, err := ui.upgrader.Upgrade(w, r, nil)
904898
if err != nil {
905899
dlog.Warnf("WebSocket upgrade error: %v", err)
906900
return
907901
}
908902

909-
// Set read/write deadlines
910-
conn.SetReadDeadline(time.Now().Add(120 * time.Second))
911-
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
912-
903+
// Register the client
913904
ui.clientsMutex.Lock()
914905
ui.clients[conn] = true
915906
ui.clientsMutex.Unlock()
916907

917908
// Send initial metrics
918-
metrics := ui.metricsCollector.GetMetrics()
919-
if err := conn.WriteJSON(metrics); err != nil {
909+
ui.writesMutex.Lock()
910+
conn.SetWriteDeadline(time.Now().Add(10 * time.Second))
911+
err = conn.WriteJSON(ui.metricsCollector.GetMetrics())
912+
ui.writesMutex.Unlock()
913+
914+
if err != nil {
920915
dlog.Warnf("WebSocket initial write error: %v", err)
921916
conn.Close()
922917
ui.clientsMutex.Lock()
@@ -935,78 +930,43 @@ func (ui *MonitoringUI) handleWebSocket(w http.ResponseWriter, r *http.Request)
935930
dlog.Debugf("WebSocket connection closed and cleaned up")
936931
}()
937932

938-
// Create a ping handler to keep the connection alive
939-
conn.SetPingHandler(func(data string) error {
940-
dlog.Debugf("Received ping from client")
941-
return conn.WriteControl(websocket.PongMessage, []byte{}, time.Now().Add(5*time.Second))
942-
})
943-
944-
// Create a pong handler to respond to server pings
945-
conn.SetPongHandler(func(data string) error {
933+
// Set up ping/pong handlers for keep-alive (using WebSocket protocol level)
934+
conn.SetReadDeadline(time.Now().Add(120 * time.Second))
935+
conn.SetPongHandler(func(string) error {
946936
dlog.Debugf("Received pong from client")
947937
conn.SetReadDeadline(time.Now().Add(120 * time.Second))
948938
return nil
949939
})
950940

951941
for {
952-
// Reset read deadline for each message
953-
conn.SetReadDeadline(time.Now().Add(120 * time.Second))
954-
955-
// Read message
956-
messageType, message, err := conn.ReadMessage()
942+
// Read message from client
943+
var msg map[string]interface{}
944+
err := conn.ReadJSON(&msg)
957945
if err != nil {
958946
if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) {
959947
dlog.Warnf("WebSocket unexpected close error: %v", err)
960-
} else {
961-
dlog.Debugf("WebSocket read error (normal): %v", err)
962948
}
963949
break
964950
}
965951

966-
// Handle ping message from client
967-
if messageType == websocket.TextMessage {
968-
var msg map[string]interface{}
969-
if err := json.Unmarshal(message, &msg); err == nil {
970-
if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
971-
dlog.Debugf("Received ping message from client")
972-
// Send a pong response
973-
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
974-
if err := conn.WriteJSON(map[string]string{"type": "pong"}); err != nil {
975-
dlog.Warnf("Error sending pong: %v", err)
976-
}
977-
978-
// Also send updated metrics
979-
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
980-
if err := conn.WriteJSON(ui.metricsCollector.GetMetrics()); err != nil {
981-
dlog.Warnf("Error sending metrics after ping: %v", err)
982-
}
983-
}
984-
}
985-
}
986-
}
987-
}()
988-
989-
// Send periodic pings to keep the connection alive
990-
go func() {
991-
ticker := time.NewTicker(30 * time.Second)
992-
defer ticker.Stop()
952+
// Handle ping message from client (application level)
953+
if msgType, ok := msg["type"].(string); ok && msgType == "ping" {
954+
dlog.Debugf("Received ping message from client")
993955

994-
for {
995-
select {
996-
case <-ticker.C:
997-
ui.clientsMutex.Lock()
998-
if _, exists := ui.clients[conn]; !exists {
999-
ui.clientsMutex.Unlock()
1000-
return // Connection is closed, stop the goroutine
956+
// Send pong response and updated metrics
957+
ui.writesMutex.Lock()
958+
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
959+
if err := conn.WriteJSON(map[string]string{"type": "pong"}); err != nil {
960+
ui.writesMutex.Unlock()
961+
dlog.Warnf("Error sending pong: %v", err)
962+
break
1001963
}
1002-
ui.clientsMutex.Unlock()
1003964

1004-
// Send ping
1005965
conn.SetWriteDeadline(time.Now().Add(5 * time.Second))
1006-
if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
1007-
dlog.Debugf("Error sending ping: %v", err)
1008-
return
966+
if err := conn.WriteJSON(ui.metricsCollector.GetMetrics()); err != nil {
967+
dlog.Warnf("Error sending metrics after ping: %v", err)
1009968
}
969+
ui.writesMutex.Unlock()
1010970
}
1011971
}
1012972
}()
@@ -1111,6 +1071,9 @@ func (ui *MonitoringUI) scheduleBroadcast() {
11111071
func (ui *MonitoringUI) broadcastMetrics() {
11121072
metrics := ui.metricsCollector.GetMetrics()
11131073

1074+
ui.writesMutex.Lock()
1075+
defer ui.writesMutex.Unlock()
1076+
11141077
ui.clientsMutex.Lock()
11151078
defer ui.clientsMutex.Unlock()
11161079

0 commit comments

Comments
 (0)