built-in mDNS; memory optimizations

This commit is contained in:
ivmarkov 2023-05-24 10:07:11 +00:00
parent fccf9fa5f6
commit 2e0a09b532
30 changed files with 780 additions and 563 deletions

View file

@ -15,13 +15,14 @@ name = "matter"
path = "src/lib.rs"
[features]
default = ["std", "crypto_mbedtls", "backtrace"]
std = ["alloc", "env_logger", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"]
default = ["os", "crypto_rustcrypto"]
os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std"]
std = ["alloc", "env_logger", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"]
backtrace = []
alloc = []
nightly = []
crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"]
crypto_mbedtls = ["alloc", "mbedtls", "esp-idf-sys"]
crypto_mbedtls = ["alloc", "mbedtls"]
crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"]
[dependencies]
@ -40,14 +41,16 @@ safemem = { version = "0.3.3", default-features = false }
owo-colors = "3"
time = { version = "0.3", default-features = false }
verhoeff = { version = "1", default-features = false }
embassy-futures = "0.1"
embassy-time = { version = "0.1.1", features = ["generic-queue-8"] }
embassy-sync = "0.2"
critical-section = "1.1.1"
domain = { version = "0.7.2", default_features = false }
# STD-only dependencies
rand = { version = "0.8.5", optional = true }
qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code
simple-mdns = { version = "0.4", features = ["sync"], optional = true }
simple-dns = { version = "0.5", optional = true }
astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev`
zeroconf = { version = "0.10", optional = true }
smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF
async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF
@ -71,14 +74,9 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o
[target.'cfg(not(target_os = "espidf"))'.dependencies]
mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true }
env_logger = { version = "0.10.0", optional = true }
libmdns = { version = "0.7", optional = true }
[target.'cfg(target_os = "espidf")'.dependencies]
esp-idf-sys = { version = "0.32", default-features = false, features = ["native"], optional = true }
[[example]]
name = "onoff_light"
path = "../examples/onoff_light/src/main.rs"
esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] }
[[example]]

View file

@ -22,7 +22,7 @@ use crate::{
error::{Error, ErrorCode},
fabric,
interaction_model::messages::GenericPath,
tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV},
tlv::{self, FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV},
transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC},
utils::writebuf::WriteBuf,
};
@ -390,7 +390,7 @@ impl AclEntry {
const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS;
type AclEntries = [Option<AclEntry>; MAX_ACL_ENTRIES];
type AclEntries = heapless::Vec<Option<AclEntry>, MAX_ACL_ENTRIES>;
pub struct AclMgr {
entries: AclEntries,
@ -398,20 +398,16 @@ pub struct AclMgr {
}
impl AclMgr {
#[inline(always)]
pub const fn new() -> Self {
const INIT: Option<AclEntry> = None;
Self {
entries: [INIT; MAX_ACL_ENTRIES],
entries: AclEntries::new(),
changed: false,
}
}
pub fn erase_all(&mut self) -> Result<(), Error> {
for i in 0..MAX_ACL_ENTRIES {
self.entries[i] = None;
}
self.entries.clear();
self.changed = true;
Ok(())
@ -427,14 +423,21 @@ impl AclMgr {
if cnt >= ENTRIES_PER_FABRIC {
Err(ErrorCode::NoSpace)?;
}
let index = self
.entries
.iter()
.position(|a| a.is_none())
.ok_or(ErrorCode::NoSpace)?;
self.entries[index] = Some(entry);
self.changed = true;
let slot = self.entries.iter().position(|a| a.is_none());
if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES {
if let Some(index) = slot {
self.entries[index] = Some(entry);
} else {
self.entries
.push(Some(entry))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
}
self.changed = true;
}
Ok(())
}
@ -459,17 +462,13 @@ impl AclMgr {
}
pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> {
for i in 0..MAX_ACL_ENTRIES {
if self.entries[i]
.filter(|e| e.fab_idx == Some(fab_idx))
.is_some()
{
self.entries[i] = None;
for entry in &mut self.entries {
if entry.map(|e| e.fab_idx == Some(fab_idx)).unwrap_or(false) {
*entry = None;
self.changed = true;
}
}
self.changed = true;
Ok(())
}
@ -505,7 +504,7 @@ impl AclMgr {
pub fn load(&mut self, data: &[u8]) -> Result<(), Error> {
let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
self.entries = AclEntries::from_tlv(&root)?;
tlv::from_tlv(&mut self.entries, &root)?;
self.changed = false;
Ok(())
@ -515,7 +514,9 @@ impl AclMgr {
if self.changed {
let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb);
self.entries.to_tlv(&mut tw, TagType::Anonymous)?;
self.entries
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
self.changed = false;
@ -527,6 +528,10 @@ impl AclMgr {
}
}
pub fn is_changed(&self) -> bool {
self.changed
}
/// Traverse fabric specific entries to find the index
///
/// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the list

View file

@ -53,6 +53,7 @@ pub struct Matter<'a> {
impl<'a> Matter<'a> {
#[cfg(feature = "std")]
#[inline(always)]
pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self {
use crate::utils::epoch::sys_epoch;
use crate::utils::rand::sys_rand;
@ -66,6 +67,7 @@ impl<'a> Matter<'a> {
/// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device
/// requires a set of device attestation certificates and keys. It is the responsibility of
/// this object to return the device attestation details when queried upon.
#[inline(always)]
pub fn new(
dev_det: &'a BasicInfoConfig,
mdns: &'a mut dyn Mdns,
@ -113,6 +115,10 @@ impl<'a> Matter<'a> {
self.acl_mgr.borrow_mut().store(buf)
}
pub fn is_changed(&self) -> bool {
self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed()
}
pub fn start(&self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> {
let open_comm_window = self.fabric_mgr.borrow().is_empty();
if open_comm_window {

View file

@ -51,7 +51,7 @@ type AesCcm = Ccm<Aes128, U16, U13>;
extern crate alloc;
#[derive(Clone)]
#[derive(Debug, Clone)]
pub struct Sha256 {
hasher: sha2::Sha256,
}

View file

@ -49,7 +49,7 @@ impl<T> Handler for &mut T
where
T: Handler,
{
fn read<'a>(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
(**self).read(attr, encoder)
}

View file

@ -49,6 +49,7 @@ pub struct FailSafe {
}
impl FailSafe {
#[inline(always)]
pub const fn new() -> Self {
Self { state: State::Idle }
}

View file

@ -138,7 +138,7 @@ impl<'a> GenCommCluster<'a> {
}
pub fn failsafe(&self) -> &RefCell<FailSafe> {
&self.failsafe
self.failsafe
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {

View file

@ -613,7 +613,7 @@ impl<'a> NocCluster<'a> {
SessionMode::Pase => {
let noc_data = transaction
.session_mut()
.get_noc_data::<NocData>()
.get_noc_data()
.ok_or(ErrorCode::NoSession)?;
let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?;

View file

@ -165,11 +165,11 @@ impl From<mbedtls::Error> for Error {
}
}
#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))]
#[cfg(target_os = "espidf")]
impl From<esp_idf_sys::EspError> for Error {
fn from(e: esp_idf_sys::EspError) -> Self {
::log::error!("Error in TLS: {}", e);
Self::new(ErrorCode::TLSStack)
::log::error!("Error in ESP: {}", e);
Self::new(ErrorCode::TLSStack) // TODO: Not a good mapping
}
}
@ -208,9 +208,9 @@ impl fmt::Debug for Error {
#[cfg(all(feature = "std", feature = "backtrace"))]
{
write!(f, "Error::{} {{\n", self)?;
writeln!(f, "Error::{} {{", self)?;
write!(f, "{}", self.backtrace())?;
write!(f, "}}\n")?;
writeln!(f, "}}")?;
}
Ok(())

View file

@ -27,7 +27,7 @@ use crate::{
error::{Error, ErrorCode},
group_keys::KeySet,
mdns::{MdnsMgr, ServiceMode},
tlv::{FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr},
tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr},
utils::writebuf::WriteBuf,
};
@ -184,7 +184,7 @@ impl Fabric {
pub const MAX_SUPPORTED_FABRICS: usize = 3;
type FabricEntries = [Option<Fabric>; MAX_SUPPORTED_FABRICS];
type FabricEntries = Vec<Option<Fabric>, MAX_SUPPORTED_FABRICS>;
pub struct FabricMgr {
fabrics: FabricEntries,
@ -192,30 +192,25 @@ pub struct FabricMgr {
}
impl FabricMgr {
#[inline(always)]
pub const fn new() -> Self {
const INIT: Option<Fabric> = None;
Self {
fabrics: [INIT; MAX_SUPPORTED_FABRICS],
fabrics: FabricEntries::new(),
changed: false,
}
}
pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> {
for fabric in &self.fabrics {
if let Some(fabric) = fabric {
mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
}
for fabric in self.fabrics.iter().flatten() {
mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
}
let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
self.fabrics = FabricEntries::from_tlv(&root)?;
tlv::from_tlv(&mut self.fabrics, &root)?;
for fabric in &self.fabrics {
if let Some(fabric) = fabric {
mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
}
for fabric in self.fabrics.iter().flatten() {
mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
}
self.changed = false;
@ -228,7 +223,9 @@ impl FabricMgr {
let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb);
self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?;
self.fabrics
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
self.changed = false;
@ -240,20 +237,32 @@ impl FabricMgr {
}
}
pub fn is_changed(&self) -> bool {
self.changed
}
pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result<u8, Error> {
for (index, fabric) in self.fabrics.iter_mut().enumerate() {
if fabric.is_none() {
mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?;
let slot = self.fabrics.iter().position(|x| x.is_none());
*fabric = Some(f);
if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS {
mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?;
self.changed = true;
self.changed = true;
if let Some(index) = slot {
self.fabrics[index] = Some(f);
return Ok((index + 1) as u8);
Ok((index + 1) as u8)
} else {
self.fabrics
.push(Some(f))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
Ok(self.fabrics.len() as u8)
}
} else {
Err(ErrorCode::NoSpace.into())
}
Err(ErrorCode::NoSpace.into())
}
pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> {
@ -311,15 +320,14 @@ impl FabricMgr {
}
pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> {
if !label.is_empty() {
if self
if !label.is_empty()
&& self
.fabrics
.iter()
.filter_map(|f| f.as_ref())
.any(|f| f.label == label)
{
return Err(ErrorCode::Invalid.into());
}
{
return Err(ErrorCode::Invalid.into());
}
let index = (index - 1) as usize;

View file

@ -605,6 +605,7 @@ impl<'a> SubscribeReq<'a> {
}
}
#[derive(Debug)]
pub struct ResumeReadReq {
pub paths: heapless::Vec<AttrPath, MAX_RESUME_PATHS>,
pub filters: heapless::Vec<DataVersionFilter, MAX_RESUME_DATAVER_FILTERS>,
@ -664,6 +665,7 @@ impl ResumeReadReq {
}
}
#[derive(Debug)]
pub struct ResumeSubscribeReq {
pub subscription_id: u32,
pub paths: heapless::Vec<AttrPath, MAX_RESUME_PATHS>,

View file

@ -77,7 +77,7 @@ pub mod msg {
EventPath,
};
#[derive(Default, FromTLV, ToTLV)]
#[derive(Debug, Default, FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")]
pub struct SubscribeReq<'a> {
pub keep_subs: bool,

View file

@ -109,6 +109,7 @@ pub struct MdnsMgr<'a> {
}
impl<'a> MdnsMgr<'a> {
#[inline(always)]
pub fn new(
vid: u16,
pid: u16,
@ -212,6 +213,428 @@ impl<'a> MdnsMgr<'a> {
}
}
pub mod builtin {
use core::cell::RefCell;
use core::fmt::Write;
use core::pin::pin;
use core::str::FromStr;
use domain::base::header::Flags;
use domain::base::iana::Class;
use domain::base::octets::{Octets256, Octets64, OctetsBuilder};
use domain::base::{Dname, MessageBuilder, Record, ShortBuf};
use domain::rdata::{Aaaa, Ptr, Srv, Txt, A};
use embassy_futures::select::select;
use embassy_sync::blocking_mutex::raw::NoopRawMutex;
use embassy_time::{Duration, Timer};
use log::info;
use crate::error::{Error, ErrorCode};
use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::transport::udp::UdpListener;
use crate::utils::select::EitherUnwrap;
const IP_BROADCAST_ADDRS: [SocketAddr; 2] = [
SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353),
SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)),
5353,
),
];
const IP_BIND_ADDR: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353);
pub fn create_record(
id: u16,
hostname: &str,
ip: [u8; 4],
ipv6: Option<[u8; 16]>,
ttl_sec: u32,
name: &str,
service: &str,
protocol: &str,
port: u16,
service_subtypes: &[&str],
txt_kvs: &[(&str, &str)],
buffer: &mut [u8],
) -> Result<usize, ShortBuf> {
let target = domain::base::octets::Octets2048::new();
let message = MessageBuilder::from_target(target)?;
let mut message = message.answer();
let mut ptr_str = heapless::String::<40>::new();
write!(ptr_str, "{}.{}.local", service, protocol).unwrap();
let mut dname = heapless::String::<60>::new();
write!(dname, "{}.{}.{}.local", name, service, protocol).unwrap();
let mut hname = heapless::String::<40>::new();
write!(hname, "{}.local", hostname).unwrap();
let ptr: Dname<Octets64> = Dname::from_str(&ptr_str).unwrap();
let record: Record<Dname<Octets64>, Ptr<_>> = Record::new(
Dname::from_str("_services._dns-sd._udp.local").unwrap(),
Class::In,
ttl_sec,
Ptr::new(ptr),
);
message.push(record)?;
let t: Dname<Octets64> = Dname::from_str(&dname).unwrap();
let record: Record<Dname<Octets64>, Ptr<_>> = Record::new(
Dname::from_str(&ptr_str).unwrap(),
Class::In,
ttl_sec,
Ptr::new(t),
);
message.push(record)?;
for sub_srv in service_subtypes {
let mut ptr_str = heapless::String::<40>::new();
write!(ptr_str, "{}._sub.{}.{}.local", sub_srv, service, protocol).unwrap();
let ptr: Dname<Octets64> = Dname::from_str(&ptr_str).unwrap();
let record: Record<Dname<Octets64>, Ptr<_>> = Record::new(
Dname::from_str("_services._dns-sd._udp.local").unwrap(),
Class::In,
ttl_sec,
Ptr::new(ptr),
);
message.push(record)?;
let t: Dname<Octets64> = Dname::from_str(&dname).unwrap();
let record: Record<Dname<Octets64>, Ptr<_>> = Record::new(
Dname::from_str(&ptr_str).unwrap(),
Class::In,
ttl_sec,
Ptr::new(t),
);
message.push(record)?;
}
let target: Dname<Octets64> = Dname::from_str(&hname).unwrap();
let record: Record<Dname<Octets64>, Srv<_>> = Record::new(
Dname::from_str(&dname).unwrap(),
Class::In,
ttl_sec,
Srv::new(0, 0, port, target),
);
message.push(record)?;
// only way I found to create multiple parts in a Txt
// each slice is the length and then the data
let mut octets = Octets256::new();
//octets.append_slice(&[1u8, b'X']).unwrap();
//octets.append_slice(&[2u8, b'A', b'B']).unwrap();
//octets.append_slice(&[0u8]).unwrap();
for (k, v) in txt_kvs {
octets
.append_slice(&[(k.len() + v.len() + 1) as u8])
.unwrap();
octets.append_slice(k.as_bytes()).unwrap();
octets.append_slice(&[b'=']).unwrap();
octets.append_slice(v.as_bytes()).unwrap();
}
let txt = Txt::from_octets(&mut octets).unwrap();
let record: Record<Dname<Octets64>, Txt<_>> =
Record::new(Dname::from_str(&dname).unwrap(), Class::In, ttl_sec, txt);
message.push(record)?;
let record: Record<Dname<Octets64>, A> = Record::new(
Dname::from_str(&hname).unwrap(),
Class::In,
ttl_sec,
A::from_octets(ip[0], ip[1], ip[2], ip[3]),
);
message.push(record)?;
if let Some(ipv6) = ipv6 {
let record: Record<Dname<Octets64>, Aaaa> = Record::new(
Dname::from_str(&hname).unwrap(),
Class::In,
ttl_sec,
Aaaa::new(ipv6.into()),
);
message.push(record)?;
}
let headerb = message.header_mut();
headerb.set_id(id);
headerb.set_opcode(domain::base::iana::Opcode::Query);
headerb.set_rcode(domain::base::iana::Rcode::NoError);
let mut flags = Flags::new();
flags.qr = true;
flags.aa = true;
headerb.set_flags(flags);
let target = message.finish();
buffer[..target.len()].copy_from_slice(target.as_ref());
Ok(target.len())
}
pub type Notification = embassy_sync::signal::Signal<NoopRawMutex, ()>;
#[derive(Debug, Clone)]
struct MdnsEntry {
key: heapless::String<64>,
record: heapless::Vec<u8, 1024>,
}
impl MdnsEntry {
#[inline(always)]
const fn new() -> Self {
Self {
key: heapless::String::new(),
record: heapless::Vec::new(),
}
}
}
pub struct Mdns<'a> {
id: u16,
hostname: &'a str,
ip: [u8; 4],
ipv6: Option<[u8; 16]>,
entries: RefCell<heapless::Vec<MdnsEntry, 4>>,
notification: Notification,
udp: RefCell<Option<UdpListener>>,
}
impl<'a> Mdns<'a> {
#[inline(always)]
pub const fn new(id: u16, hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>) -> Self {
Self {
id,
hostname,
ip,
ipv6,
entries: RefCell::new(heapless::Vec::new()),
notification: Notification::new(),
udp: RefCell::new(None),
}
}
pub fn split(&mut self) -> (MdnsApi<'_, 'a>, MdnsRunner<'_, 'a>) {
(MdnsApi(&*self), MdnsRunner(&*self))
}
async fn bind(&self) -> Result<(), Error> {
if self.udp.borrow().is_none() {
*self.udp.borrow_mut() = Some(UdpListener::new(IP_BIND_ADDR).await?);
}
Ok(())
}
pub fn close(&mut self) {
*self.udp.borrow_mut() = None;
}
fn key(
&self,
name: &str,
service: &str,
protocol: &str,
port: u16,
) -> heapless::String<64> {
let mut key = heapless::String::new();
write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap();
key
}
}
pub struct MdnsApi<'a, 'b>(&'a Mdns<'b>);
impl<'a, 'b> MdnsApi<'a, 'b> {
pub fn add(
&self,
name: &str,
service: &str,
protocol: &str,
port: u16,
service_subtypes: &[&str],
txt_kvs: &[(&str, &str)],
) -> Result<(), Error> {
info!(
"Registering mDNS service {}/{}.{} [{:?}]/{}, keys [{:?}]",
name, service, protocol, service_subtypes, port, txt_kvs
);
let key = self.0.key(name, service, protocol, port);
let mut entries = self.0.entries.borrow_mut();
entries.retain(|entry| entry.key != key);
entries
.push(MdnsEntry::new())
.map_err(|_| ErrorCode::NoSpace)?;
let entry = entries.iter_mut().last().unwrap();
entry
.record
.resize(1024, 0)
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
match create_record(
self.0.id,
self.0.hostname,
self.0.ip,
self.0.ipv6,
60, /*ttl_sec*/
name,
service,
protocol,
port,
service_subtypes,
txt_kvs,
&mut entry.record,
) {
Ok(len) => entry.record.truncate(len),
Err(_) => {
entries.pop();
Err(ErrorCode::NoSpace)?;
}
}
self.0.notification.signal(());
Ok(())
}
pub fn remove(
&self,
name: &str,
service: &str,
protocol: &str,
port: u16,
) -> Result<(), Error> {
info!(
"Deregistering mDNS service {}/{}.{}/{}",
name, service, protocol, port
);
let key = self.0.key(name, service, protocol, port);
let mut entries = self.0.entries.borrow_mut();
let old_len = entries.len();
entries.retain(|entry| entry.key != key);
if entries.len() != old_len {
self.0.notification.signal(());
}
Ok(())
}
}
pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>);
impl<'a, 'b> MdnsRunner<'a, 'b> {
pub async fn run(&mut self) -> Result<(), Error> {
let mut broadcast = pin!(self.broadcast());
let mut respond = pin!(self.respond());
select(&mut broadcast, &mut respond).await.unwrap()
}
async fn broadcast(&self) -> Result<(), Error> {
loop {
select(
self.0.notification.wait(),
Timer::after(Duration::from_secs(30)),
)
.await;
let mut index = 0;
while let Some(entry) = self
.0
.entries
.borrow()
.get(index)
.map(|entry| entry.clone())
{
info!("Broadasting mDNS entry {}", &entry.key);
self.0.bind().await?;
let udp = self.0.udp.borrow();
let udp = udp.as_ref().unwrap();
for addr in IP_BROADCAST_ADDRS {
udp.send(addr, &entry.record).await?;
}
index += 1;
}
}
}
async fn respond(&self) -> Result<(), Error> {
loop {
let mut buf = [0; 1580];
let udp = self.0.udp.borrow();
let udp = udp.as_ref().unwrap();
let (_len, _addr) = udp.recv(&mut buf).await?;
info!("Received UDP packet");
// TODO: Process the incoming packed and only answer what we are being queried about
self.0.notification.signal(());
}
}
}
impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> {
fn add(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
service_subtypes: &[&str],
txt_kvs: &[(&str, &str)],
) -> Result<(), Error> {
MdnsApi::add(
self,
name,
service,
protocol,
port,
service_subtypes,
txt_kvs,
)
}
fn remove(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
) -> Result<(), Error> {
MdnsApi::remove(self, name, service, protocol, port)
}
}
}
#[cfg(all(feature = "std", feature = "astro-dnssd"))]
pub mod astro {
use std::collections::HashMap;
@ -342,399 +765,6 @@ pub mod astro {
}
}
// TODO: Maybe future
// #[cfg(all(feature = "std", feature = "zeroconf"))]
// pub mod zeroconf {
// use std::collections::HashMap;
// use super::Mdns;
// use crate::error::{Error, ErrorCode};
// use log::info;
// use zeroconf::prelude::*;
// use zeroconf::{MdnsService, ServiceType, TxtRecord};
// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
// pub struct ServiceId {
// name: String,
// service: String,
// protocol: String,
// port: u16,
// }
// pub struct ZeroconfMdns {
// services: HashMap<ServiceId, MdnsService>,
// }
// impl ZeroconfMdns {
// pub fn new() -> Result<Self, Error> {
// Ok(Self {
// services: HashMap::new(),
// })
// }
// pub fn add(
// &mut self,
// name: &str,
// service: &str,
// protocol: &str,
// port: u16,
// service_subtypes: &[&str],
// txt_kvs: &[(&str, &str)],
// ) -> Result<(), Error> {
// info!(
// "Registering mDNS service {}/{}.{} [{:?}]/{}",
// name, service, protocol, service_subtypes, port
// );
// let _ = self.remove(name, service, protocol, port);
// let mut svc = MdnsService::new(
// ServiceType::with_sub_types(service, protocol, service_subtypes.into()).unwrap(),
// port,
// );
// let mut txt = TxtRecord::new();
// for kvs in txt_kvs {
// info!("mDNS TXT key {} val {}", kvs.0, kvs.1);
// txt.insert(kvs.0, kvs.1);
// }
// svc.set_txt_record(txt);
// //let event_loop = svc.register().map_err(|_| ErrorCode::MdnsError)?;
// self.services.insert(
// ServiceId {
// name: name.into(),
// service: service.into(),
// protocol: protocol.into(),
// port,
// },
// svc,
// );
// Ok(())
// }
// pub fn remove(
// &mut self,
// name: &str,
// service: &str,
// protocol: &str,
// port: u16,
// ) -> Result<(), Error> {
// let id = ServiceId {
// name: name.into(),
// service: service.into(),
// protocol: protocol.into(),
// port,
// };
// if self.services.remove(&id).is_some() {
// info!(
// "Deregistering mDNS service {}.{}/{}/{}",
// name, service, protocol, port
// );
// }
// Ok(())
// }
// }
// impl Mdns for ZeroconfMdns {
// fn add(
// &mut self,
// name: &str,
// service: &str,
// protocol: &str,
// port: u16,
// service_subtypes: &[&str],
// txt_kvs: &[(&str, &str)],
// ) -> Result<(), Error> {
// ZeroconfMdns::add(
// self,
// name,
// service,
// protocol,
// port,
// service_subtypes,
// txt_kvs,
// )
// }
// fn remove(
// &mut self,
// name: &str,
// service: &str,
// protocol: &str,
// port: u16,
// ) -> Result<(), Error> {
// ZeroconfMdns::remove(self, name, service, protocol, port)
// }
// }
// }
#[cfg(all(feature = "std", not(target_os = "espidf")))]
pub mod libmdns {
use super::Mdns;
use crate::error::Error;
use libmdns::{Responder, Service};
use log::info;
use std::collections::HashMap;
use std::vec::Vec;
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
pub struct ServiceId {
name: String,
service: String,
protocol: String,
port: u16,
}
pub struct LibMdns {
responder: Responder,
services: HashMap<ServiceId, Service>,
}
impl LibMdns {
pub fn new() -> Result<Self, Error> {
let responder = Responder::new()?;
Ok(Self {
responder,
services: HashMap::new(),
})
}
pub fn add(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
txt_kvs: &[(&str, &str)],
) -> Result<(), Error> {
info!(
"Registering mDNS service {}/{}.{}/{}",
name, service, protocol, port
);
let _ = self.remove(name, service, protocol, port);
let mut properties = Vec::new();
for kvs in txt_kvs {
info!("mDNS TXT key {} val {}", kvs.0, kvs.1);
properties.push(format!("{}={}", kvs.0, kvs.1));
}
let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect();
let svc = self.responder.register(
format!("{}.{}", service, protocol),
name.to_owned(),
port,
&properties,
);
self.services.insert(
ServiceId {
name: name.into(),
service: service.into(),
protocol: protocol.into(),
port,
},
svc,
);
Ok(())
}
pub fn remove(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
) -> Result<(), Error> {
let id = ServiceId {
name: name.into(),
service: service.into(),
protocol: protocol.into(),
port,
};
if self.services.remove(&id).is_some() {
info!(
"Deregistering mDNS service {}/{}.{}/{}",
name, service, protocol, port
);
}
Ok(())
}
}
impl Mdns for LibMdns {
fn add(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
_service_subtypes: &[&str],
txt_kvs: &[(&str, &str)],
) -> Result<(), Error> {
LibMdns::add(self, name, service, protocol, port, txt_kvs)
}
fn remove(
&mut self,
name: &str,
service: &str,
protocol: &str,
port: u16,
) -> Result<(), Error> {
LibMdns::remove(self, name, service, protocol, port)
}
}
}
// TODO: Maybe future
// #[cfg(feature = "std")]
// pub mod simplemdns {
// use std::net::Ipv4Addr;
// use crate::error::{Error, ErrorCode};
// use super::Mdns;
// use log::info;
// use simple_dns::{
// rdata::{RData, A, SRV, TXT, PTR},
// CharacterString, Name, ResourceRecord, CLASS,
// };
// use simple_mdns::sync_discovery::SimpleMdnsResponder;
// #[derive(Debug, Clone, Eq, PartialEq, Hash)]
// pub struct ServiceId {
// name: String,
// service_type: String,
// port: u16,
// }
// pub struct SimpleMdns {
// responder: SimpleMdnsResponder,
// }
// impl SimpleMdns {
// pub fn new() -> Result<Self, Error> {
// Ok(Self {
// responder: Default::default(),
// })
// }
// pub fn add(
// &mut self,
// name: &str,
// service_type: &str,
// port: u16,
// txt_kvs: &[(&str, &str)],
// ) -> Result<(), Error> {
// info!(
// "Registering mDNS service {}/{}/{}",
// name, service_type, port
// );
// let _ = self.remove(name, service_type, port);
// let mut txt = TXT::new();
// for kvs in txt_kvs {
// info!("mDNS TXT key {} val {}", kvs.0, kvs.1);
// let string = format!("{}={}", kvs.0, kvs.1);
// txt.add_char_string(
// CharacterString::new(string.as_bytes())
// .unwrap()
// .into_owned(),
// );
// }
// let name = Name::new_unchecked(name).into_owned();
// let service_type = Name::new_unchecked(service_type).into_owned();
// self.responder.add_resource(ResourceRecord::new(
// name.clone(),
// CLASS::IN,
// 10,
// RData::A(A {
// address: Ipv4Addr::new(192, 168, 10, 189).into(),
// }),
// ));
// self.responder.add_resource(ResourceRecord::new(
// name.clone(),
// CLASS::IN,
// 10,
// RData::SRV(SRV {
// port: port,
// priority: 0,
// weight: 0,
// target: service_type.clone(),
// }),
// ));
// self.responder.add_resource(ResourceRecord::new(
// srv_name.clone(),
// CLASS::IN,
// 10,
// RData::PTR(PTR(srv_name.clone()),
// )));
// self.responder.add_resource(ResourceRecord::new(
// srv_name,
// CLASS::IN,
// 10,
// RData::TXT(txt),
// ));
// Ok(())
// }
// pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> {
// // TODO
// // let id = ServiceId {
// // name: name.into(),
// // service_type: service_type.into(),
// // port,
// // };
// // if self.responder.remove_resource_record(resource).remove(&id).is_some() {
// // info!(
// // "Deregistering mDNS service {}/{}/{}",
// // name, service_type, port
// // );
// // }
// Ok(())
// }
// }
// impl Mdns for SimpleMdns {
// fn add(
// &mut self,
// name: &str,
// service_type: &str,
// port: u16,
// _service_subtypes: &[&str],
// txt_kvs: &[(&str, &str)],
// ) -> Result<(), Error> {
// SimpleMdns::add(self, name, service_type, port, txt_kvs)
// }
// fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> {
// SimpleMdns::remove(self, name, service_type, port)
// }
// }
// }
#[cfg(test)]
mod tests {
use super::*;

View file

@ -91,7 +91,7 @@ pub fn print_pairing_code_and_qr(
let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?;
pretty_print_pairing_code(&pairing_code);
print_qr_code(&qr_code);
print_qr_code(qr_code);
Ok(())
}

View file

@ -35,12 +35,13 @@ use crate::{
utils::{rand::Rand, writebuf::WriteBuf},
};
#[derive(PartialEq)]
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State {
Sigma1Rx,
Sigma3Rx,
}
#[derive(Debug, Clone)]
pub struct CaseSession {
state: State,
peer_sessid: u16,
@ -84,7 +85,7 @@ impl<'a> Case<'a> {
let mut case_session = ctx
.exch_ctx
.exch
.take_case_session::<CaseSession>()
.take_case_session()
.ok_or(ErrorCode::InvalidState)?;
if case_session.state != State::Sigma1Rx {
Err(ErrorCode::Invalid)?;

View file

@ -56,6 +56,8 @@ pub fn create_sc_status_report(
status_code: SCStatusCodes,
proto_data: Option<&[u8]>,
) -> Result<(), Error> {
proto_tx.reset();
let general_code = match status_code {
SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success,
SCStatusCodes::CloseSession => {
@ -79,6 +81,7 @@ pub fn create_sc_status_report(
}
pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) {
proto_tx.reset();
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8);
proto_tx.unset_reliable();

View file

@ -711,11 +711,7 @@ impl<'a> Iterator for TLVContainerIterator<'a> {
return None;
}
if is_container(element.element_type) {
self.prev_container = true;
} else {
self.prev_container = false;
}
self.prev_container = is_container(element.element_type);
Some(element)
}
}

View file

@ -61,6 +61,24 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] {
}
}
pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>(
vec: &mut heapless::Vec<T, N>,
t: &TLVElement<'a>,
) -> Result<(), Error> {
vec.clear();
t.confirm_array()?;
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?)
.map_err(|_| ErrorCode::NoSpace)?;
}
}
Ok(())
}
macro_rules! fromtlv_for {
($($t:ident)*) => {
$(
@ -110,6 +128,16 @@ impl<T: ToTLV, const N: usize> ToTLV for [T; N] {
}
}
impl<'a, T: ToTLV> ToTLV for &'a [T] {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_array(tag)?;
for i in *self {
i.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
// Generate ToTLV for standard data types
totlv_for!(i8 u8 u16 u32 u64 bool);

View file

@ -15,7 +15,6 @@
* limitations under the License.
*/
use core::any::Any;
use core::fmt;
use core::time::Duration;
use log::{error, info, trace};
@ -144,7 +143,7 @@ impl Exchange {
}
}
pub fn take_case_session<T: Any>(&mut self) -> Option<CaseSession> {
pub fn take_case_session(&mut self) -> Option<CaseSession> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::CaseSession(session) = old {
Some(session)

View file

@ -25,5 +25,4 @@ pub mod plain_hdr;
pub mod proto_ctx;
pub mod proto_hdr;
pub mod session;
#[cfg(feature = "std")]
pub mod udp;

View file

@ -17,15 +17,23 @@
use core::fmt::{Debug, Display};
#[cfg(not(feature = "std"))]
pub use no_std_net::{IpAddr, Ipv4Addr, SocketAddr};
pub use no_std_net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[cfg(feature = "std")]
pub use std::net::{IpAddr, Ipv4Addr, SocketAddr};
pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
#[derive(PartialEq, Copy, Clone)]
#[derive(Eq, PartialEq, Copy, Clone)]
pub enum Address {
Udp(SocketAddr),
}
impl Address {
pub fn unwrap_udp(self) -> SocketAddr {
match self {
Self::Udp(addr) => addr,
}
}
}
impl Default for Address {
fn default() -> Self {
Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080))

View file

@ -31,7 +31,7 @@ use super::{
pub const MAX_RX_BUF_SIZE: usize = 1583;
pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/;
#[derive(PartialEq)]
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
enum RxState {
Uninit,
PlainDecode,
@ -43,6 +43,30 @@ enum Direction<'a> {
Rx(ParseBuf<'a>, RxState),
}
impl<'a> Direction<'a> {
pub fn load(&mut self, direction: &Direction) -> Result<(), Error> {
if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) {
Err(ErrorCode::Invalid)?;
}
match self {
Self::Tx(wb) => match direction {
Direction::Tx(src_wb) => wb.load(src_wb)?,
Direction::Rx(_, _) => Err(ErrorCode::Invalid)?,
},
Self::Rx(pb, state) => match direction {
Direction::Tx(_) => Err(ErrorCode::Invalid)?,
Direction::Rx(src_pb, src_state) => {
pb.load(src_pb)?;
*state = *src_state;
}
},
}
Ok(())
}
}
pub struct Packet<'a> {
pub plain: PlainHdr,
pub proto: ProtoHdr,
@ -78,7 +102,7 @@ impl<'a> Packet<'a> {
}
}
pub fn reset(&mut self) -> () {
pub fn reset(&mut self) {
if let Direction::Tx(wb) = &mut self.data {
wb.reset();
wb.reserve(Packet::HDR_RESERVE).unwrap();
@ -91,6 +115,13 @@ impl<'a> Packet<'a> {
}
}
pub fn load(&mut self, packet: &Packet) -> Result<(), Error> {
self.plain = packet.plain.clone();
self.proto = packet.proto.clone();
self.peer = packet.peer;
self.data.load(&packet.data)
}
pub fn as_slice(&self) -> &[u8] {
match &self.data {
Direction::Rx(pb, _) => pb.as_slice(),

View file

@ -21,7 +21,7 @@ use crate::utils::writebuf::WriteBuf;
use bitflags::bitflags;
use log::info;
#[derive(Debug, PartialEq, Default)]
#[derive(Debug, PartialEq, Eq, Default, Copy, Clone)]
pub enum SessionType {
#[default]
None,
@ -38,7 +38,7 @@ bitflags! {
}
// This is the unencrypted message
#[derive(Debug, Default)]
#[derive(Debug, Default, Clone)]
pub struct PlainHdr {
pub flags: MsgFlags,
pub sess_type: SessionType,

View file

@ -36,7 +36,7 @@ bitflags! {
}
}
#[derive(Default)]
#[derive(Debug, Default, Clone)]
pub struct ProtoHdr {
pub exch_id: u16,
pub exch_flags: ExchFlags,
@ -278,7 +278,7 @@ mod tests {
decrypt_in_place(recvd_ctr, 0, &mut parsebuf, &key).unwrap();
assert_eq!(
parsebuf.into_slice(),
parsebuf.as_slice(),
[
0x5, 0x8, 0x70, 0x0, 0x1, 0x0, 0x15, 0x28, 0x0, 0x28, 0x1, 0x36, 0x2, 0x15, 0x37,
0x0, 0x24, 0x0, 0x0, 0x24, 0x1, 0x30, 0x24, 0x2, 0x2, 0x18, 0x35, 0x1, 0x24, 0x0,

View file

@ -19,11 +19,8 @@ use crate::data_model::sdm::noc::NocData;
use crate::utils::epoch::Epoch;
use crate::utils::rand::Rand;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::time::Duration;
use core::{
any::Any,
ops::{Deref, DerefMut},
};
use crate::{
error::*,
@ -166,7 +163,7 @@ impl Session {
self.data = None;
}
pub fn get_noc_data<T: Any>(&mut self) -> Option<&mut NocData> {
pub fn get_noc_data(&mut self) -> Option<&mut NocData> {
self.data.as_mut()
}
@ -325,17 +322,16 @@ pub const MAX_SESSIONS: usize = 16;
pub struct SessionMgr {
next_sess_id: u16,
sessions: [Option<Session>; MAX_SESSIONS],
sessions: heapless::Vec<Option<Session>, MAX_SESSIONS>,
epoch: Epoch,
rand: Rand,
}
impl SessionMgr {
#[inline(always)]
pub fn new(epoch: Epoch, rand: Rand) -> Self {
const INIT: Option<Session> = None;
Self {
sessions: [INIT; MAX_SESSIONS],
sessions: heapless::Vec::new(),
next_sess_id: 1,
epoch,
rand,
@ -343,10 +339,10 @@ impl SessionMgr {
}
pub fn mut_by_index(&mut self, index: usize) -> Option<&mut Session> {
self.sessions[index].as_mut()
self.sessions.get_mut(index).and_then(Option::as_mut)
}
fn get_next_sess_id(&mut self) -> u16 {
pub fn get_next_sess_id(&mut self) -> u16 {
let mut next_sess_id: u16;
loop {
next_sess_id = self.next_sess_id;
@ -366,7 +362,7 @@ impl SessionMgr {
}
pub fn get_session_for_eviction(&self) -> Option<usize> {
if self.get_empty_slot().is_none() {
if self.sessions.len() == MAX_SESSIONS && self.get_empty_slot().is_none() {
Some(self.get_lru())
} else {
None
@ -380,8 +376,8 @@ impl SessionMgr {
fn get_lru(&self) -> usize {
let mut lru_index = 0;
let mut lru_ts = (self.epoch)();
for i in 0..MAX_SESSIONS {
if let Some(s) = &self.sessions[i] {
for (i, s) in self.sessions.iter().enumerate() {
if let Some(s) = s {
if s.last_use < lru_ts {
lru_ts = s.last_use;
lru_index = i;
@ -405,10 +401,17 @@ impl SessionMgr {
/// We could have returned a SessionHandle here. But the borrow checker doesn't support
/// non-lexical lifetimes. This makes it harder for the caller of this function to take
/// action in the error return path
pub fn add_session(&mut self, session: Session) -> Result<usize, Error> {
fn add_session(&mut self, session: Session) -> Result<usize, Error> {
if let Some(index) = self.get_empty_slot() {
self.sessions[index] = Some(session);
Ok(index)
} else if self.sessions.len() < MAX_SESSIONS {
self.sessions
.push(Some(session))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
Ok(self.sessions.len() - 1)
} else {
Err(ErrorCode::NoSpace.into())
}
@ -419,7 +422,7 @@ impl SessionMgr {
self.add_session(session)
}
fn _get(
pub fn get(
&self,
sess_id: u16,
peer_addr: Address,
@ -451,14 +454,14 @@ impl SessionMgr {
Some(self.get_session_handle(index))
}
pub fn get_or_add(
fn get_or_add(
&mut self,
sess_id: u16,
peer_addr: Address,
peer_nodeid: Option<u64>,
is_encrypted: bool,
) -> Result<usize, Error> {
if let Some(index) = self._get(sess_id, peer_addr, peer_nodeid, is_encrypted) {
if let Some(index) = self.get(sess_id, peer_addr, peer_nodeid, is_encrypted) {
Ok(index)
} else if sess_id == 0 && !is_encrypted {
// We must create a new session for this case
@ -538,7 +541,7 @@ impl fmt::Display for SessionMgr {
}
pub struct SessionHandle<'a> {
sess_mgr: &'a mut SessionMgr,
pub(crate) sess_mgr: &'a mut SessionMgr,
sess_idx: usize,
}

View file

@ -15,64 +15,103 @@
* limitations under the License.
*/
use crate::{error::*, MATTER_PORT};
use log::{info, warn};
use smol::net::{Ipv6Addr, UdpSocket};
#[cfg(feature = "std")]
pub use smol_udp::*;
use super::network::Address;
#[cfg(not(feature = "std"))]
pub use dummy_udp::*;
// We could get rid of the smol here, but keeping it around in case we have to process
// any other events in this thread's context
pub struct UdpListener {
socket: UdpSocket,
}
#[cfg(feature = "std")]
mod smol_udp {
use crate::error::*;
use log::{debug, info, warn};
use smol::net::UdpSocket;
impl UdpListener {
pub async fn new() -> Result<UdpListener, Error> {
let listener = UdpListener {
socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?,
};
use crate::transport::network::SocketAddr;
info!(
"Listening on {:?} port {}",
Ipv6Addr::UNSPECIFIED,
MATTER_PORT
);
Ok(listener)
pub struct UdpListener {
socket: UdpSocket,
}
pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> {
info!("Waiting for incoming packets");
impl UdpListener {
pub async fn new(addr: SocketAddr) -> Result<UdpListener, Error> {
let listener = UdpListener {
socket: UdpSocket::bind((addr.ip(), addr.port())).await?,
};
let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| {
warn!("Error on the network: {:?}", e);
ErrorCode::Network
})?;
info!("Listening on {:?}", addr);
info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr);
Ok(listener)
}
Ok((size, Address::Udp(addr)))
}
pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
info!("Waiting for incoming packets");
pub async fn send(&self, addr: Address, out_buf: &[u8]) -> Result<usize, Error> {
match addr {
Address::Udp(addr) => {
let len = self.socket.send_to(out_buf, addr).await.map_err(|e| {
warn!("Error on the network: {:?}", e);
ErrorCode::Network
})?;
let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| {
warn!("Error on the network: {:?}", e);
ErrorCode::Network
})?;
info!(
"Send packet: {:?} ({}/{}) to addr {:?}",
out_buf,
out_buf.len(),
len,
addr
);
debug!("Got packet {:?} from addr {:?}", &in_buf[..size], addr);
Ok(len)
}
Ok((size, addr))
}
pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result<usize, Error> {
let len = self.socket.send_to(out_buf, addr).await.map_err(|e| {
warn!("Error on the network: {:?}", e);
ErrorCode::Network
})?;
debug!(
"Send packet {:?} ({}/{}) to addr {:?}",
out_buf,
out_buf.len(),
len,
addr
);
Ok(len)
}
}
}
#[cfg(not(feature = "std"))]
mod dummy_udp {
use core::future::pending;
use crate::error::*;
use log::{debug, info};
use crate::transport::network::SocketAddr;
pub struct UdpListener {}
impl UdpListener {
pub async fn new(addr: SocketAddr) -> Result<UdpListener, Error> {
let listener = UdpListener {};
info!("Pretending to listen on {:?}", addr);
Ok(listener)
}
pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> {
info!("Pretending to wait for incoming packets (looping forever)");
pending().await
}
pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result<usize, Error> {
debug!(
"Send packet {:?} ({}/{}) to addr {:?}",
out_buf,
out_buf.len(),
out_buf.len(),
addr
);
Ok(out_buf.len())
}
}
}

View file

@ -18,4 +18,5 @@
pub mod epoch;
pub mod parsebuf;
pub mod rand;
pub mod select;
pub mod writebuf;

View file

@ -35,13 +35,25 @@ impl<'a> ParseBuf<'a> {
}
}
pub fn set_len(&mut self, left: usize) {
self.left = left;
pub fn reset(&mut self) {
self.read_off = 0;
self.left = self.buf.len();
}
// Return the data that is valid as a slice, consume self
pub fn into_slice(self) -> &'a mut [u8] {
&mut self.buf[self.read_off..(self.read_off + self.left)]
pub fn load(&mut self, pb: &ParseBuf) -> Result<(), Error> {
if self.buf.len() < pb.read_off + pb.left {
Err(ErrorCode::NoSpace)?;
}
self.buf[0..pb.read_off + pb.left].copy_from_slice(&pb.buf[..pb.read_off + pb.left]);
self.read_off = pb.read_off;
self.left = pb.left;
Ok(())
}
pub fn set_len(&mut self, left: usize) {
self.left = left;
}
// Return the data that is valid as a slice
@ -114,7 +126,7 @@ mod tests {
assert_eq!(buf.le_u8().unwrap(), 0x01);
assert_eq!(buf.le_u16().unwrap(), 65);
assert_eq!(buf.le_u32().unwrap(), 0xcafebabe);
assert_eq!(buf.into_slice(), [0xa, 0xb, 0xc, 0xd]);
assert_eq!(buf.as_slice(), [0xa, 0xb, 0xc, 0xd]);
}
#[test]
@ -138,7 +150,7 @@ mod tests {
if buf.le_u8().is_ok() {
panic!("This should have returned error")
}
assert_eq!(buf.into_slice(), []);
assert_eq!(buf.as_slice(), [] as [u8; 0]);
}
#[test]
@ -154,7 +166,7 @@ mod tests {
assert_eq!(buf.as_mut_slice(), [0xa, 0xb]);
assert_eq!(buf.tail(2).unwrap(), [0xa, 0xb]);
assert_eq!(buf.into_slice(), []);
assert_eq!(buf.as_slice(), [] as [u8; 0]);
}
#[test]
@ -176,7 +188,7 @@ mod tests {
let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd];
let mut buf = ParseBuf::new(&mut test_slice);
assert_eq!(buf.parsed_as_slice(), []);
assert_eq!(buf.parsed_as_slice(), [] as [u8; 0]);
assert_eq!(buf.le_u8().unwrap(), 0x1);
assert_eq!(buf.le_u16().unwrap(), 65);
assert_eq!(buf.le_u32().unwrap(), 0xcafebabe);

View file

@ -0,0 +1,35 @@
use embassy_futures::select::{Either, Either3, Either4};
pub trait EitherUnwrap<T> {
fn unwrap(self) -> T;
}
impl<T> EitherUnwrap<T> for Either<T, T> {
fn unwrap(self) -> T {
match self {
Self::First(t) => t,
Self::Second(t) => t,
}
}
}
impl<T> EitherUnwrap<T> for Either3<T, T, T> {
fn unwrap(self) -> T {
match self {
Self::First(t) => t,
Self::Second(t) => t,
Self::Third(t) => t,
}
}
}
impl<T> EitherUnwrap<T> for Either4<T, T, T, T> {
fn unwrap(self) -> T {
match self {
Self::First(t) => t,
Self::Second(t) => t,
Self::Third(t) => t,
Self::Fourth(t) => t,
}
}
}

View file

@ -68,6 +68,18 @@ impl<'a> WriteBuf<'a> {
self.end = 0;
}
pub fn load(&mut self, wb: &WriteBuf) -> Result<(), Error> {
if self.buf_size < wb.end {
Err(ErrorCode::NoSpace)?;
}
self.buf[0..wb.end].copy_from_slice(&wb.buf[..wb.end]);
self.start = wb.start;
self.end = wb.end;
Ok(())
}
pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> {
if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() {
Err(ErrorCode::Invalid.into())