Merge pull request #82 from ivmarkov/main
Do not hold on to RefCell borrows across await points
This commit is contained in:
		
						commit
						50f18dbbee
					
				
					 6 changed files with 210 additions and 227 deletions
				
			
		
							
								
								
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/workflows/ci.yml
									
										
									
									
										vendored
									
									
								
							| 
						 | 
					@ -10,7 +10,7 @@ on:
 | 
				
			||||||
  workflow_dispatch:
 | 
					  workflow_dispatch:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
env:
 | 
					env:
 | 
				
			||||||
  RUST_TOOLCHAIN: nightly-2023-07-01
 | 
					  RUST_TOOLCHAIN: nightly
 | 
				
			||||||
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
					  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 | 
				
			||||||
  CARGO_TERM_COLOR: always
 | 
					  CARGO_TERM_COLOR: always
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -208,7 +208,6 @@ impl<'a> MdnsService<'a> {
 | 
				
			||||||
        select(&mut broadcast, &mut respond).await.unwrap()
 | 
					        select(&mut broadcast, &mut respond).await.unwrap()
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
					    async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
				
			||||||
        loop {
 | 
					        loop {
 | 
				
			||||||
            select(
 | 
					            select(
 | 
				
			||||||
| 
						 | 
					@ -258,7 +257,6 @@ impl<'a> MdnsService<'a> {
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
					    async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
 | 
				
			||||||
        loop {
 | 
					        loop {
 | 
				
			||||||
            {
 | 
					            {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -87,7 +87,6 @@ impl<'a> Case<'a> {
 | 
				
			||||||
        self.handle_casesigma3(exchange, rx, tx, &mut session).await
 | 
					        self.handle_casesigma3(exchange, rx, tx, &mut session).await
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn handle_casesigma3(
 | 
					    async fn handle_casesigma3(
 | 
				
			||||||
        &mut self,
 | 
					        &mut self,
 | 
				
			||||||
        exchange: &mut Exchange<'_>,
 | 
					        exchange: &mut Exchange<'_>,
 | 
				
			||||||
| 
						 | 
					@ -97,100 +96,81 @@ impl<'a> Case<'a> {
 | 
				
			||||||
    ) -> Result<(), Error> {
 | 
					    ) -> Result<(), Error> {
 | 
				
			||||||
        rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
 | 
					        rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let fabric_mgr = self.fabric_mgr.borrow();
 | 
					        let status = {
 | 
				
			||||||
 | 
					            let fabric_mgr = self.fabric_mgr.borrow();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
					            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
				
			||||||
        if fabric.is_none() {
 | 
					            if let Some(fabric) = fabric {
 | 
				
			||||||
            drop(fabric_mgr);
 | 
					                let root = get_root_node_struct(rx.as_slice())?;
 | 
				
			||||||
            complete_with_status(
 | 
					                let encrypted = root.find_tag(1)?.slice()?;
 | 
				
			||||||
                exchange,
 | 
					 | 
				
			||||||
                tx,
 | 
					 | 
				
			||||||
                common::SCStatusCodes::NoSharedTrustRoots,
 | 
					 | 
				
			||||||
                None,
 | 
					 | 
				
			||||||
            )
 | 
					 | 
				
			||||||
            .await?;
 | 
					 | 
				
			||||||
            return Ok(());
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        // Safe to unwrap here
 | 
					 | 
				
			||||||
        let fabric = fabric.unwrap();
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let root = get_root_node_struct(rx.as_slice())?;
 | 
					                let mut decrypted = alloc!([0; 800]);
 | 
				
			||||||
        let encrypted = root.find_tag(1)?.slice()?;
 | 
					                if encrypted.len() > decrypted.len() {
 | 
				
			||||||
 | 
					                    error!("Data too large");
 | 
				
			||||||
 | 
					                    Err(ErrorCode::NoSpace)?;
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					                let decrypted = &mut decrypted[..encrypted.len()];
 | 
				
			||||||
 | 
					                decrypted.copy_from_slice(encrypted);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut decrypted = alloc!([0; 800]);
 | 
					                let len =
 | 
				
			||||||
        if encrypted.len() > decrypted.len() {
 | 
					                    Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
				
			||||||
            error!("Data too large");
 | 
					                let decrypted = &decrypted[..len];
 | 
				
			||||||
            Err(ErrorCode::NoSpace)?;
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
        let decrypted = &mut decrypted[..encrypted.len()];
 | 
					 | 
				
			||||||
        decrypted.copy_from_slice(encrypted);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
 | 
					                let root = get_root_node_struct(decrypted)?;
 | 
				
			||||||
        let decrypted = &decrypted[..len];
 | 
					                let d = Sigma3Decrypt::from_tlv(&root)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let root = get_root_node_struct(decrypted)?;
 | 
					                let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
 | 
				
			||||||
        let d = Sigma3Decrypt::from_tlv(&root)?;
 | 
					                let mut initiator_icac = None;
 | 
				
			||||||
 | 
					                if let Some(icac) = d.initiator_icac {
 | 
				
			||||||
 | 
					                    initiator_icac = Some(alloc!(Cert::new(icac.0)?));
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
 | 
					                #[cfg(feature = "alloc")]
 | 
				
			||||||
        let mut initiator_icac = None;
 | 
					                let initiator_icac_mut = initiator_icac.as_deref();
 | 
				
			||||||
        if let Some(icac) = d.initiator_icac {
 | 
					 | 
				
			||||||
            initiator_icac = Some(alloc!(Cert::new(icac.0)?));
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #[cfg(feature = "alloc")]
 | 
					                #[cfg(not(feature = "alloc"))]
 | 
				
			||||||
        let initiator_icac_mut = initiator_icac.as_deref();
 | 
					                let initiator_icac_mut = initiator_icac.as_ref();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        #[cfg(not(feature = "alloc"))]
 | 
					                if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
 | 
				
			||||||
        let initiator_icac_mut = initiator_icac.as_ref();
 | 
					                    error!("Certificate Chain doesn't match: {}", e);
 | 
				
			||||||
 | 
					                    SCStatusCodes::InvalidParameter
 | 
				
			||||||
 | 
					                } else if let Err(e) = Case::validate_sigma3_sign(
 | 
				
			||||||
 | 
					                    d.initiator_noc.0,
 | 
				
			||||||
 | 
					                    d.initiator_icac.map(|a| a.0),
 | 
				
			||||||
 | 
					                    &initiator_noc,
 | 
				
			||||||
 | 
					                    d.signature.0,
 | 
				
			||||||
 | 
					                    case_session,
 | 
				
			||||||
 | 
					                ) {
 | 
				
			||||||
 | 
					                    error!("Sigma3 Signature doesn't match: {}", e);
 | 
				
			||||||
 | 
					                    SCStatusCodes::InvalidParameter
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    // Only now do we add this message to the TT Hash
 | 
				
			||||||
 | 
					                    let mut peer_catids: NocCatIds = Default::default();
 | 
				
			||||||
 | 
					                    initiator_noc.get_cat_ids(&mut peer_catids);
 | 
				
			||||||
 | 
					                    case_session.tt_hash.update(rx.as_slice())?;
 | 
				
			||||||
 | 
					                    let clone_data = Case::get_session_clone_data(
 | 
				
			||||||
 | 
					                        fabric.ipk.op_key(),
 | 
				
			||||||
 | 
					                        fabric.get_node_id(),
 | 
				
			||||||
 | 
					                        initiator_noc.get_node_id()?,
 | 
				
			||||||
 | 
					                        exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
 | 
				
			||||||
 | 
					                        case_session,
 | 
				
			||||||
 | 
					                        &peer_catids,
 | 
				
			||||||
 | 
					                    )?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
 | 
					                    // TODO: Handle NoSpace
 | 
				
			||||||
            error!("Certificate Chain doesn't match: {}", e);
 | 
					                    exchange
 | 
				
			||||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
					                        .with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
				
			||||||
                .await?;
 | 
					 | 
				
			||||||
            return Ok(());
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if Case::validate_sigma3_sign(
 | 
					                    SCStatusCodes::SessionEstablishmentSuccess
 | 
				
			||||||
            d.initiator_noc.0,
 | 
					                }
 | 
				
			||||||
            d.initiator_icac.map(|a| a.0),
 | 
					            } else {
 | 
				
			||||||
            &initiator_noc,
 | 
					                SCStatusCodes::NoSharedTrustRoots
 | 
				
			||||||
            d.signature.0,
 | 
					            }
 | 
				
			||||||
            case_session,
 | 
					        };
 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        .is_err()
 | 
					 | 
				
			||||||
        {
 | 
					 | 
				
			||||||
            error!("Sigma3 Signature doesn't match");
 | 
					 | 
				
			||||||
            complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
 | 
					 | 
				
			||||||
                .await?;
 | 
					 | 
				
			||||||
            return Ok(());
 | 
					 | 
				
			||||||
        }
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Only now do we add this message to the TT Hash
 | 
					        complete_with_status(exchange, tx, status, None).await
 | 
				
			||||||
        let mut peer_catids: NocCatIds = Default::default();
 | 
					 | 
				
			||||||
        initiator_noc.get_cat_ids(&mut peer_catids);
 | 
					 | 
				
			||||||
        case_session.tt_hash.update(rx.as_slice())?;
 | 
					 | 
				
			||||||
        let clone_data = Case::get_session_clone_data(
 | 
					 | 
				
			||||||
            fabric.ipk.op_key(),
 | 
					 | 
				
			||||||
            fabric.get_node_id(),
 | 
					 | 
				
			||||||
            initiator_noc.get_node_id()?,
 | 
					 | 
				
			||||||
            exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
 | 
					 | 
				
			||||||
            case_session,
 | 
					 | 
				
			||||||
            &peer_catids,
 | 
					 | 
				
			||||||
        )?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        // TODO: Handle NoSpace
 | 
					 | 
				
			||||||
        exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        complete_with_status(
 | 
					 | 
				
			||||||
            exchange,
 | 
					 | 
				
			||||||
            tx,
 | 
					 | 
				
			||||||
            SCStatusCodes::SessionEstablishmentSuccess,
 | 
					 | 
				
			||||||
            None,
 | 
					 | 
				
			||||||
        )
 | 
					 | 
				
			||||||
        .await
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn handle_casesigma1(
 | 
					    async fn handle_casesigma1(
 | 
				
			||||||
        &mut self,
 | 
					        &mut self,
 | 
				
			||||||
        exchange: &mut Exchange<'_>,
 | 
					        exchange: &mut Exchange<'_>,
 | 
				
			||||||
| 
						 | 
					@ -255,70 +235,76 @@ impl<'a> Case<'a> {
 | 
				
			||||||
        const MAX_ENCRYPTED_SIZE: usize = 800;
 | 
					        const MAX_ENCRYPTED_SIZE: usize = 800;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
 | 
					        let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
 | 
				
			||||||
        let encrypted_len = {
 | 
					        let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
 | 
				
			||||||
            let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
 | 
					
 | 
				
			||||||
 | 
					        let fabric_found = {
 | 
				
			||||||
            let fabric_mgr = self.fabric_mgr.borrow();
 | 
					            let fabric_mgr = self.fabric_mgr.borrow();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
					            let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
 | 
				
			||||||
            if fabric.is_none() {
 | 
					            if let Some(fabric) = fabric {
 | 
				
			||||||
                drop(fabric_mgr);
 | 
					                #[cfg(feature = "alloc")]
 | 
				
			||||||
                complete_with_status(
 | 
					                let signature_mut = &mut *signature;
 | 
				
			||||||
                    exchange,
 | 
					
 | 
				
			||||||
                    tx,
 | 
					                #[cfg(not(feature = "alloc"))]
 | 
				
			||||||
                    common::SCStatusCodes::NoSharedTrustRoots,
 | 
					                let signature_mut = &mut signature;
 | 
				
			||||||
                    None,
 | 
					
 | 
				
			||||||
                )
 | 
					                let sign_len = Case::get_sigma2_sign(
 | 
				
			||||||
                .await?;
 | 
					                    fabric,
 | 
				
			||||||
                return Ok(());
 | 
					                    &case_session.our_pub_key,
 | 
				
			||||||
 | 
					                    &case_session.peer_pub_key,
 | 
				
			||||||
 | 
					                    signature_mut,
 | 
				
			||||||
 | 
					                )?;
 | 
				
			||||||
 | 
					                let signature = &signature[..sign_len];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                #[cfg(feature = "alloc")]
 | 
				
			||||||
 | 
					                let encrypted_mut = &mut *encrypted;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                #[cfg(not(feature = "alloc"))]
 | 
				
			||||||
 | 
					                let encrypted_mut = &mut encrypted;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                let encrypted_len = Case::get_sigma2_encryption(
 | 
				
			||||||
 | 
					                    fabric,
 | 
				
			||||||
 | 
					                    self.rand,
 | 
				
			||||||
 | 
					                    &our_random,
 | 
				
			||||||
 | 
					                    case_session,
 | 
				
			||||||
 | 
					                    signature,
 | 
				
			||||||
 | 
					                    encrypted_mut,
 | 
				
			||||||
 | 
					                )?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                let encrypted = &encrypted[0..encrypted_len];
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                // Generate our Response Body
 | 
				
			||||||
 | 
					                tx.reset();
 | 
				
			||||||
 | 
					                tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
				
			||||||
 | 
					                tx.set_proto_opcode(OpCode::CASESigma2 as u8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
				
			||||||
 | 
					                tw.start_struct(TagType::Anonymous)?;
 | 
				
			||||||
 | 
					                tw.str8(TagType::Context(1), &our_random)?;
 | 
				
			||||||
 | 
					                tw.u16(TagType::Context(2), local_sessid)?;
 | 
				
			||||||
 | 
					                tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
 | 
				
			||||||
 | 
					                tw.str16(TagType::Context(4), encrypted)?;
 | 
				
			||||||
 | 
					                tw.end_container()?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                case_session.tt_hash.update(tx.as_mut_slice())?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                true
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                false
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					 | 
				
			||||||
            #[cfg(feature = "alloc")]
 | 
					 | 
				
			||||||
            let signature_mut = &mut *signature;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            #[cfg(not(feature = "alloc"))]
 | 
					 | 
				
			||||||
            let signature_mut = &mut signature;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            let sign_len = Case::get_sigma2_sign(
 | 
					 | 
				
			||||||
                fabric.unwrap(),
 | 
					 | 
				
			||||||
                &case_session.our_pub_key,
 | 
					 | 
				
			||||||
                &case_session.peer_pub_key,
 | 
					 | 
				
			||||||
                signature_mut,
 | 
					 | 
				
			||||||
            )?;
 | 
					 | 
				
			||||||
            let signature = &signature[..sign_len];
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            #[cfg(feature = "alloc")]
 | 
					 | 
				
			||||||
            let encrypted_mut = &mut *encrypted;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            #[cfg(not(feature = "alloc"))]
 | 
					 | 
				
			||||||
            let encrypted_mut = &mut encrypted;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            Case::get_sigma2_encryption(
 | 
					 | 
				
			||||||
                fabric.unwrap(),
 | 
					 | 
				
			||||||
                self.rand,
 | 
					 | 
				
			||||||
                &our_random,
 | 
					 | 
				
			||||||
                case_session,
 | 
					 | 
				
			||||||
                signature,
 | 
					 | 
				
			||||||
                encrypted_mut,
 | 
					 | 
				
			||||||
            )?
 | 
					 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
        let encrypted = &encrypted[0..encrypted_len];
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Generate our Response Body
 | 
					        if fabric_found {
 | 
				
			||||||
        tx.reset();
 | 
					            exchange.exchange(tx, rx).await
 | 
				
			||||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
					        } else {
 | 
				
			||||||
        tx.set_proto_opcode(OpCode::CASESigma2 as u8);
 | 
					            complete_with_status(
 | 
				
			||||||
 | 
					                exchange,
 | 
				
			||||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
					                tx,
 | 
				
			||||||
        tw.start_struct(TagType::Anonymous)?;
 | 
					                common::SCStatusCodes::NoSharedTrustRoots,
 | 
				
			||||||
        tw.str8(TagType::Context(1), &our_random)?;
 | 
					                None,
 | 
				
			||||||
        tw.u16(TagType::Context(2), local_sessid)?;
 | 
					            )
 | 
				
			||||||
        tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
 | 
					            .await
 | 
				
			||||||
        tw.str16(TagType::Context(4), encrypted)?;
 | 
					        }
 | 
				
			||||||
        tw.end_container()?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        case_session.tt_hash.update(tx.as_mut_slice())?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        exchange.exchange(tx, rx).await
 | 
					 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fn get_session_clone_data(
 | 
					    fn get_session_clone_data(
 | 
				
			||||||
| 
						 | 
					@ -515,7 +501,7 @@ impl<'a> Case<'a> {
 | 
				
			||||||
        fabric: &Fabric,
 | 
					        fabric: &Fabric,
 | 
				
			||||||
        rand: Rand,
 | 
					        rand: Rand,
 | 
				
			||||||
        our_random: &[u8],
 | 
					        our_random: &[u8],
 | 
				
			||||||
        case_session: &mut CaseSession,
 | 
					        case_session: &CaseSession,
 | 
				
			||||||
        signature: &[u8],
 | 
					        signature: &[u8],
 | 
				
			||||||
        out: &mut [u8],
 | 
					        out: &mut [u8],
 | 
				
			||||||
    ) -> Result<usize, Error> {
 | 
					    ) -> Result<usize, Error> {
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -186,7 +186,7 @@ impl CryptoSpake2 {
 | 
				
			||||||
        let (Z, V) = Self::get_ZV_as_verifier(
 | 
					        let (Z, V) = Self::get_ZV_as_verifier(
 | 
				
			||||||
            &self.w0,
 | 
					            &self.w0,
 | 
				
			||||||
            &self.L,
 | 
					            &self.L,
 | 
				
			||||||
            &mut self.M,
 | 
					            &self.M,
 | 
				
			||||||
            &X,
 | 
					            &X,
 | 
				
			||||||
            &self.xy,
 | 
					            &self.xy,
 | 
				
			||||||
            &self.order,
 | 
					            &self.order,
 | 
				
			||||||
| 
						 | 
					@ -228,7 +228,7 @@ impl CryptoSpake2 {
 | 
				
			||||||
    fn get_ZV_as_prover(
 | 
					    fn get_ZV_as_prover(
 | 
				
			||||||
        w0: &Mpi,
 | 
					        w0: &Mpi,
 | 
				
			||||||
        w1: &Mpi,
 | 
					        w1: &Mpi,
 | 
				
			||||||
        N: &mut EcPoint,
 | 
					        N: &EcPoint,
 | 
				
			||||||
        Y: &EcPoint,
 | 
					        Y: &EcPoint,
 | 
				
			||||||
        x: &Mpi,
 | 
					        x: &Mpi,
 | 
				
			||||||
        order: &Mpi,
 | 
					        order: &Mpi,
 | 
				
			||||||
| 
						 | 
					@ -264,7 +264,7 @@ impl CryptoSpake2 {
 | 
				
			||||||
    fn get_ZV_as_verifier(
 | 
					    fn get_ZV_as_verifier(
 | 
				
			||||||
        w0: &Mpi,
 | 
					        w0: &Mpi,
 | 
				
			||||||
        L: &EcPoint,
 | 
					        L: &EcPoint,
 | 
				
			||||||
        M: &mut EcPoint,
 | 
					        M: &EcPoint,
 | 
				
			||||||
        X: &EcPoint,
 | 
					        X: &EcPoint,
 | 
				
			||||||
        y: &Mpi,
 | 
					        y: &Mpi,
 | 
				
			||||||
        order: &Mpi,
 | 
					        order: &Mpi,
 | 
				
			||||||
| 
						 | 
					@ -292,7 +292,7 @@ impl CryptoSpake2 {
 | 
				
			||||||
        Ok((Z, V))
 | 
					        Ok((Z, V))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    fn invert(group: &mut EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
					    fn invert(group: &EcGroup, num: &EcPoint) -> Result<EcPoint, mbedtls::Error> {
 | 
				
			||||||
        let p = group.p()?;
 | 
					        let p = group.p()?;
 | 
				
			||||||
        let num_y = num.y()?;
 | 
					        let num_y = num.y()?;
 | 
				
			||||||
        let inverted_num_y = p.sub(&num_y)?;
 | 
					        let inverted_num_y = p.sub(&num_y)?;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -213,7 +213,6 @@ impl<'a> Pake<'a> {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(non_snake_case)]
 | 
					    #[allow(non_snake_case)]
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn handle_pasepake1(
 | 
					    async fn handle_pasepake1(
 | 
				
			||||||
        &mut self,
 | 
					        &mut self,
 | 
				
			||||||
        exchange: &mut Exchange<'_>,
 | 
					        exchange: &mut Exchange<'_>,
 | 
				
			||||||
| 
						 | 
					@ -224,32 +223,32 @@ impl<'a> Pake<'a> {
 | 
				
			||||||
        rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
 | 
					        rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
 | 
				
			||||||
        self.update_timeout(exchange, tx, false).await?;
 | 
					        self.update_timeout(exchange, tx, false).await?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let pase = self.pase.borrow();
 | 
					        {
 | 
				
			||||||
        let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
					            let pase = self.pase.borrow();
 | 
				
			||||||
 | 
					            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
 | 
					            let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
 | 
				
			||||||
        let mut pB: [u8; 65] = [0; 65];
 | 
					            let mut pB: [u8; 65] = [0; 65];
 | 
				
			||||||
        let mut cB: [u8; 32] = [0; 32];
 | 
					            let mut cB: [u8; 32] = [0; 32];
 | 
				
			||||||
        spake2p.start_verifier(&session.verifier)?;
 | 
					            spake2p.start_verifier(&session.verifier)?;
 | 
				
			||||||
        spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
 | 
					            spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Generate response
 | 
					            // Generate response
 | 
				
			||||||
        tx.reset();
 | 
					            tx.reset();
 | 
				
			||||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
					            tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
				
			||||||
        tx.set_proto_opcode(OpCode::PASEPake2 as u8);
 | 
					            tx.set_proto_opcode(OpCode::PASEPake2 as u8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
					            let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
				
			||||||
        let resp = Pake1Resp {
 | 
					            let resp = Pake1Resp {
 | 
				
			||||||
            pb: OctetStr(&pB),
 | 
					                pb: OctetStr(&pB),
 | 
				
			||||||
            cb: OctetStr(&cB),
 | 
					                cb: OctetStr(&cB),
 | 
				
			||||||
        };
 | 
					            };
 | 
				
			||||||
        resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
					            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        drop(pase);
 | 
					 | 
				
			||||||
        exchange.exchange(tx, rx).await
 | 
					        exchange.exchange(tx, rx).await
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn handle_pbkdfparamrequest(
 | 
					    async fn handle_pbkdfparamrequest(
 | 
				
			||||||
        &mut self,
 | 
					        &mut self,
 | 
				
			||||||
        exchange: &mut Exchange<'_>,
 | 
					        exchange: &mut Exchange<'_>,
 | 
				
			||||||
| 
						 | 
					@ -260,52 +259,51 @@ impl<'a> Pake<'a> {
 | 
				
			||||||
        rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
 | 
					        rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
 | 
				
			||||||
        self.update_timeout(exchange, tx, true).await?;
 | 
					        self.update_timeout(exchange, tx, true).await?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let pase = self.pase.borrow();
 | 
					        {
 | 
				
			||||||
        let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
					            let pase = self.pase.borrow();
 | 
				
			||||||
 | 
					            let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let root = tlv::get_root_node(rx.as_slice())?;
 | 
					            let root = tlv::get_root_node(rx.as_slice())?;
 | 
				
			||||||
        let a = PBKDFParamReq::from_tlv(&root)?;
 | 
					            let a = PBKDFParamReq::from_tlv(&root)?;
 | 
				
			||||||
        if a.passcode_id != 0 {
 | 
					            if a.passcode_id != 0 {
 | 
				
			||||||
            error!("Can't yet handle passcode_id != 0");
 | 
					                error!("Can't yet handle passcode_id != 0");
 | 
				
			||||||
            Err(ErrorCode::Invalid)?;
 | 
					                Err(ErrorCode::Invalid)?;
 | 
				
			||||||
        }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut our_random: [u8; 32] = [0; 32];
 | 
					            let mut our_random: [u8; 32] = [0; 32];
 | 
				
			||||||
        (self.pase.borrow().rand)(&mut our_random);
 | 
					            (self.pase.borrow().rand)(&mut our_random);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
 | 
					            let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
 | 
				
			||||||
        let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
 | 
					            let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
 | 
				
			||||||
        spake2p.set_app_data(spake2p_data);
 | 
					            spake2p.set_app_data(spake2p_data);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        // Generate response
 | 
					            // Generate response
 | 
				
			||||||
        tx.reset();
 | 
					            tx.reset();
 | 
				
			||||||
        tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
					            tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
 | 
				
			||||||
        tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
 | 
					            tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
					            let mut tw = TLVWriter::new(tx.get_writebuf()?);
 | 
				
			||||||
        let mut resp = PBKDFParamResp {
 | 
					            let mut resp = PBKDFParamResp {
 | 
				
			||||||
            init_random: a.initiator_random,
 | 
					                init_random: a.initiator_random,
 | 
				
			||||||
            our_random: OctetStr(&our_random),
 | 
					                our_random: OctetStr(&our_random),
 | 
				
			||||||
            local_sessid,
 | 
					                local_sessid,
 | 
				
			||||||
            params: None,
 | 
					                params: None,
 | 
				
			||||||
        };
 | 
					 | 
				
			||||||
        if !a.has_params {
 | 
					 | 
				
			||||||
            let params_resp = PBKDFParamRespParams {
 | 
					 | 
				
			||||||
                count: session.verifier.count,
 | 
					 | 
				
			||||||
                salt: OctetStr(&session.verifier.salt),
 | 
					 | 
				
			||||||
            };
 | 
					            };
 | 
				
			||||||
            resp.params = Some(params_resp);
 | 
					            if !a.has_params {
 | 
				
			||||||
 | 
					                let params_resp = PBKDFParamRespParams {
 | 
				
			||||||
 | 
					                    count: session.verifier.count,
 | 
				
			||||||
 | 
					                    salt: OctetStr(&session.verifier.salt),
 | 
				
			||||||
 | 
					                };
 | 
				
			||||||
 | 
					                resp.params = Some(params_resp);
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        resp.to_tlv(&mut tw, TagType::Anonymous)?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        drop(pase);
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
        exchange.exchange(tx, rx).await
 | 
					        exchange.exchange(tx, rx).await
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[allow(clippy::await_holding_refcell_ref)]
 | 
					 | 
				
			||||||
    async fn update_timeout(
 | 
					    async fn update_timeout(
 | 
				
			||||||
        &mut self,
 | 
					        &mut self,
 | 
				
			||||||
        exchange: &mut Exchange<'_>,
 | 
					        exchange: &mut Exchange<'_>,
 | 
				
			||||||
| 
						 | 
					@ -314,36 +312,38 @@ impl<'a> Pake<'a> {
 | 
				
			||||||
    ) -> Result<(), Error> {
 | 
					    ) -> Result<(), Error> {
 | 
				
			||||||
        self.check_session(exchange, tx).await?;
 | 
					        self.check_session(exchange, tx).await?;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let mut pase = self.pase.borrow_mut();
 | 
					        let status = {
 | 
				
			||||||
 | 
					            let mut pase = self.pase.borrow_mut();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if pase
 | 
					            if pase
 | 
				
			||||||
            .timeout
 | 
					                .timeout
 | 
				
			||||||
            .as_ref()
 | 
					                .as_ref()
 | 
				
			||||||
            .map(|sd| sd.is_sess_expired(pase.epoch))
 | 
					                .map(|sd| sd.is_sess_expired(pase.epoch))
 | 
				
			||||||
            .unwrap_or(false)
 | 
					                .unwrap_or(false)
 | 
				
			||||||
        {
 | 
					            {
 | 
				
			||||||
            pase.timeout = None;
 | 
					                pase.timeout = None;
 | 
				
			||||||
        }
 | 
					            }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        let status = if let Some(sd) = pase.timeout.as_mut() {
 | 
					            if let Some(sd) = pase.timeout.as_mut() {
 | 
				
			||||||
            if &sd.exch_id != exchange.id() {
 | 
					                if &sd.exch_id != exchange.id() {
 | 
				
			||||||
                info!("Other PAKE session in progress");
 | 
					                    info!("Other PAKE session in progress");
 | 
				
			||||||
                Some(SCStatusCodes::Busy)
 | 
					                    Some(SCStatusCodes::Busy)
 | 
				
			||||||
            } else {
 | 
					                } else {
 | 
				
			||||||
                None
 | 
					                    None
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            } else if new {
 | 
				
			||||||
 | 
					                None
 | 
				
			||||||
 | 
					            } else {
 | 
				
			||||||
 | 
					                error!("PAKE session not found or expired");
 | 
				
			||||||
 | 
					                Some(SCStatusCodes::SessionNotFound)
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
        } else if new {
 | 
					 | 
				
			||||||
            None
 | 
					 | 
				
			||||||
        } else {
 | 
					 | 
				
			||||||
            error!("PAKE session not found or expired");
 | 
					 | 
				
			||||||
            Some(SCStatusCodes::SessionNotFound)
 | 
					 | 
				
			||||||
        };
 | 
					        };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if let Some(status) = status {
 | 
					        if let Some(status) = status {
 | 
				
			||||||
            drop(pase);
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
            complete_with_status(exchange, tx, status, None).await
 | 
					            complete_with_status(exchange, tx, status, None).await
 | 
				
			||||||
        } else {
 | 
					        } else {
 | 
				
			||||||
 | 
					            let mut pase = self.pase.borrow_mut();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            pase.timeout = Some(Timeout::new(exchange, pase.epoch));
 | 
					            pase.timeout = Some(Timeout::new(exchange, pase.epoch));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            Ok(())
 | 
					            Ok(())
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -397,7 +397,6 @@ impl<'a> Matter<'a> {
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[inline(always)]
 | 
					    #[inline(always)]
 | 
				
			||||||
    #[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
 | 
					 | 
				
			||||||
    pub async fn handle_exchange<H>(
 | 
					    pub async fn handle_exchange<H>(
 | 
				
			||||||
        &self,
 | 
					        &self,
 | 
				
			||||||
        tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
 | 
					        tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
	Add table
		
		Reference in a new issue