/* * * Copyright (c) 2020-2022 Project CHIP Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use super::{TagType, TAG_SHIFT_BITS, TAG_SIZE_MAP}; use crate::{error::*, utils::writebuf::WriteBuf}; use log::error; #[allow(dead_code)] enum WriteElementType { S8 = 0, S16 = 1, S32 = 2, S64 = 3, U8 = 4, U16 = 5, U32 = 6, U64 = 7, False = 8, True = 9, F32 = 10, F64 = 11, Utf8l = 12, Utf16l = 13, Utf32l = 14, Utf64l = 15, Str8l = 16, Str16l = 17, Str32l = 18, Str64l = 19, Null = 20, Struct = 21, Array = 22, List = 23, EndCnt = 24, Last, } pub struct TLVWriter<'a, 'b> { buf: &'b mut WriteBuf<'a>, } impl<'a, 'b> TLVWriter<'a, 'b> { pub fn new(buf: &'b mut WriteBuf<'a>) -> Self { TLVWriter { buf } } // TODO: The current method of using writebuf's put methods force us to do // at max 3 checks while writing a single TLV (once for control, once for tag, // once for value), so do a single check and write the whole thing. #[inline(always)] fn put_control_tag( &mut self, tag_type: TagType, val_type: WriteElementType, ) -> Result<(), Error> { let (tag_id, tag_val) = match tag_type { TagType::Anonymous => (0_u8, 0), TagType::Context(v) => (1, v as u64), TagType::CommonPrf16(v) => (2, v as u64), TagType::CommonPrf32(v) => (3, v as u64), TagType::ImplPrf16(v) => (4, v as u64), TagType::ImplPrf32(v) => (5, v as u64), TagType::FullQual48(v) => (6, v as u64), TagType::FullQual64(v) => (7, v as u64), }; self.buf .le_u8(((tag_id) << TAG_SHIFT_BITS) | (val_type as u8))?; if tag_type != TagType::Anonymous { self.buf.le_uint(TAG_SIZE_MAP[tag_id as usize], tag_val)?; } Ok(()) } pub fn i8(&mut self, tag_type: TagType, data: i8) -> Result<(), Error> { self.put_control_tag(tag_type, WriteElementType::S8)?; self.buf.le_i8(data) } pub fn u8(&mut self, tag_type: TagType, data: u8) -> Result<(), Error> { self.put_control_tag(tag_type, WriteElementType::U8)?; self.buf.le_u8(data) } pub fn u16(&mut self, tag_type: TagType, data: u16) -> Result<(), Error> { if data <= 0xff { self.u8(tag_type, data as u8) } else { self.put_control_tag(tag_type, WriteElementType::U16)?; self.buf.le_u16(data) } } pub fn u32(&mut self, tag_type: TagType, data: u32) -> Result<(), Error> { if data <= 0xff { self.u8(tag_type, data as u8) } else if data <= 0xffff { self.u16(tag_type, data as u16) } else { self.put_control_tag(tag_type, WriteElementType::U32)?; self.buf.le_u32(data) } } pub fn u64(&mut self, tag_type: TagType, data: u64) -> Result<(), Error> { if data <= 0xff { self.u8(tag_type, data as u8) } else if data <= 0xffff { self.u16(tag_type, data as u16) } else if data <= 0xffffffff { self.u32(tag_type, data as u32) } else { self.put_control_tag(tag_type, WriteElementType::U64)?; self.buf.le_u64(data) } } pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { if data.len() > 256 { error!("use put_str16() instead"); return Err(Error::Invalid); } self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.buf.le_u8(data.len() as u8)?; self.buf.copy_from_slice(data) } pub fn str16(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { if data.len() <= 0xff { self.str8(tag_type, data) } else { self.put_control_tag(tag_type, WriteElementType::Str16l)?; self.buf.le_u16(data.len() as u16)?; self.buf.copy_from_slice(data) } } // This is quite hacky pub fn str16_as(&mut self, tag_type: TagType, data_gen: F) -> Result<(), Error> where F: FnOnce(&mut [u8]) -> Result, { let anchor = self.buf.get_tail(); self.put_control_tag(tag_type, WriteElementType::Str16l)?; let wb = self.buf.empty_as_mut_slice(); // Reserve 2 spaces for the control and length let str = &mut wb[2..]; let len = data_gen(str).unwrap_or_default(); if len <= 0xff { // Shift everything by 1 let str = &mut wb[1..]; for i in 0..len { str[i] = str[i + 1]; } self.buf.rewind_tail_to(anchor); self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.buf.le_u8(len as u8)?; } else { self.buf.le_u16(len as u16)?; } self.buf.forward_tail_by(len); Ok(()) } pub fn utf8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { self.put_control_tag(tag_type, WriteElementType::Utf8l)?; self.buf.le_u8(data.len() as u8)?; self.buf.copy_from_slice(data) } pub fn utf16(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { if data.len() <= 0xff { self.utf8(tag_type, data) } else { self.put_control_tag(tag_type, WriteElementType::Utf16l)?; self.buf.le_u16(data.len() as u16)?; self.buf.copy_from_slice(data) } } fn no_val(&mut self, tag_type: TagType, element: WriteElementType) -> Result<(), Error> { self.put_control_tag(tag_type, element) } pub fn start_struct(&mut self, tag_type: TagType) -> Result<(), Error> { self.no_val(tag_type, WriteElementType::Struct) } pub fn start_array(&mut self, tag_type: TagType) -> Result<(), Error> { self.no_val(tag_type, WriteElementType::Array) } pub fn start_list(&mut self, tag_type: TagType) -> Result<(), Error> { self.no_val(tag_type, WriteElementType::List) } pub fn end_container(&mut self) -> Result<(), Error> { self.no_val(TagType::Anonymous, WriteElementType::EndCnt) } pub fn null(&mut self, tag_type: TagType) -> Result<(), Error> { self.no_val(tag_type, WriteElementType::Null) } pub fn bool(&mut self, tag_type: TagType, val: bool) -> Result<(), Error> { if val { self.no_val(tag_type, WriteElementType::True) } else { self.no_val(tag_type, WriteElementType::False) } } pub fn get_tail(&self) -> usize { self.buf.get_tail() } pub fn rewind_to(&mut self, anchor: usize) { self.buf.rewind_tail_to(anchor); } } #[cfg(test)] mod tests { use super::{TLVWriter, TagType}; use crate::utils::writebuf::WriteBuf; #[test] fn test_write_success() { let mut buf: [u8; 20] = [0; 20]; let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf, buf_len); let mut tw = TLVWriter::new(&mut writebuf); tw.start_struct(TagType::Anonymous).unwrap(); tw.u8(TagType::Anonymous, 12).unwrap(); tw.u8(TagType::Context(1), 13).unwrap(); tw.u16(TagType::Anonymous, 0x1212).unwrap(); tw.u16(TagType::Context(2), 0x1313).unwrap(); tw.start_array(TagType::Context(3)).unwrap(); tw.bool(TagType::Anonymous, true).unwrap(); tw.end_container().unwrap(); tw.end_container().unwrap(); assert_eq!( buf, [21, 4, 12, 36, 1, 13, 5, 0x12, 0x012, 37, 2, 0x13, 0x13, 54, 3, 9, 24, 24, 0, 0] ); } #[test] fn test_write_overflow() { let mut buf: [u8; 6] = [0; 6]; let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf, buf_len); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Anonymous, 12).unwrap(); tw.u8(TagType::Context(1), 13).unwrap(); match tw.u16(TagType::Anonymous, 12) { Ok(_) => panic!("This should have returned error"), _ => (), } match tw.u16(TagType::Context(2), 13) { Ok(_) => panic!("This should have returned error"), _ => (), } assert_eq!(buf, [4, 12, 36, 1, 13, 4]); } #[test] fn test_put_str8() { let mut buf: [u8; 20] = [0; 20]; let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf, buf_len); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); tw.str8(TagType::Anonymous, &[10, 11, 12, 13, 14]).unwrap(); tw.u16(TagType::Context(2), 0x1313).unwrap(); tw.str8(TagType::Context(3), &[20, 21, 22]).unwrap(); assert_eq!( buf, [36, 1, 13, 16, 5, 10, 11, 12, 13, 14, 37, 2, 0x13, 0x13, 48, 3, 3, 20, 21, 22] ); } #[test] fn test_put_str16_as() { let mut buf: [u8; 20] = [0; 20]; let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf, buf_len); let mut tw = TLVWriter::new(&mut writebuf); tw.u8(TagType::Context(1), 13).unwrap(); tw.str8(TagType::Context(2), &[10, 11, 12, 13, 14]).unwrap(); tw.str16_as(TagType::Context(3), |buf| { buf[0] = 10; buf[1] = 11; Ok(2) }) .unwrap(); tw.u8(TagType::Context(4), 13).unwrap(); assert_eq!( buf, [36, 1, 13, 48, 2, 5, 10, 11, 12, 13, 14, 48, 3, 2, 10, 11, 36, 4, 13, 0] ); } }