Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(virtio-net): prepare checksum correctly #1488

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 148 additions & 30 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -263,38 +263,12 @@ impl NetworkDriver for VirtioNetDriver {
};

let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
// If a checksum isn't necessary, we have inform the host within the header
// see Virtio specification 5.1.6.2
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {

if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
header.flags = HdrF::NEEDS_CSUM;
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
EthernetFrame::new_unchecked(&packet);
let packet_header_len: u16;
let protocol;
match ethernet_frame.ethertype() {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().into();
protocol = Some(packet.next_header());
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().try_into().unwrap();
protocol = Some(packet.next_header());
}
_ => {
packet_header_len = 0;
protocol = None;
}
}
header.csum_start =
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
header.csum_offset = match protocol {
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
Some(smoltcp::wire::IpProtocol::Udp) => 6,
_ => 0,
}
.into();
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
header.csum_offset = csum_offset.into();
}

let buff_tkn = AvailBufferToken::new(
Expand Down Expand Up @@ -784,6 +758,87 @@ impl VirtioNetDriver {

Ok(())
}

/// Sets the TCP or UDP checksum field to the checksum of the pseudo-header if necessary or returns None otherwise.
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
&self,
frame: T,
) -> Option<(u16, u16)> {
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
// If a checksum calculation by the host is necessary, we have to inform the host within the header
// see Virtio specification 5.1.6.2
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
// If the Ethernet protocol is not one of these two, we default to not asking for checksum,
// as otherwise the frame will be corrupted by the device trying to write the checksum.
if let ip @ (smoltcp::wire::EthernetProtocol::Ipv4
| smoltcp::wire::EthernetProtocol::Ipv6) = ethernet_frame.ethertype()
{
let ip_header_len: u16;
let ip_packet_len: usize;
let protocol;
let pseudo_header_checksum;
match ip {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().into();
ip_packet_len = ip_packet.total_len().into();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().try_into().expect(
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
);
ip_packet_len = ip_packet.total_len();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
}
_ => unreachable!(),
}
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
if let smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp = protocol {
let ip_payload =
&mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];

// We do not care about the offset of the checksum for the protocol if we don't require checksum
// from the host, so we use None to signal that checksum from the host is not needed.
let csum_offset = match protocol {
smoltcp::wire::IpProtocol::Tcp => {
if !self.checksums.tcp.tx() {
let mut tcp_packet =
smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
tcp_packet.set_checksum(pseudo_header_checksum);
Some(16)
} else {
None
}
}
smoltcp::wire::IpProtocol::Udp => {
if !self.checksums.tcp.tx() {
let mut udp_packet =
smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
udp_packet.set_checksum(pseudo_header_checksum);
Some(6)
} else {
None
}
}
_ => None,
};
csum_offset.map(|csum_offset| (ip_header_len, csum_offset))
} else {
None
}
} else {
None
}
} else {
None
}
}
}

pub mod constants {
Expand All @@ -808,3 +863,66 @@ pub mod error {
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
}
}

/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
///
/// The calculations here can theoretically be made faster by exploiting the properties described in
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
mod partial_checksum {
use core::iter;

use smoltcp::wire::{Ipv4Packet, Ipv6Packet};

/// Calculates the checksum for the IPv4 pseudo-header as described in
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv4Packet<T>,
) -> u16 {
let src_addr = packet.src_addr();
let dst_addr = packet.dst_addr();
let address_words = src_addr
.as_bytes()
.iter()
.chain(dst_addr.as_bytes())
.copied()
.array_chunks::<{ size_of::<u16>() }>()
.map(u16::from_be_bytes);
let padded_protocol = u16::from(u8::from(packet.next_header()));
let payload_len = packet.total_len() - u16::from(packet.header_len());
address_words
.chain(iter::once(padded_protocol))
.chain(iter::once(payload_len))
.fold(0u16, ones_complement_add)
}

/// Calculates the checksum for the IPv6 pseudo-header as described in
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv6Packet<T>,
) -> u16 {
warn!("The IPv6 partial checksum implementation is untested!");
let src_addr = packet.src_addr();
let dst_addr = packet.dst_addr();
let payload_len = packet.payload_len();
let padded_protocol = u16::from(u8::from(packet.next_header()));

src_addr
.as_bytes()
.iter()
.chain(dst_addr.as_bytes())
.copied()
.array_chunks::<{ size_of::<u16>() }>()
.map(u16::from_be_bytes)
.chain(iter::once(payload_len))
.chain(iter::once(padded_protocol))
.fold(0u16, ones_complement_add)
}

/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
sum + u16::from(overflow)
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)]
#![cfg_attr(target_arch = "x86_64", feature(abi_x86_interrupt))]
#![feature(allocator_api)]
#![feature(iter_array_chunks)]
#![feature(linked_list_cursors)]
#![feature(map_try_insert)]
#![feature(maybe_uninit_as_bytes)]
Expand Down
Loading