Remove allocations from Cert handling

This commit is contained in:
ivmarkov 2023-04-24 06:10:58 +00:00
parent f7a887c1d2
commit d82e9ec0af
7 changed files with 205 additions and 242 deletions

View file

@ -17,7 +17,7 @@
use super::{CertConsumer, MAX_DEPTH}; use super::{CertConsumer, MAX_DEPTH};
use crate::error::Error; use crate::error::Error;
use chrono::{Datelike, TimeZone, Utc}; use chrono::{Datelike, TimeZone, Utc}; // TODO
use core::fmt::Write; use core::fmt::Write;
use log::warn; use log::warn;

View file

@ -15,23 +15,22 @@
* limitations under the License. * limitations under the License.
*/ */
use core::fmt; use core::fmt::{self, Write};
extern crate alloc;
use crate::{ use crate::{
crypto::KeyPair, crypto::KeyPair,
error::Error, error::Error,
tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV},
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
use alloc::{format, string::String, vec::Vec};
use log::error; use log::error;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
pub use self::asn1_writer::ASN1Writer; pub use self::asn1_writer::ASN1Writer;
use self::printer::CertPrinter; use self::printer::CertPrinter;
pub const MAX_CERT_TLV_LEN: usize = 300; // TODO
// As per https://datatracker.ietf.org/doc/html/rfc5280 // As per https://datatracker.ietf.org/doc/html/rfc5280
const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01];
@ -116,8 +115,10 @@ macro_rules! add_if {
}; };
} }
fn get_print_str(key_usage: u16) -> String { fn get_print_str(key_usage: u16) -> heapless::String<256> {
format!( let mut string = heapless::String::new();
write!(
&mut string,
"{}{}{}{}{}{}{}{}{}", "{}{}{}{}{}{}{}{}{}",
add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "),
add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "),
@ -129,6 +130,9 @@ fn get_print_str(key_usage: u16) -> String {
add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "),
add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "),
) )
.unwrap();
string
} }
#[allow(unused_assignments)] #[allow(unused_assignments)]
@ -140,7 +144,7 @@ fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Erro
} }
fn encode_extended_key_usage( fn encode_extended_key_usage(
list: &TLVArrayOwned<u8>, list: impl Iterator<Item = u8>,
w: &mut dyn CertConsumer, w: &mut dyn CertConsumer,
) -> Result<(), Error> { ) -> Result<(), Error> {
const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01];
@ -160,19 +164,18 @@ fn encode_extended_key_usage(
]; ];
w.start_seq("")?; w.start_seq("")?;
for t in list.iter() { for t in list {
let t = *t as usize; let t = t as usize;
if t > 0 && t <= encoding.len() { if t > 0 && t <= encoding.len() {
w.oid(encoding[t].0, encoding[t].1)?; w.oid(encoding[t].0, encoding[t].1)?;
} else { } else {
error!("Skipping encoding key usage out of bounds"); error!("Skipping encoding key usage out of bounds");
} }
} }
w.end_seq()?; w.end_seq()
Ok(())
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
struct BasicConstraints { struct BasicConstraints {
is_ca: bool, is_ca: bool,
@ -212,18 +215,18 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> {
w.end_seq() w.end_seq()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1, datatype = "list")] #[tlvargs(lifetime = "'a", start = 1, datatype = "list")]
struct Extensions { struct Extensions<'a> {
basic_const: Option<BasicConstraints>, basic_const: Option<BasicConstraints>,
key_usage: Option<u16>, key_usage: Option<u16>,
ext_key_usage: Option<TLVArrayOwned<u8>>, ext_key_usage: Option<TLVArray<'a, u8>>,
subj_key_id: Option<Vec<u8>>, subj_key_id: Option<OctetStr<'a>>,
auth_key_id: Option<Vec<u8>>, auth_key_id: Option<OctetStr<'a>>,
future_extensions: Option<Vec<u8>>, future_extensions: Option<OctetStr<'a>>,
} }
impl Extensions { impl<'a> Extensions<'a> {
fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13];
const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F];
@ -245,30 +248,29 @@ impl Extensions {
} }
if let Some(t) = &self.ext_key_usage { if let Some(t) = &self.ext_key_usage {
encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?;
encode_extended_key_usage(t, w)?; encode_extended_key_usage(t.iter(), w)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.subj_key_id { if let Some(t) = &self.subj_key_id {
encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?;
w.ostr("", t.as_slice())?; w.ostr("", t.0)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.auth_key_id { if let Some(t) = &self.auth_key_id {
encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?;
w.start_seq("")?; w.start_seq("")?;
w.ctx("", 0, t.as_slice())?; w.ctx("", 0, t.0)?;
w.end_seq()?; w.end_seq()?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.future_extensions { if let Some(t) = &self.future_extensions {
error!("Future Extensions Not Yet Supported: {:x?}", t.as_slice()) error!("Future Extensions Not Yet Supported: {:x?}", t.0);
} }
w.end_seq()?; w.end_seq()?;
w.end_ctx()?; w.end_ctx()?;
Ok(()) Ok(())
} }
} }
const MAX_DN_ENTRIES: usize = 5;
#[derive(FromPrimitive, Copy, Clone)] #[derive(FromPrimitive, Copy, Clone)]
enum DnTags { enum DnTags {
@ -296,20 +298,23 @@ enum DnTags {
NocCat = 22, NocCat = 22,
} }
enum DistNameValue { #[derive(Debug)]
enum DistNameValue<'a> {
Uint(u64), Uint(u64),
Utf8Str(Vec<u8>), Utf8Str(&'a [u8]),
PrintableStr(Vec<u8>), PrintableStr(&'a [u8]),
} }
#[derive(Default)] const MAX_DN_ENTRIES: usize = 5;
struct DistNames {
#[derive(Default, Debug)]
struct DistNames<'a> {
// The order in which the DNs arrive is important, as the signing // The order in which the DNs arrive is important, as the signing
// requires that the ASN1 notation retains the same order // requires that the ASN1 notation retains the same order
dn: Vec<(u8, DistNameValue)>, dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>,
} }
impl DistNames { impl<'a> DistNames<'a> {
fn u64(&self, match_id: DnTags) -> Option<u64> { fn u64(&self, match_id: DnTags) -> Option<u64> {
self.dn self.dn
.iter() .iter()
@ -339,24 +344,27 @@ impl DistNames {
const PRINTABLE_STR_THRESHOLD: u8 = 0x80; const PRINTABLE_STR_THRESHOLD: u8 = 0x80;
impl<'a> FromTLV<'a> for DistNames { impl<'a> FromTLV<'a> for DistNames<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
let mut d = Self { let mut d = Self {
dn: Vec::with_capacity(MAX_DN_ENTRIES), dn: heapless::Vec::new(),
}; };
let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?;
for t in iter { for t in iter {
if let TagType::Context(tag) = t.get_tag() { if let TagType::Context(tag) = t.get_tag() {
if let Ok(value) = t.u64() { if let Ok(value) = t.u64() {
d.dn.push((tag, DistNameValue::Uint(value))); d.dn.push((tag, DistNameValue::Uint(value)))
.map_err(|_| Error::BufferTooSmall)?;
} else if let Ok(value) = t.slice() { } else if let Ok(value) = t.slice() {
if tag > PRINTABLE_STR_THRESHOLD { if tag > PRINTABLE_STR_THRESHOLD {
d.dn.push(( d.dn.push((
tag - PRINTABLE_STR_THRESHOLD, tag - PRINTABLE_STR_THRESHOLD,
DistNameValue::PrintableStr(value.to_vec()), DistNameValue::PrintableStr(value),
)); ))
.map_err(|_| Error::BufferTooSmall)?;
} else { } else {
d.dn.push((tag, DistNameValue::Utf8Str(value.to_vec()))); d.dn.push((tag, DistNameValue::Utf8Str(value)))
.map_err(|_| Error::BufferTooSmall)?;
} }
} }
} }
@ -365,24 +373,23 @@ impl<'a> FromTLV<'a> for DistNames {
} }
} }
impl ToTLV for DistNames { impl<'a> ToTLV for DistNames<'a> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_list(tag)?; tw.start_list(tag)?;
for (name, value) in &self.dn { for (name, value) in &self.dn {
match value { match value {
DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?,
DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v.as_slice())?, DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?,
DistNameValue::PrintableStr(v) => tw.utf8( DistNameValue::PrintableStr(v) => {
TagType::Context(*name + PRINTABLE_STR_THRESHOLD), tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)?
v.as_slice(), }
)?,
} }
} }
tw.end_container() tw.end_container()
} }
} }
impl DistNames { impl<'a> DistNames<'a> {
fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03];
const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04];
@ -520,38 +527,36 @@ fn encode_dn_value(
} }
}, },
DistNameValue::Utf8Str(v) => { DistNameValue::Utf8Str(v) => {
let str = String::from_utf8(v.to_vec())?; w.utf8str("", core::str::from_utf8(v)?)?;
w.utf8str("", &str)?;
} }
DistNameValue::PrintableStr(v) => { DistNameValue::PrintableStr(v) => {
let str = String::from_utf8(v.to_vec())?; w.printstr("", core::str::from_utf8(v)?)?;
w.printstr("", &str)?;
} }
} }
w.end_seq()?; w.end_seq()?;
w.end_set() w.end_set()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(lifetime = "'a", start = 1)]
pub struct Cert { pub struct Cert<'a> {
serial_no: Vec<u8>, serial_no: OctetStr<'a>,
sign_algo: u8, sign_algo: u8,
issuer: DistNames, issuer: DistNames<'a>,
not_before: u32, not_before: u32,
not_after: u32, not_after: u32,
subject: DistNames, subject: DistNames<'a>,
pubkey_algo: u8, pubkey_algo: u8,
ec_curve_id: u8, ec_curve_id: u8,
pubkey: Vec<u8>, pubkey: OctetStr<'a>,
extensions: Extensions, extensions: Extensions<'a>,
signature: Vec<u8>, signature: OctetStr<'a>,
} }
// TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding // TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding
// rules in terms of sequence may get complicated. Need to look into this // rules in terms of sequence may get complicated. Need to look into this
impl Cert { impl<'a> Cert<'a> {
pub fn new(cert_bin: &[u8]) -> Result<Self, Error> { pub fn new(cert_bin: &'a [u8]) -> Result<Self, Error> {
let root = tlv::get_root_node(cert_bin)?; let root = tlv::get_root_node(cert_bin)?;
Cert::from_tlv(&root) Cert::from_tlv(&root)
} }
@ -569,17 +574,21 @@ impl Cert {
} }
pub fn get_pubkey(&self) -> &[u8] { pub fn get_pubkey(&self) -> &[u8] {
self.pubkey.as_slice() self.pubkey.0
} }
pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { pub fn get_subject_key_id(&self) -> Result<&[u8], Error> {
self.extensions.subj_key_id.as_deref().ok_or(Error::Invalid) if let Some(id) = self.extensions.subj_key_id.as_ref() {
Ok(id.0)
} else {
Err(Error::Invalid)
}
} }
pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> { pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> {
if let Some(our_auth_key) = &self.extensions.auth_key_id { if let Some(our_auth_key) = &self.extensions.auth_key_id {
let their_subject = their.get_subject_key_id()?; let their_subject = their.get_subject_key_id()?;
if our_auth_key == their_subject { if our_auth_key.0 == their_subject {
Ok(true) Ok(true)
} else { } else {
Ok(false) Ok(false)
@ -590,7 +599,7 @@ impl Cert {
} }
pub fn get_signature(&self) -> &[u8] { pub fn get_signature(&self) -> &[u8] {
self.signature.as_slice() self.signature.0
} }
pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> {
@ -617,7 +626,7 @@ impl Cert {
w.integer("", &[2])?; w.integer("", &[2])?;
w.end_ctx()?; w.end_ctx()?;
w.integer("Serial Num:", self.serial_no.as_slice())?; w.integer("Serial Num:", self.serial_no.0)?;
w.start_seq("Signature Algorithm:")?; w.start_seq("Signature Algorithm:")?;
let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? {
@ -647,7 +656,7 @@ impl Cert {
w.oid(str, &curve_id)?; w.oid(str, &curve_id)?;
w.end_seq()?; w.end_seq()?;
w.bitstr("Public-Key:", false, self.pubkey.as_slice())?; w.bitstr("Public-Key:", false, self.pubkey.0)?;
w.end_seq()?; w.end_seq()?;
self.extensions.encode(w)?; self.extensions.encode(w)?;
@ -658,7 +667,7 @@ impl Cert {
} }
} }
impl fmt::Display for Cert { impl<'a> fmt::Display for Cert<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut printer = CertPrinter::new(f); let mut printer = CertPrinter::new(f);
let _ = self let _ = self
@ -670,7 +679,7 @@ impl fmt::Display for Cert {
} }
pub struct CertVerifier<'a> { pub struct CertVerifier<'a> {
cert: &'a Cert, cert: &'a Cert<'a>,
} }
impl<'a> CertVerifier<'a> { impl<'a> CertVerifier<'a> {
@ -809,6 +818,7 @@ mod tests {
#[test] #[test]
fn test_tlv_conversions() { fn test_tlv_conversions() {
let _ = env_logger::try_init();
let test_input: [&[u8]; 3] = [ let test_input: [&[u8]; 3] = [
&test_vectors::NOC1_SUCCESS, &test_vectors::NOC1_SUCCESS,
&test_vectors::ICAC1_SUCCESS, &test_vectors::ICAC1_SUCCESS,

View file

@ -19,7 +19,7 @@ use core::cell::RefCell;
use core::convert::TryInto; use core::convert::TryInto;
use crate::acl::{AclEntry, AclMgr, AuthMode}; use crate::acl::{AclEntry, AclMgr, AuthMode};
use crate::cert::Cert; use crate::cert::{Cert, MAX_CERT_TLV_LEN};
use crate::crypto::{self, KeyPair}; use crate::crypto::{self, KeyPair};
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::data_model::sdm::dev_att; use crate::data_model::sdm::dev_att;
@ -158,14 +158,14 @@ pub const CLUSTER: Cluster<'static> = Cluster {
pub struct NocData { pub struct NocData {
pub key_pair: KeyPair, pub key_pair: KeyPair,
pub root_ca: Cert, pub root_ca: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
} }
impl NocData { impl NocData {
pub fn new(key_pair: KeyPair) -> Self { pub fn new(key_pair: KeyPair) -> Self {
Self { Self {
key_pair, key_pair,
root_ca: Cert::default(), root_ca: heapless::Vec::new(),
} }
} }
} }
@ -259,8 +259,10 @@ impl<'a> NocCluster<'a> {
writer.start_array(AttrDataWriter::TAG)?; writer.start_array(AttrDataWriter::TAG)?;
self.fabric_mgr.borrow().for_each(|entry, fab_idx| { self.fabric_mgr.borrow().for_each(|entry, fab_idx| {
if !attr.fab_filter || attr.fab_idx == fab_idx { if !attr.fab_filter || attr.fab_idx == fab_idx {
let root_ca_cert = entry.get_root_ca()?;
entry entry
.get_fabric_desc(fab_idx) .get_fabric_desc(fab_idx, &root_ca_cert)?
.to_tlv(&mut writer, TagType::Anonymous)?; .to_tlv(&mut writer, TagType::Anonymous)?;
} }
@ -351,12 +353,18 @@ impl<'a> NocCluster<'a> {
let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?; let r = AddNocReq::from_tlv(data).map_err(|_| NocStatus::InvalidNOC)?;
let noc_value = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?; let noc_cert = Cert::new(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?;
info!("Received NOC as: {}", noc_value); info!("Received NOC as: {}", noc_cert);
let icac_value = if !r.icac_value.0.is_empty() {
let cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?; let noc = heapless::Vec::from_slice(r.noc_value.0).map_err(|_| NocStatus::InvalidNOC)?;
info!("Received ICAC as: {}", cert);
Some(cert) let icac = if !r.icac_value.0.is_empty() {
let icac_cert = Cert::new(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?;
info!("Received ICAC as: {}", icac_cert);
let icac =
heapless::Vec::from_slice(r.icac_value.0).map_err(|_| NocStatus::InvalidNOC)?;
Some(icac)
} else { } else {
None None
}; };
@ -364,8 +372,8 @@ impl<'a> NocCluster<'a> {
let fabric = Fabric::new( let fabric = Fabric::new(
noc_data.key_pair, noc_data.key_pair,
noc_data.root_ca, noc_data.root_ca,
icac_value, icac,
noc_value, noc,
r.ipk_value.0, r.ipk_value.0,
r.vendor_id, r.vendor_id,
"", "",
@ -592,7 +600,9 @@ impl<'a> NocCluster<'a> {
let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?; let req = CommonReq::from_tlv(data).map_err(Error::map_invalid_command)?;
info!("Received Trusted Cert:{:x?}", req.str); info!("Received Trusted Cert:{:x?}", req.str);
noc_data.root_ca = Cert::new(req.str.0)?; noc_data.root_ca =
heapless::Vec::from_slice(req.str.0).map_err(|_| Error::BufferTooSmall)?;
// TODO
} }
_ => (), _ => (),
} }

View file

@ -15,14 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use alloc::string::FromUtf8Error; use core::{array::TryFromSliceError, fmt, str::Utf8Error};
use core::{array::TryFromSliceError, fmt};
use async_channel::{SendError, TryRecvError}; use async_channel::{SendError, TryRecvError};
use log::error; use log::error;
extern crate alloc;
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug, PartialEq, Clone, Copy)]
pub enum Error { pub enum Error {
AttributeNotFound, AttributeNotFound,
@ -166,8 +163,8 @@ impl<T> From<SendError<T>> for Error {
} }
} }
impl From<FromUtf8Error> for Error { impl From<Utf8Error> for Error {
fn from(_e: FromUtf8Error) -> Self { fn from(_e: Utf8Error) -> Self {
Self::Utf8Fail Self::Utf8Fail
} }
} }

View file

@ -21,7 +21,7 @@ use byteorder::{BigEndian, ByteOrder, LittleEndian};
use log::{error, info}; use log::{error, info};
use crate::{ use crate::{
cert::Cert, cert::{Cert, MAX_CERT_TLV_LEN},
crypto::{self, hkdf_sha256, HmacSha256, KeyPair}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair},
error::Error, error::Error,
group_keys::KeySet, group_keys::KeySet,
@ -30,7 +30,6 @@ use crate::{
tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr},
}; };
const MAX_CERT_TLV_LEN: usize = 300;
const COMPRESSED_FABRIC_ID_LEN: usize = 8; const COMPRESSED_FABRIC_ID_LEN: usize = 8;
macro_rules! fb_key { macro_rules! fb_key {
@ -72,9 +71,9 @@ pub struct Fabric {
fabric_id: u64, fabric_id: u64,
vendor_id: u16, vendor_id: u16,
key_pair: KeyPair, key_pair: KeyPair,
pub root_ca: Cert, pub root_ca: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
pub icac: Option<Cert>, pub icac: Option<heapless::Vec<u8, { MAX_CERT_TLV_LEN }>>,
pub noc: Cert, pub noc: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
pub ipk: KeySet, pub ipk: KeySet,
label: heapless::String<32>, label: heapless::String<32>,
mdns_service_name: heapless::String<33>, mdns_service_name: heapless::String<33>,
@ -83,20 +82,25 @@ pub struct Fabric {
impl Fabric { impl Fabric {
pub fn new( pub fn new(
key_pair: KeyPair, key_pair: KeyPair,
root_ca: Cert, root_ca: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
icac: Option<Cert>, icac: Option<heapless::Vec<u8, { MAX_CERT_TLV_LEN }>>,
noc: Cert, noc: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
ipk: &[u8], ipk: &[u8],
vendor_id: u16, vendor_id: u16,
label: &str, label: &str,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let node_id = noc.get_node_id()?; let (node_id, fabric_id) = {
let fabric_id = noc.get_fabric_id()?; let noc_p = Cert::new(&noc)?;
(noc_p.get_node_id()?, noc_p.get_fabric_id()?)
};
let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN]; let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN];
Fabric::get_compressed_id(root_ca.get_pubkey(), fabric_id, &mut compressed_id)?; let ipk = {
let ipk = KeySet::new(ipk, &compressed_id)?; let root_ca_p = Cert::new(&root_ca)?;
Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?;
KeySet::new(ipk, &compressed_id)?
};
let mut mdns_service_name = heapless::String::<33>::new(); let mut mdns_service_name = heapless::String::<33>::new();
for c in compressed_id { for c in compressed_id {
@ -144,7 +148,7 @@ impl Fabric {
let mut mac = HmacSha256::new(self.ipk.op_key())?; let mut mac = HmacSha256::new(self.ipk.op_key())?;
mac.update(random)?; mac.update(random)?;
mac.update(self.root_ca.get_pubkey())?; mac.update(self.get_root_ca()?.get_pubkey())?;
let mut buf: [u8; 8] = [0; 8]; let mut buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut buf, self.fabric_id); LittleEndian::write_u64(&mut buf, self.fabric_id);
@ -174,15 +178,25 @@ impl Fabric {
self.fabric_id self.fabric_id
} }
pub fn get_fabric_desc(&self, fab_idx: u8) -> FabricDescriptor { pub fn get_root_ca(&self) -> Result<Cert<'_>, Error> {
FabricDescriptor { Cert::new(&self.root_ca)
root_public_key: OctetStr::new(self.root_ca.get_pubkey()), }
pub fn get_fabric_desc<'a>(
&'a self,
fab_idx: u8,
root_ca_cert: &'a Cert,
) -> Result<FabricDescriptor<'a>, Error> {
let desc = FabricDescriptor {
root_public_key: OctetStr::new(root_ca_cert.get_pubkey()),
vendor_id: self.vendor_id, vendor_id: self.vendor_id,
fabric_id: self.fabric_id, fabric_id: self.fabric_id,
node_id: self.node_id, node_id: self.node_id,
label: UtfStr(self.label.as_bytes()), label: UtfStr(self.label.as_bytes()),
fab_idx: Some(fab_idx), fab_idx: Some(fab_idx),
} };
Ok(desc)
} }
fn store<T>(&self, index: usize, mut psm: T) -> Result<(), Error> fn store<T>(&self, index: usize, mut psm: T) -> Result<(), Error>
@ -191,19 +205,13 @@ impl Fabric {
{ {
let mut _kb = heapless::String::<32>::new(); let mut _kb = heapless::String::<32>::new();
let mut buf = [0u8; MAX_CERT_TLV_LEN]; psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)?;
let len = self.root_ca.as_tlv(&mut buf)?; psm.set_kv_slice(
psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])?; fb_key!(index, ST_ICA, _kb),
self.icac.as_deref().unwrap_or(&[]),
)?;
let len = if let Some(icac) = &self.icac { psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)?;
icac.as_tlv(&mut buf)?
} else {
0
};
psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])?;
let len = self.noc.as_tlv(&mut buf)?;
psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])?;
psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?; psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())?;
psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?; psm.set_kv_slice(fb_key!(index, ST_LBL, _kb), self.label.as_bytes())?;
@ -228,18 +236,21 @@ impl Fabric {
let mut _kb = heapless::String::<32>::new(); let mut _kb = heapless::String::<32>::new();
let mut buf = [0u8; MAX_CERT_TLV_LEN]; let mut buf = [0u8; MAX_CERT_TLV_LEN];
let root_ca = psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?;
let root_ca = Cert::new(root_ca)?; let root_ca =
heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)?)
.unwrap();
let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?; let icac = psm.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)?;
let icac = if !icac.is_empty() { let icac = if !icac.is_empty() {
Some(Cert::new(icac)?) Some(heapless::Vec::from_slice(icac).unwrap())
} else { } else {
None None
}; };
let noc = psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?; let noc =
let noc = Cert::new(noc)?; heapless::Vec::from_slice(psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)?)
.unwrap();
let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?; let label = psm.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)?;
let label: heapless::String<32> = core::str::from_utf8(label) let label: heapless::String<32> = core::str::from_utf8(label)
@ -293,21 +304,16 @@ impl Fabric {
{ {
let mut _kb = heapless::String::<32>::new(); let mut _kb = heapless::String::<32>::new();
let mut buf = [0u8; MAX_CERT_TLV_LEN]; psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &self.root_ca)
let len = self.root_ca.as_tlv(&mut buf)?;
psm.set_kv_slice(fb_key!(index, ST_RCA, _kb), &buf[..len])
.await?; .await?;
let len = if let Some(icac) = &self.icac { psm.set_kv_slice(
icac.as_tlv(&mut buf)? fb_key!(index, ST_ICA, _kb),
} else { self.icac.as_deref().unwrap_or(&[]),
0 )
}; .await?;
psm.set_kv_slice(fb_key!(index, ST_ICA, _kb), &buf[..len])
.await?;
let len = self.noc.as_tlv(&mut buf)?; psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &self.noc)
psm.set_kv_slice(fb_key!(index, ST_NOC, _kb), &buf[..len])
.await?; .await?;
psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key()) psm.set_kv_slice(fb_key!(index, ST_IPK, _kb), self.ipk.epoch_key())
.await?; .await?;
@ -337,24 +343,27 @@ impl Fabric {
let mut _kb = heapless::String::<32>::new(); let mut _kb = heapless::String::<32>::new();
let mut buf = [0u8; MAX_CERT_TLV_LEN]; let mut buf = [0u8; MAX_CERT_TLV_LEN];
let root_ca = psm
.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf) let root_ca = heapless::Vec::from_slice(
.await?; psm.get_kv_slice(fb_key!(index, ST_RCA, _kb), &mut buf)
let root_ca = Cert::new(root_ca)?; .await?,
)
.unwrap();
let icac = psm let icac = psm
.get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf) .get_kv_slice(fb_key!(index, ST_ICA, _kb), &mut buf)
.await?; .await?;
let icac = if !icac.is_empty() { let icac = if !icac.is_empty() {
Some(Cert::new(icac)?) Some(heapless::Vec::from_slice(icac).unwrap())
} else { } else {
None None
}; };
let noc = psm let noc = heapless::Vec::from_slice(
.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf) psm.get_kv_slice(fb_key!(index, ST_NOC, _kb), &mut buf)
.await?; .await?,
let noc = Cert::new(noc)?; )
.unwrap();
let label = psm let label = psm
.get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf) .get_kv_slice(fb_key!(index, ST_LBL, _kb), &mut buf)

View file

@ -349,7 +349,9 @@ impl<'a> Case<'a> {
verifier = verifier.add_cert(icac)?; verifier = verifier.add_cert(icac)?;
} }
verifier.add_cert(&fabric.root_ca)?.finalise()?; verifier
.add_cert(&Cert::new(&fabric.root_ca)?)?
.finalise()?;
Ok(()) Ok(())
} }
@ -481,9 +483,9 @@ impl<'a> Case<'a> {
let mut write_buf = WriteBuf::new(out); let mut write_buf = WriteBuf::new(out);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_ref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))? tw.str16(TagType::Context(2), icac_cert)?
}; };
tw.str8(TagType::Context(3), signature)?; tw.str8(TagType::Context(3), signature)?;
@ -523,9 +525,9 @@ impl<'a> Case<'a> {
let mut write_buf = WriteBuf::new(&mut buf); let mut write_buf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_deref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))?; tw.str16(TagType::Context(2), icac_cert)?;
} }
tw.str8(TagType::Context(3), our_pub_key)?; tw.str8(TagType::Context(3), our_pub_key)?;
tw.str8(TagType::Context(4), peer_pub_key)?; tw.str8(TagType::Context(4), peer_pub_key)?;

View file

@ -17,14 +17,10 @@
use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType};
use crate::error::Error; use crate::error::Error;
use alloc::borrow::ToOwned;
use alloc::{string::String, vec::Vec};
use core::fmt::Debug; use core::fmt::Debug;
use core::slice::Iter; use core::slice::Iter;
use log::error; use log::error;
extern crate alloc;
pub trait FromTLV<'a> { pub trait FromTLV<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error>
where where
@ -118,14 +114,11 @@ totlv_for!(i8 u8 u16 u32 u64 bool);
// //
// - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec
// - These only have references into the original list // - These only have references into the original list
// - String, Vec<u8>: Is the owned version of utfstr and ostr, data is cloned into this
// - String is only partially implemented
// //
// - TLVArray: Is an array of entries, with reference within the original list // - TLVArray: Is an array of entries, with reference within the original list
// - TLVArrayOwned: Is the owned version of this, data is cloned into this
/// Implements UTFString from the spec /// Implements UTFString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct UtfStr<'a>(pub &'a [u8]); pub struct UtfStr<'a>(pub &'a [u8]);
impl<'a> UtfStr<'a> { impl<'a> UtfStr<'a> {
@ -136,10 +129,6 @@ impl<'a> UtfStr<'a> {
pub fn as_str(&self) -> Result<&str, Error> { pub fn as_str(&self) -> Result<&str, Error> {
core::str::from_utf8(self.0).map_err(|_| Error::Invalid) core::str::from_utf8(self.0).map_err(|_| Error::Invalid)
} }
pub fn to_string(self) -> Result<String, Error> {
String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid)
}
} }
impl<'a> ToTLV for UtfStr<'a> { impl<'a> ToTLV for UtfStr<'a> {
@ -155,7 +144,7 @@ impl<'a> FromTLV<'a> for UtfStr<'a> {
} }
/// Implements OctetString from the spec /// Implements OctetString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct OctetStr<'a>(pub &'a [u8]); pub struct OctetStr<'a>(pub &'a [u8]);
impl<'a> OctetStr<'a> { impl<'a> OctetStr<'a> {
@ -176,41 +165,6 @@ impl<'a> ToTLV for OctetStr<'a> {
} }
} }
/// Implements the Owned version of Octet String
impl FromTLV<'_> for Vec<u8> {
fn from_tlv(t: &TLVElement) -> Result<Vec<u8>, Error> {
t.slice().map(|x| x.to_owned())
}
}
impl ToTLV for Vec<u8> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.str16(tag, self.as_slice())
}
}
/// Implements the Owned version of UTF String
impl FromTLV<'_> for String {
fn from_tlv(t: &TLVElement) -> Result<String, Error> {
match t.slice() {
Ok(x) => {
if let Ok(s) = String::from_utf8(x.to_vec()) {
Ok(s)
} else {
Err(Error::Invalid)
}
}
Err(e) => Err(e),
}
}
}
impl ToTLV for String {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.utf16(tag, self.as_bytes())
}
}
/// Applies to all the Option<> Processing /// Applies to all the Option<> Processing
impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option<T> { impl<'a, T: FromTLV<'a>> FromTLV<'a> for Option<T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Option<T>, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Option<T>, Error> {
@ -279,37 +233,6 @@ impl<T: ToTLV> ToTLV for Nullable<T> {
} }
} }
/// Owned version of a TLVArray
pub struct TLVArrayOwned<T>(Vec<T>);
impl<'a, T: FromTLV<'a>> FromTLV<'a> for TLVArrayOwned<T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
t.confirm_array()?;
let mut vec = Vec::<T>::new();
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?);
}
}
Ok(Self(vec))
}
}
impl<T: ToTLV> ToTLV for TLVArrayOwned<T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
tw.start_array(tag_type)?;
for t in &self.0 {
t.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
impl<T> TLVArrayOwned<T> {
pub fn iter(&self) -> Iter<T> {
self.0.iter()
}
}
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub enum TLVArray<'a, T> { pub enum TLVArray<'a, T> {
// This is used for the to-tlv path // This is used for the to-tlv path
@ -390,18 +313,23 @@ where
} }
} }
impl<'a, T: ToTLV> ToTLV for TLVArray<'a, T> { impl<'a, T: FromTLV<'a> + Copy + ToTLV> ToTLV for TLVArray<'a, T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
match *self { tw.start_array(tag_type)?;
Self::Slice(s) => { for a in self.iter() {
tw.start_array(tag_type)?; a.to_tlv(tw, TagType::Anonymous)?;
for a in s {
a.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
Self::Ptr(t) => t.to_tlv(tw, tag_type),
} }
tw.end_container()
// match *self {
// Self::Slice(s) => {
// tw.start_array(tag_type)?;
// for a in s {
// a.to_tlv(tw, TagType::Anonymous)?;
// }
// tw.end_container()
// }
// Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV
// }
} }
} }
@ -414,10 +342,17 @@ impl<'a, T> FromTLV<'a> for TLVArray<'a, T> {
impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "TLVArray [")?;
let mut first = true;
for i in self.iter() { for i in self.iter() {
writeln!(f, "{:?}", i)?; if !first {
write!(f, ", ")?;
}
write!(f, "{:?}", i)?;
first = false;
} }
writeln!(f) write!(f, "]")
} }
} }