From fbff6a14872a544f2be110467dbdda1e378d936f Mon Sep 17 00:00:00 2001 From: Nate Brown Date: Tue, 4 Feb 2025 16:36:34 -0600 Subject: [PATCH] More correct ipv6 header parsing (#1323) --- outside.go | 103 ++++++---- outside_test.go | 523 ++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 545 insertions(+), 81 deletions(-) diff --git a/outside.go b/outside.go index 1efe50d46..dbf75f1eb 100644 --- a/outside.go +++ b/outside.go @@ -3,7 +3,6 @@ package nebula import ( "encoding/binary" "errors" - "fmt" "net/netip" "time" @@ -271,10 +270,19 @@ func (f *Interface) handleEncrypted(ci *ConnectionState, addr netip.AddrPort, h return true } +var ( + ErrPacketTooShort = errors.New("packet is too short") + ErrUnknownIPVersion = errors.New("packet is an unknown ip version") + ErrIPv4InvalidHeaderLength = errors.New("invalid ipv4 header length") + ErrIPv4PacketTooShort = errors.New("ipv4 packet is too short") + ErrIPv6PacketTooShort = errors.New("ipv6 packet is too short") + ErrIPv6CouldNotFindPayload = errors.New("could not find payload in ipv6 packet") +) + // newPacket validates and parses the interesting bits for the firewall out of the ip and sub protocol headers func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { if len(data) < 1 { - return errors.New("packet too short") + return ErrPacketTooShort } version := int((data[0] >> 4) & 0x0f) @@ -284,13 +292,13 @@ func newPacket(data []byte, incoming bool, fp *firewall.Packet) error { case ipv6.Version: return parseV6(data, incoming, fp) } - return fmt.Errorf("packet is an unknown ip version: %v", version) + return ErrUnknownIPVersion } func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { dataLen := len(data) if dataLen < ipv6.HeaderLen { - return fmt.Errorf("ipv6 packet is less than %v bytes", ipv4.HeaderLen) + return ErrIPv6PacketTooShort } if incoming { @@ -301,11 +309,10 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { fp.RemoteAddr, _ = netip.AddrFromSlice(data[24:40]) } - //TODO: CERT-V2 whats a reasonable number of extension headers to attempt to parse? - //https://www.ietf.org/archive/id/draft-ietf-6man-eh-limits-00.html - protoAt := 6 - offset := 40 - for i := 0; i < 24; i++ { + protoAt := 6 // NextHeader is at 6 bytes into the ipv6 header + offset := ipv6.HeaderLen // Start at the end of the ipv6 header + next := 0 + for { if dataLen < offset { break } @@ -313,17 +320,18 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { proto := layers.IPProtocol(data[protoAt]) //fmt.Println(proto, protoAt) switch proto { - case layers.IPProtocolICMPv6: + case layers.IPProtocolICMPv6, layers.IPProtocolESP, layers.IPProtocolNoNextHeader: fp.Protocol = uint8(proto) fp.RemotePort = 0 fp.LocalPort = 0 fp.Fragment = false return nil - case layers.IPProtocolTCP: + case layers.IPProtocolTCP, layers.IPProtocolUDP: if dataLen < offset+4 { - return fmt.Errorf("ipv6 packet was too small") + return ErrIPv6PacketTooShort } + fp.Protocol = uint8(proto) if incoming { fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) @@ -332,62 +340,71 @@ func parseV6(data []byte, incoming bool, fp *firewall.Packet) error { fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) } + fp.Fragment = false return nil - case layers.IPProtocolUDP: - if dataLen < offset+4 { - return fmt.Errorf("ipv6 packet was too small") + case layers.IPProtocolIPv6Fragment: + // Fragment header is 8 bytes, need at least offset+4 to read the offset field + if dataLen < offset+8 { + return ErrIPv6PacketTooShort } - fp.Protocol = uint8(proto) - if incoming { - fp.RemotePort = binary.BigEndian.Uint16(data[offset : offset+2]) - fp.LocalPort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) - } else { - fp.LocalPort = binary.BigEndian.Uint16(data[offset : offset+2]) - fp.RemotePort = binary.BigEndian.Uint16(data[offset+2 : offset+4]) + + // Check if this is the first fragment + fragmentOffset := binary.BigEndian.Uint16(data[offset+2:offset+4]) &^ uint16(0x7) // Remove the reserved and M flag bits + if fragmentOffset != 0 { + // Non-first fragment, use what we have now and stop processing + fp.Protocol = data[offset] + fp.Fragment = true + fp.RemotePort = 0 + fp.LocalPort = 0 + return nil } - fp.Fragment = false - return nil - case layers.IPProtocolIPv6Fragment: - //TODO: CERT-V2 can we determine the protocol? - fp.RemotePort = 0 - fp.LocalPort = 0 - fp.Fragment = true - return nil + // The next loop should be the transport layer since we are the first fragment + next = 8 // Fragment headers are always 8 bytes - default: + case layers.IPProtocolAH: + // Auth headers, used by IPSec, have a different meaning for header length if dataLen < offset+1 { break } - next := int(data[offset+1]) * 8 - if next == 0 { - // each extension is at least 8 bytes - next = 8 + next = int(data[offset+1]+2) << 2 + + default: + // Normal ipv6 header length processing + if dataLen < offset+1 { + break } - protoAt = offset - offset = offset + next + next = int(data[offset+1]+1) << 3 } + + if next <= 0 { + // Safety check, each ipv6 header has to be at least 8 bytes + next = 8 + } + + protoAt = offset + offset = offset + next } - return fmt.Errorf("could not find payload in ipv6 packet") + return ErrIPv6CouldNotFindPayload } func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { // Do we at least have an ipv4 header worth of data? if len(data) < ipv4.HeaderLen { - return fmt.Errorf("ipv4 packet is less than %v bytes", ipv4.HeaderLen) + return ErrIPv4PacketTooShort } // Adjust our start position based on the advertised ip header length ihl := int(data[0]&0x0f) << 2 - // Well formed ip header length? + // Well-formed ip header length? if ihl < ipv4.HeaderLen { - return fmt.Errorf("ipv4 packet had an invalid header length: %v", ihl) + return ErrIPv4InvalidHeaderLength } // Check if this is the second or further fragment of a fragmented packet. @@ -403,7 +420,7 @@ func parseV4(data []byte, incoming bool, fp *firewall.Packet) error { minLen += minFwPacketLen } if len(data) < minLen { - return fmt.Errorf("ipv4 packet is less than %v bytes, ip header len: %v", minLen, ihl) + return ErrIPv4InvalidHeaderLength } // Firewall packets are locally oriented @@ -501,7 +518,7 @@ func (f *Interface) sendRecvError(endpoint netip.AddrPort, index uint32) { f.messageMetrics.Tx(header.RecvError, 0, 1) b := header.Encode(make([]byte, header.Len), header.Version, header.RecvError, 0, index, 0) - f.outside.WriteTo(b, endpoint) + _ = f.outside.WriteTo(b, endpoint) if f.l.Level >= logrus.DebugLevel { f.l.WithField("index", index). WithField("udpAddr", endpoint). diff --git a/outside_test.go b/outside_test.go index cbe622345..f19759478 100644 --- a/outside_test.go +++ b/outside_test.go @@ -1,6 +1,8 @@ package nebula import ( + "bytes" + "encoding/binary" "net" "net/netip" "testing" @@ -18,13 +20,13 @@ func Test_newPacket(t *testing.T) { // length fails err := newPacket([]byte{}, true, p) - assert.EqualError(t, err, "packet too short") + assert.ErrorIs(t, err, ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) - assert.EqualError(t, err, "ipv4 packet is less than 20 bytes") + assert.ErrorIs(t, err, ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) - assert.EqualError(t, err, "ipv6 packet is less than 20 bytes") + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -37,16 +39,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - - assert.EqualError(t, err, "ipv4 packet is less than 28 bytes, ip header len: 24") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "packet is an unknown ip version: 0") + assert.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.EqualError(t, err, "ipv4 packet had an invalid header length: 8") + assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -63,11 +64,12 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, true, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoTCP)) - assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemotePort, uint16(3)) - assert.Equal(t, p.LocalPort, uint16(4)) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) + assert.Equal(t, uint16(3), p.RemotePort) + assert.Equal(t, uint16(4), p.LocalPort) + assert.False(t, p.Fragment) // account for variable ip header length - outgoing h = ipv4.Header{ @@ -84,17 +86,94 @@ func Test_newPacket(t *testing.T) { err = newPacket(b, false, p) assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(2)) - assert.Equal(t, p.LocalAddr, netip.MustParseAddr("10.0.0.1")) - assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("10.0.0.2")) - assert.Equal(t, p.RemotePort, uint16(6)) - assert.Equal(t, p.LocalPort, uint16(5)) + assert.Equal(t, uint8(2), p.Protocol) + assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) + assert.Equal(t, uint16(6), p.RemotePort) + assert.Equal(t, uint16(5), p.LocalPort) + assert.False(t, p.Fragment) } func Test_newPacket_v6(t *testing.T) { p := &firewall.Packet{} + // invalid ipv6 ip := layers.IPv6{ + Version: 6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + buffer := gopacket.NewSerializeBuffer() + opt := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: false, + } + err := gopacket.SerializeLayers(buffer, opt, &ip) + assert.NoError(t, err) + + err = newPacket(buffer.Bytes(), true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good ICMP packet + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolICMPv6, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + icmp := layers.ICMPv6{} + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &icmp) + if err != nil { + panic(err) + } + + err = newPacket(buffer.Bytes(), true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good ESP packet + b := buffer.Bytes() + b[6] = byte(layers.IPProtocolESP) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // A good None packet + b = buffer.Bytes() + b[6] = byte(layers.IPProtocolNoNextHeader) + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.False(t, p.Fragment) + + // An unknown protocol packet + b = buffer.Bytes() + b[6] = 255 // 255 is a reserved protocol number + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + + // A good UDP packet + ip = layers.IPv6{ Version: 6, NextHeader: firewall.ProtoUDP, HopLimit: 128, @@ -106,39 +185,407 @@ func Test_newPacket_v6(t *testing.T) { SrcPort: layers.UDPPort(36123), DstPort: layers.UDPPort(22), } - err := udp.SetNetworkLayerForChecksum(&ip) + err = udp.SetNetworkLayerForChecksum(&ip) + assert.NoError(t, err) + + buffer.Clear() + err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + if err != nil { + panic(err) + } + b = buffer.Bytes() + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short UDP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good TCP packet + b[6] = byte(layers.IPProtocolTCP) + + // incoming + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // outgoing + err = newPacket(b, false, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) + + // Too short TCP packet + err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + + // A good UDP packet with an AH header + ip = layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolAH, + HopLimit: 128, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + ah := layers.IPSecAH{ + AuthenticationData: []byte{0xde, 0xad, 0xbe, 0xef, 0xde, 0xad, 0xbe, 0xef}, + } + ah.NextHeader = layers.IPProtocolUDP + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opt) if err != nil { panic(err) } + b = buffer.Bytes() + ahb := serializeAH(&ah) + b = append(b, ahb...) + b = append(b, udpHeader...) + + err = newPacket(b, true, p) + assert.Nil(t, err) + assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) + + // Invalid AH header + b = buffer.Bytes() + err = newPacket(b, true, p) + assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) +} + +func Test_newPacket_ipv6Fragment(t *testing.T) { + p := &firewall.Packet{} + + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + // First fragment + fragHeader1 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x1b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + buffer := gopacket.NewSerializeBuffer() - opt := gopacket.SerializeOptions{ + opts := gopacket.SerializeOptions{ ComputeChecksums: true, FixLengths: true, } - err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) + + err := ip.SerializeTo(buffer, opts) if err != nil { - panic(err) + t.Fatal(err) } - b := buffer.Bytes() - //test incoming - err = newPacket(b, true, p) + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader1...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) - assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) - assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::2")) - assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::1")) - assert.Equal(t, p.RemotePort, uint16(36123)) - assert.Equal(t, p.LocalPort, uint16(22)) + // Test first fragment incoming + err = newPacket(firstFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.RemotePort) + assert.Equal(t, uint16(22), p.LocalPort) + assert.False(t, p.Fragment) - //test outgoing - err = newPacket(b, false, p) + // Test first fragment outgoing + err = newPacket(firstFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(36123), p.LocalPort) + assert.Equal(t, uint16(22), p.RemotePort) + assert.False(t, p.Fragment) - assert.Nil(t, err) - assert.Equal(t, p.Protocol, uint8(firewall.ProtoUDP)) - assert.Equal(t, p.LocalAddr, netip.MustParseAddr("ff02::2")) - assert.Equal(t, p.RemoteAddr, netip.MustParseAddr("ff02::1")) - assert.Equal(t, p.LocalPort, uint16(36123)) - assert.Equal(t, p.RemotePort, uint16(22)) + // Second fragment + fragHeader2 := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0xb9, // Fragment Offset high byte (185) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + buffer.Clear() + err = ip.SerializeTo(buffer, opts) + if err != nil { + t.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader2...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Test second fragment incoming + err = newPacket(secondFrag, true, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.RemotePort) + assert.Equal(t, uint16(0), p.LocalPort) + assert.True(t, p.Fragment) + + // Test second fragment outgoing + err = newPacket(secondFrag, false, p) + assert.NoError(t, err) + assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) + assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) + assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) + assert.Equal(t, uint16(0), p.LocalPort) + assert.Equal(t, uint16(0), p.RemotePort) + assert.True(t, p.Fragment) + + // Too short of a fragment packet + err = newPacket(secondFrag[:len(secondFrag)-10], false, p) + assert.ErrorIs(t, err, ErrIPv6PacketTooShort) +} + +func BenchmarkParseV6(b *testing.B) { + // Regular UDP packet + ip := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolUDP, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + udp := &layers.UDP{ + SrcPort: layers.UDPPort(36123), + DstPort: layers.UDPPort(22), + } + + buffer := gopacket.NewSerializeBuffer() + opts := gopacket.SerializeOptions{ + ComputeChecksums: false, + FixLengths: true, + } + + err := gopacket.SerializeLayers(buffer, opts, ip, udp) + if err != nil { + b.Fatal(err) + } + normalPacket := buffer.Bytes() + + // First Fragment packet + ipFrag := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6Fragment, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + fragHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Reserved + 0x00, // Fragment Offset high byte (0) + 0x01, // Fragment Offset low byte & flags (M=1) + 0x00, 0x00, 0x00, 0x01, // Identification + } + + udpHeader := []byte{ + 0x8d, 0x7b, // Source port 36123 + 0x00, 0x16, // Destination port 22 + 0x00, 0x00, // Length + 0x00, 0x00, // Checksum + } + + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + firstFrag := buffer.Bytes() + firstFrag = append(firstFrag, fragHeader...) + firstFrag = append(firstFrag, udpHeader...) + firstFrag = append(firstFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + // Second Fragment packet + fragHeader[2] = 0xb9 // offset 185 + buffer.Clear() + err = ipFrag.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + secondFrag := buffer.Bytes() + secondFrag = append(secondFrag, fragHeader...) + secondFrag = append(secondFrag, []byte{0xde, 0xad, 0xbe, 0xef}...) + + fp := &firewall.Packet{} + + b.Run("Normal", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(normalPacket, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("FirstFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(firstFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + b.Run("SecondFragment", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(secondFrag, true, fp); err != nil { + b.Fatal(err) + } + } + }) + + // Evil packet + evilPacket := &layers.IPv6{ + Version: 6, + NextHeader: layers.IPProtocolIPv6HopByHop, + HopLimit: 64, + SrcIP: net.IPv6linklocalallrouters, + DstIP: net.IPv6linklocalallnodes, + } + + hopHeader := []byte{ + uint8(layers.IPProtocolIPv6HopByHop), // Next Header (HopByHop) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + lastHopHeader := []byte{ + uint8(layers.IPProtocolUDP), // Next Header (UDP) + 0x00, // Length + 0x00, 0x00, // Options and padding + 0x00, 0x00, 0x00, 0x00, // More options and padding + } + + buffer.Clear() + err = evilPacket.SerializeTo(buffer, opts) + if err != nil { + b.Fatal(err) + } + + evilBytes := buffer.Bytes() + for i := 0; i < 200; i++ { + evilBytes = append(evilBytes, hopHeader...) + } + evilBytes = append(evilBytes, lastHopHeader...) + evilBytes = append(evilBytes, udpHeader...) + evilBytes = append(evilBytes, []byte{0xde, 0xad, 0xbe, 0xef}...) + + b.Run("200 HopByHop headers", func(b *testing.B) { + for i := 0; i < b.N; i++ { + if err = parseV6(evilBytes, false, fp); err != nil { + b.Fatal(err) + } + } + }) +} + +// Ensure authentication data is a multiple of 8 bytes by padding if necessary +func padAuthData(authData []byte) []byte { + // Length of Authentication Data must be a multiple of 8 bytes + paddingLength := (8 - (len(authData) % 8)) % 8 // Only pad if necessary + if paddingLength > 0 { + authData = append(authData, make([]byte, paddingLength)...) + } + return authData +} + +// Custom function to manually serialize IPSecAH for both IPv4 and IPv6 +func serializeAH(ah *layers.IPSecAH) []byte { + buf := new(bytes.Buffer) + + // Ensure Authentication Data is a multiple of 8 bytes + ah.AuthenticationData = padAuthData(ah.AuthenticationData) + // Calculate Payload Length (in 32-bit words, minus 2) + payloadLen := uint8((12+len(ah.AuthenticationData))/4) - 2 + + // Serialize fields + if err := binary.Write(buf, binary.BigEndian, ah.NextHeader); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, payloadLen); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Reserved); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.SPI); err != nil { + panic(err) + } + if err := binary.Write(buf, binary.BigEndian, ah.Seq); err != nil { + panic(err) + } + if len(ah.AuthenticationData) > 0 { + if err := binary.Write(buf, binary.BigEndian, ah.AuthenticationData); err != nil { + panic(err) + } + } + + return buf.Bytes() }