From f53f3b789d5204bf940a65b65e18d072808b6909 Mon Sep 17 00:00:00 2001 From: ivmarkov Date: Tue, 1 Aug 2023 06:49:42 +0000 Subject: [PATCH] Do not hold on to RefCell borrows across await points --- .github/workflows/ci.yml | 2 +- rs-matter/src/mdns/builtin.rs | 2 - rs-matter/src/secure_channel/case.rs | 268 +++++++++--------- .../src/secure_channel/crypto_mbedtls.rs | 8 +- rs-matter/src/secure_channel/pake.rs | 156 +++++----- rs-matter/src/transport/core.rs | 1 - 6 files changed, 210 insertions(+), 227 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 57a9ce9..ea2f154 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -10,7 +10,7 @@ on: workflow_dispatch: env: - RUST_TOOLCHAIN: nightly-2023-07-01 + RUST_TOOLCHAIN: nightly GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} CARGO_TERM_COLOR: always diff --git a/rs-matter/src/mdns/builtin.rs b/rs-matter/src/mdns/builtin.rs index 97cb94e..5ca8676 100644 --- a/rs-matter/src/mdns/builtin.rs +++ b/rs-matter/src/mdns/builtin.rs @@ -208,7 +208,6 @@ impl<'a> MdnsService<'a> { select(&mut broadcast, &mut respond).await.unwrap() } - #[allow(clippy::await_holding_refcell_ref)] async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { select( @@ -258,7 +257,6 @@ impl<'a> MdnsService<'a> { } } - #[allow(clippy::await_holding_refcell_ref)] async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { loop { { diff --git a/rs-matter/src/secure_channel/case.rs b/rs-matter/src/secure_channel/case.rs index 84da032..090e989 100644 --- a/rs-matter/src/secure_channel/case.rs +++ b/rs-matter/src/secure_channel/case.rs @@ -87,7 +87,6 @@ impl<'a> Case<'a> { self.handle_casesigma3(exchange, rx, tx, &mut session).await } - #[allow(clippy::await_holding_refcell_ref)] async fn handle_casesigma3( &mut self, exchange: &mut Exchange<'_>, @@ -97,100 +96,81 @@ impl<'a> Case<'a> { ) -> Result<(), Error> { rx.check_proto_opcode(OpCode::CASESigma3 as _)?; - let fabric_mgr = self.fabric_mgr.borrow(); + let status = { + let fabric_mgr = self.fabric_mgr.borrow(); - let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; - if fabric.is_none() { - drop(fabric_mgr); - complete_with_status( - exchange, - tx, - common::SCStatusCodes::NoSharedTrustRoots, - None, - ) - .await?; - return Ok(()); - } - // Safe to unwrap here - let fabric = fabric.unwrap(); + let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; + if let Some(fabric) = fabric { + let root = get_root_node_struct(rx.as_slice())?; + let encrypted = root.find_tag(1)?.slice()?; - let root = get_root_node_struct(rx.as_slice())?; - let encrypted = root.find_tag(1)?.slice()?; + let mut decrypted = alloc!([0; 800]); + if encrypted.len() > decrypted.len() { + error!("Data too large"); + Err(ErrorCode::NoSpace)?; + } + let decrypted = &mut decrypted[..encrypted.len()]; + decrypted.copy_from_slice(encrypted); - let mut decrypted = alloc!([0; 800]); - if encrypted.len() > decrypted.len() { - error!("Data too large"); - Err(ErrorCode::NoSpace)?; - } - let decrypted = &mut decrypted[..encrypted.len()]; - decrypted.copy_from_slice(encrypted); + let len = + Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?; + let decrypted = &decrypted[..len]; - let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?; - let decrypted = &decrypted[..len]; + let root = get_root_node_struct(decrypted)?; + let d = Sigma3Decrypt::from_tlv(&root)?; - let root = get_root_node_struct(decrypted)?; - let d = Sigma3Decrypt::from_tlv(&root)?; + let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); + let mut initiator_icac = None; + if let Some(icac) = d.initiator_icac { + initiator_icac = Some(alloc!(Cert::new(icac.0)?)); + } - let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); - let mut initiator_icac = None; - if let Some(icac) = d.initiator_icac { - initiator_icac = Some(alloc!(Cert::new(icac.0)?)); - } + #[cfg(feature = "alloc")] + let initiator_icac_mut = initiator_icac.as_deref(); - #[cfg(feature = "alloc")] - let initiator_icac_mut = initiator_icac.as_deref(); + #[cfg(not(feature = "alloc"))] + let initiator_icac_mut = initiator_icac.as_ref(); - #[cfg(not(feature = "alloc"))] - let initiator_icac_mut = initiator_icac.as_ref(); + if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { + error!("Certificate Chain doesn't match: {}", e); + SCStatusCodes::InvalidParameter + } else if let Err(e) = Case::validate_sigma3_sign( + d.initiator_noc.0, + d.initiator_icac.map(|a| a.0), + &initiator_noc, + d.signature.0, + case_session, + ) { + error!("Sigma3 Signature doesn't match: {}", e); + SCStatusCodes::InvalidParameter + } else { + // Only now do we add this message to the TT Hash + let mut peer_catids: NocCatIds = Default::default(); + initiator_noc.get_cat_ids(&mut peer_catids); + case_session.tt_hash.update(rx.as_slice())?; + let clone_data = Case::get_session_clone_data( + fabric.ipk.op_key(), + fabric.get_node_id(), + initiator_noc.get_node_id()?, + exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, + case_session, + &peer_catids, + )?; - if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { - error!("Certificate Chain doesn't match: {}", e); - complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) - .await?; - return Ok(()); - } + // TODO: Handle NoSpace + exchange + .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; - if Case::validate_sigma3_sign( - d.initiator_noc.0, - d.initiator_icac.map(|a| a.0), - &initiator_noc, - d.signature.0, - case_session, - ) - .is_err() - { - error!("Sigma3 Signature doesn't match"); - complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) - .await?; - return Ok(()); - } + SCStatusCodes::SessionEstablishmentSuccess + } + } else { + SCStatusCodes::NoSharedTrustRoots + } + }; - // Only now do we add this message to the TT Hash - let mut peer_catids: NocCatIds = Default::default(); - initiator_noc.get_cat_ids(&mut peer_catids); - case_session.tt_hash.update(rx.as_slice())?; - let clone_data = Case::get_session_clone_data( - fabric.ipk.op_key(), - fabric.get_node_id(), - initiator_noc.get_node_id()?, - exchange.with_session(|sess| Ok(sess.get_peer_addr()))?, - case_session, - &peer_catids, - )?; - - // TODO: Handle NoSpace - exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?; - - complete_with_status( - exchange, - tx, - SCStatusCodes::SessionEstablishmentSuccess, - None, - ) - .await + complete_with_status(exchange, tx, status, None).await } - #[allow(clippy::await_holding_refcell_ref)] async fn handle_casesigma1( &mut self, exchange: &mut Exchange<'_>, @@ -255,70 +235,76 @@ impl<'a> Case<'a> { const MAX_ENCRYPTED_SIZE: usize = 800; let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); - let encrypted_len = { - let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); + let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]); + + let fabric_found = { let fabric_mgr = self.fabric_mgr.borrow(); let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; - if fabric.is_none() { - drop(fabric_mgr); - complete_with_status( - exchange, - tx, - common::SCStatusCodes::NoSharedTrustRoots, - None, - ) - .await?; - return Ok(()); + if let Some(fabric) = fabric { + #[cfg(feature = "alloc")] + let signature_mut = &mut *signature; + + #[cfg(not(feature = "alloc"))] + let signature_mut = &mut signature; + + let sign_len = Case::get_sigma2_sign( + fabric, + &case_session.our_pub_key, + &case_session.peer_pub_key, + signature_mut, + )?; + let signature = &signature[..sign_len]; + + #[cfg(feature = "alloc")] + let encrypted_mut = &mut *encrypted; + + #[cfg(not(feature = "alloc"))] + let encrypted_mut = &mut encrypted; + + let encrypted_len = Case::get_sigma2_encryption( + fabric, + self.rand, + &our_random, + case_session, + signature, + encrypted_mut, + )?; + + let encrypted = &encrypted[0..encrypted_len]; + + // Generate our Response Body + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::CASESigma2 as u8); + + let mut tw = TLVWriter::new(tx.get_writebuf()?); + tw.start_struct(TagType::Anonymous)?; + tw.str8(TagType::Context(1), &our_random)?; + tw.u16(TagType::Context(2), local_sessid)?; + tw.str8(TagType::Context(3), &case_session.our_pub_key)?; + tw.str16(TagType::Context(4), encrypted)?; + tw.end_container()?; + + case_session.tt_hash.update(tx.as_mut_slice())?; + + true + } else { + false } - - #[cfg(feature = "alloc")] - let signature_mut = &mut *signature; - - #[cfg(not(feature = "alloc"))] - let signature_mut = &mut signature; - - let sign_len = Case::get_sigma2_sign( - fabric.unwrap(), - &case_session.our_pub_key, - &case_session.peer_pub_key, - signature_mut, - )?; - let signature = &signature[..sign_len]; - - #[cfg(feature = "alloc")] - let encrypted_mut = &mut *encrypted; - - #[cfg(not(feature = "alloc"))] - let encrypted_mut = &mut encrypted; - - Case::get_sigma2_encryption( - fabric.unwrap(), - self.rand, - &our_random, - case_session, - signature, - encrypted_mut, - )? }; - let encrypted = &encrypted[0..encrypted_len]; - // Generate our Response Body - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::CASESigma2 as u8); - - let mut tw = TLVWriter::new(tx.get_writebuf()?); - tw.start_struct(TagType::Anonymous)?; - tw.str8(TagType::Context(1), &our_random)?; - tw.u16(TagType::Context(2), local_sessid)?; - tw.str8(TagType::Context(3), &case_session.our_pub_key)?; - tw.str16(TagType::Context(4), encrypted)?; - tw.end_container()?; - - case_session.tt_hash.update(tx.as_mut_slice())?; - - exchange.exchange(tx, rx).await + if fabric_found { + exchange.exchange(tx, rx).await + } else { + complete_with_status( + exchange, + tx, + common::SCStatusCodes::NoSharedTrustRoots, + None, + ) + .await + } } fn get_session_clone_data( @@ -515,7 +501,7 @@ impl<'a> Case<'a> { fabric: &Fabric, rand: Rand, our_random: &[u8], - case_session: &mut CaseSession, + case_session: &CaseSession, signature: &[u8], out: &mut [u8], ) -> Result { diff --git a/rs-matter/src/secure_channel/crypto_mbedtls.rs b/rs-matter/src/secure_channel/crypto_mbedtls.rs index 8ddec40..0db2664 100644 --- a/rs-matter/src/secure_channel/crypto_mbedtls.rs +++ b/rs-matter/src/secure_channel/crypto_mbedtls.rs @@ -186,7 +186,7 @@ impl CryptoSpake2 { let (Z, V) = Self::get_ZV_as_verifier( &self.w0, &self.L, - &mut self.M, + &self.M, &X, &self.xy, &self.order, @@ -228,7 +228,7 @@ impl CryptoSpake2 { fn get_ZV_as_prover( w0: &Mpi, w1: &Mpi, - N: &mut EcPoint, + N: &EcPoint, Y: &EcPoint, x: &Mpi, order: &Mpi, @@ -264,7 +264,7 @@ impl CryptoSpake2 { fn get_ZV_as_verifier( w0: &Mpi, L: &EcPoint, - M: &mut EcPoint, + M: &EcPoint, X: &EcPoint, y: &Mpi, order: &Mpi, @@ -292,7 +292,7 @@ impl CryptoSpake2 { Ok((Z, V)) } - fn invert(group: &mut EcGroup, num: &EcPoint) -> Result { + fn invert(group: &EcGroup, num: &EcPoint) -> Result { let p = group.p()?; let num_y = num.y()?; let inverted_num_y = p.sub(&num_y)?; diff --git a/rs-matter/src/secure_channel/pake.rs b/rs-matter/src/secure_channel/pake.rs index ea2b98c..947b32f 100644 --- a/rs-matter/src/secure_channel/pake.rs +++ b/rs-matter/src/secure_channel/pake.rs @@ -213,7 +213,6 @@ impl<'a> Pake<'a> { } #[allow(non_snake_case)] - #[allow(clippy::await_holding_refcell_ref)] async fn handle_pasepake1( &mut self, exchange: &mut Exchange<'_>, @@ -224,32 +223,32 @@ impl<'a> Pake<'a> { rx.check_proto_opcode(OpCode::PASEPake1 as _)?; self.update_timeout(exchange, tx, false).await?; - let pase = self.pase.borrow(); - let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; + { + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; - let mut pB: [u8; 65] = [0; 65]; - let mut cB: [u8; 32] = [0; 32]; - spake2p.start_verifier(&session.verifier)?; - spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; + let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; + let mut pB: [u8; 65] = [0; 65]; + let mut cB: [u8; 32] = [0; 32]; + spake2p.start_verifier(&session.verifier)?; + spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; - // Generate response - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::PASEPake2 as u8); + // Generate response + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PASEPake2 as u8); - let mut tw = TLVWriter::new(tx.get_writebuf()?); - let resp = Pake1Resp { - pb: OctetStr(&pB), - cb: OctetStr(&cB), - }; - resp.to_tlv(&mut tw, TagType::Anonymous)?; + let mut tw = TLVWriter::new(tx.get_writebuf()?); + let resp = Pake1Resp { + pb: OctetStr(&pB), + cb: OctetStr(&cB), + }; + resp.to_tlv(&mut tw, TagType::Anonymous)?; + } - drop(pase); exchange.exchange(tx, rx).await } - #[allow(clippy::await_holding_refcell_ref)] async fn handle_pbkdfparamrequest( &mut self, exchange: &mut Exchange<'_>, @@ -260,52 +259,51 @@ impl<'a> Pake<'a> { rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?; self.update_timeout(exchange, tx, true).await?; - let pase = self.pase.borrow(); - let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; + { + let pase = self.pase.borrow(); + let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; - let root = tlv::get_root_node(rx.as_slice())?; - let a = PBKDFParamReq::from_tlv(&root)?; - if a.passcode_id != 0 { - error!("Can't yet handle passcode_id != 0"); - Err(ErrorCode::Invalid)?; - } + let root = tlv::get_root_node(rx.as_slice())?; + let a = PBKDFParamReq::from_tlv(&root)?; + if a.passcode_id != 0 { + error!("Can't yet handle passcode_id != 0"); + Err(ErrorCode::Invalid)?; + } - let mut our_random: [u8; 32] = [0; 32]; - (self.pase.borrow().rand)(&mut our_random); + let mut our_random: [u8; 32] = [0; 32]; + (self.pase.borrow().rand)(&mut our_random); - let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; - let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; - spake2p.set_app_data(spake2p_data); + let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; + let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; + spake2p.set_app_data(spake2p_data); - // Generate response - tx.reset(); - tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); - tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); + // Generate response + tx.reset(); + tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); + tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); - let mut tw = TLVWriter::new(tx.get_writebuf()?); - let mut resp = PBKDFParamResp { - init_random: a.initiator_random, - our_random: OctetStr(&our_random), - local_sessid, - params: None, - }; - if !a.has_params { - let params_resp = PBKDFParamRespParams { - count: session.verifier.count, - salt: OctetStr(&session.verifier.salt), + let mut tw = TLVWriter::new(tx.get_writebuf()?); + let mut resp = PBKDFParamResp { + init_random: a.initiator_random, + our_random: OctetStr(&our_random), + local_sessid, + params: None, }; - resp.params = Some(params_resp); + if !a.has_params { + let params_resp = PBKDFParamRespParams { + count: session.verifier.count, + salt: OctetStr(&session.verifier.salt), + }; + resp.params = Some(params_resp); + } + resp.to_tlv(&mut tw, TagType::Anonymous)?; + + spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?; } - resp.to_tlv(&mut tw, TagType::Anonymous)?; - - spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?; - - drop(pase); exchange.exchange(tx, rx).await } - #[allow(clippy::await_holding_refcell_ref)] async fn update_timeout( &mut self, exchange: &mut Exchange<'_>, @@ -314,36 +312,38 @@ impl<'a> Pake<'a> { ) -> Result<(), Error> { self.check_session(exchange, tx).await?; - let mut pase = self.pase.borrow_mut(); + let status = { + let mut pase = self.pase.borrow_mut(); - if pase - .timeout - .as_ref() - .map(|sd| sd.is_sess_expired(pase.epoch)) - .unwrap_or(false) - { - pase.timeout = None; - } - - let status = if let Some(sd) = pase.timeout.as_mut() { - if &sd.exch_id != exchange.id() { - info!("Other PAKE session in progress"); - Some(SCStatusCodes::Busy) - } else { - None + if pase + .timeout + .as_ref() + .map(|sd| sd.is_sess_expired(pase.epoch)) + .unwrap_or(false) + { + pase.timeout = None; + } + + if let Some(sd) = pase.timeout.as_mut() { + if &sd.exch_id != exchange.id() { + info!("Other PAKE session in progress"); + Some(SCStatusCodes::Busy) + } else { + None + } + } else if new { + None + } else { + error!("PAKE session not found or expired"); + Some(SCStatusCodes::SessionNotFound) } - } else if new { - None - } else { - error!("PAKE session not found or expired"); - Some(SCStatusCodes::SessionNotFound) }; if let Some(status) = status { - drop(pase); - complete_with_status(exchange, tx, status, None).await } else { + let mut pase = self.pase.borrow_mut(); + pase.timeout = Some(Timeout::new(exchange, pase.epoch)); Ok(()) diff --git a/rs-matter/src/transport/core.rs b/rs-matter/src/transport/core.rs index 0874736..7300c65 100644 --- a/rs-matter/src/transport/core.rs +++ b/rs-matter/src/transport/core.rs @@ -397,7 +397,6 @@ impl<'a> Matter<'a> { } #[inline(always)] - #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex pub async fn handle_exchange( &self, tx_buf: &mut [u8; MAX_TX_BUF_SIZE],