From 488ef5b9f0decf5104247aec15af633d9bd9abdc Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 12 Jun 2023 09:47:20 +0000 Subject: [PATCH] Proper mDNS responder --- examples/onoff_light/src/main.rs | 22 +- matter/Cargo.toml | 2 +- matter/src/core.rs | 18 +- matter/src/data_model/root_endpoint.rs | 10 +- .../src/data_model/sdm/admin_commissioning.rs | 16 +- matter/src/data_model/sdm/noc.rs | 12 +- matter/src/fabric.rs | 16 +- matter/src/mdns.rs | 826 ++---------------- matter/src/mdns/astro.rs | 106 +++ matter/src/mdns/builtin.rs | 317 +++++++ matter/src/mdns/proto.rs | 508 +++++++++++ matter/src/secure_channel/core.rs | 8 +- matter/src/secure_channel/pake.rs | 24 +- matter/src/transport/udp.rs | 19 +- 14 files changed, 1064 insertions(+), 840 deletions(-) create mode 100644 matter/src/mdns/astro.rs create mode 100644 matter/src/mdns/builtin.rs create mode 100644 matter/src/mdns/proto.rs diff --git a/examples/onoff_light/src/main.rs b/examples/onoff_light/src/main.rs index 1bc9944..e26f20f 100644 --- a/examples/onoff_light/src/main.rs +++ b/examples/onoff_light/src/main.rs @@ -69,6 +69,16 @@ fn run() -> Result<(), Error> { core::mem::size_of::(), ); + let dev_det = BasicInfoConfig { + vid: 0xFFF1, + pid: 0x8000, + hw_ver: 2, + sw_ver: 1, + sw_ver_str: "1", + serial_no: "aabbccdd", + device_name: "OnOff Light", + }; + let (ipv4_addr, ipv6_addr) = initialize_network()?; let mdns = DefaultMdns::new( @@ -76,6 +86,8 @@ fn run() -> Result<(), Error> { "matter-demo", ipv4_addr.octets(), Some(ipv6_addr.octets()), + &dev_det, + matter::MATTER_PORT, ); let mut mdns_runner = DefaultMdnsRunner::new(&mdns); @@ -84,15 +96,7 @@ fn run() -> Result<(), Error> { let matter = Matter::new_default( // vid/pid should match those in the DAC - &BasicInfoConfig { - vid: 0xFFF1, - pid: 0x8000, - hw_ver: 2, - sw_ver: 1, - sw_ver_str: "1", - serial_no: "aabbccdd", - device_name: "OnOff Light", - }, + &dev_det, &dev_att, &mdns, matter::MATTER_PORT, diff --git a/matter/Cargo.toml b/matter/Cargo.toml index 410b30c..2f859b7 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -46,7 +46,7 @@ embassy-futures = "0.1" embassy-time = { version = "0.1.1", features = ["generic-queue-8"] } embassy-sync = "0.2" critical-section = "1.1.1" -domain = { version = "0.7.2", default_features = false } +domain = { version = "0.7.2", default_features = false, features = ["heapless"] } # STD-only dependencies rand = { version = "0.8.5", optional = true } diff --git a/matter/src/core.rs b/matter/src/core.rs index c6c5dc1..35c8677 100644 --- a/matter/src/core.rs +++ b/matter/src/core.rs @@ -25,7 +25,7 @@ use crate::{ }, error::*, fabric::FabricMgr, - mdns::{Mdns, MdnsMgr}, + mdns::Mdns, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, secure_channel::{pake::PaseMgr, spake2p::VerifierData}, utils::{epoch::Epoch, rand::Rand}, @@ -48,7 +48,7 @@ pub struct Matter<'a> { pub acl_mgr: RefCell, pub pase_mgr: RefCell, pub failsafe: RefCell, - pub mdns_mgr: MdnsMgr<'a>, + pub mdns: &'a dyn Mdns, pub epoch: Epoch, pub rand: Rand, pub dev_det: &'a BasicInfoConfig<'a>, @@ -91,7 +91,7 @@ impl<'a> Matter<'a> { acl_mgr: RefCell::new(AclMgr::new()), pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)), failsafe: RefCell::new(FailSafe::new()), - mdns_mgr: MdnsMgr::new(dev_det.vid, dev_det.pid, dev_det.device_name, port, mdns), + mdns, epoch, rand, dev_det, @@ -113,7 +113,7 @@ impl<'a> Matter<'a> { } pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> { - self.fabric_mgr.borrow_mut().load(data, &self.mdns_mgr) + self.fabric_mgr.borrow_mut().load(data, self.mdns) } pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> { @@ -149,7 +149,7 @@ impl<'a> Matter<'a> { self.pase_mgr.borrow_mut().enable_pase_session( dev_comm.verifier, dev_comm.discriminator, - &self.mdns_mgr, + self.mdns, )?; Ok(true) @@ -183,12 +183,6 @@ impl<'a> Borrow> for Matter<'a> { } } -impl<'a> Borrow> for Matter<'a> { - fn borrow(&self) -> &MdnsMgr<'a> { - &self.mdns_mgr - } -} - impl<'a> Borrow> for Matter<'a> { fn borrow(&self) -> &BasicInfoConfig<'a> { self.dev_det @@ -203,7 +197,7 @@ impl<'a> Borrow for Matter<'a> { impl<'a> Borrow for Matter<'a> { fn borrow(&self) -> &(dyn Mdns + 'a) { - self.mdns_mgr.mdns + self.mdns } } diff --git a/matter/src/data_model/root_endpoint.rs b/matter/src/data_model/root_endpoint.rs index 1bc22fe..78b8cfb 100644 --- a/matter/src/data_model/root_endpoint.rs +++ b/matter/src/data_model/root_endpoint.rs @@ -4,7 +4,7 @@ use crate::{ acl::AclMgr, fabric::FabricMgr, handler_chain_type, - mdns::MdnsMgr, + mdns::Mdns, secure_channel::pake::PaseMgr, utils::{epoch::Epoch, rand::Rand}, }; @@ -62,7 +62,7 @@ where + Borrow> + Borrow> + Borrow> - + Borrow> + + Borrow + Borrow + Borrow + 'a, @@ -90,7 +90,7 @@ pub fn wrap<'a>( fabric: &'a RefCell, acl: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, ) -> RootEndpointHandler<'a> { @@ -103,12 +103,12 @@ pub fn wrap<'a>( .chain( endpoint_id, noc::ID, - NocCluster::new(dev_att, fabric, acl, failsafe, mdns_mgr, epoch, rand), + NocCluster::new(dev_att, fabric, acl, failsafe, mdns, epoch, rand), ) .chain( endpoint_id, admin_commissioning::ID, - AdminCommCluster::new(pase, mdns_mgr, rand), + AdminCommCluster::new(pase, mdns, rand), ) .chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand)) .chain( diff --git a/matter/src/data_model/sdm/admin_commissioning.rs b/matter/src/data_model/sdm/admin_commissioning.rs index 9364311..15c803f 100644 --- a/matter/src/data_model/sdm/admin_commissioning.rs +++ b/matter/src/data_model/sdm/admin_commissioning.rs @@ -20,7 +20,7 @@ use core::convert::TryInto; use crate::data_model::objects::*; use crate::interaction_model::core::Transaction; -use crate::mdns::MdnsMgr; +use crate::mdns::Mdns; use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::spake2p::VerifierData; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; @@ -102,15 +102,15 @@ pub struct OpenCommWindowReq<'a> { pub struct AdminCommCluster<'a> { data_ver: Dataver, pase_mgr: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> AdminCommCluster<'a> { - pub fn new(pase_mgr: &'a RefCell, mdns_mgr: &'a MdnsMgr<'a>, rand: Rand) -> Self { + pub fn new(pase_mgr: &'a RefCell, mdns: &'a dyn Mdns, rand: Rand) -> Self { Self { data_ver: Dataver::new(rand), pase_mgr, - mdns_mgr, + mdns, } } @@ -152,11 +152,9 @@ impl<'a> AdminCommCluster<'a> { cmd_enter!("Open Commissioning Window"); let req = OpenCommWindowReq::from_tlv(data)?; let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0); - self.pase_mgr.borrow_mut().enable_pase_session( - verifier, - req.discriminator, - self.mdns_mgr, - )?; + self.pase_mgr + .borrow_mut() + .enable_pase_session(verifier, req.discriminator, self.mdns)?; Ok(()) } diff --git a/matter/src/data_model/sdm/noc.rs b/matter/src/data_model/sdm/noc.rs index f347b13..7fb1e37 100644 --- a/matter/src/data_model/sdm/noc.rs +++ b/matter/src/data_model/sdm/noc.rs @@ -25,7 +25,7 @@ use crate::data_model::objects::*; use crate::data_model::sdm::dev_att; use crate::fabric::{Fabric, FabricMgr, MAX_SUPPORTED_FABRICS}; use crate::interaction_model::core::Transaction; -use crate::mdns::MdnsMgr; +use crate::mdns::Mdns; use crate::tlv::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV, UtfStr}; use crate::transport::session::SessionMode; use crate::utils::epoch::Epoch; @@ -222,7 +222,7 @@ pub struct NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> NocCluster<'a> { @@ -231,7 +231,7 @@ impl<'a> NocCluster<'a> { fabric_mgr: &'a RefCell, acl_mgr: &'a RefCell, failsafe: &'a RefCell, - mdns_mgr: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, epoch: Epoch, rand: Rand, ) -> Self { @@ -243,7 +243,7 @@ impl<'a> NocCluster<'a> { fabric_mgr, acl_mgr, failsafe, - mdns_mgr, + mdns, } } @@ -383,7 +383,7 @@ impl<'a> NocCluster<'a> { let fab_idx = self .fabric_mgr .borrow_mut() - .add(fabric, self.mdns_mgr) + .add(fabric, self.mdns) .map_err(|_| NocStatus::TableFull)?; self.add_acl(fab_idx, r.case_admin_subject)?; @@ -455,7 +455,7 @@ impl<'a> NocCluster<'a> { if self .fabric_mgr .borrow_mut() - .remove(req.fab_idx, self.mdns_mgr) + .remove(req.fab_idx, self.mdns) .is_ok() { let _ = self.acl_mgr.borrow_mut().delete_for_fabric(req.fab_idx); diff --git a/matter/src/fabric.rs b/matter/src/fabric.rs index 04369ca..8959407 100644 --- a/matter/src/fabric.rs +++ b/matter/src/fabric.rs @@ -26,7 +26,7 @@ use crate::{ crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, error::{Error, ErrorCode}, group_keys::KeySet, - mdns::{MdnsMgr, ServiceMode}, + mdns::{Mdns, ServiceMode}, tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf, }; @@ -200,9 +200,9 @@ impl FabricMgr { } } - pub fn load(&mut self, data: &[u8], mdns_mgr: &MdnsMgr) -> Result<(), Error> { + pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> { for fabric in self.fabrics.iter().flatten() { - mdns_mgr.unpublish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; + mdns.remove(&fabric.mdns_service_name)?; } let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?; @@ -210,7 +210,7 @@ impl FabricMgr { tlv::from_tlv(&mut self.fabrics, &root)?; for fabric in self.fabrics.iter().flatten() { - mdns_mgr.publish_service(&fabric.mdns_service_name, ServiceMode::Commissioned)?; + mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?; } self.changed = false; @@ -241,11 +241,11 @@ impl FabricMgr { self.changed } - pub fn add(&mut self, f: Fabric, mdns_mgr: &MdnsMgr) -> Result { + pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result { let slot = self.fabrics.iter().position(|x| x.is_none()); if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS { - mdns_mgr.publish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + mdns.add(&f.mdns_service_name, ServiceMode::Commissioned)?; self.changed = true; if let Some(index) = slot { @@ -265,10 +265,10 @@ impl FabricMgr { } } - pub fn remove(&mut self, fab_idx: u8, mdns_mgr: &MdnsMgr) -> Result<(), Error> { + pub fn remove(&mut self, fab_idx: u8, mdns: &dyn Mdns) -> Result<(), Error> { if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() { if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() { - mdns_mgr.unpublish_service(&f.mdns_service_name, ServiceMode::Commissioned)?; + mdns.remove(&f.mdns_service_name)?; self.changed = true; Ok(()) } else { diff --git a/matter/src/mdns.rs b/matter/src/mdns.rs index 897ec32..d07ba10 100644 --- a/matter/src/mdns.rs +++ b/matter/src/mdns.rs @@ -17,40 +17,28 @@ use core::fmt::Write; -use crate::error::Error; +use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error}; + +#[cfg(all(feature = "std", feature = "astro-dnssd"))] +pub mod astro; +pub mod builtin; +pub mod proto; pub trait Mdns { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error>; - - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error>; + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error>; + fn remove(&self, service: &str) -> Result<(), Error>; } impl Mdns for &mut T where T: Mdns, { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + (**self).add(service, mode) } - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { - (**self).remove(name, service, protocol, port) + fn remove(&self, service: &str) -> Result<(), Error> { + (**self).remove(service) } } @@ -58,25 +46,17 @@ impl Mdns for &T where T: Mdns, { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - (**self).add(name, service, protocol, port, service_subtypes, txt_kvs) + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + (**self).add(service, mode) } - fn remove(&self, name: &str, service: &str, protocol: &str, port: u16) -> Result<(), Error> { - (**self).remove(name, service, protocol, port) + fn remove(&self, service: &str) -> Result<(), Error> { + (**self).remove(service) } } #[cfg(all(feature = "std", feature = "astro-dnssd"))] -pub type DefaultMdns = astro::Mdns; +pub type DefaultMdns<'a> = astro::Mdns<'a>; #[cfg(all(feature = "std", feature = "astro-dnssd"))] pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>; @@ -90,29 +70,18 @@ pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>; pub struct DummyMdns; impl Mdns for DummyMdns { - fn add( - &self, - _name: &str, - _service: &str, - _protocol: &str, - _port: u16, - _service_subtypes: &[&str], - _txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { + fn add(&self, _service: &str, _mode: ServiceMode) -> Result<(), Error> { Ok(()) } - fn remove( - &self, - _name: &str, - _service: &str, - _protocol: &str, - _port: u16, - ) -> Result<(), Error> { + fn remove(&self, _service: &str) -> Result<(), Error> { Ok(()) } } +pub type Service<'a> = proto::Service<'a>; + +#[derive(Debug, Clone, Eq, PartialEq)] pub enum ServiceMode { /// The commissioned state Commissioned, @@ -120,56 +89,31 @@ pub enum ServiceMode { Commissionable(u16), } -/// The mDNS service handler -pub struct MdnsMgr<'a> { - /// Vendor ID - vid: u16, - /// Product ID - pid: u16, - /// Device name - device_name: &'a str, - /// Matter port - matter_port: u16, - /// mDns service - pub(crate) mdns: &'a dyn Mdns, -} - -impl<'a> MdnsMgr<'a> { - #[inline(always)] - pub fn new( - vid: u16, - pid: u16, - device_name: &'a str, +impl ServiceMode { + pub fn service FnOnce(&Service<'a>) -> Result>( + &self, + dev_att: &BasicInfoConfig, matter_port: u16, - mdns: &'a dyn Mdns, - ) -> Self { - Self { - vid, - pid, - device_name, - matter_port, - mdns, - } - } - - /// Publish an mDNS service - /// name - is the service name (comma separated subtypes may follow) - /// mode - the current service mode - #[allow(clippy::needless_pass_by_value)] - pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { - match mode { - ServiceMode::Commissioned => { - self.mdns - .add(name, "_matter", "_tcp", self.matter_port, &[], &[]) - } + name: &str, + f: F, + ) -> Result { + match self { + Self::Commissioned => f(&Service { + name, + service: "_matter", + protocol: "_tcp", + port: matter_port, + service_subtypes: &[], + txt_kvs: &[], + }), ServiceMode::Commissionable(discriminator) => { - let discriminator_str = Self::get_discriminator_str(discriminator); - let vp = self.get_vp(); + let discriminator_str = Self::get_discriminator_str(*discriminator); + let vp = Self::get_vp(dev_att.vid, dev_att.pid); - let txt_kvs = [ + let txt_kvs = &[ ("D", discriminator_str.as_str()), ("CM", "1"), - ("DN", self.device_name), + ("DN", dev_att.device_name), ("VP", &vp), ("SII", "5000"), /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */ @@ -177,40 +121,29 @@ impl<'a> MdnsMgr<'a> { ("PI", ""), /* Pairing Instruction */ ]; - self.mdns.add( + f(&Service { name, - "_matterc", - "_udp", - self.matter_port, - &[ - &self.get_long_service_subtype(discriminator), - &self.get_short_service_type(discriminator), + service: "_matterc", + protocol: "_udp", + port: matter_port, + service_subtypes: &[ + &Self::get_long_service_subtype(*discriminator), + &Self::get_short_service_type(*discriminator), ], - &txt_kvs, - ) + txt_kvs, + }) } } } - pub fn unpublish_service(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { - match mode { - ServiceMode::Commissioned => { - self.mdns.remove(name, "_matter", "_tcp", self.matter_port) - } - ServiceMode::Commissionable(_) => { - self.mdns.remove(name, "_matterc", "_udp", self.matter_port) - } - } - } - - fn get_long_service_subtype(&self, discriminator: u16) -> heapless::String<32> { + fn get_long_service_subtype(discriminator: u16) -> heapless::String<32> { let mut serv_type = heapless::String::new(); write!(&mut serv_type, "_L{}", discriminator).unwrap(); serv_type } - fn get_short_service_type(&self, discriminator: u16) -> heapless::String<32> { + fn get_short_service_type(discriminator: u16) -> heapless::String<32> { let short = Self::compute_short_discriminator(discriminator); let mut serv_type = heapless::String::new(); @@ -223,10 +156,10 @@ impl<'a> MdnsMgr<'a> { discriminator.into() } - fn get_vp(&self) -> heapless::String<11> { + fn get_vp(vid: u16, pid: u16) -> heapless::String<11> { let mut vp = heapless::String::new(); - write!(&mut vp, "{}+{}", self.vid, self.pid).unwrap(); + write!(&mut vp, "{}+{}", vid, pid).unwrap(); vp } @@ -239,651 +172,6 @@ impl<'a> MdnsMgr<'a> { } } -pub mod builtin { - use core::cell::RefCell; - use core::fmt::Write; - use core::mem::MaybeUninit; - use core::pin::pin; - use core::str::FromStr; - - use domain::base::header::Flags; - use domain::base::iana::Class; - use domain::base::octets::{Octets256, Octets64, OctetsBuilder}; - use domain::base::{Dname, MessageBuilder, Record, ShortBuf}; - use domain::rdata::{Aaaa, Ptr, Srv, Txt, A}; - use embassy_futures::select::{select, select3}; - use embassy_time::{Duration, Timer}; - use log::info; - - use crate::error::{Error, ErrorCode}; - use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; - use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; - use crate::transport::pipe::{Chunk, Pipe}; - use crate::transport::udp::UdpListener; - use crate::utils::select::{EitherUnwrap, Notification}; - - const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ - (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), - ( - IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), - 5353, - ), - ]; - - const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); - - type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; - type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; - - #[allow(clippy::too_many_arguments)] - pub fn create_record( - id: u16, - hostname: &str, - ip: [u8; 4], - ipv6: Option<[u8; 16]>, - - ttl_sec: u32, - - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - - buffer: &mut [u8], - ) -> Result { - let target = domain::base::octets::Octets2048::new(); - let message = MessageBuilder::from_target(target)?; - - let mut message = message.answer(); - - let mut ptr_str = heapless::String::<40>::new(); - write!(ptr_str, "{}.{}.local", service, protocol).unwrap(); - - let mut dname = heapless::String::<60>::new(); - write!(dname, "{}.{}.{}.local", name, service, protocol).unwrap(); - - let mut hname = heapless::String::<40>::new(); - write!(hname, "{}.local", hostname).unwrap(); - - let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str("_services._dns-sd._udp.local").unwrap(), - Class::In, - ttl_sec, - Ptr::new(ptr), - ); - message.push(record)?; - - let t: Dname = Dname::from_str(&dname).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str(&ptr_str).unwrap(), - Class::In, - ttl_sec, - Ptr::new(t), - ); - message.push(record)?; - - for sub_srv in service_subtypes { - let mut ptr_str = heapless::String::<40>::new(); - write!(ptr_str, "{}._sub.{}.{}.local", sub_srv, service, protocol).unwrap(); - - let ptr: Dname = Dname::from_str(&ptr_str).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str("_services._dns-sd._udp.local").unwrap(), - Class::In, - ttl_sec, - Ptr::new(ptr), - ); - message.push(record)?; - - let t: Dname = Dname::from_str(&dname).unwrap(); - let record: Record, Ptr<_>> = Record::new( - Dname::from_str(&ptr_str).unwrap(), - Class::In, - ttl_sec, - Ptr::new(t), - ); - message.push(record)?; - } - - let target: Dname = Dname::from_str(&hname).unwrap(); - let record: Record, Srv<_>> = Record::new( - Dname::from_str(&dname).unwrap(), - Class::In, - ttl_sec, - Srv::new(0, 0, port, target), - ); - message.push(record)?; - - // only way I found to create multiple parts in a Txt - // each slice is the length and then the data - let mut octets = Octets256::new(); - //octets.append_slice(&[1u8, b'X']).unwrap(); - //octets.append_slice(&[2u8, b'A', b'B']).unwrap(); - //octets.append_slice(&[0u8]).unwrap(); - for (k, v) in txt_kvs { - octets - .append_slice(&[(k.len() + v.len() + 1) as u8]) - .unwrap(); - octets.append_slice(k.as_bytes()).unwrap(); - octets.append_slice(&[b'=']).unwrap(); - octets.append_slice(v.as_bytes()).unwrap(); - } - - let txt = Txt::from_octets(&mut octets).unwrap(); - - let record: Record, Txt<_>> = - Record::new(Dname::from_str(&dname).unwrap(), Class::In, ttl_sec, txt); - message.push(record)?; - - let record: Record, A> = Record::new( - Dname::from_str(&hname).unwrap(), - Class::In, - ttl_sec, - A::from_octets(ip[0], ip[1], ip[2], ip[3]), - ); - message.push(record)?; - - if let Some(ipv6) = ipv6 { - let record: Record, Aaaa> = Record::new( - Dname::from_str(&hname).unwrap(), - Class::In, - ttl_sec, - Aaaa::new(ipv6.into()), - ); - message.push(record)?; - } - - let headerb = message.header_mut(); - headerb.set_id(id); - headerb.set_opcode(domain::base::iana::Opcode::Query); - headerb.set_rcode(domain::base::iana::Rcode::NoError); - - let mut flags = Flags::new(); - flags.qr = true; - flags.aa = true; - headerb.set_flags(flags); - - let target = message.finish(); - - buffer[..target.len()].copy_from_slice(target.as_ref()); - - Ok(target.len()) - } - - #[derive(Debug, Clone)] - struct MdnsEntry { - key: heapless::String<64>, - record: heapless::Vec, - } - - impl MdnsEntry { - #[inline(always)] - const fn new(key: heapless::String<64>) -> Self { - Self { - key, - record: heapless::Vec::new(), - } - } - } - - pub struct Mdns<'a> { - id: u16, - hostname: &'a str, - ip: [u8; 4], - ipv6: Option<[u8; 16]>, - entries: RefCell>, - notification: Notification, - } - - impl<'a> Mdns<'a> { - #[inline(always)] - pub const fn new(id: u16, hostname: &'a str, ip: [u8; 4], ipv6: Option<[u8; 16]>) -> Self { - Self { - id, - hostname, - ip, - ipv6, - entries: RefCell::new(heapless::Vec::new()), - notification: Notification::new(), - } - } - - pub fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}.{} [{:?}]/{}, keys [{:?}]", - name, service, protocol, service_subtypes, port, txt_kvs - ); - - let key = self.key(name, service, protocol, port); - - let mut entries = self.entries.borrow_mut(); - - entries.retain(|entry| entry.key != key); - entries - .push(MdnsEntry::new(key)) - .map_err(|_| ErrorCode::NoSpace)?; - - let entry = entries.iter_mut().last().unwrap(); - entry - .record - .resize(1024, 0) - .map_err(|_| ErrorCode::NoSpace) - .unwrap(); - - match create_record( - self.id, - self.hostname, - self.ip, - self.ipv6, - 60, /*ttl_sec*/ - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - &mut entry.record, - ) { - Ok(len) => entry.record.truncate(len), - Err(_) => { - entries.pop(); - Err(ErrorCode::NoSpace)?; - } - } - - self.notification.signal(()); - - Ok(()) - } - - pub fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - - let key = self.key(name, service, protocol, port); - - let mut entries = self.entries.borrow_mut(); - - let old_len = entries.len(); - - entries.retain(|entry| entry.key != key); - - if entries.len() != old_len { - self.notification.signal(()); - } - - Ok(()) - } - - fn key( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> heapless::String<64> { - let mut key = heapless::String::new(); - - write!(&mut key, "{name}.{service}.{protocol}.{port}").unwrap(); - - key - } - } - - pub struct MdnsRunner<'a>(&'a Mdns<'a>); - - impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns<'a>) -> Self { - Self(mdns) - } - - pub async fn run_udp(&mut self) -> Result<(), Error> { - let mut tx_buf = MdnsTxBuf::uninit(); - let mut rx_buf = MdnsRxBuf::uninit(); - - let tx_buf = &mut tx_buf; - let rx_buf = &mut rx_buf; - - let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); - let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); - - let tx_pipe = &tx_pipe; - let rx_pipe = &rx_pipe; - - let udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; - let udp = &udp; - - let mut tx = pin!(async move { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if let Some(chunk) = data.chunk { - udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) - .await?; - data.chunk = None; - tx_pipe.data_consumed_notification.signal(()); - } - } - - tx_pipe.data_supplied_notification.wait().await; - } - }); - - let mut rx = pin!(async move { - loop { - { - let mut data = rx_pipe.data.lock().await; - - if data.chunk.is_none() { - let (len, addr) = udp.recv(data.buf).await?; - - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: Address::Udp(addr), - }); - rx_pipe.data_supplied_notification.signal(()); - } - } - - rx_pipe.data_consumed_notification.wait().await; - } - }); - - let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); - - select3(&mut tx, &mut rx, &mut run).await.unwrap() - } - - pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { - let mut broadcast = pin!(self.broadcast(tx_pipe)); - let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); - - select(&mut broadcast, &mut respond).await.unwrap() - } - - #[allow(clippy::await_holding_refcell_ref)] - async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { - loop { - select( - self.0.notification.wait(), - Timer::after(Duration::from_secs(30)), - ) - .await; - - let mut index = 0; - - 'outer: loop { - for (addr, port) in IP_BROADCAST_ADDRS { - loop { - { - let mut data = tx_pipe.data.lock().await; - - if data.chunk.is_none() { - let entries = self.0.entries.borrow(); - let entry = entries.get(index); - - if let Some(entry) = entry { - info!( - "Broadasting mDNS entry {} on {}:{}", - &entry.key, addr, port - ); - - let len = entry.record.len(); - data.buf[..len].copy_from_slice(&entry.record); - drop(entries); - - data.chunk = Some(Chunk { - start: 0, - end: len, - addr: Address::Udp(SocketAddr::new(addr, port)), - }); - - tx_pipe.data_supplied_notification.signal(()); - } else { - break 'outer; - } - - break; - } - } - - tx_pipe.data_consumed_notification.wait().await; - } - } - - index += 1; - } - } - } - - #[allow(clippy::await_holding_refcell_ref)] - async fn respond(&self, rx_pipe: &Pipe<'_>, _tx_pipe: &Pipe<'_>) -> Result<(), Error> { - loop { - { - let mut data = rx_pipe.data.lock().await; - - if let Some(_chunk) = data.chunk { - // TODO: Process the incoming packed and only answer what we are being queried about - - data.chunk = None; - rx_pipe.data_consumed_notification.signal(()); - - self.0.notification.signal(()); - } - } - - rx_pipe.data_supplied_notification.wait().await; - } - } - } - - impl<'a> super::Mdns for Mdns<'a> { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - Mdns::add( - self, - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - ) - } - - fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - Mdns::remove(self, name, service, protocol, port) - } - } -} - -#[cfg(all(feature = "std", feature = "astro-dnssd"))] -pub mod astro { - use core::cell::RefCell; - use std::collections::HashMap; - - use crate::{ - error::{Error, ErrorCode}, - transport::pipe::Pipe, - }; - use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; - use log::info; - - #[derive(Debug, Clone, Eq, PartialEq, Hash)] - struct ServiceId { - name: String, - service: String, - protocol: String, - port: u16, - } - - pub struct Mdns { - services: RefCell>, - } - - impl Mdns { - pub fn new(_id: u16, _hostname: &str, _ip: [u8; 4], _ipv6: Option<[u8; 16]>) -> Self { - Self::native_new() - } - - pub fn native_new() -> Self { - Self { - services: RefCell::new(HashMap::new()), - } - } - - pub fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - info!( - "Registering mDNS service {}/{}.{} [{:?}]/{}", - name, service, protocol, service_subtypes, port - ); - - let _ = self.remove(name, service, protocol, port); - - let composite_service_type = if !service_subtypes.is_empty() { - format!("{}.{},{}", service, protocol, service_subtypes.join(",")) - } else { - format!("{}.{}", service, protocol) - }; - - let mut builder = DNSServiceBuilder::new(&composite_service_type, port).with_name(name); - - for kvs in txt_kvs { - info!("mDNS TXT key {} val {}", kvs.0, kvs.1); - builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); - } - - let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; - - self.services.borrow_mut().insert( - ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }, - svc, - ); - - Ok(()) - } - - pub fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - let id = ServiceId { - name: name.into(), - service: service.into(), - protocol: protocol.into(), - port, - }; - - if self.services.borrow_mut().remove(&id).is_some() { - info!( - "Deregistering mDNS service {}/{}.{}/{}", - name, service, protocol, port - ); - } - - Ok(()) - } - } - - pub struct MdnsRunner<'a>(&'a Mdns); - - impl<'a> MdnsRunner<'a> { - pub const fn new(mdns: &'a Mdns) -> Self { - Self(mdns) - } - - pub async fn run_udp(&mut self) -> Result<(), Error> { - core::future::pending::>().await - } - - pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { - core::future::pending::>().await - } - } - - impl super::Mdns for Mdns { - fn add( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - service_subtypes: &[&str], - txt_kvs: &[(&str, &str)], - ) -> Result<(), Error> { - Mdns::add( - self, - name, - service, - protocol, - port, - service_subtypes, - txt_kvs, - ) - } - - fn remove( - &self, - name: &str, - service: &str, - protocol: &str, - port: u16, - ) -> Result<(), Error> { - Mdns::remove(self, name, service, protocol, port) - } - } -} - #[cfg(test)] mod tests { use super::*; @@ -891,11 +179,11 @@ mod tests { #[test] fn can_compute_short_discriminator() { let discriminator: u16 = 0b0000_1111_0000_0000; - let short = MdnsMgr::compute_short_discriminator(discriminator); + let short = ServiceMode::compute_short_discriminator(discriminator); assert_eq!(short, 0b1111); let discriminator: u16 = 840; - let short = MdnsMgr::compute_short_discriminator(discriminator); + let short = ServiceMode::compute_short_discriminator(discriminator); assert_eq!(short, 3); } } diff --git a/matter/src/mdns/astro.rs b/matter/src/mdns/astro.rs new file mode 100644 index 0000000..12426cb --- /dev/null +++ b/matter/src/mdns/astro.rs @@ -0,0 +1,106 @@ +use core::cell::RefCell; +use std::collections::HashMap; + +use crate::{ + data_model::cluster_basic_information::BasicInfoConfig, + error::{Error, ErrorCode}, + transport::pipe::Pipe, +}; +use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService}; +use log::info; + +use super::ServiceMode; + +pub struct Mdns<'a> { + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + services: RefCell>, +} + +impl<'a> Mdns<'a> { + pub fn new( + _id: u16, + _hostname: &str, + _ip: [u8; 4], + _ipv6: Option<[u8; 16]>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> Self { + Self::native_new(dev_det, matter_port) + } + + pub fn native_new(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> Self { + Self { + dev_det, + matter_port, + services: RefCell::new(HashMap::new()), + } + } + + pub fn add(&self, name: &str, mode: ServiceMode) -> Result<(), Error> { + info!("Registering mDNS service {}/{:?}", name, mode); + + let _ = self.remove(name); + + mode.service(self.dev_det, self.matter_port, name, |service| { + let composite_service_type = if !service.service_subtypes.is_empty() { + format!( + "{}.{},{}", + service.service, + service.protocol, + service.service_subtypes.join(",") + ) + } else { + format!("{}.{}", service.service, service.protocol) + }; + + let mut builder = DNSServiceBuilder::new(&composite_service_type, service.port) + .with_name(service.name); + + for kvs in service.txt_kvs { + info!("mDNS TXT key {} val {}", kvs.0, kvs.1); + builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string()); + } + + let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?; + + self.services.borrow_mut().insert(service.name.into(), svc); + + Ok(()) + }) + } + + pub fn remove(&self, name: &str) -> Result<(), Error> { + if self.services.borrow_mut().remove(name).is_some() { + info!("Deregistering mDNS service {}", name); + } + + Ok(()) + } +} + +pub struct MdnsRunner<'a>(&'a Mdns<'a>); + +impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns<'a>) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + core::future::pending::>().await + } + + pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> { + core::future::pending::>().await + } +} + +impl<'a> super::Mdns for Mdns<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + Mdns::add(self, service, mode) + } + + fn remove(&self, service: &str) -> Result<(), Error> { + Mdns::remove(self, service) + } +} diff --git a/matter/src/mdns/builtin.rs b/matter/src/mdns/builtin.rs new file mode 100644 index 0000000..95c6ad6 --- /dev/null +++ b/matter/src/mdns/builtin.rs @@ -0,0 +1,317 @@ +use core::{cell::RefCell, mem::MaybeUninit, pin::pin}; + +use domain::base::name::FromStrError; +use domain::base::{octets::ParseError, ShortBuf}; +use embassy_futures::select::{select, select3}; +use embassy_time::{Duration, Timer}; +use log::info; + +use crate::data_model::cluster_basic_information::BasicInfoConfig; +use crate::error::{Error, ErrorCode}; +use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE}; +use crate::transport::pipe::{Chunk, Pipe}; +use crate::transport::udp::UdpListener; +use crate::utils::select::{EitherUnwrap, Notification}; + +use super::{ + proto::{Host, Services}, + Service, ServiceMode, +}; + +const IP_BROADCAST_ADDRS: [(IpAddr, u16); 2] = [ + (IpAddr::V4(Ipv4Addr::new(224, 0, 0, 251)), 5353), + ( + IpAddr::V6(Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb)), + 5353, + ), +]; + +const IP_BIND_ADDR: (IpAddr, u16) = (IpAddr::V6(Ipv6Addr::UNSPECIFIED), 5353); + +type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>; +type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>; + +pub struct Mdns<'a> { + host: Host<'a>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + services: RefCell, ServiceMode), 4>>, + notification: Notification, +} + +impl<'a> Mdns<'a> { + #[inline(always)] + pub const fn new( + id: u16, + hostname: &'a str, + ip: [u8; 4], + ipv6: Option<[u8; 16]>, + dev_det: &'a BasicInfoConfig<'a>, + matter_port: u16, + ) -> Self { + Self { + host: Host { + id, + hostname, + ip, + ipv6, + }, + dev_det, + matter_port, + services: RefCell::new(heapless::Vec::new()), + notification: Notification::new(), + } + } + + pub fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + let mut services = self.services.borrow_mut(); + + services.retain(|(name, _)| name != service); + services + .push((service.into(), mode)) + .map_err(|_| ErrorCode::NoSpace)?; + + self.notification.signal(()); + + Ok(()) + } + + pub fn remove(&self, service: &str) -> Result<(), Error> { + let mut services = self.services.borrow_mut(); + + services.retain(|(name, _)| name != service); + + Ok(()) + } + + pub fn for_each(&self, mut callback: F) -> Result<(), Error> + where + F: FnMut(&Service) -> Result<(), Error>, + { + let services = self.services.borrow(); + + for (service, mode) in &*services { + mode.service(self.dev_det, self.matter_port, service, |service| { + callback(service) + })?; + } + + Ok(()) + } +} + +pub struct MdnsRunner<'a>(&'a Mdns<'a>); + +impl<'a> MdnsRunner<'a> { + pub const fn new(mdns: &'a Mdns<'a>) -> Self { + Self(mdns) + } + + pub async fn run_udp(&mut self) -> Result<(), Error> { + let mut tx_buf = MdnsTxBuf::uninit(); + let mut rx_buf = MdnsRxBuf::uninit(); + + let tx_buf = &mut tx_buf; + let rx_buf = &mut rx_buf; + + let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() }); + let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() }); + + let tx_pipe = &tx_pipe; + let rx_pipe = &rx_pipe; + + let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR.0, IP_BIND_ADDR.1)).await?; + + for (ip, _) in IP_BROADCAST_ADDRS { + udp.join_multicast(ip).await?; + } + + let udp = &udp; + + let mut tx = pin!(async move { + loop { + { + let mut data = tx_pipe.data.lock().await; + + if let Some(chunk) = data.chunk { + udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end]) + .await?; + data.chunk = None; + tx_pipe.data_consumed_notification.signal(()); + } + } + + tx_pipe.data_supplied_notification.wait().await; + } + }); + + let mut rx = pin!(async move { + loop { + { + let mut data = rx_pipe.data.lock().await; + + if data.chunk.is_none() { + let (len, addr) = udp.recv(data.buf).await?; + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(addr), + }); + rx_pipe.data_supplied_notification.signal(()); + } + } + + rx_pipe.data_consumed_notification.wait().await; + } + }); + + let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await }); + + select3(&mut tx, &mut rx, &mut run).await.unwrap() + } + + pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> { + let mut broadcast = pin!(self.broadcast(tx_pipe)); + let mut respond = pin!(self.respond(rx_pipe, tx_pipe)); + + select(&mut broadcast, &mut respond).await.unwrap() + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + select( + self.0.notification.wait(), + Timer::after(Duration::from_secs(30)), + ) + .await; + + for (addr, port) in IP_BROADCAST_ADDRS { + loop { + let sent = { + let mut data = tx_pipe.data.lock().await; + + if data.chunk.is_none() { + let len = self.0.host.broadcast(&self.0, data.buf, 60)?; + + if len > 0 { + info!("Broadasting mDNS entry to {}:{}", addr, port); + + data.chunk = Some(Chunk { + start: 0, + end: len, + addr: Address::Udp(SocketAddr::new(addr, port)), + }); + + tx_pipe.data_supplied_notification.signal(()); + } + + true + } else { + false + } + }; + + if sent { + break; + } else { + tx_pipe.data_consumed_notification.wait().await; + } + } + } + } + } + + #[allow(clippy::await_holding_refcell_ref)] + async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { + loop { + { + let mut rx_data = rx_pipe.data.lock().await; + + if let Some(rx_chunk) = rx_data.chunk { + let data = &rx_data.buf[rx_chunk.start..rx_chunk.end]; + + loop { + let sent = { + let mut tx_data = tx_pipe.data.lock().await; + + if tx_data.chunk.is_none() { + let len = self.0.host.respond(&self.0, data, tx_data.buf, 60)?; + + if len > 0 { + info!("Replying to mDNS query from {}", rx_chunk.addr); + + tx_data.chunk = Some(Chunk { + start: 0, + end: len, + addr: rx_chunk.addr, + }); + + tx_pipe.data_supplied_notification.signal(()); + } + + true + } else { + false + } + }; + + if sent { + break; + } else { + tx_pipe.data_consumed_notification.wait().await; + } + } + + // info!("Got mDNS query"); + + rx_data.chunk = None; + rx_pipe.data_consumed_notification.signal(()); + } + } + + rx_pipe.data_supplied_notification.wait().await; + } + } +} + +impl<'a> super::Mdns for Mdns<'a> { + fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> { + Mdns::add(self, service, mode) + } + + fn remove(&self, service: &str) -> Result<(), Error> { + Mdns::remove(self, service) + } +} + +impl<'a> Services for Mdns<'a> { + type Error = crate::error::Error; + + fn for_each(&self, callback: F) -> Result<(), Error> + where + F: FnMut(&Service) -> Result<(), Error>, + { + Mdns::for_each(self, callback) + } +} + +impl From for Error { + fn from(_e: ShortBuf) -> Self { + Self::new(ErrorCode::NoSpace) + } +} + +impl From for Error { + fn from(_e: ParseError) -> Self { + Self::new(ErrorCode::MdnsError) + } +} + +impl From for Error { + fn from(_e: FromStrError) -> Self { + Self::new(ErrorCode::MdnsError) + } +} diff --git a/matter/src/mdns/proto.rs b/matter/src/mdns/proto.rs new file mode 100644 index 0000000..6fac2c7 --- /dev/null +++ b/matter/src/mdns/proto.rs @@ -0,0 +1,508 @@ +use core::fmt::Write; +use core::str::FromStr; + +use domain::{ + base::{ + header::Flags, + iana::Class, + message_builder::AnswerBuilder, + name::FromStrError, + octets::{Octets256, Octets64, OctetsBuilder, ParseError}, + Dname, Message, MessageBuilder, Record, Rtype, ShortBuf, ToDname, + }, + rdata::{Aaaa, Ptr, Srv, Txt, A}, +}; +use log::trace; + +pub trait Services { + type Error: From + From + From; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>; +} + +impl Services for &mut T +where + T: Services, +{ + type Error = T::Error; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>, + { + (**self).for_each(callback) + } +} + +impl Services for &T +where + T: Services, +{ + type Error = T::Error; + + fn for_each(&self, callback: F) -> Result<(), Self::Error> + where + F: FnMut(&Service) -> Result<(), Self::Error>, + { + (**self).for_each(callback) + } +} + +pub struct Host<'a> { + pub id: u16, + pub hostname: &'a str, + pub ip: [u8; 4], + pub ipv6: Option<[u8; 16]>, +} + +impl<'a> Host<'a> { + pub fn broadcast( + &self, + services: T, + buf: &mut [u8], + ttl_sec: u32, + ) -> Result { + let buf = Buf(buf, 0); + + let message = MessageBuilder::from_target(buf)?; + + let mut answer = message.answer(); + + self.set_broadcast(services, &mut answer, ttl_sec)?; + + let buf = answer.finish(); + + Ok(buf.1) + } + + pub fn respond( + &self, + services: T, + data: &[u8], + buf: &mut [u8], + ttl_sec: u32, + ) -> Result { + let buf = Buf(buf, 0); + + let message = MessageBuilder::from_target(buf)?; + + let mut answer = message.answer(); + + if self.set_response(data, services, &mut answer, ttl_sec)? { + let buf = answer.finish(); + + Ok(buf.1) + } else { + Ok(0) + } + } + + fn set_broadcast( + &self, + services: F, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), F::Error> + where + T: OctetsBuilder + AsMut<[u8]>, + F: Services, + { + self.set_header(answer); + + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + + services.for_each(|service| { + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + + Ok(()) + })?; + + Ok(()) + } + + fn set_response( + &self, + data: &[u8], + services: F, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result + where + T: OctetsBuilder + AsMut<[u8]>, + F: Services, + { + self.set_header(answer); + + let message = Message::from_octets(data)?; + + let mut replied = false; + + for question in message.question() { + trace!("Handling question {:?}", question); + + let question = question?; + + match question.qtype() { + Rtype::A + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) => + { + self.add_ipv4(answer, ttl_sec)?; + replied = true; + } + Rtype::Aaaa + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) => + { + self.add_ipv6(answer, ttl_sec)?; + replied = true; + } + Rtype::Srv => { + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Ptr => { + services.for_each(|service| { + if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { + service.add_service_type(answer, ttl_sec)?; + replied = true; + } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { + // TODO + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Txt => { + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + Rtype::Any => { + // A / AAAA + if question + .qname() + .name_eq(&Host::host_fqdn(self.hostname, true)?) + { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + replied = true; + } + + // PTR + services.for_each(|service| { + if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) { + service.add_service_type(answer, ttl_sec)?; + replied = true; + } else if question.qname().name_eq(&service.service_type_fqdn(true)?) { + // TODO + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + service.add_service_type(answer, ttl_sec)?; + service.add_service_subtypes(answer, ttl_sec)?; + service.add_txt(answer, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + + // SRV + services.for_each(|service| { + if question.qname().name_eq(&service.service_fqdn(true)?) { + self.add_ipv4(answer, ttl_sec)?; + self.add_ipv6(answer, ttl_sec)?; + service.add_service(answer, self.hostname, ttl_sec)?; + replied = true; + } + + Ok(()) + })?; + } + _ => (), + } + } + + Ok(replied) + } + + fn set_header>(&self, answer: &mut AnswerBuilder) { + let header = answer.header_mut(); + header.set_id(self.id); + header.set_opcode(domain::base::iana::Opcode::Query); + header.set_rcode(domain::base::iana::Rcode::NoError); + + let mut flags = Flags::new(); + flags.qr = true; + flags.aa = true; + header.set_flags(flags); + } + + fn add_ipv4>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, A>::new( + Self::host_fqdn(self.hostname, false).unwrap(), + Class::In, + ttl_sec, + A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]), + )) + } + + fn add_ipv6>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + if let Some(ip) = &self.ipv6 { + answer.push(Record::, Aaaa>::new( + Self::host_fqdn(self.hostname, false).unwrap(), + Class::In, + ttl_sec, + Aaaa::new((*ip).into()), + )) + } else { + Ok(()) + } + } + + fn host_fqdn(hostname: &str, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut host_fqdn = heapless::String::<60>::new(); + write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap(); + + Dname::from_str(&host_fqdn) + } +} + +pub struct Service<'a> { + pub name: &'a str, + pub service: &'a str, + pub protocol: &'a str, + pub port: u16, + pub service_subtypes: &'a [&'a str], + pub txt_kvs: &'a [(&'a str, &'a str)], +} + +impl<'a> Service<'a> { + fn add_service>( + &self, + answer: &mut AnswerBuilder, + hostname: &str, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Srv<_>>::new( + self.service_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()), + )) + } + + fn add_service_type>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Ptr<_>>::new( + Self::dns_sd_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_type_fqdn(false).unwrap()), + ))?; + + answer.push(Record::, Ptr<_>>::new( + self.service_type_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_fqdn(false).unwrap()), + )) + } + + fn add_service_subtypes>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + for service_subtype in self.service_subtypes { + self.add_service_subtype(answer, service_subtype, ttl_sec)?; + } + + Ok(()) + } + + fn add_service_subtype>( + &self, + answer: &mut AnswerBuilder, + service_subtype: &str, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + answer.push(Record::, Ptr<_>>::new( + Self::dns_sd_fqdn(false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()), + ))?; + + answer.push(Record::, Ptr<_>>::new( + self.service_subtype_fqdn(service_subtype, false).unwrap(), + Class::In, + ttl_sec, + Ptr::new(self.service_fqdn(false).unwrap()), + )) + } + + fn add_txt>( + &self, + answer: &mut AnswerBuilder, + ttl_sec: u32, + ) -> Result<(), ShortBuf> { + // only way I found to create multiple parts in a Txt + // each slice is the length and then the data + let mut octets = Octets256::new(); + //octets.append_slice(&[1u8, b'X'])?; + //octets.append_slice(&[2u8, b'A', b'B'])?; + //octets.append_slice(&[0u8])?; + for (k, v) in self.txt_kvs { + octets.append_slice(&[(k.len() + v.len() + 1) as u8])?; + octets.append_slice(k.as_bytes())?; + octets.append_slice(&[b'='])?; + octets.append_slice(v.as_bytes())?; + } + + let txt = Txt::from_octets(&mut octets).unwrap(); + + answer.push(Record::, Txt<_>>::new( + self.service_fqdn(false).unwrap(), + Class::In, + ttl_sec, + txt, + )) + } + + fn service_fqdn(&self, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_fqdn = heapless::String::<60>::new(); + write!( + service_fqdn, + "{}.{}.{}.local{}", + self.name, self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_fqdn) + } + + fn service_type_fqdn(&self, suffix: bool) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_type_fqdn = heapless::String::<60>::new(); + write!( + service_type_fqdn, + "{}.{}.local{}", + self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_type_fqdn) + } + + fn service_subtype_fqdn( + &self, + service_subtype: &str, + suffix: bool, + ) -> Result, FromStrError> { + let suffix = if suffix { "." } else { "" }; + + let mut service_subtype_fqdn = heapless::String::<40>::new(); + write!( + service_subtype_fqdn, + "{}._sub.{}.{}.local{}", + service_subtype, self.service, self.protocol, suffix, + ) + .unwrap(); + + Dname::from_str(&service_subtype_fqdn) + } + + fn dns_sd_fqdn(suffix: bool) -> Result, FromStrError> { + if suffix { + Dname::from_str("_services._dns-sd._udp.local.") + } else { + Dname::from_str("_services._dns-sd._udp.local") + } + } +} + +struct Buf<'a>(pub &'a mut [u8], pub usize); + +impl<'a> OctetsBuilder for Buf<'a> { + type Octets = Self; + + fn append_slice(&mut self, slice: &[u8]) -> Result<(), ShortBuf> { + if self.1 + slice.len() <= self.0.len() { + let end = self.1 + slice.len(); + self.0[self.1..end].copy_from_slice(slice); + self.1 = end; + + Ok(()) + } else { + Err(ShortBuf) + } + } + + fn truncate(&mut self, len: usize) { + self.1 = len; + } + + fn freeze(self) -> Self::Octets { + self + } + + fn len(&self) -> usize { + self.1 + } + + fn is_empty(&self) -> bool { + self.1 == 0 + } +} + +impl<'a> AsMut<[u8]> for Buf<'a> { + fn as_mut(&mut self) -> &mut [u8] { + &mut self.0[..self.1] + } +} diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index 523278e..0ad17ed 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -20,7 +20,7 @@ use core::{borrow::Borrow, cell::RefCell}; use crate::{ error::*, fabric::FabricMgr, - mdns::MdnsMgr, + mdns::Mdns, secure_channel::common::*, tlv, transport::{proto_ctx::ProtoCtx, session::CloneData}, @@ -36,7 +36,7 @@ use super::{case::Case, pake::PaseMgr}; pub struct SecureChannel<'a> { case: Case<'a>, pase: &'a RefCell, - mdns: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, } impl<'a> SecureChannel<'a> { @@ -44,7 +44,7 @@ impl<'a> SecureChannel<'a> { pub fn new< T: Borrow> + Borrow> - + Borrow> + + Borrow + Borrow + Borrow, >( @@ -62,7 +62,7 @@ impl<'a> SecureChannel<'a> { pub fn wrap( pase: &'a RefCell, fabric: &'a RefCell, - mdns: &'a MdnsMgr<'a>, + mdns: &'a dyn Mdns, rand: Rand, ) -> Self { Self { diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index 60920d0..79f7d2c 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -24,7 +24,7 @@ use super::{ use crate::{ crypto, error::{Error, ErrorCode}, - mdns::{MdnsMgr, ServiceMode}, + mdns::{Mdns, ServiceMode}, secure_channel::common::OpCode, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ @@ -39,7 +39,7 @@ use log::{error, info}; #[allow(clippy::large_enum_variant)] enum PaseMgrState { - Enabled(Pake, heapless::String<16>, u16), + Enabled(Pake, heapless::String<16>), Disabled, } @@ -60,14 +60,14 @@ impl PaseMgr { } pub fn is_pase_session_enabled(&self) -> bool { - matches!(&self.state, PaseMgrState::Enabled(_, _, _)) + matches!(&self.state, PaseMgrState::Enabled(_, _)) } pub fn enable_pase_session( &mut self, verifier: VerifierData, discriminator: u16, - mdns: &MdnsMgr, + mdns: &dyn Mdns, ) -> Result<(), Error> { let mut buf = [0; 8]; (self.rand)(&mut buf); @@ -76,25 +76,21 @@ impl PaseMgr { let mut mdns_service_name = heapless::String::<16>::new(); write!(&mut mdns_service_name, "{:016X}", num).unwrap(); - mdns.publish_service( + mdns.add( &mdns_service_name, ServiceMode::Commissionable(discriminator), )?; self.state = PaseMgrState::Enabled( Pake::new(verifier, self.epoch, self.rand), mdns_service_name, - discriminator, ); Ok(()) } - pub fn disable_pase_session(&mut self, mdns: &MdnsMgr) -> Result<(), Error> { - if let PaseMgrState::Enabled(_, mdns_service_name, discriminator) = &self.state { - mdns.unpublish_service( - mdns_service_name, - ServiceMode::Commissionable(*discriminator), - )?; + pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { + if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { + mdns.remove(mdns_service_name)?; } self.state = PaseMgrState::Disabled; @@ -108,7 +104,7 @@ impl PaseMgr { where F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, { - if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { + if let PaseMgrState::Enabled(pake, _) = &mut self.state { let data = f(pake, ctx)?; Ok(Some(data)) @@ -134,7 +130,7 @@ impl PaseMgr { pub fn pasepake3_handler( &mut self, ctx: &mut ProtoCtx, - mdns: &MdnsMgr, + mdns: &dyn Mdns, ) -> Result<(bool, Option), Error> { let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; diff --git a/matter/src/transport/udp.rs b/matter/src/transport/udp.rs index 7cf5288..5308462 100644 --- a/matter/src/transport/udp.rs +++ b/matter/src/transport/udp.rs @@ -27,7 +27,7 @@ mod smol_udp { use log::{debug, info, warn}; use smol::net::UdpSocket; - use crate::transport::network::SocketAddr; + use crate::transport::network::{IpAddr, Ipv4Addr, SocketAddr}; pub struct UdpListener { socket: UdpSocket, @@ -44,9 +44,18 @@ mod smol_udp { Ok(listener) } - pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { - info!("Waiting for incoming packets"); + pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { + match ip_addr { + IpAddr::V4(ip_addr) => self + .socket + .join_multicast_v4(ip_addr, Ipv4Addr::UNSPECIFIED)?, + IpAddr::V6(ip_addr) => self.socket.join_multicast_v6(&ip_addr, 0)?, + } + Ok(()) + } + + pub async fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { let (size, addr) = self.socket.recv_from(in_buf).await.map_err(|e| { warn!("Error on the network: {:?}", e); ErrorCode::Network @@ -96,6 +105,10 @@ mod dummy_udp { Ok(listener) } + pub async fn join_multicast(&mut self, ip_addr: IpAddr) -> Result<(), Error> { + Ok(()) + } + pub async fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, SocketAddr), Error> { info!("Pretending to wait for incoming packets (looping forever)");