Merge pull request #82 from ivmarkov/main
Do not hold on to RefCell borrows across await points
This commit is contained in:
		
						commit
						50f18dbbee
					
				
					 6 changed files with 210 additions and 227 deletions
				
			
		
							
								
								
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -10,7 +10,7 @@ on:
 | 
			
		|||
  workflow_dispatch:
 | 
			
		||||
 | 
			
		||||
env:
 | 
			
		||||
  RUST_TOOLCHAIN: nightly-2023-07-01
 | 
			
		||||
  RUST_TOOLCHAIN: nightly
 | 
			
		||||
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  CARGO_TERM_COLOR: always
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -208,7 +208,6 @@ impl<'a> MdnsService<'a> {
 | 
			
		|||
        select(&mut broadcast, &mut respond).await.unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
			
		||||
        loop {
 | 
			
		||||
            select(
 | 
			
		||||
| 
						 | 
				
			
			@ -258,7 +257,6 @@ impl<'a> MdnsService<'a> {
 | 
			
		|||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
			
		||||
        loop {
 | 
			
		||||
            {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,7 +87,6 @@ impl<'a> Case<'a> {
 | 
			
		|||
        self.handle_casesigma3(exchange, rx, tx, &mut session).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_casesigma3(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -97,23 +96,11 @@ impl<'a> Case<'a> {
 | 
			
		|||
    ) -> Result<(), Error> {
 | 
			
		||||
        rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
 | 
			
		||||
 | 
			
		||||
        let status = {
 | 
			
		||||
            let fabric_mgr = self.fabric_mgr.borrow();
 | 
			
		||||
 | 
			
		||||
            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
			
		||||
        if fabric.is_none() {
 | 
			
		||||
            drop(fabric_mgr);
 | 
			
		||||
            complete_with_status(
 | 
			
		||||
                exchange,
 | 
			
		||||
                tx,
 | 
			
		||||
                common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
            .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
        // Safe to unwrap here
 | 
			
		||||
        let fabric = fabric.unwrap();
 | 
			
		||||
 | 
			
		||||
            if let Some(fabric) = fabric {
 | 
			
		||||
                let root = get_root_node_struct(rx.as_slice())?;
 | 
			
		||||
                let encrypted = root.find_tag(1)?.slice()?;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -125,7 +112,8 @@ impl<'a> Case<'a> {
 | 
			
		|||
                let decrypted = &mut decrypted[..encrypted.len()];
 | 
			
		||||
                decrypted.copy_from_slice(encrypted);
 | 
			
		||||
 | 
			
		||||
        let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
			
		||||
                let len =
 | 
			
		||||
                    Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
			
		||||
                let decrypted = &decrypted[..len];
 | 
			
		||||
 | 
			
		||||
                let root = get_root_node_struct(decrypted)?;
 | 
			
		||||
| 
						 | 
				
			
			@ -145,26 +133,17 @@ impl<'a> Case<'a> {
 | 
			
		|||
 | 
			
		||||
                if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
 | 
			
		||||
                    error!("Certificate Chain doesn't match: {}", e);
 | 
			
		||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
			
		||||
                .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        if Case::validate_sigma3_sign(
 | 
			
		||||
                    SCStatusCodes::InvalidParameter
 | 
			
		||||
                } else if let Err(e) = Case::validate_sigma3_sign(
 | 
			
		||||
                    d.initiator_noc.0,
 | 
			
		||||
                    d.initiator_icac.map(|a| a.0),
 | 
			
		||||
                    &initiator_noc,
 | 
			
		||||
                    d.signature.0,
 | 
			
		||||
                    case_session,
 | 
			
		||||
        )
 | 
			
		||||
        .is_err()
 | 
			
		||||
        {
 | 
			
		||||
            error!("Sigma3 Signature doesn't match");
 | 
			
		||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
			
		||||
                .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
                ) {
 | 
			
		||||
                    error!("Sigma3 Signature doesn't match: {}", e);
 | 
			
		||||
                    SCStatusCodes::InvalidParameter
 | 
			
		||||
                } else {
 | 
			
		||||
                    // Only now do we add this message to the TT Hash
 | 
			
		||||
                    let mut peer_catids: NocCatIds = Default::default();
 | 
			
		||||
                    initiator_noc.get_cat_ids(&mut peer_catids);
 | 
			
		||||
| 
						 | 
				
			
			@ -179,18 +158,19 @@ impl<'a> Case<'a> {
 | 
			
		|||
                    )?;
 | 
			
		||||
 | 
			
		||||
                    // TODO: Handle NoSpace
 | 
			
		||||
        exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
			
		||||
                    exchange
 | 
			
		||||
                        .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
			
		||||
 | 
			
		||||
        complete_with_status(
 | 
			
		||||
            exchange,
 | 
			
		||||
            tx,
 | 
			
		||||
            SCStatusCodes::SessionEstablishmentSuccess,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
                    SCStatusCodes::SessionEstablishmentSuccess
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                SCStatusCodes::NoSharedTrustRoots
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        complete_with_status(exchange, tx, status, None).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_casesigma1(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -255,23 +235,13 @@ impl<'a> Case<'a> {
 | 
			
		|||
        const MAX_ENCRYPTED_SIZE: usize = 800;
 | 
			
		||||
 | 
			
		||||
        let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
 | 
			
		||||
        let encrypted_len = {
 | 
			
		||||
        let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
 | 
			
		||||
 | 
			
		||||
        let fabric_found = {
 | 
			
		||||
            let fabric_mgr = self.fabric_mgr.borrow();
 | 
			
		||||
 | 
			
		||||
            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
			
		||||
            if fabric.is_none() {
 | 
			
		||||
                drop(fabric_mgr);
 | 
			
		||||
                complete_with_status(
 | 
			
		||||
                    exchange,
 | 
			
		||||
                    tx,
 | 
			
		||||
                    common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                    None,
 | 
			
		||||
                )
 | 
			
		||||
                .await?;
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if let Some(fabric) = fabric {
 | 
			
		||||
                #[cfg(feature = "alloc")]
 | 
			
		||||
                let signature_mut = &mut *signature;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -279,7 +249,7 @@ impl<'a> Case<'a> {
 | 
			
		|||
                let signature_mut = &mut signature;
 | 
			
		||||
 | 
			
		||||
                let sign_len = Case::get_sigma2_sign(
 | 
			
		||||
                fabric.unwrap(),
 | 
			
		||||
                    fabric,
 | 
			
		||||
                    &case_session.our_pub_key,
 | 
			
		||||
                    &case_session.peer_pub_key,
 | 
			
		||||
                    signature_mut,
 | 
			
		||||
| 
						 | 
				
			
			@ -292,15 +262,15 @@ impl<'a> Case<'a> {
 | 
			
		|||
                #[cfg(not(feature = "alloc"))]
 | 
			
		||||
                let encrypted_mut = &mut encrypted;
 | 
			
		||||
 | 
			
		||||
            Case::get_sigma2_encryption(
 | 
			
		||||
                fabric.unwrap(),
 | 
			
		||||
                let encrypted_len = Case::get_sigma2_encryption(
 | 
			
		||||
                    fabric,
 | 
			
		||||
                    self.rand,
 | 
			
		||||
                    &our_random,
 | 
			
		||||
                    case_session,
 | 
			
		||||
                    signature,
 | 
			
		||||
                    encrypted_mut,
 | 
			
		||||
            )?
 | 
			
		||||
        };
 | 
			
		||||
                )?;
 | 
			
		||||
 | 
			
		||||
                let encrypted = &encrypted[0..encrypted_len];
 | 
			
		||||
 | 
			
		||||
                // Generate our Response Body
 | 
			
		||||
| 
						 | 
				
			
			@ -318,7 +288,23 @@ impl<'a> Case<'a> {
 | 
			
		|||
 | 
			
		||||
                case_session.tt_hash.update(tx.as_mut_slice())?;
 | 
			
		||||
 | 
			
		||||
                true
 | 
			
		||||
            } else {
 | 
			
		||||
                false
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if fabric_found {
 | 
			
		||||
            exchange.exchange(tx, rx).await
 | 
			
		||||
        } else {
 | 
			
		||||
            complete_with_status(
 | 
			
		||||
                exchange,
 | 
			
		||||
                tx,
 | 
			
		||||
                common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn get_session_clone_data(
 | 
			
		||||
| 
						 | 
				
			
			@ -515,7 +501,7 @@ impl<'a> Case<'a> {
 | 
			
		|||
        fabric: &Fabric,
 | 
			
		||||
        rand: Rand,
 | 
			
		||||
        our_random: &[u8],
 | 
			
		||||
        case_session: &mut CaseSession,
 | 
			
		||||
        case_session: &CaseSession,
 | 
			
		||||
        signature: &[u8],
 | 
			
		||||
        out: &mut [u8],
 | 
			
		||||
    ) -> Result<usize, Error> {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,7 +186,7 @@ impl CryptoSpake2 {
 | 
			
		|||
        let (Z, V) = Self::get_ZV_as_verifier(
 | 
			
		||||
            &self.w0,
 | 
			
		||||
            &self.L,
 | 
			
		||||
            &mut self.M,
 | 
			
		||||
            &self.M,
 | 
			
		||||
            &X,
 | 
			
		||||
            &self.xy,
 | 
			
		||||
            &self.order,
 | 
			
		||||
| 
						 | 
				
			
			@ -228,7 +228,7 @@ impl CryptoSpake2 {
 | 
			
		|||
    fn get_ZV_as_prover(
 | 
			
		||||
        w0: &Mpi,
 | 
			
		||||
        w1: &Mpi,
 | 
			
		||||
        N: &mut EcPoint,
 | 
			
		||||
        N: &EcPoint,
 | 
			
		||||
        Y: &EcPoint,
 | 
			
		||||
        x: &Mpi,
 | 
			
		||||
        order: &Mpi,
 | 
			
		||||
| 
						 | 
				
			
			@ -264,7 +264,7 @@ impl CryptoSpake2 {
 | 
			
		|||
    fn get_ZV_as_verifier(
 | 
			
		||||
        w0: &Mpi,
 | 
			
		||||
        L: &EcPoint,
 | 
			
		||||
        M: &mut EcPoint,
 | 
			
		||||
        M: &EcPoint,
 | 
			
		||||
        X: &EcPoint,
 | 
			
		||||
        y: &Mpi,
 | 
			
		||||
        order: &Mpi,
 | 
			
		||||
| 
						 | 
				
			
			@ -292,7 +292,7 @@ impl CryptoSpake2 {
 | 
			
		|||
        Ok((Z, V))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn invert(group: &mut EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
			
		||||
    fn invert(group: &EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
			
		||||
        let p = group.p()?;
 | 
			
		||||
        let num_y = num.y()?;
 | 
			
		||||
        let inverted_num_y = p.sub(&num_y)?;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -213,7 +213,6 @@ impl<'a> Pake<'a> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(non_snake_case)]
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_pasepake1(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -224,6 +223,7 @@ impl<'a> Pake<'a> {
 | 
			
		|||
        rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
 | 
			
		||||
        self.update_timeout(exchange, tx, false).await?;
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let pase = self.pase.borrow();
 | 
			
		||||
            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -244,12 +244,11 @@ impl<'a> Pake<'a> {
 | 
			
		|||
                cb: OctetStr(&cB),
 | 
			
		||||
            };
 | 
			
		||||
            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        drop(pase);
 | 
			
		||||
        exchange.exchange(tx, rx).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_pbkdfparamrequest(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -260,6 +259,7 @@ impl<'a> Pake<'a> {
 | 
			
		|||
        rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
 | 
			
		||||
        self.update_timeout(exchange, tx, true).await?;
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
            let pase = self.pase.borrow();
 | 
			
		||||
            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -299,13 +299,11 @@ impl<'a> Pake<'a> {
 | 
			
		|||
            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
 | 
			
		||||
            spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
 | 
			
		||||
 | 
			
		||||
        drop(pase);
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        exchange.exchange(tx, rx).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn update_timeout(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -314,6 +312,7 @@ impl<'a> Pake<'a> {
 | 
			
		|||
    ) -> Result<(), Error> {
 | 
			
		||||
        self.check_session(exchange, tx).await?;
 | 
			
		||||
 | 
			
		||||
        let status = {
 | 
			
		||||
            let mut pase = self.pase.borrow_mut();
 | 
			
		||||
 | 
			
		||||
            if pase
 | 
			
		||||
| 
						 | 
				
			
			@ -325,7 +324,7 @@ impl<'a> Pake<'a> {
 | 
			
		|||
                pase.timeout = None;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        let status = if let Some(sd) = pase.timeout.as_mut() {
 | 
			
		||||
            if let Some(sd) = pase.timeout.as_mut() {
 | 
			
		||||
                if &sd.exch_id != exchange.id() {
 | 
			
		||||
                    info!("Other PAKE session in progress");
 | 
			
		||||
                    Some(SCStatusCodes::Busy)
 | 
			
		||||
| 
						 | 
				
			
			@ -337,13 +336,14 @@ impl<'a> Pake<'a> {
 | 
			
		|||
            } else {
 | 
			
		||||
                error!("PAKE session not found or expired");
 | 
			
		||||
                Some(SCStatusCodes::SessionNotFound)
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if let Some(status) = status {
 | 
			
		||||
            drop(pase);
 | 
			
		||||
 | 
			
		||||
            complete_with_status(exchange, tx, status, None).await
 | 
			
		||||
        } else {
 | 
			
		||||
            let mut pase = self.pase.borrow_mut();
 | 
			
		||||
 | 
			
		||||
            pase.timeout = Some(Timeout::new(exchange, pase.epoch));
 | 
			
		||||
 | 
			
		||||
            Ok(())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -397,7 +397,6 @@ impl<'a> Matter<'a> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    #[inline(always)]
 | 
			
		||||
    #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
 | 
			
		||||
    pub async fn handle_exchange<H>(
 | 
			
		||||
        &self,
 | 
			
		||||
        tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue