/* * * Copyright (c) 2020-2022 Project CHIP Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use core::{fmt::Write, time::Duration}; use super::{ common::{create_sc_status_report, SCStatusCodes}, spake2p::{Spake2P, VerifierData}, }; use crate::{ crypto, error::{Error, ErrorCode}, mdns::{Mdns, ServiceMode}, secure_channel::common::OpCode, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV}, transport::{ exchange::ExchangeCtx, network::Address, proto_ctx::ProtoCtx, session::{CloneData, SessionMode}, }, utils::{epoch::Epoch, rand::Rand}, }; use log::{error, info}; #[allow(clippy::large_enum_variant)] enum PaseMgrState { Enabled(Pake, heapless::String<16>), Disabled, } pub struct PaseMgr { state: PaseMgrState, epoch: Epoch, rand: Rand, } impl PaseMgr { #[inline(always)] pub fn new(epoch: Epoch, rand: Rand) -> Self { Self { state: PaseMgrState::Disabled, epoch, rand, } } pub fn is_pase_session_enabled(&self) -> bool { matches!(&self.state, PaseMgrState::Enabled(_, _)) } pub fn enable_pase_session( &mut self, verifier: VerifierData, discriminator: u16, mdns: &dyn Mdns, ) -> Result<(), Error> { let mut buf = [0; 8]; (self.rand)(&mut buf); let num = u64::from_be_bytes(buf); let mut mdns_service_name = heapless::String::<16>::new(); write!(&mut mdns_service_name, "{:016X}", num).unwrap(); mdns.add( &mdns_service_name, ServiceMode::Commissionable(discriminator), )?; self.state = PaseMgrState::Enabled( Pake::new(verifier, self.epoch, self.rand), mdns_service_name, ); Ok(()) } pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> { if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state { mdns.remove(mdns_service_name)?; } self.state = PaseMgrState::Disabled; Ok(()) } /// If the PASE Session is enabled, execute the closure, /// if not enabled, generate SC Status Report fn if_enabled(&mut self, ctx: &mut ProtoCtx, f: F) -> Result, Error> where F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result, { if let PaseMgrState::Enabled(pake, _) = &mut self.state { let data = f(pake, ctx)?; Ok(Some(data)) } else { error!("PASE Not enabled"); create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?; Ok(None) } } pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; Ok(true) } pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result { ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; Ok(true) } pub fn pasepake3_handler( &mut self, ctx: &mut ProtoCtx, mdns: &dyn Mdns, ) -> Result<(bool, Option), Error> { let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; self.disable_pase_session(mdns)?; Ok((true, clone_data.flatten())) } } // This file basically deals with the handlers for the PASE secure channel protocol // TLV extraction and encoding is done in this file. // We create a Spake2p object and set it up in the exchange-data. This object then // handles Spake2+ specific stuff. const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60); const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; struct SessionData { start_time: Duration, exch_id: u16, peer_addr: Address, spake2p: Spake2P, } impl SessionData { fn is_sess_expired(&self, epoch: Epoch) -> Result { Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS) } } #[allow(clippy::large_enum_variant)] enum PakeState { Idle, InProgress(SessionData), } impl PakeState { const fn new() -> Self { Self::Idle } fn take(&mut self) -> Result { let new = core::mem::replace(self, PakeState::Idle); if let PakeState::InProgress(s) = new { Ok(s) } else { Err(ErrorCode::InvalidSignature.into()) } } fn is_idle(&self) -> bool { core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle) } fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result { let sd = self.take()?; if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { Err(ErrorCode::InvalidState.into()) } else { Ok(sd) } } fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) { *self = PakeState::InProgress(SessionData { start_time: epoch(), spake2p, exch_id: exch_ctx.exch.get_id(), peer_addr: exch_ctx.sess.get_peer_addr(), }); } fn set_sess_data(&mut self, sd: SessionData) { *self = PakeState::InProgress(sd); } } impl Default for PakeState { fn default() -> Self { Self::new() } } struct Pake { verifier: VerifierData, state: PakeState, epoch: Epoch, rand: Rand, } impl Pake { pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self { // TODO: Can any PBKDF2 calculation be pre-computed here Self { verifier, state: PakeState::new(), epoch, rand, } } #[allow(non_snake_case)] pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result, Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let (status_code, ke) = sd.spake2p.handle_cA(cA); let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess { // Get the keys let ke = ke.ok_or(ErrorCode::Invalid)?; let mut session_keys: [u8; 48] = [0; 48]; crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) .map_err(|_x| ErrorCode::NoSpace)?; // Create a session let data = sd.spake2p.get_app_data(); let peer_sessid: u16 = (data & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; let mut clone_data = CloneData::new( 0, 0, peer_sessid, local_sessid, ctx.exch_ctx.sess.get_peer_addr(), SessionMode::Pase, ); clone_data.dec_key.copy_from_slice(&session_keys[0..16]); clone_data.enc_key.copy_from_slice(&session_keys[16..32]); clone_data .att_challenge .copy_from_slice(&session_keys[32..48]); // Queue a transport mgr request to add a new session Some(clone_data) } else { None }; create_sc_status_report(ctx.tx, status_code, None)?; ctx.exch_ctx.exch.close(); Ok(clone_data) } #[allow(non_snake_case)] pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?; let mut pB: [u8; 65] = [0; 65]; let mut cB: [u8; 32] = [0; 32]; sd.spake2p.start_verifier(&self.verifier)?; sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?; let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); let resp = Pake1Resp { pb: OctetStr(&pB), cb: OctetStr(&cB), }; resp.to_tlv(&mut tw, TagType::Anonymous)?; self.state.set_sess_data(sd); Ok(()) } pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { if !self.state.is_idle() { let sd = self.state.take()?; if sd.is_sess_expired(self.epoch)? { info!("Previous session expired, clearing it"); self.state = PakeState::Idle; } else { info!("Previous session in-progress, denying new request"); // little-endian timeout (here we've hardcoded 500ms) create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; return Ok(()); } } let root = tlv::get_root_node(ctx.rx.as_slice())?; let a = PBKDFParamReq::from_tlv(&root)?; if a.passcode_id != 0 { error!("Can't yet handle passcode_id != 0"); Err(ErrorCode::Invalid)?; } let mut our_random: [u8; 32] = [0; 32]; (self.rand)(&mut our_random); let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let mut spake2p = Spake2P::new(); spake2p.set_app_data(spake2p_data); // Generate response let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); let mut resp = PBKDFParamResp { init_random: a.initiator_random, our_random: OctetStr(&our_random), local_sessid, params: None, }; if !a.has_params { let params_resp = PBKDFParamRespParams { count: self.verifier.count, salt: OctetStr(&self.verifier.salt), }; resp.params = Some(params_resp); } resp.to_tlv(&mut tw, TagType::Anonymous)?; spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?; self.state .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx); Ok(()) } } #[derive(ToTLV)] #[tlvargs(start = 1)] struct Pake1Resp<'a> { pb: OctetStr<'a>, cb: OctetStr<'a>, } #[derive(ToTLV)] #[tlvargs(start = 1)] struct PBKDFParamRespParams<'a> { count: u32, salt: OctetStr<'a>, } #[derive(ToTLV)] #[tlvargs(start = 1)] struct PBKDFParamResp<'a> { init_random: OctetStr<'a>, our_random: OctetStr<'a>, local_sessid: u16, params: Option>, } #[allow(non_snake_case)] fn extract_pasepake_1_or_3_params(buf: &[u8]) -> Result<&[u8], Error> { let root = get_root_node_struct(buf)?; let pA = root.find_tag(1)?.slice()?; Ok(pA) } #[derive(FromTLV)] #[tlvargs(lifetime = "'a", start = 1)] struct PBKDFParamReq<'a> { initiator_random: OctetStr<'a>, initiator_ssid: u16, passcode_id: u16, has_params: bool, }