From 2e0a09b532f9433672a2056d22d37c035b01c08d Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Wed, 24 May 2023 10:07:11 +0000 Subject: [PATCH] built-in mDNS; memory optimizations --- matter/Cargo.toml | 22 +- matter/src/acl.rs | 57 +- matter/src/core.rs | 6 + matter/src/crypto/crypto_rustcrypto.rs | 2 +- matter/src/data_model/objects/handler.rs | 2 +- matter/src/data_model/sdm/failsafe.rs | 1 + .../data_model/sdm/general_commissioning.rs | 2 +- matter/src/data_model/sdm/noc.rs | 2 +- matter/src/error.rs | 10 +- matter/src/fabric.rs | 64 +- matter/src/interaction_model/core.rs | 2 + matter/src/interaction_model/messages.rs | 2 +- matter/src/mdns.rs | 816 +++++++++--------- matter/src/pairing/mod.rs | 2 +- matter/src/secure_channel/case.rs | 5 +- matter/src/secure_channel/common.rs | 3 + matter/src/tlv/parser.rs | 6 +- matter/src/tlv/traits.rs | 28 + matter/src/transport/exchange.rs | 3 +- matter/src/transport/mod.rs | 1 - matter/src/transport/network.rs | 14 +- matter/src/transport/packet.rs | 35 +- matter/src/transport/plain_hdr.rs | 4 +- matter/src/transport/proto_hdr.rs | 4 +- matter/src/transport/session.rs | 41 +- matter/src/transport/udp.rs | 131 ++- matter/src/utils/mod.rs | 1 + matter/src/utils/parsebuf.rs | 30 +- matter/src/utils/select.rs | 35 + matter/src/utils/writebuf.rs | 12 + 30 files changed, 780 insertions(+), 563 deletions(-) create mode 100644 matter/src/utils/select.rs diff --git a/matter/Cargo.toml b/matter/Cargo.toml index c356956..22ef439 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -15,13 +15,14 @@ name = "matter" path = "src/lib.rs" [features] -default = ["std", "crypto_mbedtls", "backtrace"] -std = ["alloc", "env_logger", "rand", "qrcode", "libmdns", "simple-mdns", "simple-dns", "async-io", "smol"] +default = ["os", "crypto_rustcrypto"] +os = ["std", "backtrace", "critical-section/std", "embassy-sync/std", "embassy-time/std"] +std = ["alloc", "env_logger", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"] backtrace = [] alloc = [] nightly = [] crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"] -crypto_mbedtls = ["alloc", "mbedtls", "esp-idf-sys"] +crypto_mbedtls = ["alloc", "mbedtls"] crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"] [dependencies] @@ -40,14 +41,16 @@ safemem = { version = "0.3.3", default-features = false } owo-colors = "3" time = { version = "0.3", default-features = false } verhoeff = { version = "1", default-features = false } +embassy-futures = "0.1" +embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } +embassy-sync = "0.2" +critical-section = "1.1.1" +domain = { version = "0.7.2", default_features = false } # STD-only dependencies rand = { version = "0.8.5", optional = true } qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code -simple-mdns = { version = "0.4", features = ["sync"], optional = true } -simple-dns = { version = "0.5", optional = true } astro-dnssd = { version = "0.3", optional = true } # On Linux needs avahi-compat-libdns_sd, i.e. on Ubuntu/Debian do `sudo apt-get install libavahi-compat-libdnssd-dev` -zeroconf = { version = "0.10", optional = true } smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF @@ -71,14 +74,9 @@ x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], o [target.'cfg(not(target_os = "espidf"))'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true } env_logger = { version = "0.10.0", optional = true } -libmdns = { version = "0.7", optional = true } [target.'cfg(target_os = "espidf")'.dependencies] -esp-idf-sys = { version = "0.32", default-features = false, features = ["native"], optional = true } - -[[example]] -name = "onoff_light" -path = "../examples/onoff_light/src/main.rs" +esp-idf-sys = { version = "0.33", default-features = false, features = ["native"] } [[example]] diff --git a/matter/src/acl.rs b/matter/src/acl.rs index 77b8e5b..8bd8b70 100644 --- a/matter/src/acl.rs +++ b/matter/src/acl.rs @@ -22,7 +22,7 @@ use crate::{ error::{Error, ErrorCode}, fabric, interaction_model::messages::GenericPath, - tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, + tlv::{self, FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC}, utils::writebuf::WriteBuf, }; @@ -390,7 +390,7 @@ impl AclEntry { const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; -type AclEntries = [Option; MAX_ACL_ENTRIES]; +type AclEntries = heapless::Vec, MAX_ACL_ENTRIES>; pub struct AclMgr { entries: AclEntries, @@ -398,20 +398,16 @@ pub struct AclMgr { } impl AclMgr { + #[inline(always)] pub const fn new() -> Self { - const INIT: Option = None; - Self { - entries: [INIT; MAX_ACL_ENTRIES], + entries: AclEntries::new(), changed: false, } } pub fn erase_all(&mut self) -> Result<(), Error> { - for i in 0..MAX_ACL_ENTRIES { - self.entries[i] = None; - } - + self.entries.clear(); self.changed = true; Ok(()) @@ -427,14 +423,21 @@ impl AclMgr { if cnt >= ENTRIES_PER_FABRIC { Err(ErrorCode::NoSpace)?; } - let index = self - .entries - .iter() - .position(|a| a.is_none()) - .ok_or(ErrorCode::NoSpace)?; - self.entries[index] = Some(entry); - self.changed = true; + let slot = self.entries.iter().position(|a| a.is_none()); + + if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES { + if let Some(index) = slot { + self.entries[index] = Some(entry); + } else { + self.entries + .push(Some(entry)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + } + + self.changed = true; + } Ok(()) } @@ -459,17 +462,13 @@ impl AclMgr { } pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> { - for i in 0..MAX_ACL_ENTRIES { - if self.entries[i] - .filter(|e| e.fab_idx == Some(fab_idx)) - .is_some() - { - self.entries[i] = None; + for entry in &mut self.entries { + if entry.map(|e| e.fab_idx == Some(fab_idx)).unwrap_or(false) { + *entry = None; + self.changed = true; } } - self.changed = true; - Ok(()) } @@ -505,7 +504,7 @@ impl AclMgr { pub fn load(&mut self, data: &[u8]) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - self.entries = AclEntries::from_tlv(&root)?; + tlv::from_tlv(&mut self.entries, &root)?; self.changed = false; Ok(()) @@ -515,7 +514,9 @@ impl AclMgr { if self.changed { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); - self.entries.to_tlv(&mut tw, TagType::Anonymous)?; + self.entries + .as_slice() + .to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -527,6 +528,10 @@ impl AclMgr { } } + pub fn is_changed(&self) -> bool { + self.changed + } + /// Traverse fabric specific entries to find the index /// /// If the ACL Mgr has 3 entries with fabric indexes, 1, 2, 1, then the list diff --git a/matter/src/core.rs b/matter/src/core.rs index 2fc1c3a..fa960f7 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -53,6 +53,7 @@ pub struct Matter<'a> { impl<'a> Matter<'a> { #[cfg(feature = "std")] + #[inline(always)] pub fn new_default(dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, port: u16) -> Self { use crate::utils::epoch::sys_epoch; use crate::utils::rand::sys_rand; @@ -66,6 +67,7 @@ impl<'a> Matter<'a> { /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// requires a set of device attestation certificates and keys. It is the responsibility of /// this object to return the device attestation details when queried upon. + #[inline(always)] pub fn new( dev_det: &'a BasicInfoConfig, mdns: &'a mut dyn Mdns, @@ -113,6 +115,10 @@ impl<'a> Matter<'a> { self.acl_mgr.borrow_mut().store(buf) } + pub fn is_changed(&self) -> bool { + self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed() + } + pub fn start(&self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> { let open_comm_window = self.fabric_mgr.borrow().is_empty(); if open_comm_window { diff --git a/matter/src/crypto/crypto_rustcrypto.rs b/matter/src/crypto/crypto_rustcrypto.rs index 6212c96..19c288e 100644 --- a/matter/src/crypto/crypto_rustcrypto.rs +++ b/matter/src/crypto/crypto_rustcrypto.rs @@ -51,7 +51,7 @@ type AesCcm = Ccm; extern crate alloc; -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct Sha256 { hasher: sha2::Sha256, } diff --git a/matter/src/data_model/objects/handler.rs b/matter/src/data_model/objects/handler.rs index a5e2b9c..143cad8 100644 --- a/matter/src/data_model/objects/handler.rs +++ b/matter/src/data_model/objects/handler.rs @@ -49,7 +49,7 @@ impl Handler for &mut T where T: Handler, { - fn read<'a>(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { + fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { (**self).read(attr, encoder) } diff --git a/matter/src/data_model/sdm/failsafe.rs b/matter/src/data_model/sdm/failsafe.rs index 301baf9..043f5b9 100644 --- a/matter/src/data_model/sdm/failsafe.rs +++ b/matter/src/data_model/sdm/failsafe.rs @@ -49,6 +49,7 @@ pub struct FailSafe { } impl FailSafe { + #[inline(always)] pub const fn new() -> Self { Self { state: State::Idle } } diff --git a/matter/src/data_model/sdm/general_commissioning.rs b/matter/src/data_model/sdm/general_commissioning.rs index f2487ef..b0cdff1 100644 --- a/matter/src/data_model/sdm/general_commissioning.rs +++ b/matter/src/data_model/sdm/general_commissioning.rs @@ -138,7 +138,7 @@ impl<'a> GenCommCluster<'a> { } pub fn failsafe(&self) -> &RefCell { - &self.failsafe + self.failsafe } pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> { diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index f7346cc..b8dda3c 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -613,7 +613,7 @@ impl<'a> NocCluster<'a> { SessionMode::Pase => { let noc_data = transaction .session_mut() - .get_noc_data::() + .get_noc_data() .ok_or(ErrorCode::NoSession)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; diff --git a/matter/src/error.rs b/matter/src/error.rs index e15cbb7..c8da820 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -165,11 +165,11 @@ impl From for Error { } } -#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))] +#[cfg(target_os = "espidf")] impl From for Error { fn from(e: esp_idf_sys::EspError) -> Self { - ::log::error!("Error in TLS: {}", e); - Self::new(ErrorCode::TLSStack) + ::log::error!("Error in ESP: {}", e); + Self::new(ErrorCode::TLSStack) // TODO: Not a good mapping } } @@ -208,9 +208,9 @@ impl fmt::Debug for Error { #[cfg(all(feature = "std", feature = "backtrace"))] { - write!(f, "Error::{} {{\n", self)?; + writeln!(f, "Error::{} {{", self)?; write!(f, "{}", self.backtrace())?; - write!(f, "}}\n")?; + writeln!(f, "}}")?; } Ok(()) diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 5658d4c..f6f64ef 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -27,7 +27,7 @@ use crate::{ error::{Error, ErrorCode}, group_keys::KeySet, mdns::{MdnsMgr, ServiceMode}, - tlv::{FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, + tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -184,7 +184,7 @@ impl Fabric { pub const MAX_SUPPORTED_FABRICS: usize = 3; -type FabricEntries = [Option; MAX_SUPPORTED_FABRICS]; +type FabricEntries = Vec, MAX_SUPPORTED_FABRICS>; pub struct FabricMgr { fabrics: FabricEntries, @@ -192,30 +192,25 @@ pub struct FabricMgr { } impl FabricMgr { + #[inline(always)] pub const fn new() -> Self { - const INIT: Option = None; - Self { - fabrics: [INIT; MAX_SUPPORTED_FABRICS], + fabrics: FabricEntries::new(), changed: false, } } pub fn load(&mut self, data: &[u8], mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { - for fabric in &self.fabrics { - if let Some(fabric) = fabric { - mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; - } + for fabric in self.fabrics.iter().flatten() { + mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; - self.fabrics = FabricEntries::from_tlv(&root)?; + tlv::from_tlv(&mut self.fabrics, &root)?; - for fabric in &self.fabrics { - if let Some(fabric) = fabric { - mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; - } + for fabric in self.fabrics.iter().flatten() { + mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } self.changed = false; @@ -228,7 +223,9 @@ impl FabricMgr { let mut wb = WriteBuf::new(buf); let mut tw = TLVWriter::new(&mut wb); - self.fabrics.to_tlv(&mut tw, TagType::Anonymous)?; + self.fabrics + .as_slice() + .to_tlv(&mut tw, TagType::Anonymous)?; self.changed = false; @@ -240,20 +237,32 @@ impl FabricMgr { } } + pub fn is_changed(&self) -> bool { + self.changed + } + pub fn add(&mut self, f: Fabric, mdns_mgr: &mut MdnsMgr) -> Result { - for (index, fabric) in self.fabrics.iter_mut().enumerate() { - if fabric.is_none() { - mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + let slot = self.fabrics.iter().position(|x| x.is_none()); - *fabric = Some(f); + if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { + mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + self.changed = true; - self.changed = true; + if let Some(index) = slot { + self.fabrics[index] = Some(f); - return Ok((index + 1) as u8); + Ok((index + 1) as u8) + } else { + self.fabrics + .push(Some(f)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + + Ok(self.fabrics.len() as u8) } + } else { + Err(ErrorCode::NoSpace.into()) } - - Err(ErrorCode::NoSpace.into()) } pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &mut MdnsMgr) -> Result<(), Error> { @@ -311,15 +320,14 @@ impl FabricMgr { } pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> { - if !label.is_empty() { - if self + if !label.is_empty() + && self .fabrics .iter() .filter_map(|f| f.as_ref()) .any(|f| f.label == label) - { - return Err(ErrorCode::Invalid.into()); - } + { + return Err(ErrorCode::Invalid.into()); } let index = (index - 1) as usize; diff --git a/matter/src/interaction_model/core.rs b/matter/src/interaction_model/core.rs index 82d2eb4..e24ec07 100644 --- a/matter/src/interaction_model/core.rs +++ b/matter/src/interaction_model/core.rs @@ -605,6 +605,7 @@ impl<'a> SubscribeReq<'a> { } } +#[derive(Debug)] pub struct ResumeReadReq { pub paths: heapless::Vec, pub filters: heapless::Vec, @@ -664,6 +665,7 @@ impl ResumeReadReq { } } +#[derive(Debug)] pub struct ResumeSubscribeReq { pub subscription_id: u32, pub paths: heapless::Vec, diff --git a/matter/src/interaction_model/messages.rs b/matter/src/interaction_model/messages.rs index edf65db..bfd8a8b 100644 --- a/matter/src/interaction_model/messages.rs +++ b/matter/src/interaction_model/messages.rs @@ -77,7 +77,7 @@ pub mod msg { EventPath, }; - #[derive(Default, FromTLV, ToTLV)] + #[derive(Debug, Default, FromTLV, ToTLV)] #[tlvargs(lifetime = "'a")] pub struct SubscribeReq<'a> { pub keep_subs: bool, diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index defb137..a187683 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -109,6 +109,7 @@ pub struct MdnsMgr<'a> { } impl<'a> MdnsMgr<'a> { + #[inline(always)] pub fn new( vid: u16, pid: u16, @@ -212,6 +213,428 @@ impl<'a> MdnsMgr<'a> { } } +pub mod builtin { + use core::cell::RefCell; + use core::fmt::Write; + use core::pin::pin; + use core::str::FromStr; + + use domain::base::header::Flags; + use domain::base::iana::Class; + use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; + use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; + use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; + use embassy_futures::select::select; + use embassy_sync::blocking_mutex::raw::NoopRawMutex; + use embassy_time::{Duration, Timer}; + use log::info; + + use crate::error::{Error, ErrorCode}; + use crate::transport::network::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; + use crate::transport::udp::UdpListener; + use crate::utils::select::EitherUnwrap; + + const IP_BROADCAST_ADDRS: [SocketAddr; 2] = [ + SocketAddr::new(IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), + SocketAddr::new( + IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), + 5353, + ), + ]; + + const IP_BIND_ADDR: SocketAddr = SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + + pub fn create_record( + id: u16, + hostname: &str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + + ttl_sec: u32, + + name: &str, + service: &str, + protocol: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + + buffer: &mut [u8], + ) -> Result { + let target = domain::base::octets::Octets2048::new(); + let message = MessageBuilder::from_target(target)?; + + let mut message = message.answer(); + + let mut ptr_str = heapless::String::<40>::new(); + write!(ptr_str, "{}.{}.local", service, protocol).unwrap(); + + let mut dname = heapless::String::<60>::new(); + write!(dname, "{}.{}.{}.local", name, service, protocol).unwrap(); + + let mut hname = heapless::String::<40>::new(); + write!(hname, "{}.local", hostname).unwrap(); + + let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str("_services._dns-sd._udp.local").unwrap(), + Class::In, + ttl_sec, + Ptr::new(ptr), + ); + message.push(record)?; + + let t: Dname = Dname::from_str(&dname).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str(&ptr_str).unwrap(), + Class::In, + ttl_sec, + Ptr::new(t), + ); + message.push(record)?; + + for sub_srv in service_subtypes { + let mut ptr_str = heapless::String::<40>::new(); + write!(ptr_str, "{}._sub.{}.{}.local", sub_srv, service, protocol).unwrap(); + + let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str("_services._dns-sd._udp.local").unwrap(), + Class::In, + ttl_sec, + Ptr::new(ptr), + ); + message.push(record)?; + + let t: Dname = Dname::from_str(&dname).unwrap(); + let record: Record, Ptr<_>> = Record::new( + Dname::from_str(&ptr_str).unwrap(), + Class::In, + ttl_sec, + Ptr::new(t), + ); + message.push(record)?; + } + + let target: Dname = Dname::from_str(&hname).unwrap(); + let record: Record, Srv<_>> = Record::new( + Dname::from_str(&dname).unwrap(), + Class::In, + ttl_sec, + Srv::new(0, 0, port, target), + ); + message.push(record)?; + + // only way I found to create multiple parts in a Txt + // each slice is the length and then the data + let mut octets = Octets256::new(); + //octets.append_slice(&[1u8, b'X']).unwrap(); + //octets.append_slice(&[2u8, b'A', b'B']).unwrap(); + //octets.append_slice(&[0u8]).unwrap(); + for (k, v) in txt_kvs { + octets + .append_slice(&[(k.len() + v.len() + 1) as u8]) + .unwrap(); + octets.append_slice(k.as_bytes()).unwrap(); + octets.append_slice(&[b'=']).unwrap(); + octets.append_slice(v.as_bytes()).unwrap(); + } + + let txt = Txt::from_octets(&mut octets).unwrap(); + + let record: Record, Txt<_>> = + Record::new(Dname::from_str(&dname).unwrap(), Class::In, ttl_sec, txt); + message.push(record)?; + + let record: Record, A> = Record::new( + Dname::from_str(&hname).unwrap(), + Class::In, + ttl_sec, + A::from_octets(ip[0], ip[1], ip[2], ip[3]), + ); + message.push(record)?; + + if let Some(ipv6) = ipv6 { + let record: Record, Aaaa> = Record::new( + Dname::from_str(&hname).unwrap(), + Class::In, + ttl_sec, + Aaaa::new(ipv6.into()), + ); + message.push(record)?; + } + + let headerb = message.header_mut(); + headerb.set_id(id); + headerb.set_opcode(domain::base::iana::Opcode::Query); + headerb.set_rcode(domain::base::iana::Rcode::NoError); + + let mut flags = Flags::new(); + flags.qr = true; + flags.aa = true; + headerb.set_flags(flags); + + let target = message.finish(); + + buffer[..target.len()].copy_from_slice(target.as_ref()); + + Ok(target.len()) + } + + pub type Notification = embassy_sync::signal::Signal; + + #[derive(Debug, Clone)] + struct MdnsEntry { + key: heapless::String<64>, + record: heapless::Vec, + } + + impl MdnsEntry { + #[inline(always)] + const fn new() -> Self { + Self { + key: heapless::String::new(), + record: heapless::Vec::new(), + } + } + } + + pub struct Mdns<'a> { + id: u16, + hostname: &'a str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + entries: RefCell>, + notification: Notification, + udp: RefCell>, + } + + impl<'a> Mdns<'a> { + #[inline(always)] + pub const fn new(id: u16, hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>) -> Self { + Self { + id, + hostname, + ip, + ipv6, + entries: RefCell::new(heapless::Vec::new()), + notification: Notification::new(), + udp: RefCell::new(None), + } + } + + pub fn split(&mut self) -> (MdnsApi<'_, 'a>, MdnsRunner<'_, 'a>) { + (MdnsApi(&*self), MdnsRunner(&*self)) + } + + async fn bind(&self) -> Result<(), Error> { + if self.udp.borrow().is_none() { + *self.udp.borrow_mut() = Some(UdpListener::new(IP_BIND_ADDR).await?); + } + + Ok(()) + } + + pub fn close(&mut self) { + *self.udp.borrow_mut() = None; + } + + fn key( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> heapless::String<64> { + let mut key = heapless::String::new(); + + write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); + + key + } + } + + pub struct MdnsApi<'a, 'b>(&'a Mdns<'b>); + + impl<'a, 'b> MdnsApi<'a, 'b> { + pub fn add( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + info!( + "Registering mDNS service {}/{}.{} [{:?}]/{}, keys [{:?}]", + name, service, protocol, service_subtypes, port, txt_kvs + ); + + let key = self.0.key(name, service, protocol, port); + + let mut entries = self.0.entries.borrow_mut(); + + entries.retain(|entry| entry.key != key); + entries + .push(MdnsEntry::new()) + .map_err(|_| ErrorCode::NoSpace)?; + + let entry = entries.iter_mut().last().unwrap(); + entry + .record + .resize(1024, 0) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + + match create_record( + self.0.id, + self.0.hostname, + self.0.ip, + self.0.ipv6, + 60, /*ttl_sec*/ + name, + service, + protocol, + port, + service_subtypes, + txt_kvs, + &mut entry.record, + ) { + Ok(len) => entry.record.truncate(len), + Err(_) => { + entries.pop(); + Err(ErrorCode::NoSpace)?; + } + } + + self.0.notification.signal(()); + + Ok(()) + } + + pub fn remove( + &self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { + info!( + "Deregistering mDNS service {}/{}.{}/{}", + name, service, protocol, port + ); + + let key = self.0.key(name, service, protocol, port); + + let mut entries = self.0.entries.borrow_mut(); + + let old_len = entries.len(); + + entries.retain(|entry| entry.key != key); + + if entries.len() != old_len { + self.0.notification.signal(()); + } + + Ok(()) + } + } + + pub struct MdnsRunner<'a, 'b>(&'a Mdns<'b>); + + impl<'a, 'b> MdnsRunner<'a, 'b> { + pub async fn run(&mut self) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast()); + let mut respond = pin!(self.respond()); + + select(&mut broadcast, &mut respond).await.unwrap() + } + + async fn broadcast(&self) -> Result<(), Error> { + loop { + select( + self.0.notification.wait(), + Timer::after(Duration::from_secs(30)), + ) + .await; + + let mut index = 0; + + while let Some(entry) = self + .0 + .entries + .borrow() + .get(index) + .map(|entry| entry.clone()) + { + info!("Broadasting mDNS entry {}", &entry.key); + + self.0.bind().await?; + + let udp = self.0.udp.borrow(); + let udp = udp.as_ref().unwrap(); + + for addr in IP_BROADCAST_ADDRS { + udp.send(addr, &entry.record).await?; + } + + index += 1; + } + } + } + + async fn respond(&self) -> Result<(), Error> { + loop { + let mut buf = [0; 1580]; + + let udp = self.0.udp.borrow(); + let udp = udp.as_ref().unwrap(); + + let (_len, _addr) = udp.recv(&mut buf).await?; + + info!("Received UDP packet"); + + // TODO: Process the incoming packed and only answer what we are being queried about + + self.0.notification.signal(()); + } + } + } + + impl<'a, 'b> super::Mdns for MdnsApi<'a, 'b> { + fn add( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + service_subtypes: &[&str], + txt_kvs: &[(&str, &str)], + ) -> Result<(), Error> { + MdnsApi::add( + self, + name, + service, + protocol, + port, + service_subtypes, + txt_kvs, + ) + } + + fn remove( + &mut self, + name: &str, + service: &str, + protocol: &str, + port: u16, + ) -> Result<(), Error> { + MdnsApi::remove(self, name, service, protocol, port) + } + } +} + #[cfg(all(feature = "std", feature = "astro-dnssd"))] pub mod astro { use std::collections::HashMap; @@ -342,399 +765,6 @@ pub mod astro { } } -// TODO: Maybe future -// #[cfg(all(feature = "std", feature = "zeroconf"))] -// pub mod zeroconf { -// use std::collections::HashMap; - -// use super::Mdns; -// use crate::error::{Error, ErrorCode}; -// use log::info; -// use zeroconf::prelude::*; -// use zeroconf::{MdnsService, ServiceType, TxtRecord}; - -// #[derive(Debug, Clone, Eq, PartialEq, Hash)] -// pub struct ServiceId { -// name: String, -// service: String, -// protocol: String, -// port: u16, -// } - -// pub struct ZeroconfMdns { -// services: HashMap, -// } - -// impl ZeroconfMdns { -// pub fn new() -> Result { -// Ok(Self { -// services: HashMap::new(), -// }) -// } - -// pub fn add( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// info!( -// "Registering mDNS service {}/{}.{} [{:?}]/{}", -// name, service, protocol, service_subtypes, port -// ); - -// let _ = self.remove(name, service, protocol, port); - -// let mut svc = MdnsService::new( -// ServiceType::with_sub_types(service, protocol, service_subtypes.into()).unwrap(), -// port, -// ); - -// let mut txt = TxtRecord::new(); - -// for kvs in txt_kvs { -// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); -// txt.insert(kvs.0, kvs.1); -// } - -// svc.set_txt_record(txt); - -// //let event_loop = svc.register().map_err(|_| ErrorCode::MdnsError)?; - -// self.services.insert( -// ServiceId { -// name: name.into(), -// service: service.into(), -// protocol: protocol.into(), -// port, -// }, -// svc, -// ); - -// Ok(()) -// } - -// pub fn remove( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// ) -> Result<(), Error> { -// let id = ServiceId { -// name: name.into(), -// service: service.into(), -// protocol: protocol.into(), -// port, -// }; - -// if self.services.remove(&id).is_some() { -// info!( -// "Deregistering mDNS service {}.{}/{}/{}", -// name, service, protocol, port -// ); -// } - -// Ok(()) -// } -// } - -// impl Mdns for ZeroconfMdns { -// fn add( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// ZeroconfMdns::add( -// self, -// name, -// service, -// protocol, -// port, -// service_subtypes, -// txt_kvs, -// ) -// } - -// fn remove( -// &mut self, -// name: &str, -// service: &str, -// protocol: &str, -// port: u16, -// ) -> Result<(), Error> { -// ZeroconfMdns::remove(self, name, service, protocol, port) -// } -// } -// } - -#[cfg(all(feature = "std", not(target_os = "espidf")))] -pub mod libmdns { - use super::Mdns; - use crate::error::Error; - use libmdns::{Responder, Service}; - use log::info; - use std::collections::HashMap; - use std::vec::Vec; - - #[derive(Debug, Clone, Eq, PartialEq, Hash)] - pub struct ServiceId { - name: String, - service: String, - protocol: String, - port: u16, - } - - pub struct LibMdns { - responder: Responder, - services: HashMap, - } - - impl LibMdns { - pub fn new() -> Result { - let responder = Responder::new()?; - - Ok(Self { - responder, - services: HashMap::new(), - }) - } - - pub fn add( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - - let _ = self.remove(name, service, protocol, port); - - let mut properties = Vec::new(); - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - properties.push(format!("{}={}", kvs.0, kvs.1)); - } - let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect(); - - let svc = self.responder.register( - format!("{}.{}", service, protocol), - name.to_owned(), - port, - &properties, - ); - - self.services.insert( - ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }, - svc, - ); - - Ok(()) - } - - pub fn remove( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }; - - if self.services.remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - } - - Ok(()) - } - } - - impl Mdns for LibMdns { - fn add( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - _service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - LibMdns::add(self, name, service, protocol, port, txt_kvs) - } - - fn remove( - &mut self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - LibMdns::remove(self, name, service, protocol, port) - } - } -} - -// TODO: Maybe future -// #[cfg(feature = "std")] -// pub mod simplemdns { -// use std::net::Ipv4Addr; - -// use crate::error::{Error, ErrorCode}; -// use super::Mdns; -// use log::info; -// use simple_dns::{ -// rdata::{RData, A, SRV, TXT, PTR}, -// CharacterString, Name, ResourceRecord, CLASS, -// }; -// use simple_mdns::sync_discovery::SimpleMdnsResponder; - -// #[derive(Debug, Clone, Eq, PartialEq, Hash)] -// pub struct ServiceId { -// name: String, -// service_type: String, -// port: u16, -// } - -// pub struct SimpleMdns { -// responder: SimpleMdnsResponder, -// } - -// impl SimpleMdns { -// pub fn new() -> Result { -// Ok(Self { -// responder: Default::default(), -// }) -// } - -// pub fn add( -// &mut self, -// name: &str, -// service_type: &str, -// port: u16, -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// info!( -// "Registering mDNS service {}/{}/{}", -// name, service_type, port -// ); - -// let _ = self.remove(name, service_type, port); - -// let mut txt = TXT::new(); -// for kvs in txt_kvs { -// info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - -// let string = format!("{}={}", kvs.0, kvs.1); -// txt.add_char_string( -// CharacterString::new(string.as_bytes()) -// .unwrap() -// .into_owned(), -// ); -// } - -// let name = Name::new_unchecked(name).into_owned(); -// let service_type = Name::new_unchecked(service_type).into_owned(); - -// self.responder.add_resource(ResourceRecord::new( -// name.clone(), -// CLASS::IN, -// 10, -// RData::A(A { -// address: Ipv4Addr::new(192, 168, 10, 189).into(), -// }), -// )); - -// self.responder.add_resource(ResourceRecord::new( -// name.clone(), -// CLASS::IN, -// 10, -// RData::SRV(SRV { -// port: port, -// priority: 0, -// weight: 0, -// target: service_type.clone(), -// }), -// )); - -// self.responder.add_resource(ResourceRecord::new( -// srv_name.clone(), -// CLASS::IN, -// 10, -// RData::PTR(PTR(srv_name.clone()), -// ))); - -// self.responder.add_resource(ResourceRecord::new( -// srv_name, -// CLASS::IN, -// 10, -// RData::TXT(txt), -// )); - -// Ok(()) -// } - -// pub fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { -// // TODO -// // let id = ServiceId { -// // name: name.into(), -// // service_type: service_type.into(), -// // port, -// // }; - -// // if self.responder.remove_resource_record(resource).remove(&id).is_some() { -// // info!( -// // "Deregistering mDNS service {}/{}/{}", -// // name, service_type, port -// // ); -// // } - -// Ok(()) -// } -// } - -// impl Mdns for SimpleMdns { -// fn add( -// &mut self, -// name: &str, -// service_type: &str, -// port: u16, -// _service_subtypes: &[&str], -// txt_kvs: &[(&str, &str)], -// ) -> Result<(), Error> { -// SimpleMdns::add(self, name, service_type, port, txt_kvs) -// } - -// fn remove(&mut self, name: &str, service_type: &str, port: u16) -> Result<(), Error> { -// SimpleMdns::remove(self, name, service_type, port) -// } -// } -// } - #[cfg(test)] mod tests { use super::*; diff --git a/matter/src/pairing/mod.rs b/matter/src/pairing/mod.rs index 2dddce5..253062e 100644 --- a/matter/src/pairing/mod.rs +++ b/matter/src/pairing/mod.rs @@ -91,7 +91,7 @@ pub fn print_pairing_code_and_qr( let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?; pretty_print_pairing_code(&pairing_code); - print_qr_code(&qr_code); + print_qr_code(qr_code); Ok(()) } diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index fbd6da8..c029963 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -35,12 +35,13 @@ use crate::{ utils::{rand::Rand, writebuf::WriteBuf}, }; -#[derive(PartialEq)] +#[derive(Debug, Copy, Clone, Eq, PartialEq)] enum State { Sigma1Rx, Sigma3Rx, } +#[derive(Debug, Clone)] pub struct CaseSession { state: State, peer_sessid: u16, @@ -84,7 +85,7 @@ impl<'a> Case<'a> { let mut case_session = ctx .exch_ctx .exch - .take_case_session::() + .take_case_session() .ok_or(ErrorCode::InvalidState)?; if case_session.state != State::Sigma1Rx { Err(ErrorCode::Invalid)?; diff --git a/matter/src/secure_channel/common.rs b/matter/src/secure_channel/common.rs index 7049ba3..c007ee5 100644 --- a/matter/src/secure_channel/common.rs +++ b/matter/src/secure_channel/common.rs @@ -56,6 +56,8 @@ pub fn create_sc_status_report( status_code: SCStatusCodes, proto_data: Option<&[u8]>, ) -> Result<(), Error> { + proto_tx.reset(); + let general_code = match status_code { SCStatusCodes::SessionEstablishmentSuccess => GeneralCode::Success, SCStatusCodes::CloseSession => { @@ -79,6 +81,7 @@ pub fn create_sc_status_report( } pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { + proto_tx.reset(); proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.unset_reliable(); diff --git a/matter/src/tlv/parser.rs b/matter/src/tlv/parser.rs index b740f5d..0c179e2 100644 --- a/matter/src/tlv/parser.rs +++ b/matter/src/tlv/parser.rs @@ -711,11 +711,7 @@ impl<'a> Iterator for TLVContainerIterator<'a> { return None; } - if is_container(element.element_type) { - self.prev_container = true; - } else { - self.prev_container = false; - } + self.prev_container = is_container(element.element_type); Some(element) } } diff --git a/matter/src/tlv/traits.rs b/matter/src/tlv/traits.rs index 3fced12..28c236b 100644 --- a/matter/src/tlv/traits.rs +++ b/matter/src/tlv/traits.rs @@ -61,6 +61,24 @@ impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] { } } +pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>( + vec: &mut heapless::Vec, + t: &TLVElement<'a>, +) -> Result<(), Error> { + vec.clear(); + + t.confirm_array()?; + + if let Some(tlv_iter) = t.enter() { + for element in tlv_iter { + vec.push(T::from_tlv(&element)?) + .map_err(|_| ErrorCode::NoSpace)?; + } + } + + Ok(()) +} + macro_rules! fromtlv_for { ($($t:ident)*) => { $( @@ -110,6 +128,16 @@ impl ToTLV for [T; N] { } } +impl<'a, T: ToTLV> ToTLV for &'a [T] { + fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { + tw.start_array(tag)?; + for i in *self { + i.to_tlv(tw, TagType::Anonymous)?; + } + tw.end_container() + } +} + // Generate ToTLV for standard data types totlv_for!(i8 u8 u16 u32 u64 bool); diff --git a/matter/src/transport/exchange.rs b/matter/src/transport/exchange.rs index 57f666c..04b63db 100644 --- a/matter/src/transport/exchange.rs +++ b/matter/src/transport/exchange.rs @@ -15,7 +15,6 @@ * limitations under the License. */ -use core::any::Any; use core::fmt; use core::time::Duration; use log::{error, info, trace}; @@ -144,7 +143,7 @@ impl Exchange { } } - pub fn take_case_session(&mut self) -> Option { + pub fn take_case_session(&mut self) -> Option { let old = core::mem::replace(&mut self.data, DataOption::None); if let DataOption::CaseSession(session) = old { Some(session) diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 0b6453e..1a81c75 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -25,5 +25,4 @@ pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; pub mod session; -#[cfg(feature = "std")] pub mod udp; diff --git a/matter/src/transport/network.rs b/matter/src/transport/network.rs index e03658b..ba50386 100644 --- a/matter/src/transport/network.rs +++ b/matter/src/transport/network.rs @@ -17,15 +17,23 @@ use core::fmt::{Debug, Display}; #[cfg(not(feature = "std"))] -pub use no_std_net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use no_std_net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; #[cfg(feature = "std")] -pub use std::net::{IpAddr, Ipv4Addr, SocketAddr}; +pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; -#[derive(PartialEq, Copy, Clone)] +#[derive(Eq, PartialEq, Copy, Clone)] pub enum Address { Udp(SocketAddr), } +impl Address { + pub fn unwrap_udp(self) -> SocketAddr { + match self { + Self::Udp(addr) => addr, + } + } +} + impl Default for Address { fn default() -> Self { Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080)) diff --git a/matter/src/transport/packet.rs b/matter/src/transport/packet.rs index 3e7e9c7..72368cb 100644 --- a/matter/src/transport/packet.rs +++ b/matter/src/transport/packet.rs @@ -31,7 +31,7 @@ use super::{ pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/; -#[derive(PartialEq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] enum RxState { Uninit, PlainDecode, @@ -43,6 +43,30 @@ enum Direction<'a> { Rx(ParseBuf<'a>, RxState), } +impl<'a> Direction<'a> { + pub fn load(&mut self, direction: &Direction) -> Result<(), Error> { + if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) { + Err(ErrorCode::Invalid)?; + } + + match self { + Self::Tx(wb) => match direction { + Direction::Tx(src_wb) => wb.load(src_wb)?, + Direction::Rx(_, _) => Err(ErrorCode::Invalid)?, + }, + Self::Rx(pb, state) => match direction { + Direction::Tx(_) => Err(ErrorCode::Invalid)?, + Direction::Rx(src_pb, src_state) => { + pb.load(src_pb)?; + *state = *src_state; + } + }, + } + + Ok(()) + } +} + pub struct Packet<'a> { pub plain: PlainHdr, pub proto: ProtoHdr, @@ -78,7 +102,7 @@ impl<'a> Packet<'a> { } } - pub fn reset(&mut self) -> () { + pub fn reset(&mut self) { if let Direction::Tx(wb) = &mut self.data { wb.reset(); wb.reserve(Packet::HDR_RESERVE).unwrap(); @@ -91,6 +115,13 @@ impl<'a> Packet<'a> { } } + pub fn load(&mut self, packet: &Packet) -> Result<(), Error> { + self.plain = packet.plain.clone(); + self.proto = packet.proto.clone(); + self.peer = packet.peer; + self.data.load(&packet.data) + } + pub fn as_slice(&self) -> &[u8] { match &self.data { Direction::Rx(pb, _) => pb.as_slice(), diff --git a/matter/src/transport/plain_hdr.rs b/matter/src/transport/plain_hdr.rs index e5a9b24..5a0728a 100644 --- a/matter/src/transport/plain_hdr.rs +++ b/matter/src/transport/plain_hdr.rs @@ -21,7 +21,7 @@ use crate::utils::writebuf::WriteBuf; use bitflags::bitflags; use log::info; -#[derive(Debug, PartialEq, Default)] +#[derive(Debug, PartialEq, Eq, Default, Copy, Clone)] pub enum SessionType { #[default] None, @@ -38,7 +38,7 @@ bitflags! { } // This is the unencrypted message -#[derive(Debug, Default)] +#[derive(Debug, Default, Clone)] pub struct PlainHdr { pub flags: MsgFlags, pub sess_type: SessionType, diff --git a/matter/src/transport/proto_hdr.rs b/matter/src/transport/proto_hdr.rs index d7f92fb..9bf80d4 100644 --- a/matter/src/transport/proto_hdr.rs +++ b/matter/src/transport/proto_hdr.rs @@ -36,7 +36,7 @@ bitflags! { } } -#[derive(Default)] +#[derive(Debug, Default, Clone)] pub struct ProtoHdr { pub exch_id: u16, pub exch_flags: ExchFlags, @@ -278,7 +278,7 @@ mod tests { decrypt_in_place(recvd_ctr, 0, &mut parsebuf, &key).unwrap(); assert_eq!( - parsebuf.into_slice(), + parsebuf.as_slice(), [ 0x5, 0x8, 0x70, 0x0, 0x1, 0x0, 0x15, 0x28, 0x0, 0x28, 0x1, 0x36, 0x2, 0x15, 0x37, 0x0, 0x24, 0x0, 0x0, 0x24, 0x1, 0x30, 0x24, 0x2, 0x2, 0x18, 0x35, 0x1, 0x24, 0x0, diff --git a/matter/src/transport/session.rs b/matter/src/transport/session.rs index 1135f05..1e3a1d4 100644 --- a/matter/src/transport/session.rs +++ b/matter/src/transport/session.rs @@ -19,11 +19,8 @@ use crate::data_model::sdm::noc::NocData; use crate::utils::epoch::Epoch; use crate::utils::rand::Rand; use core::fmt; +use core::ops::{Deref, DerefMut}; use core::time::Duration; -use core::{ - any::Any, - ops::{Deref, DerefMut}, -}; use crate::{ error::*, @@ -166,7 +163,7 @@ impl Session { self.data = None; } - pub fn get_noc_data(&mut self) -> Option<&mut NocData> { + pub fn get_noc_data(&mut self) -> Option<&mut NocData> { self.data.as_mut() } @@ -325,17 +322,16 @@ pub const MAX_SESSIONS: usize = 16; pub struct SessionMgr { next_sess_id: u16, - sessions: [Option; MAX_SESSIONS], + sessions: heapless::Vec, MAX_SESSIONS>, epoch: Epoch, rand: Rand, } impl SessionMgr { + #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { - const INIT: Option = None; - Self { - sessions: [INIT; MAX_SESSIONS], + sessions: heapless::Vec::new(), next_sess_id: 1, epoch, rand, @@ -343,10 +339,10 @@ impl SessionMgr { } pub fn mut_by_index(&mut self, index: usize) -> Option<&mut Session> { - self.sessions[index].as_mut() + self.sessions.get_mut(index).and_then(Option::as_mut) } - fn get_next_sess_id(&mut self) -> u16 { + pub fn get_next_sess_id(&mut self) -> u16 { let mut next_sess_id: u16; loop { next_sess_id = self.next_sess_id; @@ -366,7 +362,7 @@ impl SessionMgr { } pub fn get_session_for_eviction(&self) -> Option { - if self.get_empty_slot().is_none() { + if self.sessions.len() == MAX_SESSIONS && self.get_empty_slot().is_none() { Some(self.get_lru()) } else { None @@ -380,8 +376,8 @@ impl SessionMgr { fn get_lru(&self) -> usize { let mut lru_index = 0; let mut lru_ts = (self.epoch)(); - for i in 0..MAX_SESSIONS { - if let Some(s) = &self.sessions[i] { + for (i, s) in self.sessions.iter().enumerate() { + if let Some(s) = s { if s.last_use < lru_ts { lru_ts = s.last_use; lru_index = i; @@ -405,10 +401,17 @@ impl SessionMgr { /// We could have returned a SessionHandle here. But the borrow checker doesn't support /// non-lexical lifetimes. This makes it harder for the caller of this function to take /// action in the error return path - pub fn add_session(&mut self, session: Session) -> Result { + fn add_session(&mut self, session: Session) -> Result { if let Some(index) = self.get_empty_slot() { self.sessions[index] = Some(session); Ok(index) + } else if self.sessions.len() < MAX_SESSIONS { + self.sessions + .push(Some(session)) + .map_err(|_| ErrorCode::NoSpace) + .unwrap(); + + Ok(self.sessions.len() - 1) } else { Err(ErrorCode::NoSpace.into()) } @@ -419,7 +422,7 @@ impl SessionMgr { self.add_session(session) } - fn _get( + pub fn get( &self, sess_id: u16, peer_addr: Address, @@ -451,14 +454,14 @@ impl SessionMgr { Some(self.get_session_handle(index)) } - pub fn get_or_add( + fn get_or_add( &mut self, sess_id: u16, peer_addr: Address, peer_nodeid: Option, is_encrypted: bool, ) -> Result { - if let Some(index) = self._get(sess_id, peer_addr, peer_nodeid, is_encrypted) { + if let Some(index) = self.get(sess_id, peer_addr, peer_nodeid, is_encrypted) { Ok(index) } else if sess_id == 0 && !is_encrypted { // We must create a new session for this case @@ -538,7 +541,7 @@ impl fmt::Display for SessionMgr { } pub struct SessionHandle<'a> { - sess_mgr: &'a mut SessionMgr, + pub(crate) sess_mgr: &'a mut SessionMgr, sess_idx: usize, } diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 909ab1e..7cf5288 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -15,64 +15,103 @@ * limitations under the License. */ -use crate::{error::*, MATTER_PORT}; -use log::{info, warn}; -use smol::net::{Ipv6Addr, UdpSocket}; +#[cfg(feature = "std")] +pub use smol_udp::*; -use super::network::Address; +#[cfg(not(feature = "std"))] +pub use dummy_udp::*; -// We could get rid of the smol here, but keeping it around in case we have to process -// any other events in this thread's context -pub struct UdpListener { - socket: UdpSocket, -} +#[cfg(feature = "std")] +mod smol_udp { + use crate::error::*; + use log::{debug, info, warn}; + use smol::net::UdpSocket; -impl UdpListener { - pub async fn new() -> Result { - let listener = UdpListener { - socket: UdpSocket::bind((Ipv6Addr::UNSPECIFIED, MATTER_PORT)).await?, - }; + use crate::transport::network::SocketAddr; - info!( - "Listening on {:?} port {}", - Ipv6Addr::UNSPECIFIED, - MATTER_PORT - ); - - Ok(listener) + pub struct UdpListener { + socket: UdpSocket, } - pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error> { - info!("Waiting for incoming packets"); + impl UdpListener { + pub async fn new(addr: SocketAddr) -> Result { + let listener = UdpListener { + socket: UdpSocket::bind((addr.ip(), addr.port())).await?, + }; - let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { - warn!("Error on the network: {:?}", e); - ErrorCode::Network - })?; + info!("Listening on {:?}", addr); - info!("Got packet: {:?} from addr {:?}", &in_buf[..size], addr); + Ok(listener) + } - Ok((size, Address::Udp(addr))) - } + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + info!("Waiting for incoming packets"); - pub async fn send(&self, addr: Address, out_buf: &[u8]) -> Result { - match addr { - Address::Udp(addr) => { - let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { - warn!("Error on the network: {:?}", e); - ErrorCode::Network - })?; + let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; - info!( - "Send packet: {:?} ({}/{}) to addr {:?}", - out_buf, - out_buf.len(), - len, - addr - ); + debug!("Got packet {:?} from addr {:?}", &in_buf[..size], addr); - Ok(len) - } + Ok((size, addr)) + } + + pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { + let len = self.socket.send_to(out_buf, addr).await.map_err(|e| { + warn!("Error on the network: {:?}", e); + ErrorCode::Network + })?; + + debug!( + "Send packet {:?} ({}/{}) to addr {:?}", + out_buf, + out_buf.len(), + len, + addr + ); + + Ok(len) + } + } +} + +#[cfg(not(feature = "std"))] +mod dummy_udp { + use core::future::pending; + + use crate::error::*; + use log::{debug, info}; + + use crate::transport::network::SocketAddr; + + pub struct UdpListener {} + + impl UdpListener { + pub async fn new(addr: SocketAddr) -> Result { + let listener = UdpListener {}; + + info!("Pretending to listen on {:?}", addr); + + Ok(listener) + } + + pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { + info!("Pretending to wait for incoming packets (looping forever)"); + + pending().await + } + + pub async fn send(&self, addr: SocketAddr, out_buf: &[u8]) -> Result { + debug!( + "Send packet {:?} ({}/{}) to addr {:?}", + out_buf, + out_buf.len(), + out_buf.len(), + addr + ); + + Ok(out_buf.len()) } } } diff --git a/matter/src/utils/mod.rs b/matter/src/utils/mod.rs index 1e69b84..5a3fe81 100644 --- a/matter/src/utils/mod.rs +++ b/matter/src/utils/mod.rs @@ -18,4 +18,5 @@ pub mod epoch; pub mod parsebuf; pub mod rand; +pub mod select; pub mod writebuf; diff --git a/matter/src/utils/parsebuf.rs b/matter/src/utils/parsebuf.rs index 549e022..233693c 100644 --- a/matter/src/utils/parsebuf.rs +++ b/matter/src/utils/parsebuf.rs @@ -35,13 +35,25 @@ impl<'a> ParseBuf<'a> { } } - pub fn set_len(&mut self, left: usize) { - self.left = left; + pub fn reset(&mut self) { + self.read_off = 0; + self.left = self.buf.len(); } - // Return the data that is valid as a slice, consume self - pub fn into_slice(self) -> &'a mut [u8] { - &mut self.buf[self.read_off..(self.read_off + self.left)] + pub fn load(&mut self, pb: &ParseBuf) -> Result<(), Error> { + if self.buf.len() < pb.read_off + pb.left { + Err(ErrorCode::NoSpace)?; + } + + self.buf[0..pb.read_off + pb.left].copy_from_slice(&pb.buf[..pb.read_off + pb.left]); + self.read_off = pb.read_off; + self.left = pb.left; + + Ok(()) + } + + pub fn set_len(&mut self, left: usize) { + self.left = left; } // Return the data that is valid as a slice @@ -114,7 +126,7 @@ mod tests { assert_eq!(buf.le_u8().unwrap(), 0x01); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); - assert_eq!(buf.into_slice(), [0xa, 0xb, 0xc, 0xd]); + assert_eq!(buf.as_slice(), [0xa, 0xb, 0xc, 0xd]); } #[test] @@ -138,7 +150,7 @@ mod tests { if buf.le_u8().is_ok() { panic!("This should have returned error") } - assert_eq!(buf.into_slice(), []); + assert_eq!(buf.as_slice(), [] as [u8; 0]); } #[test] @@ -154,7 +166,7 @@ mod tests { assert_eq!(buf.as_mut_slice(), [0xa, 0xb]); assert_eq!(buf.tail(2).unwrap(), [0xa, 0xb]); - assert_eq!(buf.into_slice(), []); + assert_eq!(buf.as_slice(), [] as [u8; 0]); } #[test] @@ -176,7 +188,7 @@ mod tests { let mut test_slice = [0x01, 65, 0, 0xbe, 0xba, 0xfe, 0xca, 0xa, 0xb, 0xc, 0xd]; let mut buf = ParseBuf::new(&mut test_slice); - assert_eq!(buf.parsed_as_slice(), []); + assert_eq!(buf.parsed_as_slice(), [] as [u8; 0]); assert_eq!(buf.le_u8().unwrap(), 0x1); assert_eq!(buf.le_u16().unwrap(), 65); assert_eq!(buf.le_u32().unwrap(), 0xcafebabe); diff --git a/matter/src/utils/select.rs b/matter/src/utils/select.rs new file mode 100644 index 0000000..2b5d21e --- /dev/null +++ b/matter/src/utils/select.rs @@ -0,0 +1,35 @@ +use embassy_futures::select::{Either, Either3, Either4}; + +pub trait EitherUnwrap { + fn unwrap(self) -> T; +} + +impl EitherUnwrap for Either { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + } + } +} + +impl EitherUnwrap for Either3 { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + Self::Third(t) => t, + } + } +} + +impl EitherUnwrap for Either4 { + fn unwrap(self) -> T { + match self { + Self::First(t) => t, + Self::Second(t) => t, + Self::Third(t) => t, + Self::Fourth(t) => t, + } + } +} diff --git a/matter/src/utils/writebuf.rs b/matter/src/utils/writebuf.rs index 21a51e2..2f24c97 100644 --- a/matter/src/utils/writebuf.rs +++ b/matter/src/utils/writebuf.rs @@ -68,6 +68,18 @@ impl<'a> WriteBuf<'a> { self.end = 0; } + pub fn load(&mut self, wb: &WriteBuf) -> Result<(), Error> { + if self.buf_size < wb.end { + Err(ErrorCode::NoSpace)?; + } + + self.buf[0..wb.end].copy_from_slice(&wb.buf[..wb.end]); + self.start = wb.start; + self.end = wb.end; + + Ok(()) + } + pub fn reserve(&mut self, reserve: usize) -> Result<(), Error> { if self.end != 0 || self.start != 0 || self.buf_size != self.buf.len() { Err(ErrorCode::Invalid.into())