Решение на Network Packets от Мирослав Лалев

Обратно към всички решения

Към профила на Мирослав Лалев

Резултати

  • 20 точки от тестове
  • 1 бонус точка
  • 21 точки общо
  • 15 успешни тест(а)
  • 0 неуспешни тест(а)

Код

use std::convert::TryInto;
use std::fmt;
#[derive(PartialEq, Debug)]
pub enum PacketError {
InvalidPacket,
InvalidChecksum,
UnknownProtocolVersion,
CorruptedMessage,
}
impl fmt::Display for PacketError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(match self {
PacketError::InvalidPacket => "invalid packet",
PacketError::InvalidChecksum => "invalid checksum",
PacketError::UnknownProtocolVersion => "unknown protocol",
PacketError::CorruptedMessage => "corrupted message",
})
}
}
impl std::error::Error for PacketError {}
#[derive(PartialEq, Debug)]
pub struct Packet<'a> {
payload: &'a [u8],
}
impl<'a> Packet<'a> {
const CURRENT_VERSION: u8 = 1;
pub fn from_source(source: &'a [u8], size: u8) -> (Self, &'a [u8]) {
match size as usize {
0 => unreachable!(),

Не бих използвал unreachable! тук, понеже този случай със сигурност може да се reach-не -- просто трябва някой да го извика с 0 :). "Unreachable" ще рече "логически, не би трябвало да е възможно да се удари този случай". Примерно, ако в единия ръкав проверяваш n > 0, а в другия n <= 0, пак трябва да напишеш 3-ти ръкав със _, който ще бъде наистина unreachable, но компилатора няма вградената логика да разбере това.

s if s >= source.len() => (Self { payload: source }, &[]),
_ => (
Self {
payload: &source[..size as usize],
},
&source[size as usize..],
),
}
}
pub fn payload(&self) -> &[u8] {
self.payload
}
pub fn serialize(&self) -> Vec<u8> {
let mut res = vec![Packet::CURRENT_VERSION, self.size()];
res.extend(self.payload);
res.extend(self.checksum().to_be_bytes().into_iter());
res
}
fn size(&self) -> u8 {
self.payload
.len()
.try_into()
.expect("payload size must be a single byte")
}
fn checksum(&self) -> u32 {
Packet::calc_checksum(self.payload)
}
fn calc_checksum(payload: &[u8]) -> u32 {
payload.iter().fold(0u32, |sum, item| sum + *item as u32)
}
pub fn deserialize(bytes: &'a [u8]) -> Result<(Packet, &'a [u8]), PacketError> {
let mut bytes_iter = bytes.into_iter();
let version = *bytes_iter.next().ok_or(PacketError::InvalidPacket)?;
if version != Packet::CURRENT_VERSION {
return Err(PacketError::UnknownProtocolVersion);
}
let num_bytes = *bytes_iter.next().ok_or(PacketError::InvalidPacket)? as usize;
Packet::validate_next(bytes_iter.by_ref(), num_bytes + 4)?;
let payload = &bytes[2..2 + num_bytes];
let checksum_bytes = &bytes[2 + num_bytes..6 + num_bytes];
let mut checksum_buffer = [0u8; 4];
checksum_buffer.copy_from_slice(checksum_bytes);
if u32::from_be_bytes(checksum_buffer) != Packet::calc_checksum(payload) {
return Err(PacketError::InvalidChecksum);
}
Ok((Packet { payload }, bytes_iter.as_slice()))

Хитро решение да използваш итератор за да се движиш напред по байтовете. Жалко само, че payload и checksum_bytes не се четат през итератора а се стига до 2 +, 6 +, etc. Но предполагам, че нещата са се усложнили с take-ването -- къде да ги съхраниш тия байтове, трябва да се направи вектор може би.

Аз използвах split_first и split_at за тия цели -- може да хвърлиш око на тия функции: https://doc.rust-lang.org/std/primitive.slice.html#method.split_at

}
fn validate_next<T>(iter: &mut dyn Iterator<Item = T>, n: usize) -> Result<(), PacketError> {
let actual = iter.take(n).count();
if actual == n {
Ok(())
} else {
Err(PacketError::InvalidPacket)
}
}
}
pub struct PacketSerializer<'a> {
src: &'a [u8],
packet_size: u8,
}
impl<'a> PacketSerializer<'a> {
fn new(src: &'a [u8], packet_size: u8) -> Self {
Self { src, packet_size }
}
}
impl<'a> Iterator for PacketSerializer<'a> {
type Item = Packet<'a>;
fn next(&mut self) -> Option<Self::Item> {
match self.src.len() {
0 => None,
_ => {
let (p, rem) = Packet::from_source(self.src, self.packet_size);
self.src = rem;
Some(p)
}
}
}
}
pub trait Packetable: Sized {
fn to_packets(&self, packet_size: u8) -> PacketSerializer;
fn to_packet_data(&self, packet_size: u8) -> Vec<u8>;
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError>;
}
impl Packetable for String {
fn to_packets(&self, packet_size: u8) -> PacketSerializer {
PacketSerializer::new(self.as_bytes(), packet_size)
}
fn to_packet_data(&self, packet_size: u8) -> Vec<u8> {
self.to_packets(packet_size)
.map(|p| p.serialize())

Чудих се дали има начин да направя итераторът тук да работи върху референции (т.е. &Item), което би ми позволило да напиша следното: .map(Packet::serialize)?

Но пък дори и да може, ще трябва да се добави .copied след .flatten, така че май няма голямо значение.

Друг вариант е да направя serialize да взема оунършип, но тогава ще счупя апи-то.

Би могъл да опиташ type Item = &Packet, за да може PacketSerializer-а да връща references към пакети, но не съм много сигурен, че би било удобно. Пакетите вече са някакъв reference тип, да добавиш още едно ниво на индирекция ми изглежда като че би усложнило API-то ненужно (и да, би счупило тестовете, но ако искаш да експериментираш, може да си изтеглиш теста и да го модифицираш).

По принцип не бих променил интерфейса на кода само заради нещо такова. Да, би било по-красиво :). Но и това е достатъчно просто и ясно, и голяма промяна в интерфейса само заради малко подобрение в четимостта на този метод не си заслужава. Разбира се, заслужава си да мислиш и да експериментираш с тези неща -- пробвай!

Иначе, можеш да заместиш .map().flatten() с .flat_map()

.flatten()
.collect::<Vec<_>>()
}
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError> {
let mut bytes = Vec::new();
let mut tmp_ref = packet_data;
while tmp_ref.len() > 0 {
let (p, rem) = Packet::deserialize(tmp_ref)?;
bytes.extend(p.payload());
tmp_ref = rem;
}
String::from_utf8(bytes).map_err(|_| PacketError::CorruptedMessage)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn packet_creation() {
let (p, rem) = Packet::from_source(b"1234", 3);
assert_eq!(p, Packet { payload: b"123" });
assert_eq!(rem, b"4");
let (p, rem) = Packet::from_source(b"1234", 5);
assert_eq!(p, Packet { payload: b"1234" });
assert_eq!(rem, &[]);
let (p, rem) = Packet::from_source(&[], 2);
assert_eq!(p, Packet { payload: &[] });
assert_eq!(rem, &[]);
}
#[test]
#[should_panic]
fn packet_creation_no_size() {
Packet::from_source(b"1234", 0);
}
#[test]
fn packet_serialize_deserialize_empty() {
let (p, _) = Packet::from_source(&[], 1);
let bytes = p.serialize();
assert_eq!(bytes, [1, 0, 0, 0, 0, 0]);
let (p, rem) = Packet::deserialize(bytes.as_slice()).unwrap();
assert_eq!(rem, &[]);
assert_eq!(p, Packet { payload: &[] });
}
#[test]
fn packet_serialize_deserialize() {
let (p, _) = Packet::from_source(b"1", 1);
let bytes = p.serialize();
assert_eq!(bytes, [1, 1, b'1', 0, 0, 0, b'1']);
let (p, rem) = Packet::deserialize(bytes.as_slice()).unwrap();
assert_eq!(rem, &[]);
assert_eq!(p, Packet { payload: b"1" });
}
#[test]
fn packet_deserialize_errs() {
let res = Packet::deserialize(&[0, 1, 1, 0, 0, 0, 1]);
assert_eq!(res, Err(PacketError::UnknownProtocolVersion));
let res = Packet::deserialize(&[]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 3, 1, 2, 0, 0, 0, 3]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 1, 1, 0, 0, 1]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 2, 1, 2, 3, 0, 0, 0, 6]);
assert_eq!(res, Err(PacketError::InvalidChecksum));
}
#[test]
fn packet_deserialize_returns_remainder() {
let (p, rem) = Packet::deserialize(&[1, 2, 4, 4, 0, 0, 0, 8, 1, 2, 3]).unwrap();
assert_eq!(p, Packet { payload: &[4, 4] });
assert_eq!(rem, &[1, 2, 3]);
}
#[test]
fn packet_serializer() {
let mut packet_iter = PacketSerializer::new(b"012", 2).into_iter();
assert_eq!(packet_iter.next(), Some(Packet { payload: b"01" }));
assert_eq!(packet_iter.next(), Some(Packet { payload: b"2" }));
assert_eq!(packet_iter.next(), None);
}
#[test]
fn packetable_string() {
let data = String::from("abcd");
let mut packets = data.to_packets(2);
assert_eq!(packets.next(), Some(Packet { payload: b"ab" }));
assert_eq!(packets.next(), Some(Packet { payload: b"cd" }));
assert_eq!(packets.next(), None);
let data = String::from("abc");
let packet_data = data.to_packet_data(2);
assert_eq!(
packet_data,
vec![
vec![1, 2, b'a', b'b', 0, 0, 0, b'a' + b'b'],
vec![1, 1, b'c', 0, 0, 0, b'c']
]
.iter()
.flatten()
.copied()
.collect::<Vec<_>>()
);
let data = String::from_packet_data(packet_data.as_slice());
assert_eq!(data, Ok(String::from("abc")));
}
#[test]
fn packetable_string_corrupted() {
let data = String::from_packet_data(&[1, 2, b'\x80', b'\x81', 0, 0, 1, 1]);
assert_eq!(data, Err(PacketError::CorruptedMessage));
}
}

Лог от изпълнението

Compiling solution v0.1.0 (/tmp/d20200111-2173579-iiohcd/solution)
    Finished test [unoptimized + debuginfo] target(s) in 4.64s
     Running target/debug/deps/solution-a73e64ec87929bd0

running 9 tests
test tests::packet_creation ... ok
test tests::packet_creation_no_size ... ok
test tests::packet_deserialize_errs ... ok
test tests::packet_deserialize_returns_remainder ... ok
test tests::packet_serialize_deserialize ... ok
test tests::packet_serialize_deserialize_empty ... ok
test tests::packet_serializer ... ok
test tests::packetable_string ... ok
test tests::packetable_string_corrupted ... ok

test result: ok. 9 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

     Running target/debug/deps/solution_test-38971695424b36d5

running 15 tests
test solution_test::test_construct_packet_from_unicode ... ok
test solution_test::test_construct_packet_no_remainder ... ok
test solution_test::test_construct_packet_with_remainder ... ok
test solution_test::test_construct_packet_with_remainder_cyrillic ... ok
test solution_test::test_consuming_packets ... ok
test solution_test::test_deserialize_invalid_packet ... ok
test solution_test::test_deserialize_packet ... ok
test solution_test::test_deserialize_unicode_packet ... ok
test solution_test::test_full_roundtrip ... ok
test solution_test::test_full_roundtrip_for_zero_size_string ... ok
test solution_test::test_invalid_packet_combination ... ok
test solution_test::test_iterating_packets ... ok
test solution_test::test_iterating_packets_for_zero_size_string ... ok
test solution_test::test_serialize_packet ... ok
test solution_test::test_zero_size ... ok

test result: ok. 15 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

   Doc-tests solution

running 0 tests

test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out

История (2 версии и 6 коментара)

Мирослав качи първо решение на 01.12.2019 15:15 (преди почти 6 години)

Мирослав качи решение на 01.12.2019 16:05 (преди почти 6 години)

use std::convert::TryInto;
use std::fmt;
#[derive(PartialEq, Debug)]
pub enum PacketError {
InvalidPacket,
InvalidChecksum,
UnknownProtocolVersion,
CorruptedMessage,
}
impl fmt::Display for PacketError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.write_str(match self {
PacketError::InvalidPacket => "invalid packet",
PacketError::InvalidChecksum => "invalid checksum",
PacketError::UnknownProtocolVersion => "unknown protocol",
PacketError::CorruptedMessage => "corrupted message",
})
}
}
impl std::error::Error for PacketError {}
#[derive(PartialEq, Debug)]
pub struct Packet<'a> {
payload: &'a [u8],
}
impl<'a> Packet<'a> {
const CURRENT_VERSION: u8 = 1;
pub fn from_source(source: &'a [u8], size: u8) -> (Self, &'a [u8]) {
match size as usize {
0 => unreachable!(),

Не бих използвал unreachable! тук, понеже този случай със сигурност може да се reach-не -- просто трябва някой да го извика с 0 :). "Unreachable" ще рече "логически, не би трябвало да е възможно да се удари този случай". Примерно, ако в единия ръкав проверяваш n > 0, а в другия n <= 0, пак трябва да напишеш 3-ти ръкав със _, който ще бъде наистина unreachable, но компилатора няма вградената логика да разбере това.

s if s >= source.len() => (Self { payload: source }, &[]),
_ => (
Self {
payload: &source[..size as usize],
},
&source[size as usize..],
),
}
}
pub fn payload(&self) -> &[u8] {
self.payload
}
pub fn serialize(&self) -> Vec<u8> {
let mut res = vec![Packet::CURRENT_VERSION, self.size()];
res.extend(self.payload);
res.extend(self.checksum().to_be_bytes().into_iter());
res
}
fn size(&self) -> u8 {
self.payload
.len()
.try_into()
.expect("payload size must be a single byte")
}
fn checksum(&self) -> u32 {
Packet::calc_checksum(self.payload)
}
fn calc_checksum(payload: &[u8]) -> u32 {
payload.iter().fold(0u32, |sum, item| sum + *item as u32)
}
pub fn deserialize(bytes: &'a [u8]) -> Result<(Packet, &'a [u8]), PacketError> {
let mut bytes_iter = bytes.into_iter();
let version = *bytes_iter.next().ok_or(PacketError::InvalidPacket)?;
if version != Packet::CURRENT_VERSION {
return Err(PacketError::UnknownProtocolVersion);
}
let num_bytes = *bytes_iter.next().ok_or(PacketError::InvalidPacket)? as usize;
Packet::validate_next(bytes_iter.by_ref(), num_bytes + 4)?;
let payload = &bytes[2..2 + num_bytes];
let checksum_bytes = &bytes[2 + num_bytes..6 + num_bytes];
let mut checksum_buffer = [0u8; 4];
checksum_buffer.copy_from_slice(checksum_bytes);
if u32::from_be_bytes(checksum_buffer) != Packet::calc_checksum(payload) {
return Err(PacketError::InvalidChecksum);
}
Ok((Packet { payload }, bytes_iter.as_slice()))

Хитро решение да използваш итератор за да се движиш напред по байтовете. Жалко само, че payload и checksum_bytes не се четат през итератора а се стига до 2 +, 6 +, etc. Но предполагам, че нещата са се усложнили с take-ването -- къде да ги съхраниш тия байтове, трябва да се направи вектор може би.

Аз използвах split_first и split_at за тия цели -- може да хвърлиш око на тия функции: https://doc.rust-lang.org/std/primitive.slice.html#method.split_at

}
- fn validate_next<T>(iter: &mut Iterator<Item = T>, n: usize) -> Result<(), PacketError> {
+ fn validate_next<T>(iter: &mut dyn Iterator<Item = T>, n: usize) -> Result<(), PacketError> {
let actual = iter.take(n).count();
- match actual == n {
- true => Ok(()),
- false => Err(PacketError::InvalidPacket),
+ if actual == n {
+ Ok(())
+ } else {
+ Err(PacketError::InvalidPacket)
}
}
}
pub struct PacketSerializer<'a> {
src: &'a [u8],
packet_size: u8,
}
impl<'a> PacketSerializer<'a> {
fn new(src: &'a [u8], packet_size: u8) -> Self {
Self { src, packet_size }
}
}
impl<'a> Iterator for PacketSerializer<'a> {
type Item = Packet<'a>;
fn next(&mut self) -> Option<Self::Item> {
match self.src.len() {
0 => None,
_ => {
let (p, rem) = Packet::from_source(self.src, self.packet_size);
self.src = rem;
Some(p)
}
}
}
}
pub trait Packetable: Sized {
fn to_packets(&self, packet_size: u8) -> PacketSerializer;
fn to_packet_data(&self, packet_size: u8) -> Vec<u8>;
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError>;
}
impl Packetable for String {
fn to_packets(&self, packet_size: u8) -> PacketSerializer {
PacketSerializer::new(self.as_bytes(), packet_size)
}
fn to_packet_data(&self, packet_size: u8) -> Vec<u8> {
self.to_packets(packet_size)
.map(|p| p.serialize())

Чудих се дали има начин да направя итераторът тук да работи върху референции (т.е. &Item), което би ми позволило да напиша следното: .map(Packet::serialize)?

Но пък дори и да може, ще трябва да се добави .copied след .flatten, така че май няма голямо значение.

Друг вариант е да направя serialize да взема оунършип, но тогава ще счупя апи-то.

Би могъл да опиташ type Item = &Packet, за да може PacketSerializer-а да връща references към пакети, но не съм много сигурен, че би било удобно. Пакетите вече са някакъв reference тип, да добавиш още едно ниво на индирекция ми изглежда като че би усложнило API-то ненужно (и да, би счупило тестовете, но ако искаш да експериментираш, може да си изтеглиш теста и да го модифицираш).

По принцип не бих променил интерфейса на кода само заради нещо такова. Да, би било по-красиво :). Но и това е достатъчно просто и ясно, и голяма промяна в интерфейса само заради малко подобрение в четимостта на този метод не си заслужава. Разбира се, заслужава си да мислиш и да експериментираш с тези неща -- пробвай!

Иначе, можеш да заместиш .map().flatten() с .flat_map()

.flatten()
.collect::<Vec<_>>()
}
fn from_packet_data(packet_data: &[u8]) -> Result<Self, PacketError> {
let mut bytes = Vec::new();
let mut tmp_ref = packet_data;
while tmp_ref.len() > 0 {
let (p, rem) = Packet::deserialize(tmp_ref)?;
bytes.extend(p.payload());
tmp_ref = rem;
}
String::from_utf8(bytes).map_err(|_| PacketError::CorruptedMessage)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn packet_creation() {
let (p, rem) = Packet::from_source(b"1234", 3);
assert_eq!(p, Packet { payload: b"123" });
assert_eq!(rem, b"4");
let (p, rem) = Packet::from_source(b"1234", 5);
assert_eq!(p, Packet { payload: b"1234" });
assert_eq!(rem, &[]);
let (p, rem) = Packet::from_source(&[], 2);
assert_eq!(p, Packet { payload: &[] });
assert_eq!(rem, &[]);
}
#[test]
#[should_panic]
fn packet_creation_no_size() {
Packet::from_source(b"1234", 0);
}
#[test]
fn packet_serialize_deserialize_empty() {
let (p, _) = Packet::from_source(&[], 1);
let bytes = p.serialize();
assert_eq!(bytes, [1, 0, 0, 0, 0, 0]);
let (p, rem) = Packet::deserialize(bytes.as_slice()).unwrap();
assert_eq!(rem, &[]);
assert_eq!(p, Packet { payload: &[] });
}
#[test]
fn packet_serialize_deserialize() {
let (p, _) = Packet::from_source(b"1", 1);
let bytes = p.serialize();
assert_eq!(bytes, [1, 1, b'1', 0, 0, 0, b'1']);
let (p, rem) = Packet::deserialize(bytes.as_slice()).unwrap();
assert_eq!(rem, &[]);
assert_eq!(p, Packet { payload: b"1" });
}
#[test]
fn packet_deserialize_errs() {
let res = Packet::deserialize(&[0, 1, 1, 0, 0, 0, 1]);
assert_eq!(res, Err(PacketError::UnknownProtocolVersion));
let res = Packet::deserialize(&[]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 3, 1, 2, 0, 0, 0, 3]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 1, 1, 0, 0, 1]);
assert_eq!(res, Err(PacketError::InvalidPacket));
let res = Packet::deserialize(&[1, 2, 1, 2, 3, 0, 0, 0, 6]);
assert_eq!(res, Err(PacketError::InvalidChecksum));
}
#[test]
fn packet_deserialize_returns_remainder() {
let (p, rem) = Packet::deserialize(&[1, 2, 4, 4, 0, 0, 0, 8, 1, 2, 3]).unwrap();
assert_eq!(p, Packet { payload: &[4, 4] });
assert_eq!(rem, &[1, 2, 3]);
}
#[test]
fn packet_serializer() {
let mut packet_iter = PacketSerializer::new(b"012", 2).into_iter();
assert_eq!(packet_iter.next(), Some(Packet { payload: b"01" }));
assert_eq!(packet_iter.next(), Some(Packet { payload: b"2" }));
assert_eq!(packet_iter.next(), None);
}
#[test]
fn packetable_string() {
let data = String::from("abcd");
let mut packets = data.to_packets(2);
assert_eq!(packets.next(), Some(Packet { payload: b"ab" }));
assert_eq!(packets.next(), Some(Packet { payload: b"cd" }));
assert_eq!(packets.next(), None);
let data = String::from("abc");
let packet_data = data.to_packet_data(2);
assert_eq!(
packet_data,
vec![
vec![1, 2, b'a', b'b', 0, 0, 0, b'a' + b'b'],
vec![1, 1, b'c', 0, 0, 0, b'c']
]
.iter()
.flatten()
.copied()
.collect::<Vec<_>>()
);
let data = String::from_packet_data(packet_data.as_slice());
assert_eq!(data, Ok(String::from("abc")));
}
#[test]
fn packetable_string_corrupted() {
let data = String::from_packet_data(&[1, 2, b'\x80', b'\x81', 0, 0, 1, 1]);
assert_eq!(data, Err(PacketError::CorruptedMessage));
}
}