Skip to content

Commit

Permalink
Add support for Netlink socket addresses
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinmehall committed Jan 29, 2024
1 parent 8533f7c commit 4c1c746
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 4 deletions.
22 changes: 21 additions & 1 deletion src/backend/libc/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use crate::backend::c;
use crate::ffi::CStr;
use crate::io;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
use crate::net::{
netlink::SocketAddrNetlink,
xdp::{SockaddrXdpFlags, SocketAddrXdp},
};
use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrV4, SocketAddrV6};
use core::mem::size_of;

Expand Down Expand Up @@ -208,6 +211,17 @@ pub(crate) unsafe fn read_sockaddr(
u32::from_be(decode.sxdp_shared_umem_fd),
)))
}
#[cfg(target_os = "linux")]
c::AF_NETLINK => {
if len < size_of::<c::sockaddr_nl>() {
return Err(io::Errno::INVAL);
}
let decode = &*storage.cast::<c::sockaddr_nl>();
Ok(SocketAddrAny::Netlink(SocketAddrNetlink::new(
decode.nl_pid,
decode.nl_groups,
)))
}
_ => Err(io::Errno::INVAL),
}
}
Expand Down Expand Up @@ -327,6 +341,12 @@ unsafe fn inner_read_sockaddr_os(
u32::from_be(decode.sxdp_shared_umem_fd),
))
}
#[cfg(target_os = "linux")]
c::AF_NETLINK => {
assert!(len >= size_of::<c::sockaddr_nl>());
let decode = &*storage.cast::<c::sockaddr_nl>();
SocketAddrAny::Netlink(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
}
other => unimplemented!("{:?}", other),
}
}
22 changes: 21 additions & 1 deletion src/backend/linux_raw/net/read_sockaddr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
use crate::backend::c;
use crate::io;
#[cfg(target_os = "linux")]
use crate::net::xdp::{SockaddrXdpFlags, SocketAddrXdp};
use crate::net::{
netlink::SocketAddrNetlink,
xdp::{SockaddrXdpFlags, SocketAddrXdp},
};
use crate::net::{Ipv4Addr, Ipv6Addr, SocketAddrAny, SocketAddrUnix, SocketAddrV4, SocketAddrV6};
use core::mem::size_of;
use core::slice;
Expand Down Expand Up @@ -127,6 +130,17 @@ pub(crate) unsafe fn read_sockaddr(
u32::from_be(decode.sxdp_shared_umem_fd),
)))
}
#[cfg(target_os = "linux")]
c::AF_NETLINK => {
if len < size_of::<c::sockaddr_nl>() {
return Err(io::Errno::INVAL);
}
let decode = &*storage.cast::<c::sockaddr_nl>();
Ok(SocketAddrAny::Netlink(SocketAddrNetlink::new(
decode.nl_pid,
decode.nl_groups,
)))
}
_ => Err(io::Errno::NOTSUP),
}
}
Expand Down Expand Up @@ -216,6 +230,12 @@ pub(crate) unsafe fn read_sockaddr_os(storage: *const c::sockaddr, len: usize) -
u32::from_be(decode.sxdp_shared_umem_fd),
))
}
#[cfg(target_os = "linux")]
c::AF_NETLINK => {
assert!(len >= size_of::<c::sockaddr_nl>());
let decode = &*storage.cast::<c::sockaddr_nl>();
SocketAddrAny::Netlink(SocketAddrNetlink::new(decode.nl_pid, decode.nl_groups))
}
other => unimplemented!("{:?}", other),
}
}
15 changes: 13 additions & 2 deletions src/net/socket_addr_any.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
#![allow(unsafe_code)]

use crate::backend::c;
#[cfg(target_os = "linux")]
use crate::net::xdp::SocketAddrXdp;
#[cfg(unix)]
use crate::net::SocketAddrUnix;
#[cfg(target_os = "linux")]
use crate::net::{netlink::SocketAddrNetlink, xdp::SocketAddrXdp};
use crate::net::{AddressFamily, SocketAddr, SocketAddrV4, SocketAddrV6};
use crate::{backend, io};
#[cfg(feature = "std")]
Expand All @@ -39,6 +39,9 @@ pub enum SocketAddrAny {
/// `struct sockaddr_xdp`
#[cfg(target_os = "linux")]
Xdp(SocketAddrXdp),
/// `struct sockaddr_nl`
#[cfg(target_os = "linux")]
Netlink(SocketAddrNetlink),
}

impl From<SocketAddr> for SocketAddrAny {
Expand Down Expand Up @@ -84,6 +87,8 @@ impl SocketAddrAny {
Self::Unix(_) => AddressFamily::UNIX,
#[cfg(target_os = "linux")]
Self::Xdp(_) => AddressFamily::XDP,
#[cfg(target_os = "linux")]
Self::Netlink(_) => AddressFamily::NETLINK,
}
}

Expand All @@ -103,6 +108,8 @@ impl SocketAddrAny {
SocketAddrAny::Unix(a) => a.write_sockaddr(storage),
#[cfg(target_os = "linux")]
SocketAddrAny::Xdp(a) => a.write_sockaddr(storage),
#[cfg(target_os = "linux")]
SocketAddrAny::Netlink(a) => a.write_sockaddr(storage),
}
}

Expand All @@ -128,6 +135,8 @@ impl fmt::Debug for SocketAddrAny {
Self::Unix(unix) => unix.fmt(fmt),
#[cfg(target_os = "linux")]
Self::Xdp(xdp) => xdp.fmt(fmt),
#[cfg(target_os = "linux")]
Self::Netlink(nl) => nl.fmt(fmt),
}
}
}
Expand Down Expand Up @@ -158,6 +167,8 @@ unsafe impl SocketAddress for SocketAddrAny {
Self::Unix(a) => a.with_sockaddr(f),
#[cfg(target_os = "linux")]
Self::Xdp(a) => a.with_sockaddr(f),
#[cfg(target_os = "linux")]
Self::Netlink(a) => a.with_sockaddr(f),
}
}
}
64 changes: 64 additions & 0 deletions src/net/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,8 @@ pub mod netlink {
use {
super::{new_raw_protocol, Protocol},
crate::backend::c,
crate::net::SocketAddress,
core::mem,
};

/// `NETLINK_UNUSED`
Expand Down Expand Up @@ -1024,6 +1026,68 @@ pub mod netlink {
/// `NETLINK_GET_STRICT_CHK`
#[cfg(linux_kernel)]
pub const GET_STRICT_CHK: Protocol = Protocol(new_raw_protocol(c::NETLINK_GET_STRICT_CHK as _));

/// A Netlink socket address.
///
/// Used to bind to a Netlink socket.
///
/// Not ABI compatible with `struct sockaddr_nl`
#[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Hash, Debug)]
#[cfg(linux_kernel)]
pub struct SocketAddrNetlink {
/// Port ID
pid: u32,

/// Multicast groups mask
groups: u32,
}

#[cfg(linux_kernel)]
impl SocketAddrNetlink {
/// Construct a netlink address
#[inline]
pub fn new(pid: u32, groups: u32) -> Self {
Self { pid, groups }
}

/// Return port id.
#[inline]
pub fn pid(&self) -> u32 {
self.pid
}

/// Set port id.
#[inline]
pub fn set_pid(&mut self, pid: u32) {
self.pid = pid;
}

/// Return multicast groups mask.
#[inline]
pub fn groups(&self) -> u32 {
self.groups
}

/// Set multicast groups mask.
#[inline]
pub fn set_groups(&mut self, groups: u32) {
self.groups = groups;
}
}

#[cfg(linux_kernel)]
#[allow(unsafe_code)]
unsafe impl SocketAddress for SocketAddrNetlink {
type CSockAddr = c::sockaddr_nl;

fn encode(&self) -> Self::CSockAddr {
let mut addr: c::sockaddr_nl = unsafe { mem::zeroed() };
addr.nl_family = c::AF_NETLINK as _;
addr.nl_pid = self.pid;
addr.nl_groups = self.groups;
addr
}
}
}

/// `ETH_P_*` constants.
Expand Down

0 comments on commit 4c1c746

Please sign in to comment.