Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 205 additions & 51 deletions src/dns_parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1424,11 +1424,8 @@ enum PacketState {

/// A single packet for outgoing DNS message.
pub struct DnsOutPacket {
/// All bytes in `data` concatenated is the actual packet on the wire.
data: Vec<Vec<u8>>,

/// Current logical size of the packet. It starts with the size of the mandatory header.
size: usize,
/// All bytes in `data` is the actual packet on the wire.
data: Vec<u8>,

/// An internal state, not defined by DNS.
state: PacketState,
Expand All @@ -1440,19 +1437,18 @@ pub struct DnsOutPacket {
impl DnsOutPacket {
fn new() -> Self {
Self {
data: Vec::new(),
size: MSG_HEADER_LEN, // Header is mandatory.
data: vec![0; MSG_HEADER_LEN],
Comment thread
keepsimple1 marked this conversation as resolved.
state: PacketState::Init,
names: HashMap::new(),
}
}

pub fn size(&self) -> usize {
self.size
self.data.len()
}

pub fn to_bytes(&self) -> Vec<u8> {
self.data.concat()
pub fn as_bytes(&self) -> &[u8] {
Comment thread
keepsimple1 marked this conversation as resolved.
&self.data
}

fn write_question(&mut self, question: &DnsQuestion) {
Expand All @@ -1465,8 +1461,7 @@ impl DnsOutPacket {
/// Returns false if the packet exceeds the max size with this record, nothing is written to the packet.
/// otherwise returns true.
fn write_record(&mut self, record_ext: &dyn DnsRecordExt, now: u64) -> bool {
let start_data_length = self.data.len();
let start_size = self.size;
let start_size = self.size();

let record = record_ext.get_record();
self.write_name(record.get_name());
Expand All @@ -1484,19 +1479,14 @@ impl DnsOutPacket {
self.write_u32(record.get_remaining_ttl(now));
}

let index = self.data.len();

// Adjust size for the short we will write before this record
self.size += 2;
// Placeholder for record size
self.write_short(0);
let record_offset = self.size();
record_ext.write(self);
self.size -= 2;

let length: usize = self.data[index..].iter().map(|x| x.len()).sum();
self.insert_short(index, length as u16);
self.insert_short(record_offset - 2, (self.size() - record_offset) as u16);

if self.size > MAX_MSG_ABSOLUTE {
self.data.truncate(start_data_length);
self.size = start_size;
if self.size() > MAX_MSG_ABSOLUTE {
self.data.truncate(start_size);
self.state = PacketState::Finished;
return false;
}
Expand All @@ -1505,8 +1495,7 @@ impl DnsOutPacket {
}

pub(crate) fn insert_short(&mut self, index: usize, value: u16) {
self.data.insert(index, value.to_be_bytes().to_vec());
self.size += 2;
self.data[index..index + 2].copy_from_slice(&value.to_be_bytes());
}

/// Parses a DNS name that may contain escaped characters according to RFC 6763 Section 4.3.
Expand Down Expand Up @@ -1613,7 +1602,7 @@ impl DnsOutPacket {
}

// Store this position for potential future compression
self.names.insert(remaining, self.size as u16);
self.names.insert(remaining, self.size() as u16);

// Write the label
self.write_utf8(label);
Expand All @@ -1623,30 +1612,26 @@ impl DnsOutPacket {
self.write_byte(0);
}

fn write_utf8(&mut self, utf: &str) {
assert!(utf.len() < 64);
self.write_byte(utf.len() as u8);
self.write_bytes(utf.as_bytes());
fn write_byte(&mut self, v: u8) {
self.data.push(v);
}

fn write_bytes(&mut self, s: &[u8]) {
self.data.push(s.to_vec());
self.size += s.len();
self.data.extend(s);
}

fn write_u32(&mut self, int: u32) {
self.data.push(int.to_be_bytes().to_vec());
self.size += 4;
fn write_utf8(&mut self, s: &str) {
assert!(s.len() < 64);
self.write_byte(s.len() as u8);
self.write_bytes(s.as_bytes());
}

fn write_short(&mut self, short: u16) {
self.data.push(short.to_be_bytes().to_vec());
self.size += 2;
fn write_u32(&mut self, v: u32) {
self.data.extend(&v.to_be_bytes());
}

fn write_byte(&mut self, byte: u8) {
self.data.push(vec![byte]);
self.size += 1;
fn write_short(&mut self, v: u16) {
self.data.extend(&v.to_be_bytes());
}

/// Writes the header fields and finish the packet.
Expand Down Expand Up @@ -1680,21 +1665,19 @@ impl DnsOutPacket {
auth_count: u16,
addi_count: u16,
) {
self.insert_short(0, addi_count);
self.insert_short(0, auth_count);
self.insert_short(0, a_count);
self.insert_short(0, q_count);
self.insert_short(0, flags);
self.insert_short(0, id);

// Adjust the size as it was already initialized to include the header.
self.size -= MSG_HEADER_LEN;
self.insert_short(2, flags);
self.insert_short(4, q_count);
self.insert_short(6, a_count);
self.insert_short(8, auth_count);
self.insert_short(10, addi_count);

self.state = PacketState::Finished;
}
}

/// Representation of one outgoing DNS message that could be sent in one or more packet(s).
#[derive(Debug)]
pub struct DnsOutgoing {
flags: u16,
id: u16,
Expand Down Expand Up @@ -1935,7 +1918,7 @@ impl DnsOutgoing {
/// Returns a list of actual DNS packet data to be sent on the wire.
pub fn to_data_on_wire(&self) -> Vec<Vec<u8>> {
let packet_list = self.to_packets();
packet_list.iter().map(|p| p.data.concat()).collect()
packet_list.into_iter().map(|p| p.data).collect()
}

/// Encode self into one or more packets.
Expand Down Expand Up @@ -2601,5 +2584,176 @@ const fn u32_from_be_slice(s: &[u8]) -> u32 {
const fn get_expiration_time(created: u64, ttl: u32, percent: u32) -> u64 {
// 'created' is in millis, 'ttl' is in seconds, hence:
// ttl * 1000 * (percent / 100) => ttl * percent * 10
created + (ttl * percent * 10) as u64
created + (ttl as u64 * percent as u64 * 10)
Comment thread
keepsimple1 marked this conversation as resolved.
}

#[cfg(test)]
mod tests {
use super::{
DnsAddress, DnsHostInfo, DnsOutgoing, DnsPointer, DnsTxt, RRType, CLASS_CACHE_FLUSH,
CLASS_IN,
};
use crate::InterfaceId;
use std::collections::HashMap;
use std::net::{IpAddr, Ipv4Addr};

#[test]
fn test_dns_outgoing_serialization_empty() {
let out = DnsOutgoing::new(0);
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(packets[0].as_bytes(), &[0; 12]);
let expected_names = HashMap::new();
assert_eq!(&packets[0].names, &expected_names);
}

#[test]
fn test_dns_outgoing_serialization_question() {
let mut out = DnsOutgoing::new(0);
out.add_question("123.test", RRType::A);
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(
packets[0].as_bytes(),
&[
0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, // Header
// Payload
3, 49, 50, 51, 4, 116, 101, 115, 116, 0, 0, 1, 0, 1,
]
);
let mut expected_names = HashMap::new();
expected_names.insert("123.test".to_string(), 12);
expected_names.insert("test".to_string(), 16);
assert_eq!(&packets[0].names, &expected_names);
}

#[test]
fn test_dns_outgoing_serialization_question_with_authority() {
let mut out = DnsOutgoing::new(0);
out.add_question("123.test", RRType::ANY);
out.add_authority(Box::new(DnsTxt::new(
"124.test",
CLASS_IN,
0x00112233,
b"help".to_vec(),
)));
out.add_authority(Box::new(DnsHostInfo::new(
"124.test",
RRType::CNAME,
CLASS_IN,
0x00112233,
"arm".to_string(),
"linux".to_string(),
)));
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(
packets[0].as_bytes(),
&[
0, 0, 0, 0, 0, 1, 0, 0, 0, 2, 0, 0, // Header
// Payload
3, 49, 50, 51, 4, 116, 101, 115, 116, 0, 0, 255, 0, 1, 3, 49, 50, 52, 192, 16, 0,
16, 0, 1, 0, 17, 34, 51, 0, 4, 104, 101, 108, 112, 192, 26, 0, 5, 0, 1, 0, 17, 34,
51, 0, 8, 97, 114, 109, 108, 105, 110, 117, 120,
]
);
let mut expected_names = HashMap::new();
expected_names.insert("123.test".to_string(), 12);
expected_names.insert("test".to_string(), 16);
expected_names.insert("124.test".to_string(), 26);
assert_eq!(&packets[0].names, &expected_names);
}

#[test]
fn test_dns_outgoing_serialization_additional_answer() {
let mut out = DnsOutgoing::new(0);
out.add_additional_answer(DnsAddress::new(
"test.local",
RRType::A,
CLASS_IN | CLASS_CACHE_FLUSH,
0xdead_beef,
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
InterfaceId::default(),
));
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(
packets[0].as_bytes(),
&[
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, // Header
// Payload
4, 116, 101, 115, 116, 5, 108, 111, 99, 97, 108, 0, 0, 1, 128, 1, 222, 173, 190,
239, 0, 4, 127, 0, 0, 1,
]
);
let mut expected_names = HashMap::new();
expected_names.insert("test.local".to_string(), 12);
expected_names.insert("local".to_string(), 17);
assert_eq!(&packets[0].names, &expected_names);
}

#[test]
fn test_dns_outgoing_serialization_answer_at_time() {
let mut out = DnsOutgoing::new(0);
out.add_answer_at_time(
DnsPointer::new(
"test",
RRType::PTR,
CLASS_IN,
0xaaaa5555,
"test-service".to_string(),
),
0,
);
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(
packets[0].as_bytes(),
&[
0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, // Header
// Payload
4, 116, 101, 115, 116, 0, 0, 12, 0, 1, 170, 170, 85, 85, 0, 14, 12, 116, 101, 115,
116, 45, 115, 101, 114, 118, 105, 99, 101, 0,
]
);

let mut out = DnsOutgoing::new(0);
out.add_answer_at_time(
DnsPointer::new(
"test",
RRType::CNAME,
CLASS_IN,
0xaaaa5555,
"test-service.local".to_string(),
),
0,
);
out.add_answer_at_time(
DnsPointer::new(
"test",
RRType::AAAA,
CLASS_IN,
0xffffffff,
"test-service.local".to_string(),
),
0,
);
let packets = out.to_packets();
assert_eq!(packets.len(), 1);
assert_eq!(
packets[0].as_bytes(),
&[
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, // Header
// Payload
4, 116, 101, 115, 116, 0, 0, 5, 0, 1, 170, 170, 85, 85, 0, 20, 12, 116, 101, 115,
116, 45, 115, 101, 114, 118, 105, 99, 101, 5, 108, 111, 99, 97, 108, 0, 192, 12, 0,
28, 0, 1, 255, 255, 255, 255, 0, 2, 192, 28,
]
);
let mut expected_names = HashMap::new();
expected_names.insert("test".to_string(), 12);
expected_names.insert("test-service.local".to_string(), 28);
expected_names.insert("local".to_string(), 41);
assert_eq!(&packets[0].names, &expected_names);
}
}