Решение на Network Packets от Антонио Миндов

Обратно към всички решения

Към профила на Антонио Миндов

Резултати

  • 20 точки от тестове
  • 0 бонус точки
  • 20 точки общо
  • 15 успешни тест(а)
  • 0 неуспешни тест(а)

Код

use std::fmt;
use std::cmp;
use std::convert::TryInto;
#[derive(Debug)]
pub enum PacketError {
InvalidPacket,
InvalidChecksum,
UnknownProtocolVersion,
CorruptedMessage,
}
impl fmt::Display for PacketError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
PacketError::InvalidPacket => write!(f, "Ooops that was a bad packet!"),
PacketError::InvalidChecksum => write!(f, "Something doesn't add up!"),
PacketError::UnknownProtocolVersion => write!(f, "You must be from the future with that protocol version!"),
PacketError::CorruptedMessage => write!(f, "That message came from a bulgarian politician! Rejected!"),
}
}
}
impl std::error::Error for PacketError {}
pub struct ByteConsumer<'a> {
data: &'a[u8],
consumed: usize
}
impl<'a> ByteConsumer<'a> {
pub fn from(data: &[u8]) -> ByteConsumer {
ByteConsumer {data, consumed: 0}
}
pub fn single(&mut self) -> u8 {
// The user of this class should take care of checking if there are more bytes left
// much like an array
if self.bytes_left() < 1 {
panic!("These are not the bytes you were looking for!")
}
let result = self.data[self.consumed];
self.consumed += 1;
result
}
pub fn consume(&mut self, amount: usize) -> &'a [u8] {
// The user of this class should take care of checking if there are more bytes left
// much like an array
if self.bytes_left() < amount {
panic!("These are not the bytes you were looking for!")
}
let result = &self.data[self.consumed..self.consumed+amount];
self.consumed += amount;
result
}
pub fn bytes_left(&self) -> usize {
self.data.len() - self.consumed
}
pub fn rest_bytes(& self) -> &'a[u8] {
&self.data[self.consumed..]
}
}

Хм, това е интересна структура. Едно нещо, което можеше да я опрости е просто да променяш self.data когато викнеш consume да бъде нов slice, който е отместен с консумираната стойност. Може да видиш методите split_at примерно, които директно ще ти дадат парче от slice-а и остатъка, на който можеш да кажеш self.data = remainder. Тогава няма да има нужда да поддържаш ръчно self.consumed, bytes_left ще е просто len(), a rest_bytes ще е просто self.data :).

Можеше и да преизползваш consume в single, примерно consume(1).get(0).unwrap() or something. Тук аз бих накарал single да връща Option<u32> и в кода по-долу можеш да кажеш примерно let payload_size = consumer.single().ok_or(PacketError::InvalidPacket)?. Или можеше .single() директно да връща Result с PacketError ако приемеш, че е вътрешна структура на този модул използвана само за него.

Да кажеш "този метод очаква да си проверил еди-какво си преди да го извикаш" може да е ок понякога, но в Rust не е трудно да работиш с опционалност и грешки и това ти дава структура с методи, които спокойно могат да се използват без това допълнително знание. Колкото повече неща трябва да помниш, за да извикаш един метод правилно, толкова по-лесно е да допуснеш грешка.

#[derive(PartialEq, Debug)]
pub struct Packet<'a> {
pub version: u8,
payload: &'a [u8]
}
static PACKET_CHECKSUM_BYTES: usize = std::mem::size_of::<u32>();
static MINIMUM_PACKET_SIZE: usize = 3 * std::mem::size_of::<u8>() + PACKET_CHECKSUM_BYTES;
impl<'a> Packet<'a> {
pub fn from_source(source: &'a [u8], size: u8) -> (Self, &'a [u8]) {
if size == 0 {
panic!("No packets for you mister!")
}
let real_size = cmp::min(size as usize, source.len());
let (payload, rest) = source.split_at(real_size);
(Packet { version: 1, payload }, rest)
}
pub fn payload(&'a self) -> &'a [u8] {
&self.payload
}
pub fn serialize(& self) -> Vec<u8> {
let length = self.payload.len() as u8;
let mut result = Vec::new();
result.push(self.version);
result.push(length);
result.extend(self.payload.iter());
result.extend(&Packet::checksum_of_bytes(&self.payload).to_be_bytes());
result
}
pub fn deserialize(bytes: &[u8]) -> Result<(Packet, &[u8]), PacketError> {
let mut consumer = ByteConsumer::from(bytes);
if consumer.bytes_left() < MINIMUM_PACKET_SIZE {
return Err(PacketError::InvalidPacket)
}
let version = consumer.single();
if version != 1 {
return Err(PacketError::UnknownProtocolVersion)
}
let payload_size = consumer.single() as usize;
if consumer.bytes_left() < payload_size + PACKET_CHECKSUM_BYTES {
return Err(PacketError::InvalidPacket)
}
let payload = consumer.consume(payload_size);
let checksum_bytes: Vec<u8> = consumer.consume(PACKET_CHECKSUM_BYTES).to_vec();
// Safe unwrap because we are sure there are 4 checksum bytes
let real_bytes = (&checksum_bytes[..]).try_into().expect("Something went wrong...");
let given_checksum = u32::from_be_bytes(real_bytes);
if Packet::checksum_of_bytes(&payload) != given_checksum {
return Err(PacketError::InvalidChecksum)
}
return Ok((Packet::from_source(payload, payload_size as u8).0, consumer.rest_bytes()))
}
fn checksum_of_bytes(bytes: &[u8]) -> u32 {
bytes.iter().map(|&b| b as u32).sum()
}
}
pub struct PacketSerializer<'a> {
data: &'a [u8],
packet_size: u8
}
impl<'a> Iterator for PacketSerializer<'a> {
type Item = Packet<'a>;
fn next(&mut self) -> Option<Self::Item> {
if self.data.len() == 0 {
return None
}
let (result, new_data) = Packet::from_source(self.data, self.packet_size);
self.data = new_data;
Some(result)
}
}
pub trait Packetable: Sized {
fn to_packets(&self, packet_size: u8) -> PacketSerializer;
fn to_packet_data(&self, packet_size: u8) -> Vec<u8>;
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError>;
}
impl Packetable for String {
fn to_packets(&self, packet_size: u8) -> PacketSerializer {
PacketSerializer {data:self.as_bytes(), packet_size}
}
fn to_packet_data(&self, packet_size: u8) -> Vec<u8> {
let mut result = Vec::new();
for packet in self.to_packets(packet_size) {
result.extend(packet.serialize())
}
result
}
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError> {
let mut data: Vec<u8> = Vec::new();
let mut packet_data_copy = packet_data;
while packet_data_copy.len() > 0 {
let (packet, new_packet_data) = Packet::deserialize(packet_data_copy)?;
data.extend(packet.payload());
packet_data_copy = new_packet_data;
}
String::from_utf8(data).map_err(|_| PacketError::CorruptedMessage)
}
}
#[cfg(test)]
mod tests {
use crate::*;
#[test]
fn test_packets_construction() {
let source = b"hello";
let (packet, remainder) = Packet::from_source(source, 100);
assert_eq!(packet.payload().len(), source.len());
assert_eq!(remainder, b"");
let (packet, remainder) = Packet::from_source(source, 3);
assert_eq!(packet.payload(), b"hel");
assert_eq!(remainder, b"lo");
}
#[test]
#[should_panic]
fn test_packet_construction_with_no_size() {
let source = b"hello";
Packet::from_source(source, 0);
}
#[test]
fn test_packet_serialization() {
let source: Vec<u8> = vec![0, 1, 2, 3, 4];
let (packet, _) = Packet::from_source(&source, source.len() as u8);
let serialized = packet.serialize();
let payload_end = 2+source.len();
assert_eq!(serialized[0], 1);
assert_eq!(serialized[1], source.len() as u8);
assert_eq!(serialized[payload_end..payload_end+4], 10u32.to_be_bytes());
}
#[test]
fn test_packet_deserialization() {
let source: Vec<u8> = vec![0, 1, 2, 3, 4];
let (packet, _) = Packet::from_source(&source, source.len() as u8);
let serialized = packet.serialize();
let mut whole_bytes = serialized.clone();
whole_bytes.extend(serialized.clone());
if let Ok((packet, rest)) = Packet::deserialize(&whole_bytes) {
assert_eq!(packet.version, 1);
assert_eq!(packet.payload(), &source[..]);
assert_eq!(rest, &serialized[..]);
} else {
assert!(false)
}

Ако очакваш в тест нещо да ти върне Ok, напълно смислено е да му викнеш unwrap. Това ще ти даде по-ясна грешка от assert!(false), защото ще ти каже каква е стойността, която unwrap-ва.

}
#[test]
fn test_basic_packets() {
let source = b"hello";
let (packet, remainder) = Packet::from_source(source, 100);
assert_eq!(packet.payload().len(), source.len());
assert_eq!(remainder, b"");
assert!(packet.serialize().len() > 0);
if let Err(_) = Packet::deserialize(&packet.serialize()) {
assert!(false, "Couldn't deserialize serialized packet");
}
}
#[test]
fn test_basic_packet_iteration() {
let source = String::from("hell");
let packets = source.to_packets(2).collect::<Vec<Packet>>();
assert_eq!(packets[0].payload(), &source.as_bytes()[0..2]);
assert_eq!(packets[1].payload(), &source.as_bytes()[2..4]);
}
#[test]
fn test_basic_packet_deserialization_to_data() {
let source = String::from("Hello Rusty ones!");
let packet_data = source.to_packet_data(5);
if let Ok(deserialized) = String::from_packet_data(&packet_data) {
assert_eq!(deserialized, source);
} else {
assert!(false, "Couldn't deserialize serialized packet data");
}
}
#[test]
fn test_basic_iteration() {
let source = String::from("hello");
let packets = source.to_packets(100).collect::<Vec<Packet>>();
assert!(packets.len() > 0);
let data = source.to_packet_data(100);
assert!(data.len() > 0);
if let Err(_) = String::from_packet_data(&data) {
assert!(false, "Couldn't deserialize serialized packet data");
}
}
}

Лог от изпълнението

Compiling solution v0.1.0 (/tmp/d20200111-2173579-1113f5w/solution)
    Finished test [unoptimized + debuginfo] target(s) in 4.56s
     Running target/debug/deps/solution-a73e64ec87929bd0

running 8 tests
test tests::test_basic_iteration ... ok
test tests::test_basic_packet_deserialization_to_data ... ok
test tests::test_basic_packet_iteration ... ok
test tests::test_basic_packets ... ok
test tests::test_packet_construction_with_no_size ... ok
test tests::test_packet_deserialization ... ok
test tests::test_packet_serialization ... ok
test tests::test_packets_construction ... ok

test result: ok. 8 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

     Running target/debug/deps/solution_test-38971695424b36d5

running 15 tests
test solution_test::test_construct_packet_from_unicode ... ok
test solution_test::test_construct_packet_no_remainder ... ok
test solution_test::test_construct_packet_with_remainder ... ok
test solution_test::test_construct_packet_with_remainder_cyrillic ... ok
test solution_test::test_consuming_packets ... ok
test solution_test::test_deserialize_invalid_packet ... ok
test solution_test::test_deserialize_packet ... ok
test solution_test::test_deserialize_unicode_packet ... ok
test solution_test::test_full_roundtrip ... ok
test solution_test::test_full_roundtrip_for_zero_size_string ... ok
test solution_test::test_invalid_packet_combination ... ok
test solution_test::test_iterating_packets ... ok
test solution_test::test_iterating_packets_for_zero_size_string ... ok
test solution_test::test_serialize_packet ... ok
test solution_test::test_zero_size ... ok

test result: ok. 15 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

   Doc-tests solution

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

История (1 версия и 4 коментара)

Антонио качи първо решение на 02.12.2019 13:25 (преди почти 6 години)

Добро решение, макар че ByteConsumer-а можеше да се съкрати доста, до състояние, в което си мисля, че не е много полезен :). Оценявам тестовете, но си тествал само с ASCII низове. Когато се работи с низове и байтове, много добра идея е да се подсигуриш, че кода ти обработва UTF8 низове (примерно кирилица) без да експлодира.