diff --git a/rs-matter/src/cert/mod.rs b/rs-matter/src/cert/mod.rs index 5821c07..1e1cda2 100644 --- a/rs-matter/src/cert/mod.rs +++ b/rs-matter/src/cert/mod.rs @@ -175,7 +175,7 @@ fn encode_extended_key_usage( w.end_seq() } -#[derive(FromTLV, ToTLV, Default, Debug)] +#[derive(FromTLV, ToTLV, Default, Debug, PartialEq)] #[tlvargs(start = 1)] struct BasicConstraints { is_ca: bool, @@ -215,7 +215,7 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> { w.end_seq() } -#[derive(ToTLV, Default, Debug)] +#[derive(ToTLV, Default, Debug, PartialEq)] #[tlvargs(lifetime = "'a", start = 1, datatype = "list")] struct Extensions<'a> { basic_const: Option, @@ -348,7 +348,7 @@ enum DnTags { NocCat = 22, } -#[derive(Debug)] +#[derive(Debug, PartialEq)] enum DistNameValue<'a> { Uint(u64), Utf8Str(&'a [u8]), @@ -357,7 +357,7 @@ enum DistNameValue<'a> { const MAX_DN_ENTRIES: usize = 5; -#[derive(Default, Debug)] +#[derive(Default, Debug, PartialEq)] struct DistNames<'a> { // The order in which the DNs arrive is important, as the signing // requires that the ASN1 notation retains the same order @@ -595,7 +595,7 @@ fn encode_dn_value( w.end_set() } -#[derive(FromTLV, ToTLV, Default, Debug)] +#[derive(FromTLV, ToTLV, Default, Debug, PartialEq)] #[tlvargs(lifetime = "'a", start = 1)] pub struct Cert<'a> { serial_no: OctetStr<'a>, @@ -922,7 +922,10 @@ mod tests { let mut wb = WriteBuf::new(&mut buf); let mut tw = TLVWriter::new(&mut wb); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); - assert_eq!(*input, wb.as_slice()); + + let root2 = tlv::get_root_node(wb.as_slice()).unwrap(); + let cert2 = Cert::from_tlv(&root2).unwrap(); + assert_eq!(cert, cert2); } } diff --git a/rs-matter/src/tlv/traits.rs b/rs-matter/src/tlv/traits.rs index 8a7f49a..77e8d57 100644 --- a/rs-matter/src/tlv/traits.rs +++ b/rs-matter/src/tlv/traits.rs @@ -371,6 +371,28 @@ impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> { } } +impl<'a, 'b, T> PartialEq> for TLVArray<'a, T> +where + T: ToTLV + FromTLV<'a> + Clone + PartialEq, + 'b: 'a, +{ + fn eq(&self, other: &TLVArray<'b, T>) -> bool { + let mut iter1 = self.iter(); + let mut iter2 = other.iter(); + loop { + match (iter1.next(), iter2.next()) { + (None, None) => return true, + (Some(x), Some(y)) => { + if x != y { + return false; + } + } + _ => return false, + } + } + } +} + impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> where T: ToTLV + FromTLV<'a> + Clone + PartialEq, diff --git a/rs-matter/tests/common/handlers.rs b/rs-matter/tests/common/handlers.rs index 868f21a..198eb73 100644 --- a/rs-matter/tests/common/handlers.rs +++ b/rs-matter/tests/common/handlers.rs @@ -220,7 +220,7 @@ impl<'a> ImEngine<'a> { let out = &out[out.len() - 1]; let root = tlv::get_root_node_struct(&out.data).unwrap(); - match expected { + match *expected { WriteResponse::TransactionSuccess(t) => { assert_eq!(out.action, OpCode::WriteResponse); let resp = WriteResp::from_tlv(&root).unwrap();