Do not hold on to RefCell borrows across await points

This commit is contained in:
ivmarkov 2023-08-01 06:49:42 +00:00
parent ede024cf71
commit f53f3b789d
6 changed files with 210 additions and 227 deletions

View file

@ -10,7 +10,7 @@ on:
workflow_dispatch: workflow_dispatch:
env: env:
RUST_TOOLCHAIN: nightly-2023-07-01 RUST_TOOLCHAIN: nightly
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
CARGO_TERM_COLOR: always CARGO_TERM_COLOR: always

View file

@ -208,7 +208,6 @@ impl<'a> MdnsService<'a> {
select(&mut broadcast, &mut respond).await.unwrap() select(&mut broadcast, &mut respond).await.unwrap()
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> { async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop { loop {
select( select(
@ -258,7 +257,6 @@ impl<'a> MdnsService<'a> {
} }
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> { async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop { loop {
{ {

View file

@ -87,7 +87,6 @@ impl<'a> Case<'a> {
self.handle_casesigma3(exchange, rx, tx, &mut session).await self.handle_casesigma3(exchange, rx, tx, &mut session).await
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_casesigma3( async fn handle_casesigma3(
&mut self, &mut self,
exchange: &mut Exchange<'_>, exchange: &mut Exchange<'_>,
@ -97,100 +96,81 @@ impl<'a> Case<'a> {
) -> Result<(), Error> { ) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma3 as _)?; rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
let fabric_mgr = self.fabric_mgr.borrow(); let status = {
let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if let Some(fabric) = fabric {
drop(fabric_mgr); let root = get_root_node_struct(rx.as_slice())?;
complete_with_status( let encrypted = root.find_tag(1)?.slice()?;
exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots,
None,
)
.await?;
return Ok(());
}
// Safe to unwrap here
let fabric = fabric.unwrap();
let root = get_root_node_struct(rx.as_slice())?; let mut decrypted = alloc!([0; 800]);
let encrypted = root.find_tag(1)?.slice()?; if encrypted.len() > decrypted.len() {
error!("Data too large");
Err(ErrorCode::NoSpace)?;
}
let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted);
let mut decrypted = alloc!([0; 800]); let len =
if encrypted.len() > decrypted.len() { Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
error!("Data too large"); let decrypted = &decrypted[..len];
Err(ErrorCode::NoSpace)?;
}
let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted);
let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?; let root = get_root_node_struct(decrypted)?;
let decrypted = &decrypted[..len]; let d = Sigma3Decrypt::from_tlv(&root)?;
let root = get_root_node_struct(decrypted)?; let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
let d = Sigma3Decrypt::from_tlv(&root)?; let mut initiator_icac = None;
if let Some(icac) = d.initiator_icac {
initiator_icac = Some(alloc!(Cert::new(icac.0)?));
}
let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?); #[cfg(feature = "alloc")]
let mut initiator_icac = None; let initiator_icac_mut = initiator_icac.as_deref();
if let Some(icac) = d.initiator_icac {
initiator_icac = Some(alloc!(Cert::new(icac.0)?));
}
#[cfg(feature = "alloc")] #[cfg(not(feature = "alloc"))]
let initiator_icac_mut = initiator_icac.as_deref(); let initiator_icac_mut = initiator_icac.as_ref();
#[cfg(not(feature = "alloc"))] if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
let initiator_icac_mut = initiator_icac.as_ref(); error!("Certificate Chain doesn't match: {}", e);
SCStatusCodes::InvalidParameter
} else if let Err(e) = Case::validate_sigma3_sign(
d.initiator_noc.0,
d.initiator_icac.map(|a| a.0),
&initiator_noc,
d.signature.0,
case_session,
) {
error!("Sigma3 Signature doesn't match: {}", e);
SCStatusCodes::InvalidParameter
} else {
// Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data(
fabric.ipk.op_key(),
fabric.get_node_id(),
initiator_noc.get_node_id()?,
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
case_session,
&peer_catids,
)?;
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) { // TODO: Handle NoSpace
error!("Certificate Chain doesn't match: {}", e); exchange
complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None) .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
.await?;
return Ok(());
}
if Case::validate_sigma3_sign( SCStatusCodes::SessionEstablishmentSuccess
d.initiator_noc.0, }
d.initiator_icac.map(|a| a.0), } else {
&initiator_noc, SCStatusCodes::NoSharedTrustRoots
d.signature.0, }
case_session, };
)
.is_err()
{
error!("Sigma3 Signature doesn't match");
complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
.await?;
return Ok(());
}
// Only now do we add this message to the TT Hash complete_with_status(exchange, tx, status, None).await
let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data(
fabric.ipk.op_key(),
fabric.get_node_id(),
initiator_noc.get_node_id()?,
exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
case_session,
&peer_catids,
)?;
// TODO: Handle NoSpace
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
complete_with_status(
exchange,
tx,
SCStatusCodes::SessionEstablishmentSuccess,
None,
)
.await
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_casesigma1( async fn handle_casesigma1(
&mut self, &mut self,
exchange: &mut Exchange<'_>, exchange: &mut Exchange<'_>,
@ -255,70 +235,76 @@ impl<'a> Case<'a> {
const MAX_ENCRYPTED_SIZE: usize = 800; const MAX_ENCRYPTED_SIZE: usize = 800;
let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]); let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
let encrypted_len = { let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
let fabric_found = {
let fabric_mgr = self.fabric_mgr.borrow(); let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if let Some(fabric) = fabric {
drop(fabric_mgr); #[cfg(feature = "alloc")]
complete_with_status( let signature_mut = &mut *signature;
exchange,
tx, #[cfg(not(feature = "alloc"))]
common::SCStatusCodes::NoSharedTrustRoots, let signature_mut = &mut signature;
None,
) let sign_len = Case::get_sigma2_sign(
.await?; fabric,
return Ok(()); &case_session.our_pub_key,
&case_session.peer_pub_key,
signature_mut,
)?;
let signature = &signature[..sign_len];
#[cfg(feature = "alloc")]
let encrypted_mut = &mut *encrypted;
#[cfg(not(feature = "alloc"))]
let encrypted_mut = &mut encrypted;
let encrypted_len = Case::get_sigma2_encryption(
fabric,
self.rand,
&our_random,
case_session,
signature,
encrypted_mut,
)?;
let encrypted = &encrypted[0..encrypted_len];
// Generate our Response Body
tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::CASESigma2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.start_struct(TagType::Anonymous)?;
tw.str8(TagType::Context(1), &our_random)?;
tw.u16(TagType::Context(2), local_sessid)?;
tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
tw.str16(TagType::Context(4), encrypted)?;
tw.end_container()?;
case_session.tt_hash.update(tx.as_mut_slice())?;
true
} else {
false
} }
#[cfg(feature = "alloc")]
let signature_mut = &mut *signature;
#[cfg(not(feature = "alloc"))]
let signature_mut = &mut signature;
let sign_len = Case::get_sigma2_sign(
fabric.unwrap(),
&case_session.our_pub_key,
&case_session.peer_pub_key,
signature_mut,
)?;
let signature = &signature[..sign_len];
#[cfg(feature = "alloc")]
let encrypted_mut = &mut *encrypted;
#[cfg(not(feature = "alloc"))]
let encrypted_mut = &mut encrypted;
Case::get_sigma2_encryption(
fabric.unwrap(),
self.rand,
&our_random,
case_session,
signature,
encrypted_mut,
)?
}; };
let encrypted = &encrypted[0..encrypted_len];
// Generate our Response Body if fabric_found {
tx.reset(); exchange.exchange(tx, rx).await
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); } else {
tx.set_proto_opcode(OpCode::CASESigma2 as u8); complete_with_status(
exchange,
let mut tw = TLVWriter::new(tx.get_writebuf()?); tx,
tw.start_struct(TagType::Anonymous)?; common::SCStatusCodes::NoSharedTrustRoots,
tw.str8(TagType::Context(1), &our_random)?; None,
tw.u16(TagType::Context(2), local_sessid)?; )
tw.str8(TagType::Context(3), &case_session.our_pub_key)?; .await
tw.str16(TagType::Context(4), encrypted)?; }
tw.end_container()?;
case_session.tt_hash.update(tx.as_mut_slice())?;
exchange.exchange(tx, rx).await
} }
fn get_session_clone_data( fn get_session_clone_data(
@ -515,7 +501,7 @@ impl<'a> Case<'a> {
fabric: &Fabric, fabric: &Fabric,
rand: Rand, rand: Rand,
our_random: &[u8], our_random: &[u8],
case_session: &mut CaseSession, case_session: &CaseSession,
signature: &[u8], signature: &[u8],
out: &mut [u8], out: &mut [u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {

View file

@ -186,7 +186,7 @@ impl CryptoSpake2 {
let (Z, V) = Self::get_ZV_as_verifier( let (Z, V) = Self::get_ZV_as_verifier(
&self.w0, &self.w0,
&self.L, &self.L,
&mut self.M, &self.M,
&X, &X,
&self.xy, &self.xy,
&self.order, &self.order,
@ -228,7 +228,7 @@ impl CryptoSpake2 {
fn get_ZV_as_prover( fn get_ZV_as_prover(
w0: &Mpi, w0: &Mpi,
w1: &Mpi, w1: &Mpi,
N: &mut EcPoint, N: &EcPoint,
Y: &EcPoint, Y: &EcPoint,
x: &Mpi, x: &Mpi,
order: &Mpi, order: &Mpi,
@ -264,7 +264,7 @@ impl CryptoSpake2 {
fn get_ZV_as_verifier( fn get_ZV_as_verifier(
w0: &Mpi, w0: &Mpi,
L: &EcPoint, L: &EcPoint,
M: &mut EcPoint, M: &EcPoint,
X: &EcPoint, X: &EcPoint,
y: &Mpi, y: &Mpi,
order: &Mpi, order: &Mpi,
@ -292,7 +292,7 @@ impl CryptoSpake2 {
Ok((Z, V)) Ok((Z, V))
} }
fn invert(group: &mut EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> { fn invert(group: &EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
let p = group.p()?; let p = group.p()?;
let num_y = num.y()?; let num_y = num.y()?;
let inverted_num_y = p.sub(&num_y)?; let inverted_num_y = p.sub(&num_y)?;

View file

@ -213,7 +213,6 @@ impl<'a> Pake<'a> {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_pasepake1( async fn handle_pasepake1(
&mut self, &mut self,
exchange: &mut Exchange<'_>, exchange: &mut Exchange<'_>,
@ -224,32 +223,32 @@ impl<'a> Pake<'a> {
rx.check_proto_opcode(OpCode::PASEPake1 as _)?; rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
self.update_timeout(exchange, tx, false).await?; self.update_timeout(exchange, tx, false).await?;
let pase = self.pase.borrow(); {
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let pA = extract_pasepake_1_or_3_params(rx.as_slice())?; let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let mut pB: [u8; 65] = [0; 65]; let mut pB: [u8; 65] = [0; 65];
let mut cB: [u8; 32] = [0; 32]; let mut cB: [u8; 32] = [0; 32];
spake2p.start_verifier(&session.verifier)?; spake2p.start_verifier(&session.verifier)?;
spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?; spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
// Generate response // Generate response
tx.reset(); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PASEPake2 as u8); tx.set_proto_opcode(OpCode::PASEPake2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?); let mut tw = TLVWriter::new(tx.get_writebuf()?);
let resp = Pake1Resp { let resp = Pake1Resp {
pb: OctetStr(&pB), pb: OctetStr(&pB),
cb: OctetStr(&cB), cb: OctetStr(&cB),
}; };
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
}
drop(pase);
exchange.exchange(tx, rx).await exchange.exchange(tx, rx).await
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_pbkdfparamrequest( async fn handle_pbkdfparamrequest(
&mut self, &mut self,
exchange: &mut Exchange<'_>, exchange: &mut Exchange<'_>,
@ -260,52 +259,51 @@ impl<'a> Pake<'a> {
rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?; rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
self.update_timeout(exchange, tx, true).await?; self.update_timeout(exchange, tx, true).await?;
let pase = self.pase.borrow(); {
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let root = tlv::get_root_node(rx.as_slice())?; let root = tlv::get_root_node(rx.as_slice())?;
let a = PBKDFParamReq::from_tlv(&root)?; let a = PBKDFParamReq::from_tlv(&root)?;
if a.passcode_id != 0 { if a.passcode_id != 0 {
error!("Can't yet handle passcode_id != 0"); error!("Can't yet handle passcode_id != 0");
Err(ErrorCode::Invalid)?; Err(ErrorCode::Invalid)?;
} }
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
(self.pase.borrow().rand)(&mut our_random); (self.pase.borrow().rand)(&mut our_random);
let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?; let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
spake2p.set_app_data(spake2p_data); spake2p.set_app_data(spake2p_data);
// Generate response // Generate response
tx.reset(); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL); tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?); let mut tw = TLVWriter::new(tx.get_writebuf()?);
let mut resp = PBKDFParamResp { let mut resp = PBKDFParamResp {
init_random: a.initiator_random, init_random: a.initiator_random,
our_random: OctetStr(&our_random), our_random: OctetStr(&our_random),
local_sessid, local_sessid,
params: None, params: None,
};
if !a.has_params {
let params_resp = PBKDFParamRespParams {
count: session.verifier.count,
salt: OctetStr(&session.verifier.salt),
}; };
resp.params = Some(params_resp); if !a.has_params {
let params_resp = PBKDFParamRespParams {
count: session.verifier.count,
salt: OctetStr(&session.verifier.salt),
};
resp.params = Some(params_resp);
}
resp.to_tlv(&mut tw, TagType::Anonymous)?;
spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
} }
resp.to_tlv(&mut tw, TagType::Anonymous)?;
spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
drop(pase);
exchange.exchange(tx, rx).await exchange.exchange(tx, rx).await
} }
#[allow(clippy::await_holding_refcell_ref)]
async fn update_timeout( async fn update_timeout(
&mut self, &mut self,
exchange: &mut Exchange<'_>, exchange: &mut Exchange<'_>,
@ -314,36 +312,38 @@ impl<'a> Pake<'a> {
) -> Result<(), Error> { ) -> Result<(), Error> {
self.check_session(exchange, tx).await?; self.check_session(exchange, tx).await?;
let mut pase = self.pase.borrow_mut(); let status = {
let mut pase = self.pase.borrow_mut();
if pase if pase
.timeout .timeout
.as_ref() .as_ref()
.map(|sd| sd.is_sess_expired(pase.epoch)) .map(|sd| sd.is_sess_expired(pase.epoch))
.unwrap_or(false) .unwrap_or(false)
{ {
pase.timeout = None; pase.timeout = None;
} }
let status = if let Some(sd) = pase.timeout.as_mut() { if let Some(sd) = pase.timeout.as_mut() {
if &sd.exch_id != exchange.id() { if &sd.exch_id != exchange.id() {
info!("Other PAKE session in progress"); info!("Other PAKE session in progress");
Some(SCStatusCodes::Busy) Some(SCStatusCodes::Busy)
} else { } else {
None None
}
} else if new {
None
} else {
error!("PAKE session not found or expired");
Some(SCStatusCodes::SessionNotFound)
} }
} else if new {
None
} else {
error!("PAKE session not found or expired");
Some(SCStatusCodes::SessionNotFound)
}; };
if let Some(status) = status { if let Some(status) = status {
drop(pase);
complete_with_status(exchange, tx, status, None).await complete_with_status(exchange, tx, status, None).await
} else { } else {
let mut pase = self.pase.borrow_mut();
pase.timeout = Some(Timeout::new(exchange, pase.epoch)); pase.timeout = Some(Timeout::new(exchange, pase.epoch));
Ok(()) Ok(())

View file

@ -397,7 +397,6 @@ impl<'a> Matter<'a> {
} }
#[inline(always)] #[inline(always)]
#[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
pub async fn handle_exchange<H>( pub async fn handle_exchange<H>(
&self, &self,
tx_buf: &mut [u8; MAX_TX_BUF_SIZE], tx_buf: &mut [u8; MAX_TX_BUF_SIZE],