From 505fa39e8205e8b7218c0433294dd79881b8e8d1 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Mon, 24 Apr 2023 09:00:08 +0000 Subject: [PATCH] Create new secure channel sessions without async-channel --- matter/Cargo.toml | 1 - matter/src/error.rs | 15 ------- matter/src/secure_channel/case.rs | 16 ++++---- matter/src/secure_channel/core.rs | 24 +++++++---- matter/src/secure_channel/pake.rs | 30 ++++++++------ matter/src/transport/mgr.rs | 55 +++++++++++++------------ matter/src/transport/mod.rs | 1 - matter/src/transport/proto_ctx.rs | 2 +- matter/src/transport/queue.rs | 67 ------------------------------- 9 files changed, 71 insertions(+), 140 deletions(-) delete mode 100644 matter/src/transport/queue.rs diff --git a/matter/Cargo.toml b/matter/Cargo.toml index b626083..53bea88 100644 --- a/matter/Cargo.toml +++ b/matter/Cargo.toml @@ -45,7 +45,6 @@ smol = "1.3.0" owning_ref = "0.4.1" safemem = "0.3.3" chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } -async-channel = "1.8" # crypto openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } diff --git a/matter/src/error.rs b/matter/src/error.rs index 04d55b3..e644a7a 100644 --- a/matter/src/error.rs +++ b/matter/src/error.rs @@ -17,7 +17,6 @@ use core::{array::TryFromSliceError, fmt, str::Utf8Error}; -use async_channel::{SendError, TryRecvError}; use log::error; #[derive(Debug, PartialEq, Clone, Copy)] @@ -156,26 +155,12 @@ impl From for Error { } } -impl From> for Error { - fn from(e: SendError) -> Self { - error!("Error in channel send {}", e); - Self::Invalid - } -} - impl From for Error { fn from(_e: Utf8Error) -> Self { Self::Utf8Fail } } -impl From for Error { - fn from(e: TryRecvError) -> Self { - error!("Error in channel try_recv {}", e); - Self::Invalid - } -} - impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "{:?}", self) diff --git a/matter/src/secure_channel/case.rs b/matter/src/secure_channel/case.rs index 4ffb6b2..e681ec9 100644 --- a/matter/src/secure_channel/case.rs +++ b/matter/src/secure_channel/case.rs @@ -30,7 +30,6 @@ use crate::{ transport::{ network::Address, proto_ctx::ProtoCtx, - queue::{Msg, WorkQ}, session::{CaseDetails, CloneData, NocCatIds, SessionMode}, }, utils::{epoch::UtcCalendar, rand::Rand, writebuf::WriteBuf}, @@ -83,7 +82,10 @@ impl<'a> Case<'a> { } } - pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn casesigma3_handler( + &mut self, + ctx: &mut ProtoCtx, + ) -> Result<(bool, Option), Error> { let mut case_session = ctx .exch_ctx .exch @@ -104,7 +106,7 @@ impl<'a> Case<'a> { None, )?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } // Safe to unwrap here let fabric = fabric.unwrap(); @@ -137,7 +139,7 @@ impl<'a> Case<'a> { error!("Certificate Chain doesn't match: {}", e); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } if Case::validate_sigma3_sign( @@ -152,7 +154,7 @@ impl<'a> Case<'a> { error!("Sigma3 Signature doesn't match"); common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?; ctx.exch_ctx.exch.close(); - return Ok(true); + return Ok((true, None)); } // Only now do we add this message to the TT Hash @@ -167,13 +169,11 @@ impl<'a> Case<'a> { &case_session, &peer_catids, )?; - // Queue a transport mgr request to add a new session - WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?; ctx.exch_ctx.exch.clear_data(); ctx.exch_ctx.exch.close(); - Ok(true) + Ok((true, Some(clone_data))) } pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { diff --git a/matter/src/secure_channel/core.rs b/matter/src/secure_channel/core.rs index be806a7..e69dca5 100644 --- a/matter/src/secure_channel/core.rs +++ b/matter/src/secure_channel/core.rs @@ -23,7 +23,7 @@ use crate::{ mdns::MdnsMgr, secure_channel::common::*, tlv, - transport::proto_ctx::ProtoCtx, + transport::{proto_ctx::ProtoCtx, session::CloneData}, utils::{epoch::UtcCalendar, rand::Rand}, }; use log::{error, info}; @@ -55,22 +55,30 @@ impl<'a> SecureChannel<'a> { } } - pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result { + pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option), Error> { let proto_opcode: OpCode = num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); info!("Received Opcode: {:?}", proto_opcode); info!("Received Data:"); tlv::print_tlv_list(ctx.rx.as_slice()); - let reply = match proto_opcode { - OpCode::MRPStandAloneAck => Ok(true), - OpCode::PBKDFParamRequest => self.pase.borrow_mut().pbkdfparamreq_handler(ctx), - OpCode::PASEPake1 => self.pase.borrow_mut().pasepake1_handler(ctx), + let (reply, clone_data) = match proto_opcode { + OpCode::MRPStandAloneAck => Ok((true, None)), + OpCode::PBKDFParamRequest => self + .pase + .borrow_mut() + .pbkdfparamreq_handler(ctx) + .map(|reply| (reply, None)), + OpCode::PASEPake1 => self + .pase + .borrow_mut() + .pasepake1_handler(ctx) + .map(|reply| (reply, None)), OpCode::PASEPake3 => self .pase .borrow_mut() .pasepake3_handler(ctx, &mut self.mdns.borrow_mut()), - OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), + OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), _ => { error!("OpCode Not Handled: {:?}", proto_opcode); @@ -83,6 +91,6 @@ impl<'a> SecureChannel<'a> { tlv::print_tlv_list(ctx.tx.as_mut_slice()); } - Ok(reply) + Ok((reply, clone_data)) } } diff --git a/matter/src/secure_channel/pake.rs b/matter/src/secure_channel/pake.rs index ce05fb6..1901686 100644 --- a/matter/src/secure_channel/pake.rs +++ b/matter/src/secure_channel/pake.rs @@ -31,7 +31,6 @@ use crate::{ exchange::ExchangeCtx, network::Address, proto_ctx::ProtoCtx, - queue::{Msg, WorkQ}, session::{CloneData, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, @@ -101,15 +100,18 @@ impl PaseMgr { /// If the PASE Session is enabled, execute the closure, /// if not enabled, generate SC Status Report - fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> + fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result, Error> where - F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<(), Error>, + F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, { if let PaseMgrState::Enabled(pake, _, _) = &mut self.state { - f(pake, ctx) + let data = f(pake, ctx)?; + + Ok(Some(data)) } else { error!("PASE Not enabled"); - create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None) + create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?; + Ok(None) } } @@ -129,10 +131,10 @@ impl PaseMgr { &mut self, ctx: &mut ProtoCtx, mdns: &mut MdnsMgr, - ) -> Result { - self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; + ) -> Result<(bool, Option), Error> { + let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; - Ok(true) + Ok((true, clone_data.flatten())) } } @@ -230,13 +232,13 @@ impl Pake { } #[allow(non_snake_case)] - pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { + pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result, Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let (status_code, ke) = sd.spake2p.handle_cA(cA); - if status_code == SCStatusCodes::SessionEstablishmentSuccess { + let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys let ke = ke.ok_or(Error::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; @@ -262,12 +264,14 @@ impl Pake { .copy_from_slice(&session_keys[32..48]); // Queue a transport mgr request to add a new session - WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; - } + Some(clone_data) + } else { + None + }; create_sc_status_report(ctx.tx, status_code, None)?; ctx.exch_ctx.exch.close(); - Ok(()) + Ok(clone_data) } #[allow(non_snake_case)] diff --git a/matter/src/transport/mgr.rs b/matter/src/transport/mgr.rs index 2ada942..1d68f34 100644 --- a/matter/src/transport/mgr.rs +++ b/matter/src/transport/mgr.rs @@ -33,12 +33,14 @@ use crate::utils::epoch::{Epoch, UtcCalendar}; use crate::utils::rand::Rand; use super::proto_ctx::ProtoCtx; +use super::session::CloneData; -#[derive(Copy, Clone, Eq, PartialEq)] enum RecvState { New, OpenExchange, + AddSession(CloneData), EvictSession, + EvictSession2(CloneData), Ack, } @@ -69,7 +71,7 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { fn maybe_next_action(&mut self) -> Result>>, Error> { self.mgr.exch_mgr.purge(); - match self.state { + match core::mem::replace(&mut self.state, RecvState::New) { RecvState::New => { self.mgr.exch_mgr.get_sess_mgr().decode(self.rx)?; self.state = RecvState::OpenExchange; @@ -80,13 +82,18 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL { let mut proto_ctx = ProtoCtx::new(exch_ctx, self.rx, self.tx); - if self.mgr.secure_channel.handle(&mut proto_ctx)? { - proto_ctx.send()?; + let (reply, clone_data) = self.mgr.secure_channel.handle(&mut proto_ctx)?; - self.state = RecvState::Ack; - Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + if let Some(clone_data) = clone_data { + self.state = RecvState::AddSession(clone_data); } else { self.state = RecvState::Ack; + } + + if reply { + proto_ctx.send()?; + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } else { Ok(None) } } else { @@ -106,11 +113,27 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } Err(err) => Err(err), }, + RecvState::AddSession(clone_data) => match self.mgr.exch_mgr.add_session(&clone_data) { + Ok(_) => { + self.state = RecvState::Ack; + Ok(None) + } + Err(Error::NoSpace) => { + self.state = RecvState::EvictSession2(clone_data); + Ok(None) + } + Err(err) => Err(err), + }, RecvState::EvictSession => { self.mgr.exch_mgr.evict_session(self.tx)?; self.state = RecvState::OpenExchange; Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) } + RecvState::EvictSession2(clone_data) => { + self.mgr.exch_mgr.evict_session(self.tx)?; + self.state = RecvState::AddSession(clone_data); + Ok(Some(Some(RecvAction::Send(self.tx.as_slice())))) + } RecvState::Ack => { if let Some(exch_id) = self.mgr.exch_mgr.pending_ack() { info!("Sending MRP Standalone ACK for exch {}", exch_id); @@ -127,7 +150,6 @@ impl<'r, 'a, 'p> RecvCompletion<'r, 'a, 'p> { } } -#[derive(Copy, Clone, Eq, PartialEq)] enum NotifyState {} pub enum NotifyAction<'r, 'p> { @@ -212,23 +234,4 @@ impl<'a> TransportMgr<'a> { pub fn notify(&mut self, _tx: &mut Packet) -> Result { Ok(false) } - - // async fn handle_queue_msgs(&mut self) -> Result<(), Error> { - // if let Ok(msg) = self.rx_q.try_recv() { - // match msg { - // Msg::NewSession(clone_data) => { - // // If a new session was created, add it - // let _ = self - // .exch_mgr - // .add_session(&clone_data) - // .await - // .map_err(|e| error!("Error adding new session {:?}", e)); - // } - // _ => { - // error!("Queue Message Type not yet handled {:?}", msg); - // } - // } - // } - // Ok(()) - // } } diff --git a/matter/src/transport/mod.rs b/matter/src/transport/mod.rs index 43acccd..1a81c75 100644 --- a/matter/src/transport/mod.rs +++ b/matter/src/transport/mod.rs @@ -24,6 +24,5 @@ pub mod packet; pub mod plain_hdr; pub mod proto_ctx; pub mod proto_hdr; -pub mod queue; pub mod session; pub mod udp; diff --git a/matter/src/transport/proto_ctx.rs b/matter/src/transport/proto_ctx.rs index 747a1e6..c4bf7f3 100644 --- a/matter/src/transport/proto_ctx.rs +++ b/matter/src/transport/proto_ctx.rs @@ -38,6 +38,6 @@ impl<'a, 'b> ProtoCtx<'a, 'b> { pub fn send(&mut self) -> Result<&[u8], Error> { self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)?; - Ok(self.tx.as_mut_slice()) + Ok(self.tx.as_slice()) } } diff --git a/matter/src/transport/queue.rs b/matter/src/transport/queue.rs deleted file mode 100644 index b0c0f37..0000000 --- a/matter/src/transport/queue.rs +++ /dev/null @@ -1,67 +0,0 @@ -/* - * - * Copyright (c) 2020-2022 Project CHIP Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -use std::sync::Once; - -use async_channel::{bounded, Receiver, Sender}; - -use crate::error::Error; - -use super::session::CloneData; - -#[derive(Debug)] -pub enum Msg { - Tx(), - Rx(), - NewSession(CloneData), -} - -#[derive(Clone)] -pub struct WorkQ { - tx: Sender, -} - -static mut G_WQ: Option = None; -static INIT: Once = Once::new(); - -impl WorkQ { - pub fn init() -> Result, Error> { - let (tx, rx) = bounded::(3); - WorkQ::configure(tx); - Ok(rx) - } - - fn configure(tx: Sender) { - unsafe { - INIT.call_once(|| { - G_WQ = Some(WorkQ { tx }); - }); - } - } - - pub fn get() -> Result { - unsafe { G_WQ.as_ref().cloned().ok_or(Error::Invalid) } - } - - pub fn sync_send(&self, msg: Msg) -> Result<(), Error> { - smol::block_on(self.send(msg)) - } - - pub async fn send(&self, msg: Msg) -> Result<(), Error> { - self.tx.send(msg).await.map_err(|e| e.into()) - } -}