Merge pull request #82 from ivmarkov/main
Do not hold on to RefCell borrows across await points
This commit is contained in:
		
						commit
						50f18dbbee
					
				
					 6 changed files with 210 additions and 227 deletions
				
			
		
							
								
								
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							| 
						 | 
				
			
			@ -10,7 +10,7 @@ on:
 | 
			
		|||
  workflow_dispatch:
 | 
			
		||||
 | 
			
		||||
env:
 | 
			
		||||
  RUST_TOOLCHAIN: nightly-2023-07-01
 | 
			
		||||
  RUST_TOOLCHAIN: nightly
 | 
			
		||||
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
			
		||||
  CARGO_TERM_COLOR: always
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -208,7 +208,6 @@ impl<'a> MdnsService<'a> {
 | 
			
		|||
        select(&mut broadcast, &mut respond).await.unwrap()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
			
		||||
        loop {
 | 
			
		||||
            select(
 | 
			
		||||
| 
						 | 
				
			
			@ -258,7 +257,6 @@ impl<'a> MdnsService<'a> {
 | 
			
		|||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
			
		||||
        loop {
 | 
			
		||||
            {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -87,7 +87,6 @@ impl<'a> Case<'a> {
 | 
			
		|||
        self.handle_casesigma3(exchange, rx, tx, &mut session).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_casesigma3(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -97,100 +96,81 @@ impl<'a> Case<'a> {
 | 
			
		|||
    ) -> Result<(), Error> {
 | 
			
		||||
        rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
 | 
			
		||||
 | 
			
		||||
        let fabric_mgr = self.fabric_mgr.borrow();
 | 
			
		||||
        let status = {
 | 
			
		||||
            let fabric_mgr = self.fabric_mgr.borrow();
 | 
			
		||||
 | 
			
		||||
        let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
			
		||||
        if fabric.is_none() {
 | 
			
		||||
            drop(fabric_mgr);
 | 
			
		||||
            complete_with_status(
 | 
			
		||||
                exchange,
 | 
			
		||||
                tx,
 | 
			
		||||
                common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
            .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
        // Safe to unwrap here
 | 
			
		||||
        let fabric = fabric.unwrap();
 | 
			
		||||
            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
			
		||||
            if let Some(fabric) = fabric {
 | 
			
		||||
                let root = get_root_node_struct(rx.as_slice())?;
 | 
			
		||||
                let encrypted = root.find_tag(1)?.slice()?;
 | 
			
		||||
 | 
			
		||||
        let root = get_root_node_struct(rx.as_slice())?;
 | 
			
		||||
        let encrypted = root.find_tag(1)?.slice()?;
 | 
			
		||||
                let mut decrypted = alloc!([0; 800]);
 | 
			
		||||
                if encrypted.len() > decrypted.len() {
 | 
			
		||||
                    error!("Data too large");
 | 
			
		||||
                    Err(ErrorCode::NoSpace)?;
 | 
			
		||||
                }
 | 
			
		||||
                let decrypted = &mut decrypted[..encrypted.len()];
 | 
			
		||||
                decrypted.copy_from_slice(encrypted);
 | 
			
		||||
 | 
			
		||||
        let mut decrypted = alloc!([0; 800]);
 | 
			
		||||
        if encrypted.len() > decrypted.len() {
 | 
			
		||||
            error!("Data too large");
 | 
			
		||||
            Err(ErrorCode::NoSpace)?;
 | 
			
		||||
        }
 | 
			
		||||
        let decrypted = &mut decrypted[..encrypted.len()];
 | 
			
		||||
        decrypted.copy_from_slice(encrypted);
 | 
			
		||||
                let len =
 | 
			
		||||
                    Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
			
		||||
                let decrypted = &decrypted[..len];
 | 
			
		||||
 | 
			
		||||
        let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
			
		||||
        let decrypted = &decrypted[..len];
 | 
			
		||||
                let root = get_root_node_struct(decrypted)?;
 | 
			
		||||
                let d = Sigma3Decrypt::from_tlv(&root)?;
 | 
			
		||||
 | 
			
		||||
        let root = get_root_node_struct(decrypted)?;
 | 
			
		||||
        let d = Sigma3Decrypt::from_tlv(&root)?;
 | 
			
		||||
                let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
 | 
			
		||||
                let mut initiator_icac = None;
 | 
			
		||||
                if let Some(icac) = d.initiator_icac {
 | 
			
		||||
                    initiator_icac = Some(alloc!(Cert::new(icac.0)?));
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
        let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
 | 
			
		||||
        let mut initiator_icac = None;
 | 
			
		||||
        if let Some(icac) = d.initiator_icac {
 | 
			
		||||
            initiator_icac = Some(alloc!(Cert::new(icac.0)?));
 | 
			
		||||
        }
 | 
			
		||||
                #[cfg(feature = "alloc")]
 | 
			
		||||
                let initiator_icac_mut = initiator_icac.as_deref();
 | 
			
		||||
 | 
			
		||||
        #[cfg(feature = "alloc")]
 | 
			
		||||
        let initiator_icac_mut = initiator_icac.as_deref();
 | 
			
		||||
                #[cfg(not(feature = "alloc"))]
 | 
			
		||||
                let initiator_icac_mut = initiator_icac.as_ref();
 | 
			
		||||
 | 
			
		||||
        #[cfg(not(feature = "alloc"))]
 | 
			
		||||
        let initiator_icac_mut = initiator_icac.as_ref();
 | 
			
		||||
                if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
 | 
			
		||||
                    error!("Certificate Chain doesn't match: {}", e);
 | 
			
		||||
                    SCStatusCodes::InvalidParameter
 | 
			
		||||
                } else if let Err(e) = Case::validate_sigma3_sign(
 | 
			
		||||
                    d.initiator_noc.0,
 | 
			
		||||
                    d.initiator_icac.map(|a| a.0),
 | 
			
		||||
                    &initiator_noc,
 | 
			
		||||
                    d.signature.0,
 | 
			
		||||
                    case_session,
 | 
			
		||||
                ) {
 | 
			
		||||
                    error!("Sigma3 Signature doesn't match: {}", e);
 | 
			
		||||
                    SCStatusCodes::InvalidParameter
 | 
			
		||||
                } else {
 | 
			
		||||
                    // Only now do we add this message to the TT Hash
 | 
			
		||||
                    let mut peer_catids: NocCatIds = Default::default();
 | 
			
		||||
                    initiator_noc.get_cat_ids(&mut peer_catids);
 | 
			
		||||
                    case_session.tt_hash.update(rx.as_slice())?;
 | 
			
		||||
                    let clone_data = Case::get_session_clone_data(
 | 
			
		||||
                        fabric.ipk.op_key(),
 | 
			
		||||
                        fabric.get_node_id(),
 | 
			
		||||
                        initiator_noc.get_node_id()?,
 | 
			
		||||
                        exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
 | 
			
		||||
                        case_session,
 | 
			
		||||
                        &peer_catids,
 | 
			
		||||
                    )?;
 | 
			
		||||
 | 
			
		||||
        if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
 | 
			
		||||
            error!("Certificate Chain doesn't match: {}", e);
 | 
			
		||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
			
		||||
                .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
                    // TODO: Handle NoSpace
 | 
			
		||||
                    exchange
 | 
			
		||||
                        .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
			
		||||
 | 
			
		||||
        if Case::validate_sigma3_sign(
 | 
			
		||||
            d.initiator_noc.0,
 | 
			
		||||
            d.initiator_icac.map(|a| a.0),
 | 
			
		||||
            &initiator_noc,
 | 
			
		||||
            d.signature.0,
 | 
			
		||||
            case_session,
 | 
			
		||||
        )
 | 
			
		||||
        .is_err()
 | 
			
		||||
        {
 | 
			
		||||
            error!("Sigma3 Signature doesn't match");
 | 
			
		||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
			
		||||
                .await?;
 | 
			
		||||
            return Ok(());
 | 
			
		||||
        }
 | 
			
		||||
                    SCStatusCodes::SessionEstablishmentSuccess
 | 
			
		||||
                }
 | 
			
		||||
            } else {
 | 
			
		||||
                SCStatusCodes::NoSharedTrustRoots
 | 
			
		||||
            }
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        // Only now do we add this message to the TT Hash
 | 
			
		||||
        let mut peer_catids: NocCatIds = Default::default();
 | 
			
		||||
        initiator_noc.get_cat_ids(&mut peer_catids);
 | 
			
		||||
        case_session.tt_hash.update(rx.as_slice())?;
 | 
			
		||||
        let clone_data = Case::get_session_clone_data(
 | 
			
		||||
            fabric.ipk.op_key(),
 | 
			
		||||
            fabric.get_node_id(),
 | 
			
		||||
            initiator_noc.get_node_id()?,
 | 
			
		||||
            exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
 | 
			
		||||
            case_session,
 | 
			
		||||
            &peer_catids,
 | 
			
		||||
        )?;
 | 
			
		||||
 | 
			
		||||
        // TODO: Handle NoSpace
 | 
			
		||||
        exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
			
		||||
 | 
			
		||||
        complete_with_status(
 | 
			
		||||
            exchange,
 | 
			
		||||
            tx,
 | 
			
		||||
            SCStatusCodes::SessionEstablishmentSuccess,
 | 
			
		||||
            None,
 | 
			
		||||
        )
 | 
			
		||||
        .await
 | 
			
		||||
        complete_with_status(exchange, tx, status, None).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_casesigma1(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -255,70 +235,76 @@ impl<'a> Case<'a> {
 | 
			
		|||
        const MAX_ENCRYPTED_SIZE: usize = 800;
 | 
			
		||||
 | 
			
		||||
        let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
 | 
			
		||||
        let encrypted_len = {
 | 
			
		||||
            let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
 | 
			
		||||
        let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
 | 
			
		||||
 | 
			
		||||
        let fabric_found = {
 | 
			
		||||
            let fabric_mgr = self.fabric_mgr.borrow();
 | 
			
		||||
 | 
			
		||||
            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
			
		||||
            if fabric.is_none() {
 | 
			
		||||
                drop(fabric_mgr);
 | 
			
		||||
                complete_with_status(
 | 
			
		||||
                    exchange,
 | 
			
		||||
                    tx,
 | 
			
		||||
                    common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                    None,
 | 
			
		||||
                )
 | 
			
		||||
                .await?;
 | 
			
		||||
                return Ok(());
 | 
			
		||||
            if let Some(fabric) = fabric {
 | 
			
		||||
                #[cfg(feature = "alloc")]
 | 
			
		||||
                let signature_mut = &mut *signature;
 | 
			
		||||
 | 
			
		||||
                #[cfg(not(feature = "alloc"))]
 | 
			
		||||
                let signature_mut = &mut signature;
 | 
			
		||||
 | 
			
		||||
                let sign_len = Case::get_sigma2_sign(
 | 
			
		||||
                    fabric,
 | 
			
		||||
                    &case_session.our_pub_key,
 | 
			
		||||
                    &case_session.peer_pub_key,
 | 
			
		||||
                    signature_mut,
 | 
			
		||||
                )?;
 | 
			
		||||
                let signature = &signature[..sign_len];
 | 
			
		||||
 | 
			
		||||
                #[cfg(feature = "alloc")]
 | 
			
		||||
                let encrypted_mut = &mut *encrypted;
 | 
			
		||||
 | 
			
		||||
                #[cfg(not(feature = "alloc"))]
 | 
			
		||||
                let encrypted_mut = &mut encrypted;
 | 
			
		||||
 | 
			
		||||
                let encrypted_len = Case::get_sigma2_encryption(
 | 
			
		||||
                    fabric,
 | 
			
		||||
                    self.rand,
 | 
			
		||||
                    &our_random,
 | 
			
		||||
                    case_session,
 | 
			
		||||
                    signature,
 | 
			
		||||
                    encrypted_mut,
 | 
			
		||||
                )?;
 | 
			
		||||
 | 
			
		||||
                let encrypted = &encrypted[0..encrypted_len];
 | 
			
		||||
 | 
			
		||||
                // Generate our Response Body
 | 
			
		||||
                tx.reset();
 | 
			
		||||
                tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
                tx.set_proto_opcode(OpCode::CASESigma2 as u8);
 | 
			
		||||
 | 
			
		||||
                let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
                tw.start_struct(TagType::Anonymous)?;
 | 
			
		||||
                tw.str8(TagType::Context(1), &our_random)?;
 | 
			
		||||
                tw.u16(TagType::Context(2), local_sessid)?;
 | 
			
		||||
                tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
 | 
			
		||||
                tw.str16(TagType::Context(4), encrypted)?;
 | 
			
		||||
                tw.end_container()?;
 | 
			
		||||
 | 
			
		||||
                case_session.tt_hash.update(tx.as_mut_slice())?;
 | 
			
		||||
 | 
			
		||||
                true
 | 
			
		||||
            } else {
 | 
			
		||||
                false
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            #[cfg(feature = "alloc")]
 | 
			
		||||
            let signature_mut = &mut *signature;
 | 
			
		||||
 | 
			
		||||
            #[cfg(not(feature = "alloc"))]
 | 
			
		||||
            let signature_mut = &mut signature;
 | 
			
		||||
 | 
			
		||||
            let sign_len = Case::get_sigma2_sign(
 | 
			
		||||
                fabric.unwrap(),
 | 
			
		||||
                &case_session.our_pub_key,
 | 
			
		||||
                &case_session.peer_pub_key,
 | 
			
		||||
                signature_mut,
 | 
			
		||||
            )?;
 | 
			
		||||
            let signature = &signature[..sign_len];
 | 
			
		||||
 | 
			
		||||
            #[cfg(feature = "alloc")]
 | 
			
		||||
            let encrypted_mut = &mut *encrypted;
 | 
			
		||||
 | 
			
		||||
            #[cfg(not(feature = "alloc"))]
 | 
			
		||||
            let encrypted_mut = &mut encrypted;
 | 
			
		||||
 | 
			
		||||
            Case::get_sigma2_encryption(
 | 
			
		||||
                fabric.unwrap(),
 | 
			
		||||
                self.rand,
 | 
			
		||||
                &our_random,
 | 
			
		||||
                case_session,
 | 
			
		||||
                signature,
 | 
			
		||||
                encrypted_mut,
 | 
			
		||||
            )?
 | 
			
		||||
        };
 | 
			
		||||
        let encrypted = &encrypted[0..encrypted_len];
 | 
			
		||||
 | 
			
		||||
        // Generate our Response Body
 | 
			
		||||
        tx.reset();
 | 
			
		||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
        tx.set_proto_opcode(OpCode::CASESigma2 as u8);
 | 
			
		||||
 | 
			
		||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
        tw.start_struct(TagType::Anonymous)?;
 | 
			
		||||
        tw.str8(TagType::Context(1), &our_random)?;
 | 
			
		||||
        tw.u16(TagType::Context(2), local_sessid)?;
 | 
			
		||||
        tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
 | 
			
		||||
        tw.str16(TagType::Context(4), encrypted)?;
 | 
			
		||||
        tw.end_container()?;
 | 
			
		||||
 | 
			
		||||
        case_session.tt_hash.update(tx.as_mut_slice())?;
 | 
			
		||||
 | 
			
		||||
        exchange.exchange(tx, rx).await
 | 
			
		||||
        if fabric_found {
 | 
			
		||||
            exchange.exchange(tx, rx).await
 | 
			
		||||
        } else {
 | 
			
		||||
            complete_with_status(
 | 
			
		||||
                exchange,
 | 
			
		||||
                tx,
 | 
			
		||||
                common::SCStatusCodes::NoSharedTrustRoots,
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
            .await
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn get_session_clone_data(
 | 
			
		||||
| 
						 | 
				
			
			@ -515,7 +501,7 @@ impl<'a> Case<'a> {
 | 
			
		|||
        fabric: &Fabric,
 | 
			
		||||
        rand: Rand,
 | 
			
		||||
        our_random: &[u8],
 | 
			
		||||
        case_session: &mut CaseSession,
 | 
			
		||||
        case_session: &CaseSession,
 | 
			
		||||
        signature: &[u8],
 | 
			
		||||
        out: &mut [u8],
 | 
			
		||||
    ) -> Result<usize, Error> {
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -186,7 +186,7 @@ impl CryptoSpake2 {
 | 
			
		|||
        let (Z, V) = Self::get_ZV_as_verifier(
 | 
			
		||||
            &self.w0,
 | 
			
		||||
            &self.L,
 | 
			
		||||
            &mut self.M,
 | 
			
		||||
            &self.M,
 | 
			
		||||
            &X,
 | 
			
		||||
            &self.xy,
 | 
			
		||||
            &self.order,
 | 
			
		||||
| 
						 | 
				
			
			@ -228,7 +228,7 @@ impl CryptoSpake2 {
 | 
			
		|||
    fn get_ZV_as_prover(
 | 
			
		||||
        w0: &Mpi,
 | 
			
		||||
        w1: &Mpi,
 | 
			
		||||
        N: &mut EcPoint,
 | 
			
		||||
        N: &EcPoint,
 | 
			
		||||
        Y: &EcPoint,
 | 
			
		||||
        x: &Mpi,
 | 
			
		||||
        order: &Mpi,
 | 
			
		||||
| 
						 | 
				
			
			@ -264,7 +264,7 @@ impl CryptoSpake2 {
 | 
			
		|||
    fn get_ZV_as_verifier(
 | 
			
		||||
        w0: &Mpi,
 | 
			
		||||
        L: &EcPoint,
 | 
			
		||||
        M: &mut EcPoint,
 | 
			
		||||
        M: &EcPoint,
 | 
			
		||||
        X: &EcPoint,
 | 
			
		||||
        y: &Mpi,
 | 
			
		||||
        order: &Mpi,
 | 
			
		||||
| 
						 | 
				
			
			@ -292,7 +292,7 @@ impl CryptoSpake2 {
 | 
			
		|||
        Ok((Z, V))
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn invert(group: &mut EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
			
		||||
    fn invert(group: &EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
			
		||||
        let p = group.p()?;
 | 
			
		||||
        let num_y = num.y()?;
 | 
			
		||||
        let inverted_num_y = p.sub(&num_y)?;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -213,7 +213,6 @@ impl<'a> Pake<'a> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(non_snake_case)]
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_pasepake1(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -224,32 +223,32 @@ impl<'a> Pake<'a> {
 | 
			
		|||
        rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
 | 
			
		||||
        self.update_timeout(exchange, tx, false).await?;
 | 
			
		||||
 | 
			
		||||
        let pase = self.pase.borrow();
 | 
			
		||||
        let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
        {
 | 
			
		||||
            let pase = self.pase.borrow();
 | 
			
		||||
            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
 | 
			
		||||
        let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
 | 
			
		||||
        let mut pB: [u8; 65] = [0; 65];
 | 
			
		||||
        let mut cB: [u8; 32] = [0; 32];
 | 
			
		||||
        spake2p.start_verifier(&session.verifier)?;
 | 
			
		||||
        spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
 | 
			
		||||
            let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
 | 
			
		||||
            let mut pB: [u8; 65] = [0; 65];
 | 
			
		||||
            let mut cB: [u8; 32] = [0; 32];
 | 
			
		||||
            spake2p.start_verifier(&session.verifier)?;
 | 
			
		||||
            spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
 | 
			
		||||
 | 
			
		||||
        // Generate response
 | 
			
		||||
        tx.reset();
 | 
			
		||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
        tx.set_proto_opcode(OpCode::PASEPake2 as u8);
 | 
			
		||||
            // Generate response
 | 
			
		||||
            tx.reset();
 | 
			
		||||
            tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
            tx.set_proto_opcode(OpCode::PASEPake2 as u8);
 | 
			
		||||
 | 
			
		||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
        let resp = Pake1Resp {
 | 
			
		||||
            pb: OctetStr(&pB),
 | 
			
		||||
            cb: OctetStr(&cB),
 | 
			
		||||
        };
 | 
			
		||||
        resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
            let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
            let resp = Pake1Resp {
 | 
			
		||||
                pb: OctetStr(&pB),
 | 
			
		||||
                cb: OctetStr(&cB),
 | 
			
		||||
            };
 | 
			
		||||
            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        drop(pase);
 | 
			
		||||
        exchange.exchange(tx, rx).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn handle_pbkdfparamrequest(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -260,52 +259,51 @@ impl<'a> Pake<'a> {
 | 
			
		|||
        rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
 | 
			
		||||
        self.update_timeout(exchange, tx, true).await?;
 | 
			
		||||
 | 
			
		||||
        let pase = self.pase.borrow();
 | 
			
		||||
        let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
        {
 | 
			
		||||
            let pase = self.pase.borrow();
 | 
			
		||||
            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
			
		||||
 | 
			
		||||
        let root = tlv::get_root_node(rx.as_slice())?;
 | 
			
		||||
        let a = PBKDFParamReq::from_tlv(&root)?;
 | 
			
		||||
        if a.passcode_id != 0 {
 | 
			
		||||
            error!("Can't yet handle passcode_id != 0");
 | 
			
		||||
            Err(ErrorCode::Invalid)?;
 | 
			
		||||
        }
 | 
			
		||||
            let root = tlv::get_root_node(rx.as_slice())?;
 | 
			
		||||
            let a = PBKDFParamReq::from_tlv(&root)?;
 | 
			
		||||
            if a.passcode_id != 0 {
 | 
			
		||||
                error!("Can't yet handle passcode_id != 0");
 | 
			
		||||
                Err(ErrorCode::Invalid)?;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
        let mut our_random: [u8; 32] = [0; 32];
 | 
			
		||||
        (self.pase.borrow().rand)(&mut our_random);
 | 
			
		||||
            let mut our_random: [u8; 32] = [0; 32];
 | 
			
		||||
            (self.pase.borrow().rand)(&mut our_random);
 | 
			
		||||
 | 
			
		||||
        let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
 | 
			
		||||
        let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
 | 
			
		||||
        spake2p.set_app_data(spake2p_data);
 | 
			
		||||
            let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
 | 
			
		||||
            let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
 | 
			
		||||
            spake2p.set_app_data(spake2p_data);
 | 
			
		||||
 | 
			
		||||
        // Generate response
 | 
			
		||||
        tx.reset();
 | 
			
		||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
        tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
 | 
			
		||||
            // Generate response
 | 
			
		||||
            tx.reset();
 | 
			
		||||
            tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
			
		||||
            tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
 | 
			
		||||
 | 
			
		||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
        let mut resp = PBKDFParamResp {
 | 
			
		||||
            init_random: a.initiator_random,
 | 
			
		||||
            our_random: OctetStr(&our_random),
 | 
			
		||||
            local_sessid,
 | 
			
		||||
            params: None,
 | 
			
		||||
        };
 | 
			
		||||
        if !a.has_params {
 | 
			
		||||
            let params_resp = PBKDFParamRespParams {
 | 
			
		||||
                count: session.verifier.count,
 | 
			
		||||
                salt: OctetStr(&session.verifier.salt),
 | 
			
		||||
            let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
			
		||||
            let mut resp = PBKDFParamResp {
 | 
			
		||||
                init_random: a.initiator_random,
 | 
			
		||||
                our_random: OctetStr(&our_random),
 | 
			
		||||
                local_sessid,
 | 
			
		||||
                params: None,
 | 
			
		||||
            };
 | 
			
		||||
            resp.params = Some(params_resp);
 | 
			
		||||
            if !a.has_params {
 | 
			
		||||
                let params_resp = PBKDFParamRespParams {
 | 
			
		||||
                    count: session.verifier.count,
 | 
			
		||||
                    salt: OctetStr(&session.verifier.salt),
 | 
			
		||||
                };
 | 
			
		||||
                resp.params = Some(params_resp);
 | 
			
		||||
            }
 | 
			
		||||
            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
 | 
			
		||||
            spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
 | 
			
		||||
        }
 | 
			
		||||
        resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
			
		||||
 | 
			
		||||
        spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
 | 
			
		||||
 | 
			
		||||
        drop(pase);
 | 
			
		||||
 | 
			
		||||
        exchange.exchange(tx, rx).await
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
			
		||||
    async fn update_timeout(
 | 
			
		||||
        &mut self,
 | 
			
		||||
        exchange: &mut Exchange<'_>,
 | 
			
		||||
| 
						 | 
				
			
			@ -314,36 +312,38 @@ impl<'a> Pake<'a> {
 | 
			
		|||
    ) -> Result<(), Error> {
 | 
			
		||||
        self.check_session(exchange, tx).await?;
 | 
			
		||||
 | 
			
		||||
        let mut pase = self.pase.borrow_mut();
 | 
			
		||||
        let status = {
 | 
			
		||||
            let mut pase = self.pase.borrow_mut();
 | 
			
		||||
 | 
			
		||||
        if pase
 | 
			
		||||
            .timeout
 | 
			
		||||
            .as_ref()
 | 
			
		||||
            .map(|sd| sd.is_sess_expired(pase.epoch))
 | 
			
		||||
            .unwrap_or(false)
 | 
			
		||||
        {
 | 
			
		||||
            pase.timeout = None;
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        let status = if let Some(sd) = pase.timeout.as_mut() {
 | 
			
		||||
            if &sd.exch_id != exchange.id() {
 | 
			
		||||
                info!("Other PAKE session in progress");
 | 
			
		||||
                Some(SCStatusCodes::Busy)
 | 
			
		||||
            } else {
 | 
			
		||||
                None
 | 
			
		||||
            if pase
 | 
			
		||||
                .timeout
 | 
			
		||||
                .as_ref()
 | 
			
		||||
                .map(|sd| sd.is_sess_expired(pase.epoch))
 | 
			
		||||
                .unwrap_or(false)
 | 
			
		||||
            {
 | 
			
		||||
                pase.timeout = None;
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if let Some(sd) = pase.timeout.as_mut() {
 | 
			
		||||
                if &sd.exch_id != exchange.id() {
 | 
			
		||||
                    info!("Other PAKE session in progress");
 | 
			
		||||
                    Some(SCStatusCodes::Busy)
 | 
			
		||||
                } else {
 | 
			
		||||
                    None
 | 
			
		||||
                }
 | 
			
		||||
            } else if new {
 | 
			
		||||
                None
 | 
			
		||||
            } else {
 | 
			
		||||
                error!("PAKE session not found or expired");
 | 
			
		||||
                Some(SCStatusCodes::SessionNotFound)
 | 
			
		||||
            }
 | 
			
		||||
        } else if new {
 | 
			
		||||
            None
 | 
			
		||||
        } else {
 | 
			
		||||
            error!("PAKE session not found or expired");
 | 
			
		||||
            Some(SCStatusCodes::SessionNotFound)
 | 
			
		||||
        };
 | 
			
		||||
 | 
			
		||||
        if let Some(status) = status {
 | 
			
		||||
            drop(pase);
 | 
			
		||||
 | 
			
		||||
            complete_with_status(exchange, tx, status, None).await
 | 
			
		||||
        } else {
 | 
			
		||||
            let mut pase = self.pase.borrow_mut();
 | 
			
		||||
 | 
			
		||||
            pase.timeout = Some(Timeout::new(exchange, pase.epoch));
 | 
			
		||||
 | 
			
		||||
            Ok(())
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -397,7 +397,6 @@ impl<'a> Matter<'a> {
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
    #[inline(always)]
 | 
			
		||||
    #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
 | 
			
		||||
    pub async fn handle_exchange<H>(
 | 
			
		||||
        &self,
 | 
			
		||||
        tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
	Add table
		
		Reference in a new issue