Skip to content

Commit 2bf2ad6

Browse files
committed
fix(transport): make connection multiaddrs match the full multiaddr including sni and certhash components
1 parent f008fe3 commit 2bf2ad6

File tree

9 files changed

+123
-38
lines changed

9 files changed

+123
-38
lines changed

p2p/test/transport/gating_test.go

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func TestInterceptSecuredOutgoing(t *testing.T) {
101101
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
102102
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
103103
// remove the certhash component from WebTransport and WebRTC addresses
104-
require.Equal(t, stripCertHash(h2.Addrs()[0]).String(), addrs.RemoteMultiaddr().String())
104+
require.Equal(t, h2.Addrs()[0].String(), addrs.RemoteMultiaddr().String())
105105
}),
106106
)
107107
err := h1.Connect(ctx, peer.AddrInfo{ID: h2.ID(), Addrs: h2.Addrs()})
@@ -135,8 +135,7 @@ func TestInterceptUpgradedOutgoing(t *testing.T) {
135135
connGater.EXPECT().InterceptAddrDial(h2.ID(), gomock.Any()).Return(true),
136136
connGater.EXPECT().InterceptSecured(network.DirOutbound, h2.ID(), gomock.Any()).Return(true),
137137
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
138-
// remove the certhash component from WebTransport addresses
139-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.RemoteMultiaddr())
138+
require.Equal(t, h2.Addrs()[0], c.RemoteMultiaddr())
140139
require.Equal(t, h1.ID(), c.LocalPeer())
141140
require.Equal(t, h2.ID(), c.RemotePeer())
142141
}))
@@ -170,12 +169,12 @@ func TestInterceptAccept(t *testing.T) {
170169
// In WebRTC, retransmissions of the STUN packet might cause us to create multiple connections,
171170
// if the first connection attempt is rejected.
172171
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
173-
// remove the certhash component from WebTransport addresses
172+
// remove the certhash component from WebRTC and WebTransport addresses
174173
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
175174
}).AnyTimes()
176175
} else {
177176
connGater.EXPECT().InterceptAccept(gomock.Any()).Do(func(addrs network.ConnMultiaddrs) {
178-
// remove the certhash component from WebTransport addresses
177+
// remove the certhash component from WebRTC and WebTransport addresses
179178
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
180179
})
181180
}
@@ -213,8 +212,7 @@ func TestInterceptSecuredIncoming(t *testing.T) {
213212
gomock.InOrder(
214213
connGater.EXPECT().InterceptAccept(gomock.Any()).Return(true),
215214
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Do(func(_ network.Direction, _ peer.ID, addrs network.ConnMultiaddrs) {
216-
// remove the certhash component from WebTransport addresses
217-
require.Equal(t, stripCertHash(h2.Addrs()[0]), addrs.LocalMultiaddr())
215+
require.Equal(t, h2.Addrs()[0], addrs.LocalMultiaddr())
218216
}),
219217
)
220218
h1.Peerstore().AddAddrs(h2.ID(), h2.Addrs(), time.Hour)
@@ -248,7 +246,7 @@ func TestInterceptUpgradedIncoming(t *testing.T) {
248246
connGater.EXPECT().InterceptSecured(network.DirInbound, h1.ID(), gomock.Any()).Return(true),
249247
connGater.EXPECT().InterceptUpgraded(gomock.Any()).Do(func(c network.Conn) {
250248
// remove the certhash component from WebTransport addresses
251-
require.Equal(t, stripCertHash(h2.Addrs()[0]), c.LocalMultiaddr())
249+
require.Equal(t, h2.Addrs()[0], c.LocalMultiaddr())
252250
require.Equal(t, h1.ID(), c.RemotePeer())
253251
require.Equal(t, h2.ID(), c.LocalPeer())
254252
}),

p2p/test/transport/transport_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -867,3 +867,29 @@ func TestConnClosedWhenRemoteCloses(t *testing.T) {
867867
})
868868
}
869869
}
870+
871+
func TestConnMatchingAddress(t *testing.T) {
872+
for _, tc := range transportsToTest {
873+
t.Run(tc.Name, func(t *testing.T) {
874+
server := tc.HostGenerator(t, TransportTestCaseOpts{})
875+
client1 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
876+
client2 := tc.HostGenerator(t, TransportTestCaseOpts{NoListen: true})
877+
defer server.Close()
878+
defer client1.Close()
879+
defer client2.Close()
880+
881+
client1.Peerstore().AddAddrs(server.ID(), server.Addrs(), peerstore.PermanentAddrTTL)
882+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
883+
defer cancel()
884+
err := client1.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: server.Addrs()})
885+
require.NoError(t, err)
886+
887+
client1Conns := client1.Network().ConnsToPeer(server.ID())
888+
require.Equal(t, 1, len(client1Conns))
889+
remoteMA := client1Conns[0].RemoteMultiaddr()
890+
891+
err = client2.Connect(ctx, peer.AddrInfo{ID: server.ID(), Addrs: []ma.Multiaddr{remoteMA}})
892+
require.NoError(t, err)
893+
})
894+
}
895+
}

p2p/transport/webrtc/listener.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -264,14 +264,13 @@ func (l *listener) setupConnection(
264264
return nil, err
265265
}
266266

267-
localMultiaddrWithoutCerthash, _ := ma.SplitFunc(l.localMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
268267
conn, err := newConnection(
269268
network.DirInbound,
270269
w.PeerConnection,
271270
l.transport,
272271
scope,
273272
l.transport.localPeerId,
274-
localMultiaddrWithoutCerthash,
273+
l.localMultiaddr,
275274
remotePeer,
276275
remotePubKey,
277276
remoteMultiaddr,

p2p/transport/webrtc/transport.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,6 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
387387
if err != nil {
388388
return nil, err
389389
}
390-
remoteMultiaddrWithoutCerthash, _ := ma.SplitFunc(remoteMultiaddr, func(c ma.Component) bool { return c.Protocol().Code == ma.P_CERTHASH })
391390

392391
conn, err := newConnection(
393392
network.DirOutbound,
@@ -398,7 +397,7 @@ func (t *WebRTCTransport) dial(ctx context.Context, scope network.ConnManagement
398397
localAddr,
399398
p,
400399
remotePubKey,
401-
remoteMultiaddrWithoutCerthash,
400+
remoteMultiaddr,
402401
w.IncomingDataChannels,
403402
w.PeerConnectionClosedCh,
404403
)

p2p/transport/websocket/conn.go

Lines changed: 71 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package websocket
22

33
import (
4+
"fmt"
5+
ma "github.com/multiformats/go-multiaddr"
6+
manet "github.com/multiformats/go-multiaddr/net"
47
"io"
58
"net"
69
"sync"
@@ -25,17 +28,72 @@ type Conn struct {
2528
closeOnce sync.Once
2629

2730
readLock, writeLock sync.Mutex
31+
32+
laddr, raddr *Addr
33+
laddrma, raddrma ma.Multiaddr
2834
}
2935

3036
var _ net.Conn = (*Conn)(nil)
3137

32-
// NewConn creates a Conn given a regular gorilla/websocket Conn.
33-
func NewConn(raw *ws.Conn, secure bool) *Conn {
38+
// NewOutboundConn creates an outbound Conn given a regular gorilla/websocket Conn.
39+
func NewOutboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
40+
return newConn(raw, secure, sni, false)
41+
}
42+
43+
// NewInboundConn creates an inbound Conn given a regular gorilla/websocket Conn.
44+
func NewInboundConn(raw *ws.Conn, secure bool, sni string) (*Conn, error) {
45+
return newConn(raw, secure, sni, true)
46+
}
47+
48+
// newConn creates a Conn given a regular gorilla/websocket Conn.
49+
func newConn(raw *ws.Conn, secure bool, sni string, inbound bool) (*Conn, error) {
50+
laddr := NewAddrWithScheme(raw.LocalAddr().String(), secure)
51+
raddr := NewAddrWithScheme(raw.RemoteAddr().String(), secure)
52+
53+
laddrma, err := manet.FromNetAddr(laddr)
54+
if err != nil {
55+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
56+
}
57+
58+
raddrma, err := manet.FromNetAddr(raddr)
59+
if err != nil {
60+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
61+
}
62+
63+
if secure && sni != "" {
64+
var wssMA ma.Multiaddr
65+
if inbound {
66+
wssMA = laddrma
67+
} else {
68+
wssMA = raddrma
69+
}
70+
71+
if withoutWSS := wssMA.Decapsulate(ma.StringCast("/wss")); withoutWSS.Equal(wssMA) {
72+
return nil, fmt.Errorf("missing wss component from converted multiaddr")
73+
} else {
74+
tlsSniWsMa, err := ma.NewMultiaddr(fmt.Sprintf("/tls/sni/%s/ws", sni))
75+
if err != nil {
76+
return nil, fmt.Errorf("failed to convert connection address to multiaddr: %s", err)
77+
}
78+
wssMA = withoutWSS.Encapsulate(tlsSniWsMa)
79+
}
80+
81+
if inbound {
82+
laddrma = wssMA
83+
} else {
84+
raddrma = wssMA
85+
}
86+
}
87+
3488
return &Conn{
3589
Conn: raw,
3690
secure: secure,
3791
DefaultMessageType: ws.BinaryMessage,
38-
}
92+
laddr: laddr,
93+
raddr: raddr,
94+
laddrma: laddrma,
95+
raddrma: raddrma,
96+
}, nil
3997
}
4098

4199
func (c *Conn) Read(b []byte) (int, error) {
@@ -122,11 +180,19 @@ func (c *Conn) Close() error {
122180
}
123181

124182
func (c *Conn) LocalAddr() net.Addr {
125-
return NewAddrWithScheme(c.Conn.LocalAddr().String(), c.secure)
183+
return c.laddr
126184
}
127185

128186
func (c *Conn) RemoteAddr() net.Addr {
129-
return NewAddrWithScheme(c.Conn.RemoteAddr().String(), c.secure)
187+
return c.raddr
188+
}
189+
190+
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
191+
return c.laddrma
192+
}
193+
194+
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
195+
return c.raddrma
130196
}
131197

132198
func (c *Conn) SetDeadline(t time.Time) error {

p2p/transport/websocket/listener.go

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,20 @@ func (l *listener) ServeHTTP(w http.ResponseWriter, r *http.Request) {
112112
return
113113
}
114114

115+
var sni string
116+
if r.TLS != nil {
117+
sni = r.TLS.ServerName
118+
}
119+
mnc, err := NewInboundConn(c, l.isWss, sni)
120+
if err != nil {
121+
_ = c.Close()
122+
return
123+
}
124+
115125
select {
116-
case l.incoming <- NewConn(c, l.isWss):
126+
case l.incoming <- mnc:
117127
case <-l.closed:
118-
c.Close()
128+
mnc.Close()
119129
}
120130
// The connection has been hijacked, it's safe to return.
121131
}
@@ -126,13 +136,7 @@ func (l *listener) Accept() (manet.Conn, error) {
126136
if !ok {
127137
return nil, transport.ErrListenerClosed
128138
}
129-
130-
mnc, err := manet.WrapNetConn(c)
131-
if err != nil {
132-
c.Close()
133-
return nil, err
134-
}
135-
return mnc, nil
139+
return c, nil
136140
case <-l.closed:
137141
return nil, transport.ErrListenerClosed
138142
}

p2p/transport/websocket/websocket.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,8 +188,8 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
188188
}
189189
isWss := wsurl.Scheme == "wss"
190190
dialer := ws.Dialer{HandshakeTimeout: 30 * time.Second}
191+
var sni string
191192
if isWss {
192-
sni := ""
193193
sni, err = raddr.ValueForProtocol(ma.P_SNI)
194194
if err != nil {
195195
sni = ""
@@ -220,7 +220,7 @@ func (t *WebsocketTransport) maDial(ctx context.Context, raddr ma.Multiaddr) (ma
220220
return nil, err
221221
}
222222

223-
mnc, err := manet.WrapNetConn(NewConn(wscon, isWss))
223+
mnc, err := NewOutboundConn(wscon, isWss, sni)
224224
if err != nil {
225225
wscon.Close()
226226
return nil, err

p2p/transport/webtransport/listener.go

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -234,10 +234,7 @@ func (l *listener) Accept() (tpt.CapableConn, error) {
234234
}
235235

236236
func (l *listener) handshake(ctx context.Context, sess *webtransport.Session) (*connSecurityMultiaddrs, error) {
237-
local, err := toWebtransportMultiaddr(sess.LocalAddr())
238-
if err != nil {
239-
return nil, fmt.Errorf("error determiniting local addr: %w", err)
240-
}
237+
local := l.Multiaddr()
241238
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
242239
if err != nil {
243240
return nil, fmt.Errorf("error determiniting remote addr: %w", err)

p2p/transport/webtransport/transport.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (t *transport) dialWithScope(ctx context.Context, raddr ma.Multiaddr, p pee
172172
if err != nil {
173173
return nil, err
174174
}
175-
sconn, err := t.upgrade(ctx, sess, p, certHashes)
175+
sconn, err := t.upgrade(ctx, sess, p, certHashes, raddr)
176176
if err != nil {
177177
sess.CloseWithError(1, "")
178178
return nil, err
@@ -230,15 +230,11 @@ func (t *transport) dial(ctx context.Context, addr ma.Multiaddr, url, sni string
230230
return sess, conn, err
231231
}
232232

233-
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash) (*connSecurityMultiaddrs, error) {
233+
func (t *transport) upgrade(ctx context.Context, sess *webtransport.Session, p peer.ID, certHashes []multihash.DecodedMultihash, remote ma.Multiaddr) (*connSecurityMultiaddrs, error) {
234234
local, err := toWebtransportMultiaddr(sess.LocalAddr())
235235
if err != nil {
236236
return nil, fmt.Errorf("error determining local addr: %w", err)
237237
}
238-
remote, err := toWebtransportMultiaddr(sess.RemoteAddr())
239-
if err != nil {
240-
return nil, fmt.Errorf("error determining remote addr: %w", err)
241-
}
242238

243239
str, err := sess.OpenStreamSync(ctx)
244240
if err != nil {

0 commit comments

Comments
 (0)