Compare commits

68 commits
main ... no_std

Author SHA1 Message Date
Kedar Sovani
7504f6f5b6
Merge pull request #65 from tomoyuki-nakabayashi/tlv_tool/add-hexstring-option
tlv_tool: Add hexstring option.
2023-07-12 15:21:11 +05:30
Restyled.io
a453610557 Restyled by prettier-markdown 2023-07-12 15:50:56 +09:00
tomoyuki nakabayashi
995bb95ac1 tlv_tool: Add hexstring option. 2023-07-10 06:23:46 +09:00
Kedar Sovani
0cc5fe97f5
Merge pull request #63 from thekuwayama/add__tools/tlv_tool_workspace.exclude
fix: add tools/tlv_tool to workspace.exclude
2023-07-03 09:29:14 +05:30
thekuwayama
f15e541e41 fix: add tools/tlv_tool workspace.exclude 2023-07-02 05:31:25 +09:00
Kedar Sovani
aa6aa476a9
Merge pull request #61 from ivmarkov/no_std
Fix #60
2023-07-01 10:35:10 +05:30
ivmarkov
24cc92aa11 Fix #60 2023-06-30 13:04:53 +00:00
ivmarkov
b40a0afbd0 Configurable parts_list in descriptor 2023-06-17 14:00:44 +00:00
ivmarkov
8e2340363f Add from/to TLV for i16, i32 and i64 2023-06-16 18:42:11 +00:00
ivmarkov
28ec278756 More comments for tailoring the example for no_std 2023-06-13 10:38:02 +00:00
ivmarkov
42470e1a34 Fix the no_std build 2023-06-13 10:12:42 +00:00
ivmarkov
864692845b Workaround broken join_multicast_v4 on ESP-IDF 2023-06-13 07:02:37 +00:00
ivmarkov
8d9bd1332c Support for ESP-IDF build 2023-06-12 11:41:33 +00:00
ivmarkov
9070e87944 Proper mDNS responder 2023-06-12 09:47:20 +00:00
ivmarkov
aa159d1772 Clippy 2023-06-10 18:51:34 +00:00
ivmarkov
df311ac6e0 Default mDns impl 2023-06-10 18:47:21 +00:00
ivmarkov
e2277a17a4 Make Matter covariant over its lifetime 2023-06-10 16:41:33 +00:00
ivmarkov
eb21772a09 Simplify main user-facing API 2023-06-09 10:16:07 +00:00
ivmarkov
6c6f74e2e0 Fix a bug in mDNS 2023-06-01 04:59:01 +00:00
ivmarkov
78f2282cd4 Make sure nix is not brought in no-std compiles 2023-05-31 12:51:37 +00:00
ivmarkov
526e592a5c Make the example working again 2023-05-28 14:05:43 +00:00
ivmarkov
c2e72e5f0a More inlines 2023-05-28 11:45:27 +00:00
ivmarkov
dcbfa1f0e3 Clippy 2023-05-28 11:13:02 +00:00
ivmarkov
1c26df0712 Control memory by removing implicit copy 2023-05-28 11:04:46 +00:00
ivmarkov
2e0a09b532 built-in mDNS; memory optimizations 2023-05-24 10:07:11 +00:00
ivmarkov
fccf9fa5f6 no_std needs default features switched off for several crates 2023-05-14 09:08:51 +00:00
ivmarkov
f89f77c3f3 Move MATTER_PORT outside of STD-only udp module 2023-05-14 09:08:51 +00:00
ivmarkov
d48b97d77f Just use time-rs in no_std mode 2023-05-14 09:08:51 +00:00
ivmarkov
86083bd831 Builds for STD with ESP IDF 2023-05-14 09:08:51 +00:00
ivmarkov
e817fa8411 Colorizing is now no_std compatible 2023-05-14 09:08:51 +00:00
ivmarkov
a539f4621e More crypto fixes 2023-05-14 09:08:51 +00:00
ivmarkov
2f2e332c75 Fix compilation errors in crypto 2023-05-14 09:08:51 +00:00
ivmarkov
86fb8ce1f0 Fix no_std errors 2023-05-14 09:08:51 +00:00
ivmarkov
5fc3d2d510 Remove heapless::String from QR API 2023-05-14 09:08:51 +00:00
imarkov
2a2bdab9c5 Optional feature to capture stacktrace on error 2023-05-14 09:08:51 +00:00
ivmarkov
0677a5938a Persistence - trace info 2023-05-14 09:08:51 +00:00
ivmarkov
fdea2863fa Persistence bugfixing 2023-05-14 09:08:51 +00:00
ivmarkov
7437cf2c94 Simple persistance via TLV 2023-05-14 09:08:51 +00:00
ivmarkov
1392810a6c Bugfix: unnecessary struct container 2023-05-14 09:08:51 +00:00
ivmarkov
bb275cd50a Bugfix: subscription_id was not sent 2023-05-14 09:08:51 +00:00
ivmarkov
b21f257c47 Bugfix: missing descriptor cluster 2023-05-14 09:08:51 +00:00
ivmarkov
6ea96c390e Error log on arm failure 2023-05-14 09:08:51 +00:00
ivmarkov
669ef8accc Bugfix: only report devtype for the queried endpoint 2023-05-14 09:08:51 +00:00
ivmarkov
1895f34439 TX packets are reused; need way to reset them 2023-05-14 09:08:51 +00:00
ivmarkov
fb2d5a4a23 Root cert buffer too short 2023-05-14 09:08:51 +00:00
ivmarkov
cf7fac7631 MRP standalone ack messages should not be acknowledged 2023-05-14 09:08:51 +00:00
ivmarkov
4c83112b33 Bugfix: fabric adding wrongly started at index 0 2023-05-14 09:08:51 +00:00
ivmarkov
b4f92b0063 Bugfix: two separate failsafe instances were used 2023-05-14 09:08:51 +00:00
ivmarkov
875ac697ad Restore transaction completion code 2023-05-14 09:08:51 +00:00
ivmarkov
40a476e0d9 Bugfix: arm failsafe was reporting wrong status 2023-05-14 09:08:51 +00:00
ivmarkov
d12e1cfa13 Heap-allocated packets not necessary; no_std and no-alloc build supported end-to-end 2023-05-14 09:08:51 +00:00
ivmarkov
e9b4dc5a5c Comm with chip-tool 2023-05-14 09:08:51 +00:00
ivmarkov
c28df04cb5 Actually add the bonjour feature 2023-05-14 09:08:51 +00:00
ivmarkov
52185ec9a4 Cleanup a bit the mDns story 2023-05-14 09:08:51 +00:00
ivmarkov
a7ca17fabc On-off example now buildable 2023-05-14 09:08:51 +00:00
ivmarkov
17002db7e1 no_std printing of QR code (kind of...) 2023-05-14 09:08:51 +00:00
ivmarkov
1ef431eceb Cleanup the dependencies as much as possible 2023-05-14 09:08:51 +00:00
ivmarkov
c594cf1c55 Fix compilation error since the introduction of UtcCalendar 2023-05-14 09:08:51 +00:00
ivmarkov
117c36ee61 More ergonomic api when STD is available 2023-05-14 09:08:51 +00:00
ivmarkov
625baa72a3 Create new secure channel sessions without async-channel 2023-05-14 09:08:51 +00:00
ivmarkov
2b6317a9e2 Chrono dep made optional 2023-05-14 09:08:51 +00:00
ivmarkov
0b807f03a6 Linux & MacOS mDNS services now implement the Mdns trait 2023-05-14 09:08:51 +00:00
ivmarkov
86a1b5ce7e Fix several no_std incompatibilities 2023-05-14 09:08:51 +00:00
ivmarkov
d82e9ec0af Remove allocations from Cert handling 2023-05-14 09:08:51 +00:00
ivmarkov
f7a887c1d2 Remove allocations from Base38 and QR calc 2023-05-14 09:08:51 +00:00
ivmarkov
26fb6b01c5 Long reads and subscriptions reintroduced 2023-05-14 09:08:51 +00:00
ivmarkov
c11a1a1372 Start reintroducing long reads and subscriptions from mainline 2023-05-14 09:08:51 +00:00
ivmarkov
d446007f6b Support for no_std

Further no_std compat
2023-05-14 09:08:51 +00:00
129 changed files with 11611 additions and 8137 deletions

1
.gitignore vendored
View file

@@ -1,3 +1,4 @@
 target
 Cargo.lock
 .vscode
+.embuild

View file

@@ -1,4 +1,16 @@
 [workspace]
-members = ["matter", "matter_macro_derive", "boxslab", "tools/tlv_tool"]
-exclude = ["examples/*"]
+members = ["matter", "matter_macro_derive"]
+exclude = ["examples/*", "tools/tlv_tool"]
+
+# For compatibility with ESP IDF
+[patch.crates-io]
+smol = { git = "https://github.com/esp-rs-compat/smol" }
+polling = { git = "https://github.com/esp-rs-compat/polling" }
+socket2 = { git = "https://github.com/esp-rs-compat/socket2" }
+
+[profile.release]
+opt-level = 3
+
+[profile.dev]
+debug = true
+opt-level = 3

View file

@@ -13,13 +13,31 @@ Building the library:
 $ cargo build
 ```
 
-Building the example:
+Building and running the example (Linux, MacOS X):
 ```
-$ RUST_LOG="matter" cargo run --example onoff_light
+$ cargo run --example onoff_light
 ```
 
-With the chip-tool (the current tool for testing Matter) use the Ethernet commissioning mechanism:
+Building the example (Espressif's ESP-IDF):
+
+* Install all build prerequisites described [here](https://github.com/esp-rs/esp-idf-template#prerequisites)
+* Build with the following command line:
+```
+export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid; export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features std,crypto_rustcrypto --target xtensa-esp32-espidf -Zbuild-std=std,panic_abort
+```
+* If you are building for a different Espressif MCU, change the `MCU` variable, the `xtensa-esp32-espidf` target and the name of the `CARGO_TARGET_<esp-idf-target-uppercase>_LINKER` variable to match your MCU and its Rust target. Available Espressif MCUs and targets are:
+  * esp32 / xtensa-esp32-espidf
+  * esp32s2 / xtensa-esp32s2-espidf
+  * esp32s3 / xtensa-esp32s3-espidf
+  * esp32c3 / riscv32imc-esp-espidf
+  * esp32c5 / riscv32imc-esp-espidf
+  * esp32c6 / riscv32imac-esp-espidf
+* Put in `WIFI_SSID` / `WIFI_PASS` the SSID & password for your wireless router
+* Flash using the `espflash` utility described in the build prerequisites' link above
+
+## Test
+
+With the `chip-tool` (the current tool for testing Matter) use the Ethernet commissioning mechanism:
 ```
 $ chip-tool pairing code 12344321 <Pairing-Code>
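Editor's note: as an illustration of the bullet above about building for a different Espressif MCU, a hypothetical command line for the esp32c3 / riscv32imc-esp-espidf pair could look like the untested sketch below; relative to the esp32 line in the README, only the `MCU` value, the Cargo linker variable name, and the `--target` change:
```
export MCU=esp32c3; export CARGO_TARGET_RISCV32IMC_ESP_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid; export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features std,crypto_rustcrypto --target riscv32imc-esp-espidf -Zbuild-std=std,panic_abort
```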

View file

@@ -1,9 +0,0 @@
[package]
name = "boxslab"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bitmaps={version="3.2.0", features=[]}

View file

@@ -1,237 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::{
mem::MaybeUninit,
ops::{Deref, DerefMut},
sync::Mutex,
};
// TODO: why is max bitmap size 64 a correct max size? Could we match
// boxslabs instead or store used/not used inside the box slabs themselves?
const MAX_BITMAP_SIZE: usize = 64;
pub struct Bitmap {
inner: bitmaps::Bitmap<MAX_BITMAP_SIZE>,
max_size: usize,
}
impl Bitmap {
pub fn new(max_size: usize) -> Self {
assert!(max_size <= MAX_BITMAP_SIZE);
Bitmap {
inner: bitmaps::Bitmap::new(),
max_size,
}
}
pub fn set(&mut self, index: usize) -> bool {
assert!(index < self.max_size);
self.inner.set(index, true)
}
pub fn reset(&mut self, index: usize) -> bool {
assert!(index < self.max_size);
self.inner.set(index, false)
}
pub fn first_false_index(&self) -> Option<usize> {
match self.inner.first_false_index() {
Some(idx) if idx < self.max_size => Some(idx),
_ => None,
}
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn is_full(&self) -> bool {
self.first_false_index().is_none()
}
}
#[macro_export]
macro_rules! box_slab {
($name:ident,$t:ty,$v:expr) => {
use std::mem::MaybeUninit;
use std::sync::Once;
use $crate::{BoxSlab, Slab, SlabPool};
pub struct $name;
impl SlabPool for $name {
type SlabType = $t;
fn get_slab() -> &'static Slab<Self> {
const MAYBE_INIT: MaybeUninit<$t> = MaybeUninit::uninit();
static mut SLAB_POOL: [MaybeUninit<$t>; $v] = [MAYBE_INIT; $v];
static mut SLAB_SPACE: Option<Slab<$name>> = None;
static mut INIT: Once = Once::new();
unsafe {
INIT.call_once(|| {
SLAB_SPACE = Some(Slab::<$name>::init(&mut SLAB_POOL, $v));
});
SLAB_SPACE.as_ref().unwrap()
}
}
}
};
}
pub trait SlabPool {
type SlabType: 'static;
fn get_slab() -> &'static Slab<Self>
where
Self: Sized;
}
pub struct Inner<T: 'static + SlabPool> {
pool: &'static mut [MaybeUninit<T::SlabType>],
map: Bitmap,
}
// TODO: Instead of a mutex, we should replace this with a CAS loop
pub struct Slab<T: 'static + SlabPool>(Mutex<Inner<T>>);
impl<T: SlabPool> Slab<T> {
pub fn init(pool: &'static mut [MaybeUninit<T::SlabType>], size: usize) -> Self {
Self(Mutex::new(Inner {
pool,
map: Bitmap::new(size),
}))
}
pub fn try_new(new_object: T::SlabType) -> Option<BoxSlab<T>> {
let slab = T::get_slab();
let mut inner = slab.0.lock().unwrap();
if let Some(index) = inner.map.first_false_index() {
inner.map.set(index);
inner.pool[index].write(new_object);
let cell_ptr = unsafe { &mut *inner.pool[index].as_mut_ptr() };
Some(BoxSlab {
data: cell_ptr,
index,
})
} else {
None
}
}
pub fn free(&self, index: usize) {
let mut inner = self.0.lock().unwrap();
inner.map.reset(index);
let old_value = std::mem::replace(&mut inner.pool[index], MaybeUninit::uninit());
let _old_value = unsafe { old_value.assume_init() };
// This will drop the old_value
}
}
pub struct BoxSlab<T: 'static + SlabPool> {
// Because the data is a reference within the MaybeUninit, we don't have a mechanism
// to go out to the MaybeUninit from this reference. Hence this index
index: usize,
// TODO: We should figure out a way to get rid of the index too
data: &'static mut T::SlabType,
}
impl<T: 'static + SlabPool> Drop for BoxSlab<T> {
fn drop(&mut self) {
T::get_slab().free(self.index);
}
}
impl<T: SlabPool> Deref for BoxSlab<T> {
type Target = T::SlabType;
fn deref(&self) -> &Self::Target {
self.data
}
}
impl<T: SlabPool> DerefMut for BoxSlab<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data
}
}
#[cfg(test)]
mod tests {
use std::{ops::Deref, sync::Arc};
pub struct Test {
val: Arc<u32>,
}
box_slab!(TestSlab, Test, 3);
#[test]
fn simple_alloc_free() {
{
let a = Slab::<TestSlab>::try_new(Test { val: Arc::new(10) }).unwrap();
assert_eq!(*a.val.deref(), 10);
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(!inner.map.is_empty());
}
// Validates that the 'Drop' got executed
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(inner.map.is_empty());
println!("Box Size {}", std::mem::size_of::<Box<Test>>());
println!("BoxSlab Size {}", std::mem::size_of::<BoxSlab<TestSlab>>());
}
#[test]
fn alloc_full_block() {
{
let a = Slab::<TestSlab>::try_new(Test { val: Arc::new(10) }).unwrap();
let b = Slab::<TestSlab>::try_new(Test { val: Arc::new(11) }).unwrap();
let c = Slab::<TestSlab>::try_new(Test { val: Arc::new(12) }).unwrap();
// Test that at overflow, we return None
assert!(Slab::<TestSlab>::try_new(Test { val: Arc::new(13) }).is_none(),);
assert_eq!(*b.val.deref(), 11);
{
let inner = TestSlab::get_slab().0.lock().unwrap();
// Test that the bitmap is marked as full
assert!(inner.map.is_full());
}
// Purposefully drop, to test that new allocation is possible
std::mem::drop(b);
let d = Slab::<TestSlab>::try_new(Test { val: Arc::new(21) }).unwrap();
assert_eq!(*d.val.deref(), 21);
// Ensure older allocations are still valid
assert_eq!(*a.val.deref(), 10);
assert_eq!(*c.val.deref(), 12);
}
// Validates that the 'Drop' got executed - test that the bitmap is empty
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(inner.map.is_empty());
}
#[test]
fn test_drop_logic() {
let root = Arc::new(10);
{
let _a = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
let _b = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
let _c = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
assert_eq!(Arc::strong_count(&root), 4);
}
// Test that Drop was correctly called on all the members of the pool
assert_eq!(Arc::strong_count(&root), 1);
}
}

View file

@@ -16,7 +16,7 @@
  */
 
 use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher};
-use matter::error::Error;
+use matter::error::{Error, ErrorCode};
 
 pub struct HardCodedDevAtt {}
@@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt {
             data.copy_from_slice(src);
             Ok(src.len())
         } else {
-            Err(Error::NoSpace)
+            Err(ErrorCode::NoSpace.into())
         }
     }
 }

View file

@ -15,40 +15,415 @@
* limitations under the License. * limitations under the License.
*/ */
mod dev_att; use core::borrow::Borrow;
use matter::core::{self, CommissioningData}; use core::pin::pin;
use embassy_futures::select::select;
use log::info;
use matter::core::{CommissioningData, Matter};
use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::device_types::device_type_add_on_off_light; use matter::data_model::cluster_on_off;
use matter::data_model::core::DataModel;
use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT;
use matter::data_model::objects::*;
use matter::data_model::root_endpoint;
use matter::data_model::system_model::descriptor;
use matter::error::Error;
use matter::interaction_model::core::InteractionModel;
use matter::mdns::{DefaultMdns, DefaultMdnsRunner};
use matter::secure_channel::spake2p::VerifierData; use matter::secure_channel::spake2p::VerifierData;
use matter::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
fn main() { use matter::transport::{
env_logger::init(); core::RecvAction, core::Transport, packet::MAX_RX_BUF_SIZE, packet::MAX_TX_BUF_SIZE,
let comm_data = CommissioningData { udp::UdpListener,
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456),
discriminator: 250,
}; };
use matter::utils::select::EitherUnwrap;
// vid/pid should match those in the DAC mod dev_att;
let dev_info = BasicInfoConfig {
#[cfg(feature = "std")]
fn main() -> Result<(), Error> {
let thread = std::thread::Builder::new()
.stack_size(120 * 1024)
.spawn(run)
.unwrap();
thread.join().unwrap()
}
// NOTE (no_std): For no_std, name this entry point according to your MCU platform
#[cfg(not(feature = "std"))]
#[no_mangle]
fn app_main() {
run().unwrap();
}
fn run() -> Result<(), Error> {
initialize_logger();
info!(
"Matter memory: mDNS={}, Matter={}, Transport={}",
core::mem::size_of::<DefaultMdns>(),
core::mem::size_of::<Matter>(),
core::mem::size_of::<Transport>(),
);
let dev_det = BasicInfoConfig {
vid: 0xFFF1, vid: 0xFFF1,
pid: 0x8000, pid: 0x8000,
hw_ver: 2, hw_ver: 2,
sw_ver: 1, sw_ver: 1,
sw_ver_str: "1".to_string(), sw_ver_str: "1",
serial_no: "aabbccdd".to_string(), serial_no: "aabbccdd",
device_name: "OnOff Light".to_string(), device_name: "OnOff Light",
}; };
let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); let (ipv4_addr, ipv6_addr, interface) = initialize_network()?;
let dm = matter.get_data_model();
let mdns = DefaultMdns::new(
0,
"matter-demo",
ipv4_addr.octets(),
Some(ipv6_addr.octets()),
interface,
&dev_det,
matter::MATTER_PORT,
);
let mut mdns_runner = DefaultMdnsRunner::new(&mdns);
let dev_att = dev_att::HardCodedDevAtt::new();
#[cfg(feature = "std")]
let epoch = matter::utils::epoch::sys_epoch;
#[cfg(feature = "std")]
let rand = matter::utils::rand::sys_rand;
// NOTE (no_std): For no_std, provide your own function here
#[cfg(not(feature = "std"))]
let epoch = matter::utils::epoch::dummy_epoch;
// NOTE (no_std): For no_std, provide your own function here
#[cfg(not(feature = "std"))]
let rand = matter::utils::rand::dummy_rand;
let matter = Matter::new(
// vid/pid should match those in the DAC
&dev_det,
&dev_att,
&mdns,
epoch,
rand,
matter::MATTER_PORT,
);
let psm_path = std::env::temp_dir().join("matter-iot");
info!("Persisting from/to {}", psm_path.display());
#[cfg(all(feature = "std", not(target_os = "espidf")))]
let psm = matter::persist::FilePsm::new(psm_path)?;
let mut buf = [0; 4096];
let buf = &mut buf;
#[cfg(all(feature = "std", not(target_os = "espidf")))]
{ {
let mut node = dm.node.write().unwrap(); if let Some(data) = psm.load("acls", buf)? {
let endpoint = device_type_add_on_off_light(&mut node).unwrap(); matter.load_acls(data)?;
println!("Added OnOff Light Device type at endpoint id: {}", endpoint);
println!("Data Model now is: {}", node);
} }
matter.start_daemon().unwrap(); if let Some(data) = psm.load("fabrics", buf)? {
matter.load_fabrics(data)?;
}
}
let mut transport = Transport::new(&matter);
transport.start(
CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456, *matter.borrow()),
discriminator: 250,
},
buf,
)?;
let node = Node {
id: 0,
endpoints: &[
root_endpoint::endpoint(0),
Endpoint {
id: 1,
device_type: DEV_TYPE_ON_OFF_LIGHT,
clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER],
},
],
};
let mut handler = handler(&matter);
let mut im = InteractionModel(DataModel::new(matter.borrow(), &node, &mut handler));
let mut rx_buf = [0; MAX_RX_BUF_SIZE];
let mut tx_buf = [0; MAX_TX_BUF_SIZE];
let im = &mut im;
let mdns_runner = &mut mdns_runner;
let transport = &mut transport;
let rx_buf = &mut rx_buf;
let tx_buf = &mut tx_buf;
let mut io_fut = pin!(async move {
// NOTE (no_std): On no_std, the `UdpListener` implementation is a no-op so you might want to
// replace it with your own UDP stack
let udp = UdpListener::new(SocketAddr::new(
IpAddr::V6(Ipv6Addr::UNSPECIFIED),
matter::MATTER_PORT,
))
.await?;
loop {
let (len, addr) = udp.recv(rx_buf).await?;
let mut completion = transport.recv(Address::Udp(addr), &mut rx_buf[..len], tx_buf);
while let Some(action) = completion.next_action()? {
match action {
RecvAction::Send(addr, buf) => {
udp.send(addr.unwrap_udp(), buf).await?;
}
RecvAction::Interact(mut ctx) => {
if im.handle(&mut ctx)? && ctx.send()? {
udp.send(ctx.tx.peer.unwrap_udp(), ctx.tx.as_slice())
.await?;
}
}
}
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
{
if let Some(data) = transport.matter().store_fabrics(buf)? {
psm.store("fabrics", data)?;
}
if let Some(data) = transport.matter().store_acls(buf)? {
psm.store("acls", data)?;
}
}
}
#[allow(unreachable_code)]
Ok::<_, matter::error::Error>(())
});
// NOTE (no_std): On no_std, the `run_udp` is a no-op so you might want to replace it with `run` and
// connect the pipes of the `run` method with your own UDP stack
let mut mdns_fut = pin!(async move { mdns_runner.run_udp().await });
let mut fut = pin!(async move { select(&mut io_fut, &mut mdns_fut).await.unwrap() });
#[cfg(feature = "std")]
smol::block_on(&mut fut)?;
// NOTE (no_std): For no_std, replace with your own more efficient no_std executor,
// because the executor used below is a simple busy-loop poller
#[cfg(not(feature = "std"))]
embassy_futures::block_on(&mut fut)?;
Ok(())
}
fn handler<'a>(matter: &'a Matter<'a>) -> impl Handler + 'a {
root_endpoint::handler(0, matter)
.chain(
1,
descriptor::ID,
descriptor::DescriptorCluster::new(*matter.borrow()),
)
.chain(
1,
cluster_on_off::ID,
cluster_on_off::OnOffCluster::new(*matter.borrow()),
)
}
// NOTE (no_std): For no_std, implement here your own way of initializing the logger
#[cfg(all(not(feature = "std"), not(target_os = "espidf")))]
#[inline(never)]
fn initialize_logger() {}
// NOTE (no_std): For no_std, implement here your own way of initializing the network
#[cfg(all(not(feature = "std"), not(target_os = "espidf")))]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
Ok((Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0))
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
#[inline(never)]
fn initialize_logger() {
env_logger::init_from_env(
env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"),
);
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
use log::error;
use matter::error::ErrorCode;
use nix::{net::if_::InterfaceFlags, sys::socket::SockaddrIn6};
let interfaces = || {
nix::ifaddrs::getifaddrs().unwrap().filter(|ia| {
ia.flags
.contains(InterfaceFlags::IFF_UP | InterfaceFlags::IFF_BROADCAST)
&& !ia
.flags
.intersects(InterfaceFlags::IFF_LOOPBACK | InterfaceFlags::IFF_POINTOPOINT)
})
};
// A quick and dirty way to get a network interface that has a link-local IPv6 address assigned as well as a non-loopback IPv4
// Most likely, this is the interface we need
// (as opposed to all the docker and libvirt interfaces that might be assigned on the machine and which seem by default to be IPv4 only)
let (iname, ip, ipv6) = interfaces()
.filter_map(|ia| {
ia.address
.and_then(|addr| addr.as_sockaddr_in6().map(SockaddrIn6::ip))
.filter(|ip| ip.octets()[..2] == [0xfe, 0x80])
.map(|ipv6| (ia.interface_name, ipv6))
})
.filter_map(|(iname, ipv6)| {
interfaces()
.filter(|ia2| ia2.interface_name == iname)
.find_map(|ia2| {
ia2.address
.and_then(|addr| addr.as_sockaddr_in().map(|addr| addr.ip().into()))
.map(|ip| (iname.clone(), ip, ipv6))
})
})
.next()
.ok_or_else(|| {
error!("Cannot find network interface suitable for mDNS broadcasting");
ErrorCode::Network
})?;
info!(
"Will use network interface {} with {}/{} for mDNS",
iname, ip, ipv6
);
Ok((ip, ipv6, 0 as _))
}
#[cfg(target_os = "espidf")]
#[inline(never)]
fn initialize_logger() {
esp_idf_svc::log::EspLogger::initialize_default();
}
#[cfg(target_os = "espidf")]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
use core::time::Duration;
use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration};
use esp_idf_hal::prelude::Peripherals;
use esp_idf_svc::handle::RawHandle;
use esp_idf_svc::wifi::{BlockingWifi, EspWifi};
use esp_idf_svc::{eventloop::EspSystemEventLoop, nvs::EspDefaultNvsPartition};
use esp_idf_sys::{
self as _, esp, esp_ip6_addr_t, esp_netif_create_ip6_linklocal, esp_netif_get_ip6_linklocal,
}; // If using the `binstart` feature of `esp-idf-sys`, always keep this module imported
const SSID: &'static str = env!("WIFI_SSID");
const PASSWORD: &'static str = env!("WIFI_PASS");
#[allow(clippy::needless_update)]
{
// VFS is necessary for poll-based async IO
esp_idf_sys::esp!(unsafe {
esp_idf_sys::esp_vfs_eventfd_register(&esp_idf_sys::esp_vfs_eventfd_config_t {
max_fds: 5,
..Default::default()
})
})?;
}
let peripherals = Peripherals::take().unwrap();
let sys_loop = EspSystemEventLoop::take()?;
let nvs = EspDefaultNvsPartition::take()?;
let mut wifi = EspWifi::new(peripherals.modem, sys_loop.clone(), Some(nvs))?;
let mut bwifi = BlockingWifi::wrap(&mut wifi, sys_loop)?;
let wifi_configuration: Configuration = Configuration::Client(ClientConfiguration {
ssid: SSID.into(),
bssid: None,
auth_method: AuthMethod::WPA2Personal,
password: PASSWORD.into(),
channel: None,
});
bwifi.set_configuration(&wifi_configuration)?;
bwifi.start()?;
info!("Wifi started");
bwifi.connect()?;
info!("Wifi connected");
esp!(unsafe {
esp_netif_create_ip6_linklocal(bwifi.wifi_mut().sta_netif_mut().handle() as _)
})?;
bwifi.wait_netif_up()?;
info!("Wifi netif up");
let ip_info = wifi.sta_netif().get_ip_info()?;
let mut ipv6: esp_ip6_addr_t = Default::default();
info!("Waiting for IPv6 address");
while esp!(unsafe { esp_netif_get_ip6_linklocal(wifi.sta_netif().handle() as _, &mut ipv6) })
.is_err()
{
info!("Waiting...");
std::thread::sleep(Duration::from_secs(2));
}
info!("Wifi DHCP info: {:?}, IPv6: {:?}", ip_info, ipv6.addr);
let ipv4_octets = ip_info.ip.octets();
let ipv6_octets = [
ipv6.addr[0].to_le_bytes()[0],
ipv6.addr[0].to_le_bytes()[1],
ipv6.addr[0].to_le_bytes()[2],
ipv6.addr[0].to_le_bytes()[3],
ipv6.addr[1].to_le_bytes()[0],
ipv6.addr[1].to_le_bytes()[1],
ipv6.addr[1].to_le_bytes()[2],
ipv6.addr[1].to_le_bytes()[3],
ipv6.addr[2].to_le_bytes()[0],
ipv6.addr[2].to_le_bytes()[1],
ipv6.addr[2].to_le_bytes()[2],
ipv6.addr[2].to_le_bytes()[3],
ipv6.addr[3].to_le_bytes()[0],
ipv6.addr[3].to_le_bytes()[1],
ipv6.addr[3].to_le_bytes()[2],
ipv6.addr[3].to_le_bytes()[3],
];
let interface = wifi.sta_netif().get_index();
// Not OK of course, but for a demo this is good enough
// Wifi will continue to be available and working in the background
core::mem::forget(wifi);
Ok((ipv4_octets.into(), ipv6_octets.into(), interface))
} }
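Editor's note: the `NOTE (no_std)` comments in the new `main.rs` above leave several platform hooks (entry-point naming, logger, network, clock and RNG) to the integrator. Below is a minimal, hypothetical sketch of the clock/RNG part only; it assumes the values passed to `Matter::new` in place of `dummy_epoch`/`dummy_rand` are plain function pointers with the shapes used in the example, and the names are illustrative rather than part of the library:

```rust
use core::time::Duration;

// Hypothetical no_std glue for two of the hooks the example leaves open.
// Assumed shapes, mirroring how `dummy_epoch`/`dummy_rand` are used above:
// a clock returning elapsed time as `Duration`, and a byte-filling RNG.

/// Time since boot, e.g. read from a hardware timer (stubbed out here).
fn my_epoch() -> Duration {
    // Replace with a real monotonic clock source for your MCU.
    Duration::from_millis(0)
}

/// Fill `buf` with random bytes from the platform RNG (stubbed out here).
fn my_rand(buf: &mut [u8]) {
    // Replace with a real hardware RNG; an all-zero "RNG" is NOT secure.
    buf.fill(0);
}

// `my_epoch` and `my_rand` would then be passed to `Matter::new(...)`
// where the example passes `epoch` and `rand`.
```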

View file

@@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt {
             data.copy_from_slice(src);
             Ok(src.len())
         } else {
-            Err(Error::NoSpace)
+            Err(ErrorCode::NoSpace.into())
         }
     }
 }

View file

@@ -15,4 +15,4 @@
  */
 
-pub mod dev_att;
+// TODO pub mod dev_att;

View file

@ -15,55 +15,56 @@
* limitations under the License. * limitations under the License.
*/ */
mod dev_att; // TODO
use matter::core::{self, CommissioningData}; // mod dev_att;
use matter::data_model::cluster_basic_information::BasicInfoConfig; // use matter::core::{self, CommissioningData};
use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; // use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; // use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster};
use matter::secure_channel::spake2p::VerifierData; // use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER;
// use matter::secure_channel::spake2p::VerifierData;
fn main() { fn main() {
env_logger::init(); // env_logger::init();
let comm_data = CommissioningData { // let comm_data = CommissioningData {
// TODO: Hard-coded for now // // TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456), // verifier: VerifierData::new_with_pw(123456),
discriminator: 250, // discriminator: 250,
}; // };
// vid/pid should match those in the DAC // // vid/pid should match those in the DAC
let dev_info = BasicInfoConfig { // let dev_info = BasicInfoConfig {
vid: 0xFFF1, // vid: 0xFFF1,
pid: 0x8002, // pid: 0x8002,
hw_ver: 2, // hw_ver: 2,
sw_ver: 1, // sw_ver: 1,
sw_ver_str: "1".to_string(), // sw_ver_str: "1".to_string(),
serial_no: "aabbccdd".to_string(), // serial_no: "aabbccdd".to_string(),
device_name: "Smart Speaker".to_string(), // device_name: "Smart Speaker".to_string(),
}; // };
let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); // let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap();
let dm = matter.get_data_model(); // let dm = matter.get_data_model();
{ // {
let mut node = dm.node.write().unwrap(); // let mut node = dm.node.write().unwrap();
let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); // let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap();
let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); // let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap();
// Add some callbacks // // Add some callbacks
let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); // let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback."));
let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); // let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback."));
let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); // let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback."));
let start_over_callback = // let start_over_callback =
Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); // Box::new(|| log::info!("Comamnd [StartOver] handled with callback."));
media_playback_cluster.add_callback(Commands::Play, play_callback); // media_playback_cluster.add_callback(Commands::Play, play_callback);
media_playback_cluster.add_callback(Commands::Pause, pause_callback); // media_playback_cluster.add_callback(Commands::Pause, pause_callback);
media_playback_cluster.add_callback(Commands::Stop, stop_callback); // media_playback_cluster.add_callback(Commands::Stop, stop_callback);
media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); // media_playback_cluster.add_callback(Commands::StartOver, start_over_callback);
node.add_cluster(endpoint_audio, media_playback_cluster) // node.add_cluster(endpoint_audio, media_playback_cluster)
.unwrap(); // .unwrap();
println!("Added Speaker type at endpoint id: {}", endpoint_audio) // println!("Added Speaker type at endpoint id: {}", endpoint_audio)
} // }
matter.start_daemon().unwrap(); // matter.start_daemon().unwrap();
} }

View file

@@ -0,0 +1,69 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
mod dev_att;
use matter::core::{self, CommissioningData};
use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster};
use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER;
use matter::secure_channel::spake2p::VerifierData;
fn main() {
env_logger::init();
let comm_data = CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456),
discriminator: 250,
};
// vid/pid should match those in the DAC
let dev_info = BasicInfoConfig {
vid: 0xFFF1,
pid: 0x8002,
hw_ver: 2,
sw_ver: 1,
sw_ver_str: "1".to_string(),
serial_no: "aabbccdd".to_string(),
device_name: "Smart Speaker".to_string(),
};
let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap();
let dm = matter.get_data_model();
{
let mut node = dm.node.write().unwrap();
let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap();
let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap();
// Add some callbacks
let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback."));
let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback."));
let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback."));
let start_over_callback =
Box::new(|| log::info!("Comamnd [StartOver] handled with callback."));
media_playback_cluster.add_callback(Commands::Play, play_callback);
media_playback_cluster.add_callback(Commands::Pause, pause_callback);
media_playback_cluster.add_callback(Commands::Stop, stop_callback);
media_playback_cluster.add_callback(Commands::StartOver, start_over_callback);
node.add_cluster(endpoint_audio, media_playback_cluster)
.unwrap();
println!("Added Speaker type at endpoint id: {}", endpoint_audio)
}
matter.start_daemon().unwrap();
}

View file

@ -15,38 +15,49 @@ name = "matter"
path = "src/lib.rs" path = "src/lib.rs"
[features] [features]
default = ["crypto_mbedtls"] default = ["os", "crypto_rustcrypto"]
crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"]
crypto_mbedtls = ["mbedtls"] std = ["alloc", "rand", "qrcode", "async-io", "smol", "esp-idf-sys/std"]
crypto_esp_mbedtls = ["esp-idf-sys"] backtrace = []
crypto_rustcrypto = ["sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert"] alloc = []
nightly = []
crypto_openssl = ["alloc", "openssl", "foreign-types", "hmac", "sha2"]
crypto_mbedtls = ["alloc", "mbedtls"]
crypto_rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"]
[dependencies] [dependencies]
boxslab = { path = "../boxslab" }
matter_macro_derive = { path = "../matter_macro_derive" } matter_macro_derive = { path = "../matter_macro_derive" }
bitflags = "1.3" bitflags = { version = "1.3", default-features = false }
byteorder = "1.4.3" byteorder = { version = "1.4.3", default-features = false }
heapless = { version = "0.7.16", features = ["x86-sync-pool"] } heapless = "0.7.16"
generic-array = "0.14.6" num = { version = "0.4", default-features = false }
num = "0.4"
num-derive = "0.3.3" num-derive = "0.3.3"
num-traits = "0.2.15" num-traits = { version = "0.2.15", default-features = false }
strum = { version = "0.24", features = ["derive"], default-features = false }
log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] }
env_logger = { version = "0.10.0", default-features = false, features = [] } no-std-net = "0.6"
rand = "0.8.5" subtle = { version = "2.4.1", default-features = false }
esp-idf-sys = { version = "0.32", features = ["binstart"], optional = true } safemem = { version = "0.3.3", default-features = false }
subtle = "2.4.1" owo-colors = "3"
colored = "2.0.0" time = { version = "0.3", default-features = false }
smol = "1.3.0" verhoeff = { version = "1", default-features = false }
owning_ref = "0.4.1" embassy-futures = "0.1"
safemem = "0.3.3" embassy-time = { version = "0.1.1", features = ["generic-queue-8"] }
chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } embassy-sync = "0.2"
async-channel = "1.8" critical-section = "1.1.1"
domain = { version = "0.7.2", default_features = false, features = ["heapless"] }
# STD-only dependencies
rand = { version = "0.8.5", optional = true }
qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code
smol = { version = "1.2", optional = true } # =1.2 for compatibility with ESP IDF
async-io = { version = "=1.12", optional = true } # =1.2 for compatibility with ESP IDF
# crypto # crypto
openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true }
foreign-types = { version = "0.3.2", optional = true } foreign-types = { version = "0.3.2", optional = true }
mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true }
# rust-crypto
sha2 = { version = "0.10", default-features = false, optional = true } sha2 = { version = "0.10", default-features = false, optional = true }
hmac = { version = "0.12", optional = true } hmac = { version = "0.12", optional = true }
pbkdf2 = { version = "0.12", optional = true } pbkdf2 = { version = "0.12", optional = true }
@ -56,28 +67,30 @@ ccm = { version = "0.5", default-features = false, features = ["alloc"], optiona
p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true } p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true }
elliptic-curve = { version = "0.13.2", optional = true } elliptic-curve = { version = "0.13.2", optional = true }
crypto-bigint = { version = "0.4", default-features = false, optional = true } crypto-bigint = { version = "0.4", default-features = false, optional = true }
# Note: requires std rand_core = { version = "0.6", default-features = false, optional = true }
x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], optional = true } # TODO: requires `alloc`
# to compute the check digit
verhoeff = "1"
# print QR code
qrcode = { version = "0.12", default-features = false }
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]
astro-dnssd = "0.3" astro-dnssd = { version = "0.3" }
# MDNS support [target.'cfg(not(target_os = "espidf"))'.dependencies]
[target.'cfg(target_os = "linux")'.dependencies] mbedtls = { git = "https://github.com/fortanix/rust-mbedtls", optional = true }
lazy_static = "1.4.0" env_logger = { version = "0.10.0", optional = true }
libmdns = { version = "0.7.4" } nix = { version = "0.26", features = ["net"], optional = true }
[target.'cfg(target_os = "espidf")'.dependencies]
esp-idf-sys = { version = "0.33", default-features = false, features = ["native", "binstart"] }
esp-idf-hal = { version = "0.41", features = ["embassy-sync", "critical-section"] }
esp-idf-svc = { version = "0.46", features = ["embassy-time-driver"] }
embedded-svc = "0.25"
[build-dependencies]
embuild = "0.31.2"
[[example]] [[example]]
name = "onoff_light" name = "onoff_light"
path = "../examples/onoff_light/src/main.rs" path = "../examples/onoff_light/src/main.rs"
# [[example]]
[[example]] # name = "speaker"
name = "speaker" # path = "../examples/speaker/src/main.rs"
path = "../examples/speaker/src/main.rs"

11
matter/build.rs Normal file
View file

@@ -0,0 +1,11 @@
use std::env::var;
// Necessary because of this issue: https://github.com/rust-lang/cargo/issues/9641
fn main() -> Result<(), Box<dyn std::error::Error>> {
if var("TARGET").unwrap().ends_with("-espidf") {
embuild::build::CfgArgs::output_propagated("ESP_IDF")?;
embuild::build::LinkArgs::output_propagated("ESP_IDF")?;
}
Ok(())
}

View file

@ -15,19 +15,15 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{cell::RefCell, fmt::Display};
fmt::Display,
sync::{Arc, Mutex, MutexGuard, RwLock},
};
use crate::{ use crate::{
data_model::objects::{Access, ClusterId, EndptId, Privilege}, data_model::objects::{Access, ClusterId, EndptId, Privilege},
error::Error, error::{Error, ErrorCode},
fabric, fabric,
interaction_model::messages::GenericPath, interaction_model::messages::GenericPath,
sys::Psm, tlv::{self, FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV},
tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC},
transport::session::MAX_CAT_IDS_PER_NOC,
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
use log::error; use log::error;
@ -54,7 +50,7 @@ impl FromTLV<'_> for AuthMode {
{ {
num::FromPrimitive::from_u32(t.u32()?) num::FromPrimitive::from_u32(t.u32()?)
.filter(|a| *a != AuthMode::Invalid) .filter(|a| *a != AuthMode::Invalid)
.ok_or(Error::Invalid) .ok_or_else(|| ErrorCode::Invalid.into())
} }
} }
@ -116,7 +112,7 @@ impl AccessorSubjects {
return Ok(()); return Ok(());
} }
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
/// Match the match_subject with any of the current subjects /// Match the match_subject with any of the current subjects
@ -146,7 +142,7 @@ impl AccessorSubjects {
} }
impl Display for AccessorSubjects { impl Display for AccessorSubjects {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::result::Result<(), core::fmt::Error> {
write!(f, "[")?; write!(f, "[")?;
for i in self.0 { for i in self.0 {
if is_noc_cat(i) { if is_noc_cat(i) {
@ -160,7 +156,7 @@ impl Display for AccessorSubjects {
} }
/// The Accessor Object /// The Accessor Object
pub struct Accessor { pub struct Accessor<'a> {
/// The fabric index of the accessor /// The fabric index of the accessor
pub fab_idx: u8, pub fab_idx: u8,
/// Accessor's subject: could be node-id, NoC CAT, group id /// Accessor's subject: could be node-id, NoC CAT, group id
@ -168,15 +164,37 @@ pub struct Accessor {
/// The Authmode of this session /// The Authmode of this session
auth_mode: AuthMode, auth_mode: AuthMode,
// TODO: Is this the right place for this though, or should we just use a global-acl-handle-get // TODO: Is this the right place for this though, or should we just use a global-acl-handle-get
acl_mgr: Arc<AclMgr>, acl_mgr: &'a RefCell<AclMgr>,
} }
impl Accessor { impl<'a> Accessor<'a> {
pub fn new( pub fn for_session(session: &Session, acl_mgr: &'a RefCell<AclMgr>) -> Self {
match session.get_session_mode() {
SessionMode::Case(c) => {
let mut subject =
AccessorSubjects::new(session.get_peer_node_id().unwrap_or_default());
for i in c.cat_ids {
if i != 0 {
let _ = subject.add_catid(i);
}
}
Accessor::new(c.fab_idx, subject, AuthMode::Case, acl_mgr)
}
SessionMode::Pase => {
Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr)
}
SessionMode::PlainText => {
Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, acl_mgr)
}
}
}
pub const fn new(
fab_idx: u8, fab_idx: u8,
subjects: AccessorSubjects, subjects: AccessorSubjects,
auth_mode: AuthMode, auth_mode: AuthMode,
acl_mgr: Arc<AclMgr>, acl_mgr: &'a RefCell<AclMgr>,
) -> Self { ) -> Self {
Self { Self {
fab_idx, fab_idx,
@ -188,9 +206,9 @@ impl Accessor {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct AccessDesc<'a> { pub struct AccessDesc {
/// The object to be acted upon /// The object to be acted upon
path: &'a GenericPath, path: GenericPath,
/// The target permissions /// The target permissions
target_perms: Option<Access>, target_perms: Option<Access>,
// The operation being done // The operation being done
@ -200,8 +218,8 @@ pub struct AccessDesc<'a> {
/// Access Request Object /// Access Request Object
pub struct AccessReq<'a> { pub struct AccessReq<'a> {
accessor: &'a Accessor, accessor: &'a Accessor<'a>,
object: AccessDesc<'a>, object: AccessDesc,
} }
impl<'a> AccessReq<'a> { impl<'a> AccessReq<'a> {
@ -209,7 +227,7 @@ impl<'a> AccessReq<'a> {
/// ///
/// An access request specifies the _accessor_ attempting to access _path_ /// An access request specifies the _accessor_ attempting to access _path_
/// with _operation_ /// with _operation_
pub fn new(accessor: &'a Accessor, path: &'a GenericPath, operation: Access) -> Self { pub fn new(accessor: &'a Accessor, path: GenericPath, operation: Access) -> Self {
AccessReq { AccessReq {
accessor, accessor,
object: AccessDesc { object: AccessDesc {
@ -220,6 +238,10 @@ impl<'a> AccessReq<'a> {
} }
} }
pub fn operation(&self) -> Access {
self.object.operation
}
/// Add target's permissions to the request /// Add target's permissions to the request
/// ///
/// The permissions that are associated with the target (identified by the /// The permissions that are associated with the target (identified by the
@ -234,11 +256,11 @@ impl<'a> AccessReq<'a> {
/// _accessor_ the necessary privileges to access the target as per its /// _accessor_ the necessary privileges to access the target as per its
/// permissions /// permissions
pub fn allow(&self) -> bool { pub fn allow(&self) -> bool {
self.accessor.acl_mgr.allow(self) self.accessor.acl_mgr.borrow().allow(self)
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone, Debug, PartialEq)] #[derive(FromTLV, ToTLV, Clone, Debug, PartialEq)]
pub struct Target { pub struct Target {
cluster: Option<ClusterId>, cluster: Option<ClusterId>,
endpoint: Option<EndptId>, endpoint: Option<EndptId>,
@ -261,7 +283,7 @@ impl Target {
type Subjects = [Option<u64>; SUBJECTS_PER_ENTRY]; type Subjects = [Option<u64>; SUBJECTS_PER_ENTRY];
type Targets = [Option<Target>; TARGETS_PER_ENTRY]; type Targets = [Option<Target>; TARGETS_PER_ENTRY];
#[derive(ToTLV, FromTLV, Copy, Clone, Debug, PartialEq)] #[derive(ToTLV, FromTLV, Clone, Debug, PartialEq)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
pub struct AclEntry { pub struct AclEntry {
privilege: Privilege, privilege: Privilege,
@ -292,7 +314,7 @@ impl AclEntry {
.subjects .subjects
.iter() .iter()
.position(|s| s.is_none()) .position(|s| s.is_none())
.ok_or(Error::NoSpace)?; .ok_or(ErrorCode::NoSpace)?;
self.subjects[index] = Some(subject); self.subjects[index] = Some(subject);
Ok(()) Ok(())
} }
@ -306,7 +328,7 @@ impl AclEntry {
.targets .targets
.iter() .iter()
.position(|s| s.is_none()) .position(|s| s.is_none())
.ok_or(Error::NoSpace)?; .ok_or(ErrorCode::NoSpace)?;
self.targets[index] = Some(target); self.targets[index] = Some(target);
Ok(()) Ok(())
} }
@ -367,35 +389,151 @@ impl AclEntry {
} }
const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS;
type AclEntries = [Option<AclEntry>; MAX_ACL_ENTRIES];
#[derive(ToTLV, FromTLV, Debug)] type AclEntries = heapless::Vec<Option<AclEntry>, MAX_ACL_ENTRIES>;
struct AclMgrInner {
pub struct AclMgr {
entries: AclEntries, entries: AclEntries,
changed: bool,
} }
const ACL_KV_ENTRY: &str = "acl"; impl AclMgr {
const ACL_KV_MAX_SIZE: usize = 300; #[inline(always)]
impl AclMgrInner { pub const fn new() -> Self {
pub fn store(&self, psm: &MutexGuard<Psm>) -> Result<(), Error> { Self {
let mut acl_tlvs = [0u8; ACL_KV_MAX_SIZE]; entries: AclEntries::new(),
let mut wb = WriteBuf::new(&mut acl_tlvs, ACL_KV_MAX_SIZE); changed: false,
let mut tw = TLVWriter::new(&mut wb); }
self.entries.to_tlv(&mut tw, TagType::Anonymous)?;
psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())
} }
pub fn load(psm: &MutexGuard<Psm>) -> Result<Self, Error> { pub fn erase_all(&mut self) -> Result<(), Error> {
let mut acl_tlvs = Vec::new(); self.entries.clear();
psm.get_kv_slice(ACL_KV_ENTRY, &mut acl_tlvs)?; self.changed = true;
let root = TLVList::new(&acl_tlvs)
Ok(())
}
pub fn add(&mut self, entry: AclEntry) -> Result<(), Error> {
let cnt = self
.entries
.iter() .iter()
.next() .flatten()
.ok_or(Error::Invalid)?; .filter(|a| a.fab_idx == entry.fab_idx)
.count();
if cnt >= ENTRIES_PER_FABRIC {
Err(ErrorCode::NoSpace)?;
}
Ok(Self { let slot = self.entries.iter().position(|a| a.is_none());
entries: AclEntries::from_tlv(&root)?,
}) if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES {
if let Some(index) = slot {
self.entries[index] = Some(entry);
} else {
self.entries
.push(Some(entry))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
}
self.changed = true;
}
Ok(())
}
// Since the entries are fabric-scoped, the index is only for entries with the matching fabric index
pub fn edit(&mut self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> {
let old = self.for_index_in_fabric(index, fab_idx)?;
*old = Some(new);
self.changed = true;
Ok(())
}
pub fn delete(&mut self, index: u8, fab_idx: u8) -> Result<(), Error> {
let old = self.for_index_in_fabric(index, fab_idx)?;
*old = None;
self.changed = true;
Ok(())
}
pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> {
for entry in &mut self.entries {
if entry
.as_ref()
.map(|e| e.fab_idx == Some(fab_idx))
.unwrap_or(false)
{
*entry = None;
self.changed = true;
}
}
Ok(())
}
pub fn for_each_acl<T>(&self, mut f: T) -> Result<(), Error>
where
T: FnMut(&AclEntry) -> Result<(), Error>,
{
for entry in self.entries.iter().flatten() {
f(entry)?;
}
Ok(())
}
pub fn allow(&self, req: &AccessReq) -> bool {
// PASE Sessions have implicit access grant
if req.accessor.auth_mode == AuthMode::Pase {
return true;
}
for e in self.entries.iter().flatten() {
if e.allow(req) {
return true;
}
}
error!(
"ACL Disallow for subjects {} fab idx {}",
req.accessor.subjects, req.accessor.fab_idx
);
error!("{}", self);
false
}
pub fn load(&mut self, data: &[u8]) -> Result<(), Error> {
let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
tlv::from_tlv(&mut self.entries, &root)?;
self.changed = false;
Ok(())
}
pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
if self.changed {
let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb);
self.entries
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
self.changed = false;
let len = tw.get_tail();
Ok(Some(&buf[..len]))
} else {
Ok(None)
}
}
pub fn is_changed(&self) -> bool {
self.changed
} }
/// Traverse fabric specific entries to find the index /// Traverse fabric specific entries to find the index
@ -411,180 +549,25 @@ impl AclMgrInner {
for (curr_index, entry) in self for (curr_index, entry) in self
.entries .entries
.iter_mut() .iter_mut()
.filter(|e| e.filter(|e1| e1.fab_idx == Some(fab_idx)).is_some()) .filter(|e| {
e.as_ref()
.filter(|e1| e1.fab_idx == Some(fab_idx))
.is_some()
})
.enumerate() .enumerate()
{ {
if curr_index == index as usize { if curr_index == index as usize {
return Ok(entry); return Ok(entry);
} }
} }
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
} }
pub struct AclMgr { impl core::fmt::Display for AclMgr {
inner: RwLock<AclMgrInner>, fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// The Option<> is solely because test execution is faster
// Doing this here adds the least overhead during ACL verification
psm: Option<Arc<Mutex<Psm>>>,
}
impl AclMgr {
pub fn new() -> Result<Self, Error> {
AclMgr::new_with(true)
}
pub fn new_with(psm_support: bool) -> Result<Self, Error> {
const INIT: Option<AclEntry> = None;
let mut psm = None;
let inner = if !psm_support {
AclMgrInner {
entries: [INIT; MAX_ACL_ENTRIES],
}
} else {
let psm_handle = Psm::get()?;
let inner = {
let psm_lock = psm_handle.lock().unwrap();
AclMgrInner::load(&psm_lock)
};
psm = Some(psm_handle);
inner.unwrap_or({
// Error loading from PSM
AclMgrInner {
entries: [INIT; MAX_ACL_ENTRIES],
}
})
};
Ok(Self {
inner: RwLock::new(inner),
psm,
})
}
pub fn erase_all(&self) {
let mut inner = self.inner.write().unwrap();
for i in 0..MAX_ACL_ENTRIES {
inner.entries[i] = None;
}
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
let _ = inner.store(&psm).map_err(|e| {
error!("Error in storing ACLs {}", e);
});
}
}
pub fn add(&self, entry: AclEntry) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let cnt = inner
.entries
.iter()
.flatten()
.filter(|a| a.fab_idx == entry.fab_idx)
.count();
if cnt >= ENTRIES_PER_FABRIC {
return Err(Error::NoSpace);
}
let index = inner
.entries
.iter()
.position(|a| a.is_none())
.ok_or(Error::NoSpace)?;
inner.entries[index] = Some(entry);
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
// Since the entries are fabric-scoped, the index is only for entries with the matching fabric index
pub fn edit(&self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let old = inner.for_index_in_fabric(index, fab_idx)?;
*old = Some(new);
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn delete(&self, index: u8, fab_idx: u8) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let old = inner.for_index_in_fabric(index, fab_idx)?;
*old = None;
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn delete_for_fabric(&self, fab_idx: u8) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
for i in 0..MAX_ACL_ENTRIES {
if inner.entries[i]
.filter(|e| e.fab_idx == Some(fab_idx))
.is_some()
{
inner.entries[i] = None;
}
}
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn for_each_acl<T>(&self, mut f: T) -> Result<(), Error>
where
T: FnMut(&AclEntry),
{
let inner = self.inner.read().unwrap();
for entry in inner.entries.iter().flatten() {
f(entry)
}
Ok(())
}
pub fn allow(&self, req: &AccessReq) -> bool {
// PASE Sessions have implicit access grant
if req.accessor.auth_mode == AuthMode::Pase {
return true;
}
let inner = self.inner.read().unwrap();
for e in inner.entries.iter().flatten() {
if e.allow(req) {
return true;
}
}
error!(
"ACL Disallow for subjects {} fab idx {}",
req.accessor.subjects, req.accessor.fab_idx
);
error!("{}", self);
false
}
}
impl std::fmt::Display for AclMgr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let inner = self.inner.read().unwrap();
write!(f, "ACLS: [")?; write!(f, "ACLS: [")?;
for i in inner.entries.iter().flatten() { for i in self.entries.iter().flatten() {
write!(f, " {{ {:?} }}, ", i)?; write!(f, " {{ {:?} }}, ", i)?;
} }
write!(f, "]") write!(f, "]")
@ -594,22 +577,23 @@ impl std::fmt::Display for AclMgr {
#[cfg(test)] #[cfg(test)]
#[allow(clippy::bool_assert_comparison)] #[allow(clippy::bool_assert_comparison)]
mod tests { mod tests {
use core::cell::RefCell;
use crate::{ use crate::{
acl::{gen_noc_cat, AccessorSubjects}, acl::{gen_noc_cat, AccessorSubjects},
data_model::objects::{Access, Privilege}, data_model::objects::{Access, Privilege},
interaction_model::messages::GenericPath, interaction_model::messages::GenericPath,
}; };
use std::sync::Arc;
use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target};
#[test] #[test]
fn test_basic_empty_subject_target() { fn test_basic_empty_subject_target() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Default deny // Default deny
@ -617,46 +601,46 @@ mod tests {
// Deny for session mode mismatch // Deny for session mode mismatch
let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase); let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Deny for fab idx mismatch // Deny for fab idx mismatch
let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow // Allow
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_subject() { fn test_subject() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for subject mismatch // Deny for subject mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112232).unwrap(); new.add_subject(112232).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for subject match - target is wildcard // Allow for subject match - target is wildcard
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_cat() { fn test_cat() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let allow_cat = 0xABCD; let allow_cat = 0xABCD;
let disallow_cat = 0xCAFE; let disallow_cat = 0xCAFE;
@ -666,35 +650,35 @@ mod tests {
let mut subjects = AccessorSubjects::new(112233); let mut subjects = AccessorSubjects::new(112233);
subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap(); subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap();
let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); let accessor = Accessor::new(2, subjects, AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for CAT id mismatch // Deny for CAT id mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) new.add_subject_catid(gen_noc_cat(disallow_cat, v2))
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Deny of CAT version mismatch // Deny of CAT version mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for CAT match // Allow for CAT match
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_cat_version() { fn test_cat_version() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let allow_cat = 0xABCD; let allow_cat = 0xABCD;
let disallow_cat = 0xCAFE; let disallow_cat = 0xCAFE;
@ -704,32 +688,32 @@ mod tests {
let mut subjects = AccessorSubjects::new(112233); let mut subjects = AccessorSubjects::new(112233);
subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap(); subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap();
let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); let accessor = Accessor::new(2, subjects, AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for CAT id mismatch // Deny for CAT id mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) new.add_subject_catid(gen_noc_cat(disallow_cat, v2))
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for CAT match and version more than ACL version // Allow for CAT match and version more than ACL version
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_target() { fn test_target() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for target mismatch // Deny for target mismatch
@ -740,7 +724,7 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for cluster match - subject wildcard // Allow for cluster match - subject wildcard
@ -751,11 +735,11 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
// Clean Slate // Clean Slate
am.erase_all(); am.borrow_mut().erase_all().unwrap();
// Allow for endpoint match - subject wildcard // Allow for endpoint match - subject wildcard
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
@ -765,11 +749,11 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
// Clean Slate // Clean Slate
am.erase_all(); am.borrow_mut().erase_all().unwrap();
// Allow for exact match // Allow for exact match
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
@ -780,16 +764,15 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_privilege() { fn test_privilege() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone());
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
// Create an Exact Match ACL with View privilege // Create an Exact Match ACL with View privilege
@ -801,10 +784,10 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Write on an RWVA without admin access - deny // Write on an RWVA without admin access - deny
let mut req = AccessReq::new(&accessor, &path, Access::WRITE); let mut req = AccessReq::new(&accessor, path.clone(), Access::WRITE);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
@ -817,40 +800,40 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Write on an RWVA with admin access - allow // Write on an RWVA with admin access - allow
let mut req = AccessReq::new(&accessor, &path, Access::WRITE); let mut req = AccessReq::new(&accessor, path, Access::WRITE);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_delete_for_fabric() { fn test_delete_for_fabric() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let mut req2 = AccessReq::new(&accessor2, &path, Access::READ); let mut req2 = AccessReq::new(&accessor2, path.clone(), Access::READ);
req2.set_target_perms(Access::RWVA); req2.set_target_perms(Access::RWVA);
let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, &am);
let mut req3 = AccessReq::new(&accessor3, &path, Access::READ); let mut req3 = AccessReq::new(&accessor3, path, Access::READ);
req3.set_target_perms(Access::RWVA); req3.set_target_perms(Access::RWVA);
// Allow for subject match - target is wildcard - Fabric idx 2 // Allow for subject match - target is wildcard - Fabric idx 2
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Allow for subject match - target is wildcard - Fabric idx 3 // Allow for subject match - target is wildcard - Fabric idx 3
let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed // Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed
assert_eq!(req2.allow(), true); assert_eq!(req2.allow(), true);
assert_eq!(req3.allow(), true); assert_eq!(req3.allow(), true);
am.delete_for_fabric(2).unwrap(); am.borrow_mut().delete_for_fabric(2).unwrap();
assert_eq!(req2.allow(), false); assert_eq!(req2.allow(), false);
assert_eq!(req3.allow(), true); assert_eq!(req3.allow(), true);
} }
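The reworked tests above drive the ACL manager through a plain `RefCell` instead of `Arc` plus locks, and `AccessReq` now takes the path by value. A condensed sketch of that calling pattern, written as in-crate code and using only the constructors and methods visible in these hunks:

```rust
use core::cell::RefCell;

use crate::acl::{AccessReq, Accessor, AccessorSubjects, AclEntry, AclMgr, AuthMode};
use crate::data_model::objects::{Access, Privilege};
use crate::interaction_model::messages::GenericPath;

fn grant_view_to_fabric_2() -> bool {
    // Interior mutability replaces Arc<AclMgr>: a shared &RefCell is enough.
    let am = RefCell::new(AclMgr::new());

    // One CASE entry for fabric index 2, subject 112233.
    let mut entry = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
    entry.add_subject(112233).unwrap();
    am.borrow_mut().add(entry).unwrap();

    // Build a read request against endpoint 1 / cluster 0x04D2 and evaluate it.
    let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
    let path = GenericPath::new(Some(1), Some(1234), None);
    let mut req = AccessReq::new(&accessor, path, Access::READ);
    req.set_target_perms(Access::RWVA);
    req.allow()
}
```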


@@ -15,10 +15,14 @@
  * limitations under the License.
  */
+use time::OffsetDateTime;
 use super::{CertConsumer, MAX_DEPTH};
-use crate::error::Error;
-use chrono::{Datelike, TimeZone, Utc};
-use log::warn;
+use crate::{
+    error::{Error, ErrorCode},
+    utils::epoch::MATTER_EPOCH_SECS,
+};
+use core::fmt::Write;
#[derive(Debug)] #[derive(Debug)]
pub struct ASN1Writer<'a> { pub struct ASN1Writer<'a> {
@ -52,7 +56,7 @@ impl<'a> ASN1Writer<'a> {
self.offset += size; self.offset += size;
return Ok(()); return Ok(());
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
pub fn append_tlv<F>(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error> pub fn append_tlv<F>(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error>
@ -68,7 +72,7 @@ impl<'a> ASN1Writer<'a> {
self.offset += len; self.offset += len;
return Ok(()); return Ok(());
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
fn add_compound(&mut self, val: u8) -> Result<(), Error> { fn add_compound(&mut self, val: u8) -> Result<(), Error> {
@ -78,7 +82,7 @@ impl<'a> ASN1Writer<'a> {
self.depth[self.current_depth] = self.offset; self.depth[self.current_depth] = self.offset;
self.current_depth += 1; self.current_depth += 1;
if self.current_depth >= MAX_DEPTH { if self.current_depth >= MAX_DEPTH {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
Ok(()) Ok(())
} }
@ -111,7 +115,7 @@ impl<'a> ASN1Writer<'a> {
fn end_compound(&mut self) -> Result<(), Error> { fn end_compound(&mut self) -> Result<(), Error> {
if self.current_depth == 0 { if self.current_depth == 0 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
let seq_len = self.get_compound_len(); let seq_len = self.get_compound_len();
let write_offset = self.get_length_encoding_offset(); let write_offset = self.get_length_encoding_offset();
@ -146,7 +150,7 @@ impl<'a> ASN1Writer<'a> {
// This is done with an 0xA2 followed by 2 bytes of actual len // This is done with an 0xA2 followed by 2 bytes of actual len
3 3
} else { } else {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?
}; };
Ok(len) Ok(len)
} }
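These hunks show the error-handling migration that recurs throughout this branch: the old flat `Error` variants become `ErrorCode`s wrapped into `Error`, so call sites either convert explicitly with `.into()` or bail with `?` on the bare code. A minimal in-crate sketch of the two idioms, assuming the `From<ErrorCode> for Error` impl this diff relies on:

```rust
use crate::error::{Error, ErrorCode};

fn reserve(capacity: usize, needed: usize) -> Result<(), Error> {
    if needed > capacity {
        // Explicit form: build the Error from its code.
        return Err(ErrorCode::NoSpace.into());
    }
    Ok(())
}

fn require_non_empty(data: &[u8]) -> Result<(), Error> {
    if data.is_empty() {
        // Early-return form: `?` on the code converts it via From<ErrorCode>.
        Err(ErrorCode::Invalid)?;
    }
    Ok(())
}
```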
@@ -261,28 +265,38 @@ impl<'a> CertConsumer for ASN1Writer<'a> {
     }
     fn utctime(&mut self, _tag: &str, epoch: u32) -> Result<(), Error> {
-        let mut matter_epoch = Utc
-            .with_ymd_and_hms(2000, 1, 1, 0, 0, 0)
-            .unwrap()
-            .timestamp();
-        matter_epoch += epoch as i64;
-        let dt = match Utc.timestamp_opt(matter_epoch, 0) {
-            chrono::LocalResult::None => return Err(Error::InvalidTime),
-            chrono::LocalResult::Single(s) => s,
-            chrono::LocalResult::Ambiguous(_, a) => {
-                warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}");
-                a
-            }
-        };
+        let matter_epoch = MATTER_EPOCH_SECS + epoch as u64;
+        let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap();
+        let mut time_str: heapless::String<32> = heapless::String::<32>::new();
         if dt.year() >= 2050 {
             // If year is >= 2050, ASN.1 requires it to be Generalised Time
-            let time_str = format!("{}Z", dt.format("%Y%m%d%H%M%S"));
+            write!(
+                &mut time_str,
+                "{:04}{:02}{:02}{:02}{:02}{:02}Z",
+                dt.year(),
+                dt.month() as u8,
+                dt.day(),
+                dt.hour(),
+                dt.minute(),
+                dt.second()
+            )
+            .unwrap();
             self.write_str(0x18, time_str.as_bytes())
         } else {
-            let time_str = format!("{}Z", dt.format("%y%m%d%H%M%S"));
+            write!(
+                &mut time_str,
+                "{:02}{:02}{:02}{:02}{:02}{:02}Z",
+                dt.year() % 100,
+                dt.month() as u8,
+                dt.day(),
+                dt.hour(),
+                dt.minute(),
+                dt.second()
+            )
+            .unwrap();
             self.write_str(0x17, time_str.as_bytes())
         }
     }
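Both certificate writers drop `chrono` and heap-allocated `format!` in favour of the `time` crate plus a fixed-capacity `heapless::String` written through `core::fmt::Write`. A standalone sketch of the same UTCTime formatting, with the epoch offset assumed to equal the crate's `MATTER_EPOCH_SECS` (2000-01-01T00:00:00Z expressed in Unix seconds):

```rust
use core::fmt::Write;

use time::OffsetDateTime;

// Matter timestamps count seconds from 2000-01-01T00:00:00Z; this constant is
// assumed to match the crate's MATTER_EPOCH_SECS.
const MATTER_EPOCH_SECS: u64 = 946_684_800;

/// Render a Matter epoch as an ASN.1 UTCTime string (YYMMDDHHMMSSZ) without
/// touching the heap, mirroring the else-branch of utctime() above.
fn utctime_str(epoch: u32) -> heapless::String<32> {
    let unix = MATTER_EPOCH_SECS + epoch as u64;
    let dt = OffsetDateTime::from_unix_timestamp(unix as i64).unwrap();

    let mut s = heapless::String::<32>::new();
    write!(
        &mut s,
        "{:02}{:02}{:02}{:02}{:02}{:02}Z",
        dt.year() % 100,
        dt.month() as u8,
        dt.day(),
        dt.hour(),
        dt.minute(),
        dt.second()
    )
    .unwrap();
    s
}
```

Years from 2050 onward would use the four-digit GeneralizedTime form, exactly as the `if` branch in the hunk does.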


@@ -15,12 +15,12 @@
  * limitations under the License.
  */
-use std::fmt;
+use core::fmt::{self, Write};
 use crate::{
-    crypto::{CryptoKeyPair, KeyPair},
-    error::Error,
-    tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV},
+    crypto::KeyPair,
+    error::{Error, ErrorCode},
+    tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV},
     utils::writebuf::WriteBuf,
 };
 use log::error;
@ -29,6 +29,8 @@ use num_derive::FromPrimitive;
pub use self::asn1_writer::ASN1Writer; pub use self::asn1_writer::ASN1Writer;
use self::printer::CertPrinter; use self::printer::CertPrinter;
pub const MAX_CERT_TLV_LEN: usize = 1024; // TODO
// As per https://datatracker.ietf.org/doc/html/rfc5280 // As per https://datatracker.ietf.org/doc/html/rfc5280
const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01];
@ -113,8 +115,10 @@ macro_rules! add_if {
}; };
} }
fn get_print_str(key_usage: u16) -> String { fn get_print_str(key_usage: u16) -> heapless::String<256> {
format!( let mut string = heapless::String::new();
write!(
&mut string,
"{}{}{}{}{}{}{}{}{}", "{}{}{}{}{}{}{}{}{}",
add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "),
add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "),
@ -126,6 +130,9 @@ fn get_print_str(key_usage: u16) -> String {
add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "),
add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "),
) )
.unwrap();
string
} }
#[allow(unused_assignments)] #[allow(unused_assignments)]
@ -137,7 +144,7 @@ fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Erro
} }
fn encode_extended_key_usage( fn encode_extended_key_usage(
list: &TLVArrayOwned<u8>, list: impl Iterator<Item = u8>,
w: &mut dyn CertConsumer, w: &mut dyn CertConsumer,
) -> Result<(), Error> { ) -> Result<(), Error> {
const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01];
@ -157,19 +164,18 @@ fn encode_extended_key_usage(
]; ];
w.start_seq("")?; w.start_seq("")?;
for t in list.iter() { for t in list {
let t = *t as usize; let t = t as usize;
if t > 0 && t <= encoding.len() { if t > 0 && t <= encoding.len() {
w.oid(encoding[t].0, encoding[t].1)?; w.oid(encoding[t].0, encoding[t].1)?;
} else { } else {
error!("Skipping encoding key usage out of bounds"); error!("Skipping encoding key usage out of bounds");
} }
} }
w.end_seq()?; w.end_seq()
Ok(())
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
struct BasicConstraints { struct BasicConstraints {
is_ca: bool, is_ca: bool,
@ -209,18 +215,18 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> {
w.end_seq() w.end_seq()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1, datatype = "list")] #[tlvargs(lifetime = "'a", start = 1, datatype = "list")]
struct Extensions { struct Extensions<'a> {
basic_const: Option<BasicConstraints>, basic_const: Option<BasicConstraints>,
key_usage: Option<u16>, key_usage: Option<u16>,
ext_key_usage: Option<TLVArrayOwned<u8>>, ext_key_usage: Option<TLVArray<'a, u8>>,
subj_key_id: Option<Vec<u8>>, subj_key_id: Option<OctetStr<'a>>,
auth_key_id: Option<Vec<u8>>, auth_key_id: Option<OctetStr<'a>>,
future_extensions: Option<Vec<u8>>, future_extensions: Option<OctetStr<'a>>,
} }
impl Extensions { impl<'a> Extensions<'a> {
fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13];
const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F];
@ -242,30 +248,29 @@ impl Extensions {
} }
if let Some(t) = &self.ext_key_usage { if let Some(t) = &self.ext_key_usage {
encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?;
encode_extended_key_usage(t, w)?; encode_extended_key_usage(t.iter(), w)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.subj_key_id { if let Some(t) = &self.subj_key_id {
encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?;
w.ostr("", t.as_slice())?; w.ostr("", t.0)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.auth_key_id { if let Some(t) = &self.auth_key_id {
encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?;
w.start_seq("")?; w.start_seq("")?;
w.ctx("", 0, t.as_slice())?; w.ctx("", 0, t.0)?;
w.end_seq()?; w.end_seq()?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.future_extensions { if let Some(t) = &self.future_extensions {
error!("Future Extensions Not Yet Supported: {:x?}", t.as_slice()) error!("Future Extensions Not Yet Supported: {:x?}", t.0);
} }
w.end_seq()?; w.end_seq()?;
w.end_ctx()?; w.end_ctx()?;
Ok(()) Ok(())
} }
} }
const MAX_DN_ENTRIES: usize = 5;
#[derive(FromPrimitive, Copy, Clone)] #[derive(FromPrimitive, Copy, Clone)]
enum DnTags { enum DnTags {
@ -293,20 +298,23 @@ enum DnTags {
NocCat = 22, NocCat = 22,
} }
enum DistNameValue { #[derive(Debug)]
enum DistNameValue<'a> {
Uint(u64), Uint(u64),
Utf8Str(Vec<u8>), Utf8Str(&'a [u8]),
PrintableStr(Vec<u8>), PrintableStr(&'a [u8]),
} }
#[derive(Default)] const MAX_DN_ENTRIES: usize = 5;
struct DistNames {
#[derive(Default, Debug)]
struct DistNames<'a> {
// The order in which the DNs arrive is important, as the signing // The order in which the DNs arrive is important, as the signing
// requires that the ASN1 notation retains the same order // requires that the ASN1 notation retains the same order
dn: Vec<(u8, DistNameValue)>, dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>,
} }
impl DistNames { impl<'a> DistNames<'a> {
fn u64(&self, match_id: DnTags) -> Option<u64> { fn u64(&self, match_id: DnTags) -> Option<u64> {
self.dn self.dn
.iter() .iter()
@ -336,24 +344,27 @@ impl DistNames {
const PRINTABLE_STR_THRESHOLD: u8 = 0x80; const PRINTABLE_STR_THRESHOLD: u8 = 0x80;
impl<'a> FromTLV<'a> for DistNames { impl<'a> FromTLV<'a> for DistNames<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
let mut d = Self { let mut d = Self {
dn: Vec::with_capacity(MAX_DN_ENTRIES), dn: heapless::Vec::new(),
}; };
let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; let iter = t.confirm_list()?.enter().ok_or(ErrorCode::Invalid)?;
for t in iter { for t in iter {
if let TagType::Context(tag) = t.get_tag() { if let TagType::Context(tag) = t.get_tag() {
if let Ok(value) = t.u64() { if let Ok(value) = t.u64() {
d.dn.push((tag, DistNameValue::Uint(value))); d.dn.push((tag, DistNameValue::Uint(value)))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} else if let Ok(value) = t.slice() { } else if let Ok(value) = t.slice() {
if tag > PRINTABLE_STR_THRESHOLD { if tag > PRINTABLE_STR_THRESHOLD {
d.dn.push(( d.dn.push((
tag - PRINTABLE_STR_THRESHOLD, tag - PRINTABLE_STR_THRESHOLD,
DistNameValue::PrintableStr(value.to_vec()), DistNameValue::PrintableStr(value),
)); ))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} else { } else {
d.dn.push((tag, DistNameValue::Utf8Str(value.to_vec()))); d.dn.push((tag, DistNameValue::Utf8Str(value)))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} }
} }
} }
@ -362,24 +373,23 @@ impl<'a> FromTLV<'a> for DistNames {
} }
} }
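`DistNames` now stores its distinguished-name pairs in a `heapless::Vec` bounded by `MAX_DN_ENTRIES`, so an over-long DN list surfaces as `ErrorCode::BufferTooSmall` instead of allocating. A tiny sketch of that bounded-push pattern with plain types (the error string is a stand-in for the crate's error):

```rust
// heapless::Vec::push returns Err(item) once the fixed capacity is exhausted,
// which the parser above maps to a BufferTooSmall-style error instead of panicking.
fn collect_tags<const N: usize>(tags: &[u8]) -> Result<heapless::Vec<u8, N>, &'static str> {
    let mut out = heapless::Vec::<u8, N>::new();
    for &t in tags {
        out.push(t).map_err(|_| "capacity exhausted")?;
    }
    Ok(out)
}
```

With `N = MAX_DN_ENTRIES` this is exactly the failure mode `from_tlv` reports when a certificate carries more DN entries than the fixed budget.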
impl ToTLV for DistNames { impl<'a> ToTLV for DistNames<'a> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_list(tag)?; tw.start_list(tag)?;
for (name, value) in &self.dn { for (name, value) in &self.dn {
match value { match value {
DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?,
DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v.as_slice())?, DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?,
DistNameValue::PrintableStr(v) => tw.utf8( DistNameValue::PrintableStr(v) => {
TagType::Context(*name + PRINTABLE_STR_THRESHOLD), tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)?
v.as_slice(), }
)?,
} }
} }
tw.end_container() tw.end_container()
} }
} }
impl DistNames { impl<'a> DistNames<'a> {
fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03];
const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04];
@ -509,52 +519,60 @@ fn encode_dn_value(
w.oid(name, oid)?; w.oid(name, oid)?;
match value { match value {
DistNameValue::Uint(v) => match expected_len { DistNameValue::Uint(v) => match expected_len {
Some(IntToStringLen::Len16) => w.utf8str("", format!("{:016X}", v).as_str())?, Some(IntToStringLen::Len16) => {
Some(IntToStringLen::Len8) => w.utf8str("", format!("{:08X}", v).as_str())?, let mut string = heapless::String::<32>::new();
write!(&mut string, "{:016X}", v).unwrap();
w.utf8str("", &string)?
}
Some(IntToStringLen::Len8) => {
let mut string = heapless::String::<32>::new();
write!(&mut string, "{:08X}", v).unwrap();
w.utf8str("", &string)?
}
_ => { _ => {
error!("Invalid encoding"); error!("Invalid encoding");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?
} }
}, },
DistNameValue::Utf8Str(v) => { DistNameValue::Utf8Str(v) => {
let str = String::from_utf8(v.to_vec())?; w.utf8str("", core::str::from_utf8(v)?)?;
w.utf8str("", &str)?;
} }
DistNameValue::PrintableStr(v) => { DistNameValue::PrintableStr(v) => {
let str = String::from_utf8(v.to_vec())?; w.printstr("", core::str::from_utf8(v)?)?;
w.printstr("", &str)?;
} }
} }
w.end_seq()?; w.end_seq()?;
w.end_set() w.end_set()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(lifetime = "'a", start = 1)]
pub struct Cert { pub struct Cert<'a> {
serial_no: Vec<u8>, serial_no: OctetStr<'a>,
sign_algo: u8, sign_algo: u8,
issuer: DistNames, issuer: DistNames<'a>,
not_before: u32, not_before: u32,
not_after: u32, not_after: u32,
subject: DistNames, subject: DistNames<'a>,
pubkey_algo: u8, pubkey_algo: u8,
ec_curve_id: u8, ec_curve_id: u8,
pubkey: Vec<u8>, pubkey: OctetStr<'a>,
extensions: Extensions, extensions: Extensions<'a>,
signature: Vec<u8>, signature: OctetStr<'a>,
} }
// TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding // TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding
// rules in terms of sequence may get complicated. Need to look into this // rules in terms of sequence may get complicated. Need to look into this
impl Cert { impl<'a> Cert<'a> {
pub fn new(cert_bin: &[u8]) -> Result<Self, Error> { pub fn new(cert_bin: &'a [u8]) -> Result<Self, Error> {
let root = tlv::get_root_node(cert_bin)?; let root = tlv::get_root_node(cert_bin)?;
Cert::from_tlv(&root) Cert::from_tlv(&root)
} }
pub fn get_node_id(&self) -> Result<u64, Error> { pub fn get_node_id(&self) -> Result<u64, Error> {
self.subject.u64(DnTags::NodeId).ok_or(Error::NoNodeId) self.subject
.u64(DnTags::NodeId)
.ok_or_else(|| Error::from(ErrorCode::NoNodeId))
} }
pub fn get_cat_ids(&self, output: &mut [u32]) { pub fn get_cat_ids(&self, output: &mut [u32]) {
@ -562,21 +580,27 @@ impl Cert {
} }
pub fn get_fabric_id(&self) -> Result<u64, Error> { pub fn get_fabric_id(&self) -> Result<u64, Error> {
self.subject.u64(DnTags::FabricId).ok_or(Error::NoFabricId) self.subject
.u64(DnTags::FabricId)
.ok_or_else(|| Error::from(ErrorCode::NoFabricId))
} }
pub fn get_pubkey(&self) -> &[u8] { pub fn get_pubkey(&self) -> &[u8] {
self.pubkey.as_slice() self.pubkey.0
} }
pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { pub fn get_subject_key_id(&self) -> Result<&[u8], Error> {
self.extensions.subj_key_id.as_deref().ok_or(Error::Invalid) if let Some(id) = self.extensions.subj_key_id.as_ref() {
Ok(id.0)
} else {
Err(ErrorCode::Invalid.into())
}
} }
pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> { pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> {
if let Some(our_auth_key) = &self.extensions.auth_key_id { if let Some(our_auth_key) = &self.extensions.auth_key_id {
let their_subject = their.get_subject_key_id()?; let their_subject = their.get_subject_key_id()?;
if our_auth_key == their_subject { if our_auth_key.0 == their_subject {
Ok(true) Ok(true)
} else { } else {
Ok(false) Ok(false)
@ -587,11 +611,11 @@ impl Cert {
} }
pub fn get_signature(&self) -> &[u8] { pub fn get_signature(&self) -> &[u8] {
self.signature.as_slice() self.signature.0
} }
pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> {
let mut wb = WriteBuf::new(buf, buf.len()); let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
self.to_tlv(&mut tw, TagType::Anonymous)?; self.to_tlv(&mut tw, TagType::Anonymous)?;
Ok(wb.as_slice().len()) Ok(wb.as_slice().len())
@ -614,10 +638,10 @@ impl Cert {
w.integer("", &[2])?; w.integer("", &[2])?;
w.end_ctx()?; w.end_ctx()?;
w.integer("Serial Num:", self.serial_no.as_slice())?; w.integer("Serial Num:", self.serial_no.0)?;
w.start_seq("Signature Algorithm:")?; w.start_seq("Signature Algorithm:")?;
let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(ErrorCode::Invalid)? {
SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256), SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256),
}; };
w.oid(str, &oid)?; w.oid(str, &oid)?;
@ -634,17 +658,17 @@ impl Cert {
w.start_seq("")?; w.start_seq("")?;
w.start_seq("Public Key Algorithm")?; w.start_seq("Public Key Algorithm")?;
let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(Error::Invalid)? { let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(ErrorCode::Invalid)? {
PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY), PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY),
}; };
w.oid(str, &pub_key)?; w.oid(str, &pub_key)?;
let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(Error::Invalid)? { let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(ErrorCode::Invalid)? {
EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1),
}; };
w.oid(str, &curve_id)?; w.oid(str, &curve_id)?;
w.end_seq()?; w.end_seq()?;
w.bitstr("Public-Key:", false, self.pubkey.as_slice())?; w.bitstr("Public-Key:", false, self.pubkey.0)?;
w.end_seq()?; w.end_seq()?;
self.extensions.encode(w)?; self.extensions.encode(w)?;
@ -655,7 +679,7 @@ impl Cert {
} }
} }
impl fmt::Display for Cert { impl<'a> fmt::Display for Cert<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut printer = CertPrinter::new(f); let mut printer = CertPrinter::new(f);
let _ = self let _ = self
@ -667,7 +691,7 @@ impl fmt::Display for Cert {
} }
pub struct CertVerifier<'a> { pub struct CertVerifier<'a> {
cert: &'a Cert, cert: &'a Cert<'a>,
} }
impl<'a> CertVerifier<'a> { impl<'a> CertVerifier<'a> {
@ -677,7 +701,7 @@ impl<'a> CertVerifier<'a> {
pub fn add_cert(self, parent: &'a Cert) -> Result<CertVerifier<'a>, Error> { pub fn add_cert(self, parent: &'a Cert) -> Result<CertVerifier<'a>, Error> {
if !self.cert.is_authority(parent)? { if !self.cert.is_authority(parent)? {
return Err(Error::InvalidAuthKey); Err(ErrorCode::InvalidAuthKey)?;
} }
let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE];
let len = self.cert.as_asn1(&mut asn1)?; let len = self.cert.as_asn1(&mut asn1)?;
@ -731,8 +755,9 @@ mod printer;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use crate::cert::Cert; use crate::cert::Cert;
use crate::error::Error;
use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV};
use crate::utils::writebuf::WriteBuf; use crate::utils::writebuf::WriteBuf;
@ -777,29 +802,41 @@ mod tests {
#[test] #[test]
fn test_verify_chain_incomplete() { fn test_verify_chain_incomplete() {
// The chain doesn't lead up to a self-signed certificate // The chain doesn't lead up to a self-signed certificate
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!( assert_eq!(
Err(Error::InvalidAuthKey), Err(ErrorCode::InvalidAuthKey),
a.add_cert(&icac).unwrap().finalise() a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code())
); );
} }
#[test] #[test]
fn test_auth_key_chain_incorrect() { fn test_auth_key_chain_incorrect() {
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!(Err(Error::InvalidAuthKey), a.add_cert(&icac).map(|_| ())); assert_eq!(
Err(ErrorCode::InvalidAuthKey),
a.add_cert(&icac).map(|_| ()).map_err(|e| e.code())
);
} }
#[test] #[test]
fn test_cert_corrupted() { fn test_cert_corrupted() {
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!(Err(Error::InvalidSignature), a.add_cert(&icac).map(|_| ())); assert_eq!(
Err(ErrorCode::InvalidSignature),
a.add_cert(&icac).map(|_| ()).map_err(|e| e.code())
);
} }
#[test] #[test]
@ -811,12 +848,11 @@ mod tests {
]; ];
for input in test_input.iter() { for input in test_input.iter() {
println!("Testing next input..."); info!("Testing next input...");
let root = tlv::get_root_node(input).unwrap(); let root = tlv::get_root_node(input).unwrap();
let cert = Cert::from_tlv(&root).unwrap(); let cert = Cert::from_tlv(&root).unwrap();
let mut buf = [0u8; 1024]; let mut buf = [0u8; 1024];
let buf_len = buf.len(); let mut wb = WriteBuf::new(&mut buf);
let mut wb = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap();
assert_eq!(*input, wb.as_slice()); assert_eq!(*input, wb.as_slice());
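`Cert` is now `Cert<'a>` and borrows every field (`OctetStr`, `TLVArray`, `DistNames<'a>`) straight out of the TLV buffer, and `WriteBuf::new` takes just the target slice. A minimal round-trip sketch in the style of the test above, written as in-crate code:

```rust
use crate::cert::Cert;
use crate::error::Error;
use crate::tlv::{TLVWriter, TagType, ToTLV};
use crate::utils::writebuf::WriteBuf;

/// Parse a Matter TLV certificate without copying it, then re-encode it into
/// a caller-provided buffer, returning the encoded length.
fn reencode_cert(cert_tlv: &[u8], out: &mut [u8]) -> Result<usize, Error> {
    // Cert<'a> borrows the serial number, DNs, keys and signature from cert_tlv.
    let cert = Cert::new(cert_tlv)?;

    // WriteBuf::new now infers its capacity from the slice itself.
    let mut wb = WriteBuf::new(out);
    let mut tw = TLVWriter::new(&mut wb);
    cert.to_tlv(&mut tw, TagType::Anonymous)?;
    Ok(wb.as_slice().len())
}
```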


@@ -15,11 +15,11 @@
  * limitations under the License.
  */
+use time::OffsetDateTime;
 use super::{CertConsumer, MAX_DEPTH};
-use crate::error::Error;
-use chrono::{TimeZone, Utc};
-use log::warn;
-use std::fmt;
+use crate::{error::Error, utils::epoch::MATTER_EPOCH_SECS};
+use core::fmt;
 pub struct CertPrinter<'a, 'b> {
     level: usize,
@@ -123,23 +123,11 @@ impl<'a, 'b> CertConsumer for CertPrinter<'a, 'b> {
         Ok(())
     }
     fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error> {
-        let mut matter_epoch = Utc
-            .with_ymd_and_hms(2000, 1, 1, 0, 0, 0)
-            .unwrap()
-            .timestamp();
-        matter_epoch += epoch as i64;
-        let dt = match Utc.timestamp_opt(matter_epoch, 0) {
-            chrono::LocalResult::None => return Err(Error::InvalidTime),
-            chrono::LocalResult::Single(s) => s,
-            chrono::LocalResult::Ambiguous(_, a) => {
-                warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}");
-                a
-            }
-        };
-        let _ = writeln!(self.f, "{} {} {}", SPACE[self.level], tag, dt);
+        let matter_epoch = MATTER_EPOCH_SECS + epoch as u64;
+        let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap();
+        let _ = writeln!(self.f, "{} {} {:?}", SPACE[self.level], tag, dt);
         Ok(())
     }
 }


@@ -17,7 +17,7 @@
 //! Base38 encoding and decoding functions.
-use crate::error::Error;
+use crate::error::{Error, ErrorCode};
 const BASE38_CHARS: [char; 38] = [
     '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
@@ -77,60 +77,68 @@ const DECODE_BASE38: [u8; 46] = [
     35, // 'Z', =90
 ];
-const BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK: [u8; 3] = [2, 4, 5];
 const RADIX: u32 = BASE38_CHARS.len() as u32;
/// Encode a byte array into a base38 string. /// Encode a byte array into a base38 string.
/// ///
/// # Arguments /// # Arguments
/// * `bytes` - byte array to encode /// * `bytes` - byte array to encode
/// * `length` - optional length of the byte array to encode. If not specified, the entire byte array is encoded. pub fn encode_string<const N: usize>(bytes: &[u8]) -> Result<heapless::String<N>, Error> {
pub fn encode(bytes: &[u8], length: Option<usize>) -> String { let mut string = heapless::String::new();
let mut offset = 0; for c in encode(bytes) {
let mut result = String::new(); string.push(c).map_err(|_| ErrorCode::NoSpace)?;
}
// if length is specified, use it, otherwise use the length of the byte array Ok(string)
// if length is specified but is greater than the length of the byte array, use the length of the byte array }
let b_len = bytes.len();
let length = length.map(|l| l.min(b_len)).unwrap_or(b_len);
while offset < length { pub fn encode(bytes: &[u8]) -> impl Iterator<Item = char> + '_ {
let remaining = length - offset; (0..bytes.len() / 3)
match remaining.cmp(&2) { .flat_map(move |index| {
std::cmp::Ordering::Greater => { let offset = index * 3;
result.push_str(&encode_base38(
encode_base38(
((bytes[offset + 2] as u32) << 16) ((bytes[offset + 2] as u32) << 16)
| ((bytes[offset + 1] as u32) << 8) | ((bytes[offset + 1] as u32) << 8)
| (bytes[offset] as u32), | (bytes[offset] as u32),
5, 5,
)); )
offset += 3; })
} .chain(
std::cmp::Ordering::Equal => { core::iter::once(bytes.len() % 3).flat_map(move |remainder| {
result.push_str(&encode_base38( let offset = bytes.len() / 3 * 3;
match remainder {
2 => encode_base38(
((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32),
4, 4,
)); ),
break; 1 => encode_base38(bytes[offset] as u32, 2),
} _ => encode_base38(0, 0),
std::cmp::Ordering::Less => {
result.push_str(&encode_base38(bytes[offset] as u32, 2));
break;
}
} }
}),
)
} }
result fn encode_base38(mut value: u32, repeat: usize) -> impl Iterator<Item = char> {
(0..repeat).map(move |_| {
let remainder = value % RADIX;
let c = BASE38_CHARS[remainder as usize];
value = (value - remainder) / RADIX;
c
})
} }
fn encode_base38(mut value: u32, char_count: u8) -> String { pub fn decode_vec<const N: usize>(base38_str: &str) -> Result<heapless::Vec<u8, N>, Error> {
let mut result = String::new(); let mut vec = heapless::Vec::new();
for _ in 0..char_count {
let remainder = value % 38; for byte in decode(base38_str) {
result.push(BASE38_CHARS[remainder as usize]); vec.push(byte?).map_err(|_| ErrorCode::NoSpace)?;
value = (value - remainder) / 38;
} }
result
Ok(vec)
} }
/// Decode a base38-encoded string into a byte slice /// Decode a base38-encoded string into a byte slice
@ -138,64 +146,71 @@ fn encode_base38(mut value: u32, char_count: u8) -> String {
/// # Arguments /// # Arguments
/// * `base38_str` - base38-encoded string to decode /// * `base38_str` - base38-encoded string to decode
/// ///
/// Fails if the string contains invalid characters /// Fails if the string contains invalid characters or if the supplied buffer is too small to fit the decoded data
pub fn decode(base38_str: &str) -> Result<Vec<u8>, Error> { pub fn decode(base38_str: &str) -> impl Iterator<Item = Result<u8, Error>> + '_ {
let mut result = Vec::new(); let stru = base38_str.as_bytes();
let mut base38_characters_number: usize = base38_str.len();
let mut decoded_base38_characters: usize = 0;
while base38_characters_number > 0 { (0..stru.len() / 5)
let base38_characters_in_chunk: usize; .flat_map(move |index| {
let bytes_in_decoded_chunk: usize; let offset = index * 5;
decode_base38(&stru[offset..offset + 5])
if base38_characters_number >= BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize { })
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize; .chain({
bytes_in_decoded_chunk = 3; let offset = stru.len() / 5 * 5;
} else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize { decode_base38(&stru[offset..])
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize; })
bytes_in_decoded_chunk = 2; .take_while(Result::is_ok)
} else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize {
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize;
bytes_in_decoded_chunk = 1;
} else {
return Err(Error::InvalidData);
} }
fn decode_base38(chars: &[u8]) -> impl Iterator<Item = Result<u8, Error>> {
let mut value = 0u32; let mut value = 0u32;
let mut cerr = None;
for i in (1..=base38_characters_in_chunk).rev() { let repeat = match chars.len() {
let mut base38_chars = base38_str.chars(); 5 => 3,
let v = decode_char(base38_chars.nth(decoded_base38_characters + i - 1).unwrap())?; 4 => 2,
2 => 1,
0 => 0,
_ => -1,
};
value = value * RADIX + v as u32; if repeat >= 0 {
for c in chars.iter().rev() {
match decode_char(*c) {
Ok(v) => value = value * RADIX + v as u32,
Err(err) => {
cerr = Some(err.code());
break;
}
}
}
} else {
cerr = Some(ErrorCode::InvalidData)
} }
decoded_base38_characters += base38_characters_in_chunk; (0..repeat)
base38_characters_number -= base38_characters_in_chunk; .map(move |_| {
if let Some(err) = cerr {
Err(err.into())
} else {
let byte = (value & 0xff) as u8;
for _i in 0..bytes_in_decoded_chunk {
result.push(value as u8);
value >>= 8; value >>= 8;
Ok(byte)
}
})
.take_while(Result::is_ok)
} }
if value > 0 { fn decode_char(c: u8) -> Result<u8, Error> {
// encoded value is too big to represent a correct chunk of size 1, 2 or 3 bytes
return Err(Error::InvalidArgument);
}
}
Ok(result)
}
fn decode_char(c: char) -> Result<u8, Error> {
let c = c as u8;
if !(45..=90).contains(&c) { if !(45..=90).contains(&c) {
return Err(Error::InvalidData); Err(ErrorCode::InvalidData)?;
} }
let c = DECODE_BASE38[c as usize - 45]; let c = DECODE_BASE38[c as usize - 45];
if c == UNUSED { if c == UNUSED {
return Err(Error::InvalidData); Err(ErrorCode::InvalidData)?;
} }
Ok(c) Ok(c)
@ -211,15 +226,17 @@ mod tests {
#[test] #[test]
fn can_base38_encode() { fn can_base38_encode() {
assert_eq!(encode(&DECODED, None), ENCODED); assert_eq!(
assert_eq!(encode(&DECODED, Some(11)), ENCODED); encode_string::<{ ENCODED.len() }>(&DECODED).unwrap(),
ENCODED
// length is greater than the length of the byte array );
assert_eq!(encode(&DECODED, Some(12)), ENCODED);
} }
#[test] #[test]
fn can_base38_decode() { fn can_base38_decode() {
assert_eq!(decode(ENCODED).expect("can not decode base38"), DECODED); assert_eq!(
decode_vec::<{ DECODED.len() }>(ENCODED).expect("Cannot decode base38"),
DECODED
);
} }
} }
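The rewritten base38 module is allocation-free: `encode` and `decode` are lazy iterators, with `encode_string` and `decode_vec` as fixed-capacity `heapless` wrappers that fail with `NoSpace` instead of growing. A small round-trip sketch in the module's own test style (only functions defined in this file are used; the capacities are illustrative):

```rust
#[test]
fn base38_roundtrip() {
    let payload = [0x12_u8, 0x34, 0x56, 0x78];

    // Fixed-capacity wrapper: errors out instead of allocating.
    let text = super::encode_string::<16>(&payload).unwrap();

    // Streaming form: encode() yields chars one by one.
    assert!(super::encode(&payload).eq(text.chars()));

    // decode() yields Result<u8, _>; decode_vec() collects into a bounded Vec.
    let bytes = super::decode_vec::<16>(&text).unwrap();
    assert_eq!(&bytes[..], &payload[..]);
}
```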


@@ -15,21 +15,24 @@
  * limitations under the License.
  */
+use core::{borrow::Borrow, cell::RefCell};
 use crate::{
     acl::AclMgr,
     data_model::{
-        cluster_basic_information::BasicInfoConfig, core::DataModel,
-        sdm::dev_att::DevAttDataFetcher,
+        cluster_basic_information::BasicInfoConfig,
+        sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe},
     },
     error::*,
     fabric::FabricMgr,
-    interaction_model::InteractionModel,
     mdns::Mdns,
     pairing::{print_pairing_code_and_qr, DiscoveryCapabilities},
-    secure_channel::{core::SecureChannel, pake::PaseMgr, spake2p::VerifierData},
-    transport,
+    secure_channel::{pake::PaseMgr, spake2p::VerifierData},
+    utils::{epoch::Epoch, rand::Rand},
 };
-use std::sync::Arc;
+
+/* The Matter Port */
+pub const MATTER_PORT: u16 = 5540;
/// Device Commissioning Data /// Device Commissioning Data
pub struct CommissioningData { pub struct CommissioningData {
@ -40,71 +43,172 @@ pub struct CommissioningData {
} }
/// The primary Matter Object /// The primary Matter Object
pub struct Matter { pub struct Matter<'a> {
transport_mgr: transport::mgr::Mgr, pub fabric_mgr: RefCell<FabricMgr>,
data_model: DataModel, pub acl_mgr: RefCell<AclMgr>,
fabric_mgr: Arc<FabricMgr>, pub pase_mgr: RefCell<PaseMgr>,
pub failsafe: RefCell<FailSafe>,
pub mdns: &'a dyn Mdns,
pub epoch: Epoch,
pub rand: Rand,
pub dev_det: &'a BasicInfoConfig<'a>,
pub dev_att: &'a dyn DevAttDataFetcher,
pub port: u16,
}
impl<'a> Matter<'a> {
#[cfg(feature = "std")]
#[inline(always)]
pub fn new_default(
dev_det: &'a BasicInfoConfig<'a>,
dev_att: &'a dyn DevAttDataFetcher,
mdns: &'a dyn Mdns,
port: u16,
) -> Self {
use crate::utils::epoch::sys_epoch;
use crate::utils::rand::sys_rand;
Self::new(dev_det, dev_att, mdns, sys_epoch, sys_rand, port)
} }
impl Matter {
/// Creates a new Matter object /// Creates a new Matter object
/// ///
/// # Parameters /// # Parameters
/// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device
/// requires a set of device attestation certificates and keys. It is the responsibility of /// requires a set of device attestation certificates and keys. It is the responsibility of
/// this object to return the device attestation details when queried upon. /// this object to return the device attestation details when queried upon.
#[inline(always)]
pub fn new( pub fn new(
dev_det: BasicInfoConfig, dev_det: &'a BasicInfoConfig<'a>,
dev_att: Box<dyn DevAttDataFetcher>, dev_att: &'a dyn DevAttDataFetcher,
mdns: &'a dyn Mdns,
epoch: Epoch,
rand: Rand,
port: u16,
) -> Self {
Self {
fabric_mgr: RefCell::new(FabricMgr::new()),
acl_mgr: RefCell::new(AclMgr::new()),
pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)),
failsafe: RefCell::new(FailSafe::new()),
mdns,
epoch,
rand,
dev_det,
dev_att,
port,
}
}
pub fn dev_det(&self) -> &BasicInfoConfig<'_> {
self.dev_det
}
pub fn dev_att(&self) -> &dyn DevAttDataFetcher {
self.dev_att
}
pub fn port(&self) -> u16 {
self.port
}
pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> {
self.fabric_mgr.borrow_mut().load(data, self.mdns)
}
pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> {
self.acl_mgr.borrow_mut().load(data)
}
pub fn store_fabrics<'b>(&self, buf: &'b mut [u8]) -> Result<Option<&'b [u8]>, Error> {
self.fabric_mgr.borrow_mut().store(buf)
}
pub fn store_acls<'b>(&self, buf: &'b mut [u8]) -> Result<Option<&'b [u8]>, Error> {
self.acl_mgr.borrow_mut().store(buf)
}
pub fn is_changed(&self) -> bool {
self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed()
}
pub fn start_comissioning(
&self,
dev_comm: CommissioningData, dev_comm: CommissioningData,
) -> Result<Box<Matter>, Error> { buf: &mut [u8],
let mdns = Mdns::get()?; ) -> Result<bool, Error> {
mdns.set_values(dev_det.vid, dev_det.pid, &dev_det.device_name); if !self.pase_mgr.borrow().is_pase_session_enabled() && self.fabric_mgr.borrow().is_empty()
{
print_pairing_code_and_qr(
self.dev_det,
&dev_comm,
DiscoveryCapabilities::default(),
buf,
)?;
let fabric_mgr = Arc::new(FabricMgr::new()?); self.pase_mgr.borrow_mut().enable_pase_session(
let open_comm_window = fabric_mgr.is_empty(); dev_comm.verifier,
if open_comm_window { dev_comm.discriminator,
print_pairing_code_and_qr(&dev_det, &dev_comm, DiscoveryCapabilities::default()); self.mdns,
} )?;
let acl_mgr = Arc::new(AclMgr::new()?); Ok(true)
let mut pase = PaseMgr::new(); } else {
let data_model = Ok(false)
DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; }
let mut matter = Box::new(Matter { }
transport_mgr: transport::mgr::Mgr::new()?, }
data_model,
fabric_mgr, impl<'a> Borrow<RefCell<FabricMgr>> for Matter<'a> {
}); fn borrow(&self) -> &RefCell<FabricMgr> {
let interaction_model = &self.fabric_mgr
Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); }
matter.transport_mgr.register_protocol(interaction_model)?; }
if open_comm_window { impl<'a> Borrow<RefCell<AclMgr>> for Matter<'a> {
pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; fn borrow(&self) -> &RefCell<AclMgr> {
} &self.acl_mgr
}
let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); }
matter.transport_mgr.register_protocol(secure_channel)?;
Ok(matter) impl<'a> Borrow<RefCell<PaseMgr>> for Matter<'a> {
} fn borrow(&self) -> &RefCell<PaseMgr> {
&self.pase_mgr
/// Returns an Arc to [DataModel] }
/// }
/// The Data Model is where you express what is the type of your device. Typically
/// once you gets this reference, you acquire the write lock and add your device impl<'a> Borrow<RefCell<FailSafe>> for Matter<'a> {
/// types, clusters, attributes, commands to the data model. fn borrow(&self) -> &RefCell<FailSafe> {
pub fn get_data_model(&self) -> DataModel { &self.failsafe
self.data_model.clone() }
} }
/// Starts the Matter daemon impl<'a> Borrow<BasicInfoConfig<'a>> for Matter<'a> {
/// fn borrow(&self) -> &BasicInfoConfig<'a> {
/// This call does NOT return self.dev_det
/// }
/// This call starts the Matter daemon that starts communication with other Matter }
/// devices on the network.
pub fn start_daemon(&mut self) -> Result<(), Error> { impl<'a> Borrow<dyn DevAttDataFetcher + 'a> for Matter<'a> {
self.transport_mgr.start() fn borrow(&self) -> &(dyn DevAttDataFetcher + 'a) {
self.dev_att
}
}
impl<'a> Borrow<dyn Mdns + 'a> for Matter<'a> {
fn borrow(&self) -> &(dyn Mdns + 'a) {
self.mdns
}
}
impl<'a> Borrow<Epoch> for Matter<'a> {
fn borrow(&self) -> &Epoch {
&self.epoch
}
}
impl<'a> Borrow<Rand> for Matter<'a> {
fn borrow(&self) -> &Rand {
&self.rand
} }
} }
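The `Matter` object is now a plain, borrow-based struct: its managers live in `RefCell`s, the mDNS implementation and device details are borrowed from the caller, and the commissioning window is opened explicitly against a caller-supplied scratch buffer. The `Borrow` impls above let generic transport and data-model code pull whichever manager it needs from a single `&Matter` reference. A minimal bring-up sketch under the `std` feature, using only the constructors shown in this hunk (the helper name and buffer size are illustrative):

```rust
use crate::data_model::cluster_basic_information::BasicInfoConfig;
use crate::data_model::sdm::dev_att::DevAttDataFetcher;
use crate::error::Error;
use crate::mdns::Mdns;

// `Matter`, `CommissioningData` and `MATTER_PORT` are the items defined in
// this file; the sketch assumes it sits alongside them.
fn bring_up<'a>(
    dev_det: &'a BasicInfoConfig<'a>,
    dev_att: &'a dyn DevAttDataFetcher,
    mdns: &'a dyn Mdns,
    dev_comm: CommissioningData,
) -> Result<(Matter<'a>, bool), Error> {
    // `new_default` wires in the std clock and RNG; no_std targets call
    // `Matter::new` with their own `Epoch` and `Rand` functions instead.
    let matter = Matter::new_default(dev_det, dev_att, mdns, MATTER_PORT);

    // Scratch space for the textual pairing code / QR payload printed when
    // the commissioning window opens (the size is an assumption).
    let mut buf = [0; 1024];
    let commissioning_open = matter.start_comissioning(dev_comm, &mut buf)?;

    Ok((matter, commissioning_open))
}
```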


@@ -17,41 +17,120 @@
 use log::error;
-use crate::error::Error;
-use super::CryptoKeyPair;
+use crate::{
+    error::{Error, ErrorCode},
+    utils::rand::Rand,
+};
+
+pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> {
+    error!("This API should never get called");
+    Ok(())
+}
pub struct KeyPairDummy {} #[derive(Clone, Debug)]
pub struct Sha256 {}
impl KeyPairDummy { impl Sha256 {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
Ok(Self {}) Ok(Self {})
} }
pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> {
Ok(())
} }
impl CryptoKeyPair for KeyPairDummy { pub fn finish(self, _digest: &mut [u8]) -> Result<(), Error> {
fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { Ok(())
}
}
pub struct HmacSha256 {}
impl HmacSha256 {
pub fn new(_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(Self {})
} }
fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(())
} }
fn get_private_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
pub fn finish(self, _out: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(())
} }
fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> { }
#[derive(Debug)]
pub struct KeyPair;
impl KeyPair {
pub fn new(_rand: Rand) -> Result<Self, Error> {
Ok(Self)
}
pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result<Self, Error> {
Ok(Self {})
}
pub fn new_from_public(_pub_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid)
Ok(Self {})
} }
fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
Ok(0)
}
pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
Ok(0)
}
pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
}
pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
}
pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
} }
} }
pub fn pbkdf2_hmac(_pass: &[u8], _iter: usize, _salt: &[u8], _key: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called");
Ok(())
}
pub fn encrypt_in_place(
_key: &[u8],
_nonce: &[u8],
_ad: &[u8],
_data: &mut [u8],
_data_len: usize,
) -> Result<usize, Error> {
Ok(0)
}
pub fn decrypt_in_place(
_key: &[u8],
_nonce: &[u8],
_ad: &[u8],
_data: &mut [u8],
) -> Result<usize, Error> {
Ok(0)
}
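This backend stubs the whole crypto surface so the crate still links on targets without a real provider; what matters for callers is the in-place AEAD convention: `data` carries the payload, `data_len` marks how much of it is plaintext, and the authentication tag is expected to fit in the same buffer. A hedged caller sketch; the 16-byte tag size, the `crate::crypto` module path and the exact return-value semantics are assumptions, not taken from this diff:

```rust
use crate::crypto::{decrypt_in_place, encrypt_in_place};
use crate::error::Error;

const AES_CCM_TAG_LEN: usize = 16; // assumed tag size for Matter's AES-CCM

/// Encrypt `payload_len` bytes of `buf` in place; the buffer must leave
/// headroom for the tag. Returns the ciphertext-plus-tag length.
fn seal(
    key: &[u8],
    nonce: &[u8],
    aad: &[u8],
    buf: &mut [u8],
    payload_len: usize,
) -> Result<usize, Error> {
    assert!(buf.len() >= payload_len + AES_CCM_TAG_LEN);
    encrypt_in_place(key, nonce, aad, buf, payload_len)
}

/// Decrypt `buf` in place and return the plaintext length (ciphertext minus tag).
fn open(key: &[u8], nonce: &[u8], aad: &[u8], buf: &mut [u8]) -> Result<usize, Error> {
    decrypt_in_place(key, nonce, aad, buf)
}
```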


@@ -17,9 +17,8 @@
 use log::error;
-use crate::error::Error;
-use super::CryptoKeyPair;
+use crate::error::{Error, ErrorCode};
+use crate::utils::rand::Rand;
 pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> {
     error!("This API should never get called");
@ -62,18 +61,17 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub struct KeyPair {} pub struct KeyPair {}
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Ok(Self {}) Ok(Self {})
} }
pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result<Self, Error> { pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called");
Ok(Self {}) Ok(Self {})
} }
@ -82,28 +80,33 @@ impl KeyPair {
Ok(Self {}) Ok(Self {})
} }
pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
} }
impl CryptoKeyPair for KeyPair { pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { Ok(0)
error!("This API should never get called");
Err(Error::Invalid)
} }
fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
Err(Error::Invalid) Ok(0)
} }
fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
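
A pattern that repeats across every crypto backend in this branch: the concrete `Error::X` returns become `ErrorCode::X` values converted into `Error` at the boundary. A minimal sketch of the two forms used above, assuming a `From<ErrorCode> for Error` impl in `error.rs` (not part of this hunk):

```rust
// Sketch only: the two equivalent error-return styles seen in the hunks above.
use crate::error::{Error, ErrorCode};

// Explicit conversion, as in `Err(ErrorCode::Invalid.into())`.
fn reject() -> Result<(), Error> {
    Err(ErrorCode::Invalid.into())
}

// Early-return via `?`, as in `Err(ErrorCode::NoSpace)?;` -- both rely on the
// assumed `From<ErrorCode> for Error` conversion.
fn check_space(buf: &[u8], needed: usize) -> Result<(), Error> {
    if buf.len() < needed {
        Err(ErrorCode::NoSpace)?;
    }
    Ok(())
}
```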


@ -15,9 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; extern crate alloc;
use log::error; use alloc::sync::Arc;
use log::{error, info};
use mbedtls::{ use mbedtls::{
bignum::Mpi, bignum::Mpi,
cipher::{Authenticated, Cipher}, cipher::{Authenticated, Cipher},
@ -28,12 +30,12 @@ use mbedtls::{
x509, x509,
}; };
use super::CryptoKeyPair;
use crate::{ use crate::{
// TODO: We should move ASN1Writer out of Cert, // TODO: We should move ASN1Writer out of Cert,
// so Crypto doesn't have to depend on Cert // so Crypto doesn't have to depend on Cert
cert::{ASN1Writer, CertConsumer}, cert::{ASN1Writer, CertConsumer},
error::Error, error::{Error, ErrorCode},
utils::rand::Rand,
}; };
pub struct HmacSha256 { pub struct HmacSha256 {
@ -48,11 +50,13 @@ impl HmacSha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.inner.update(data).map_err(|_| Error::TLSStack) self.inner
.update(data)
.map_err(|_| ErrorCode::TLSStack.into())
} }
pub fn finish(self, out: &mut [u8]) -> Result<(), Error> { pub fn finish(self, out: &mut [u8]) -> Result<(), Error> {
self.inner.finish(out).map_err(|_| Error::TLSStack)?; self.inner.finish(out).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
} }
@ -62,7 +66,7 @@ pub struct KeyPair {
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?;
Ok(Self { Ok(Self {
key: Pk::generate_ec(&mut ctr_drbg, EcGroupId::SecP256R1)?, key: Pk::generate_ec(&mut ctr_drbg, EcGroupId::SecP256R1)?,
@ -85,10 +89,8 @@ impl KeyPair {
key: Pk::public_from_ec_components(group, pub_key)?, key: Pk::public_from_ec_components(group, pub_key)?,
}) })
} }
}
impl CryptoKeyPair for KeyPair { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
let tmp_priv = self.key.ec_private()?; let tmp_priv = self.key.ec_private()?;
let mut tmp_key = let mut tmp_key =
Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?; Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?;
@ -103,16 +105,16 @@ impl CryptoKeyPair for KeyPair {
Ok(Some(a)) => Ok(a), Ok(Some(a)) => Ok(a),
Ok(None) => { Ok(None) => {
error!("Error in writing CSR: None received"); error!("Error in writing CSR: None received");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
Err(e) => { Err(e) => {
error!("Error in writing CSR {}", e); error!("Error in writing CSR {}", e);
Err(Error::TLSStack) Err(ErrorCode::TLSStack.into())
} }
} }
} }
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let public_key = self.key.ec_public()?; let public_key = self.key.ec_public()?;
let group = EcGroup::new(EcGroupId::SecP256R1)?; let group = EcGroup::new(EcGroupId::SecP256R1)?;
let vec = public_key.to_binary(&group, false)?; let vec = public_key.to_binary(&group, false)?;
@ -122,7 +124,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
let priv_key_mpi = self.key.ec_private()?; let priv_key_mpi = self.key.ec_private()?;
let vec = priv_key_mpi.to_binary()?; let vec = priv_key_mpi.to_binary()?;
@ -131,7 +133,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
@ -149,7 +151,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
let tmp_key = self.key.ec_private()?; let tmp_key = self.key.ec_private()?;
@ -162,7 +164,7 @@ impl CryptoKeyPair for KeyPair {
let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
safemem::write_bytes(signature, 0); safemem::write_bytes(signature, 0);
@ -175,7 +177,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
let tmp_key = self.key.ec_public()?; let tmp_key = self.key.ec_public()?;
@ -192,14 +194,20 @@ impl CryptoKeyPair for KeyPair {
let mbedtls_sign = &mbedtls_sign[..len]; let mbedtls_sign = &mbedtls_sign[..len];
if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) {
println!("The error is {}", e); info!("The error is {}", e);
Err(Error::InvalidSignature) Err(ErrorCode::InvalidSignature.into())
} else { } else {
Ok(()) Ok(())
} }
} }
} }
impl core::fmt::Debug for KeyPair {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("KeyPair").finish()
}
}
fn convert_r_s_to_asn1_sign(signature: &[u8], mbedtls_sign: &mut [u8]) -> Result<usize, Error> { fn convert_r_s_to_asn1_sign(signature: &[u8], mbedtls_sign: &mut [u8]) -> Result<usize, Error> {
let r = &signature[0..32]; let r = &signature[0..32];
let s = &signature[32..64]; let s = &signature[32..64];
@ -224,7 +232,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
// Type 0x2 is Integer (first integer is r) // Type 0x2 is Integer (first integer is r)
if signature[offset] != 2 { if signature[offset] != 2 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
offset += 1; offset += 1;
@ -249,7 +257,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
// Type 0x2 is Integer (this integer is s) // Type 0x2 is Integer (this integer is s)
if signature[offset] != 2 { if signature[offset] != 2 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
offset += 1; offset += 1;
@ -268,17 +276,17 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
Ok(64) Ok(64)
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> {
mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key) mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> {
Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| Error::TLSStack) Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn encrypt_in_place( pub fn encrypt_in_place(
@ -299,7 +307,7 @@ pub fn encrypt_in_place(
cipher cipher
.encrypt_auth_inplace(ad, data, tag) .encrypt_auth_inplace(ad, data, tag)
.map(|(len, _)| len) .map(|(len, _)| len)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn decrypt_in_place( pub fn decrypt_in_place(
@ -321,7 +329,7 @@ pub fn decrypt_in_place(
.map(|(len, _)| len) .map(|(len, _)| len)
.map_err(|e| { .map_err(|e| {
error!("Error during decryption: {:?}", e); error!("Error during decryption: {:?}", e);
Error::TLSStack ErrorCode::TLSStack.into()
}) })
} }
@ -338,12 +346,12 @@ impl Sha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.ctx.update(data).map_err(|_| Error::TLSStack)?; self.ctx.update(data).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> { pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> {
self.ctx.finish(digest).map_err(|_| Error::TLSStack)?; self.ctx.finish(digest).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
} }
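
The other recurring API change above: `KeyPair::new` now takes a `Rand` from `crate::utils::rand` instead of sourcing entropy itself (the mbedtls backend still draws from `OsEntropy` and ignores the argument; the RustCrypto backend further down actually uses it). A caller-side sketch relying only on that constructor:

```rust
// Sketch: platform code hands its randomness source down to the crypto layer.
// `crate::crypto` as the mount point for the selected backend is an assumption.
use crate::crypto::KeyPair;
use crate::error::Error;
use crate::utils::rand::Rand;

fn generate_node_key(rand: Rand) -> Result<KeyPair, Error> {
    KeyPair::new(rand)
}
```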


@ -15,9 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use crate::error::{Error, ErrorCode};
use crate::utils::rand::Rand;
use super::CryptoKeyPair; use alloc::vec;
use foreign_types::ForeignTypeRef; use foreign_types::ForeignTypeRef;
use log::error; use log::error;
use openssl::asn1::Asn1Type; use openssl::asn1::Asn1Type;
@ -40,6 +41,9 @@ use openssl::x509::{X509NameBuilder, X509ReqBuilder, X509};
// problem while using OpenSSL's Signer // problem while using OpenSSL's Signer
// TODO: Use proper OpenSSL method for this // TODO: Use proper OpenSSL method for this
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
extern crate alloc;
pub struct HmacSha256 { pub struct HmacSha256 {
ctx: Hmac<sha2::Sha256>, ctx: Hmac<sha2::Sha256>,
} }
@ -47,7 +51,8 @@ pub struct HmacSha256 {
impl HmacSha256 { impl HmacSha256 {
pub fn new(key: &[u8]) -> Result<Self, Error> { pub fn new(key: &[u8]) -> Result<Self, Error> {
Ok(Self { Ok(Self {
ctx: Hmac::<sha2::Sha256>::new_from_slice(key).map_err(|_x| Error::InvalidKeyLength)?, ctx: Hmac::<sha2::Sha256>::new_from_slice(key)
.map_err(|_x| ErrorCode::InvalidKeyLength)?,
}) })
} }
@ -62,16 +67,18 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub enum KeyType { pub enum KeyType {
Public(EcKey<pkey::Public>), Public(EcKey<pkey::Public>),
Private(EcKey<pkey::Private>), Private(EcKey<pkey::Private>),
} }
#[derive(Debug)]
pub struct KeyPair { pub struct KeyPair {
key: KeyType, key: KeyType,
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let key = EcKey::generate(&group)?; let key = EcKey::generate(&group)?;
Ok(Self { Ok(Self {
@ -108,14 +115,12 @@ impl KeyPair {
fn private_key(&self) -> Result<&EcKey<Private>, Error> { fn private_key(&self) -> Result<&EcKey<Private>, Error> {
match &self.key { match &self.key {
KeyType::Public(_) => Err(Error::Invalid), KeyType::Public(_) => Err(ErrorCode::Invalid.into()),
KeyType::Private(k) => Ok(&k), KeyType::Private(k) => Ok(&k),
} }
} }
}
impl CryptoKeyPair for KeyPair { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let mut bn_ctx = BigNumContext::new()?; let mut bn_ctx = BigNumContext::new()?;
let s = self.public_key_point().to_bytes( let s = self.public_key_point().to_bytes(
@ -128,14 +133,14 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
let s = self.private_key()?.private_key().to_vec(); let s = self.private_key()?.private_key().to_vec();
let len = s.len(); let len = s.len();
priv_key[..len].copy_from_slice(s.as_slice()); priv_key[..len].copy_from_slice(s.as_slice());
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?; let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?;
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
@ -149,7 +154,7 @@ impl CryptoKeyPair for KeyPair {
Ok(deriver.derive(secret)?) Ok(deriver.derive(secret)?)
} }
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
let mut builder = X509ReqBuilder::new()?; let mut builder = X509ReqBuilder::new()?;
builder.set_version(0)?; builder.set_version(0)?;
@ -170,18 +175,18 @@ impl CryptoKeyPair for KeyPair {
a.copy_from_slice(csr); a.copy_from_slice(csr);
Ok(a) Ok(a)
} else { } else {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
// First get the SHA256 of the message // First get the SHA256 of the message
let mut h = Hasher::new(MessageDigest::sha256())?; let mut h = Hasher::new(MessageDigest::sha256())?;
h.update(msg)?; h.update(msg)?;
let msg = h.finish()?; let msg = h.finish()?;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
safemem::write_bytes(signature, 0); safemem::write_bytes(signature, 0);
@ -193,7 +198,7 @@ impl CryptoKeyPair for KeyPair {
Ok(64) Ok(64)
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
// First get the SHA256 of the message // First get the SHA256 of the message
let mut h = Hasher::new(MessageDigest::sha256())?; let mut h = Hasher::new(MessageDigest::sha256())?;
h.update(msg)?; h.update(msg)?;
@ -208,11 +213,11 @@ impl CryptoKeyPair for KeyPair {
KeyType::Public(key) => key, KeyType::Public(key) => key,
_ => { _ => {
error!("Not yet supported"); error!("Not yet supported");
return Err(Error::Invalid); return Err(ErrorCode::Invalid.into());
} }
}; };
if !sig.verify(&msg, k)? { if !sig.verify(&msg, k)? {
Err(Error::InvalidSignature) Err(ErrorCode::InvalidSignature.into())
} else { } else {
Ok(()) Ok(())
} }
@ -223,7 +228,7 @@ const P256_KEY_LEN: usize = 256 / 8;
pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> { pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> {
if out_key.len() != P256_KEY_LEN { if out_key.len() != P256_KEY_LEN {
error!("Insufficient length"); error!("Insufficient length");
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
let key = X509::from_der(der)?.public_key()?.public_key_to_der()?; let key = X509::from_der(der)?.public_key()?.public_key_to_der()?;
let len = key.len(); let len = key.len();
@ -235,7 +240,7 @@ pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Erro
pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> {
openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key) openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> {
@ -295,7 +300,7 @@ pub fn lowlevel_encrypt_aead(
aad: &[u8], aad: &[u8],
data: &[u8], data: &[u8],
tag: &mut [u8], tag: &mut [u8],
) -> Result<Vec<u8>, ErrorStack> { ) -> Result<alloc::vec::Vec<u8>, ErrorStack> {
let t = symm::Cipher::aes_128_ccm(); let t = symm::Cipher::aes_128_ccm();
let mut ctx = CipherCtx::new()?; let mut ctx = CipherCtx::new()?;
CipherCtxRef::encrypt_init( CipherCtxRef::encrypt_init(
@ -331,7 +336,7 @@ pub fn lowlevel_decrypt_aead(
aad: &[u8], aad: &[u8],
data: &[u8], data: &[u8],
tag: &[u8], tag: &[u8],
) -> Result<Vec<u8>, ErrorStack> { ) -> Result<alloc::vec::Vec<u8>, ErrorStack> {
let t = symm::Cipher::aes_128_ccm(); let t = symm::Cipher::aes_128_ccm();
let mut ctx = CipherCtx::new()?; let mut ctx = CipherCtx::new()?;
CipherCtxRef::decrypt_init( CipherCtxRef::decrypt_init(
@ -375,7 +380,9 @@ impl Sha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.hasher.update(data).map_err(|_| Error::TLSStack) self.hasher
.update(data)
.map_err(|_| ErrorCode::TLSStack.into())
} }
pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> { pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> {
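
Alongside the `ErrorCode` migration, these backends now pull heap types from `alloc` (`extern crate alloc`, `alloc::sync::Arc`, `alloc::vec::Vec`) so the same code builds with or without `std`. A generic sketch of that idiom, not specific to this crate:

```rust
// Sketch: no_std-friendly allocation. A global allocator must still be
// provided by the target; only `alloc` (not `std`) is required.
extern crate alloc;

use alloc::sync::Arc;
use alloc::vec::Vec;

fn share_copy(data: &[u8]) -> Arc<Vec<u8>> {
    let owned: Vec<u8> = data.to_vec();
    Arc::new(owned)
}
```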


@ -15,9 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
use std::convert::{TryFrom, TryInto}; use core::convert::{TryFrom, TryInto};
use aes::Aes128; use aes::Aes128;
use alloc::vec;
use ccm::{ use ccm::{
aead::generic_array::GenericArray, aead::generic_array::GenericArray,
consts::{U13, U16}, consts::{U13, U16},
@ -33,20 +34,24 @@ use p256::{
use sha2::Digest; use sha2::Digest;
use x509_cert::{ use x509_cert::{
attr::AttributeType, attr::AttributeType,
der::{asn1::BitString, Any, Encode}, der::{asn1::BitString, Any, Encode, Writer},
name::RdnSequence, name::RdnSequence,
request::CertReq, request::CertReq,
spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned},
}; };
use crate::error::Error; use crate::{
error::{Error, ErrorCode},
use super::CryptoKeyPair; secure_channel::crypto_rustcrypto::RandRngCore,
utils::rand::Rand,
};
type HmacSha256I = hmac::Hmac<sha2::Sha256>; type HmacSha256I = hmac::Hmac<sha2::Sha256>;
type AesCcm = Ccm<Aes128, U16, U13>; type AesCcm = Ccm<Aes128, U16, U13>;
#[derive(Clone)] extern crate alloc;
#[derive(Debug, Clone)]
pub struct Sha256 { pub struct Sha256 {
hasher: sha2::Sha256, hasher: sha2::Sha256,
} }
@ -79,7 +84,7 @@ impl HmacSha256 {
Ok(Self { Ok(Self {
inner: HmacSha256I::new_from_slice(key).map_err(|e| { inner: HmacSha256I::new_from_slice(key).map_err(|e| {
error!("Error creating HmacSha256 {:?}", e); error!("Error creating HmacSha256 {:?}", e);
Error::TLSStack ErrorCode::TLSStack
})?, })?,
}) })
} }
@ -96,18 +101,20 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub enum KeyType { pub enum KeyType {
Private(SecretKey), Private(SecretKey),
Public(PublicKey), Public(PublicKey),
} }
#[derive(Debug)]
pub struct KeyPair { pub struct KeyPair {
key: KeyType, key: KeyType,
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(rand: Rand) -> Result<Self, Error> {
let mut rng = rand::thread_rng(); let mut rng = RandRngCore(rand);
let secret_key = SecretKey::random(&mut rng); let secret_key = SecretKey::random(&mut rng);
Ok(Self { Ok(Self {
@ -143,13 +150,11 @@ impl KeyPair {
fn private_key(&self) -> Result<&SecretKey, Error> { fn private_key(&self) -> Result<&SecretKey, Error> {
match &self.key { match &self.key {
KeyType::Private(key) => Ok(key), KeyType::Private(key) => Ok(key),
KeyType::Public(_) => Err(Error::Crypto), KeyType::Public(_) => Err(ErrorCode::Crypto.into()),
}
} }
} }
impl CryptoKeyPair for KeyPair { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
match &self.key { match &self.key {
KeyType::Private(key) => { KeyType::Private(key) => {
let bytes = key.to_bytes(); let bytes = key.to_bytes();
@ -158,10 +163,10 @@ impl CryptoKeyPair for KeyPair {
priv_key[..slice.len()].copy_from_slice(slice); priv_key[..slice.len()].copy_from_slice(slice);
Ok(len) Ok(len)
} }
KeyType::Public(_) => Err(Error::Crypto), KeyType::Public(_) => Err(ErrorCode::Crypto.into()),
} }
} }
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
use p256::ecdsa::signature::Signer; use p256::ecdsa::signature::Signer;
let subject = RdnSequence(vec![x509_cert::name::RelativeDistinguishedName( let subject = RdnSequence(vec![x509_cert::name::RelativeDistinguishedName(
@ -200,7 +205,7 @@ impl CryptoKeyPair for KeyPair {
attributes: Default::default(), attributes: Default::default(),
}; };
let mut message = vec![]; let mut message = vec![];
info.encode(&mut message).unwrap(); info.encode(&mut VecWriter(&mut message)).unwrap();
// Can't use self.sign_msg as the signature has to be in DER format // Can't use self.sign_msg as the signature has to be in DER format
let private_key = self.private_key()?; let private_key = self.private_key()?;
@ -224,14 +229,14 @@ impl CryptoKeyPair for KeyPair {
Ok(a) Ok(a)
} }
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let point = self.public_key_point().to_encoded_point(false); let point = self.public_key_point().to_encoded_point(false);
let bytes = point.as_bytes(); let bytes = point.as_bytes();
let len = bytes.len(); let len = bytes.len();
pub_key[..len].copy_from_slice(bytes); pub_key[..len].copy_from_slice(bytes);
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
let encoded_point = EncodedPoint::from_bytes(peer_pub_key).unwrap(); let encoded_point = EncodedPoint::from_bytes(peer_pub_key).unwrap();
let peer_pubkey = PublicKey::from_encoded_point(&encoded_point).unwrap(); let peer_pubkey = PublicKey::from_encoded_point(&encoded_point).unwrap();
let private_key = self.private_key()?; let private_key = self.private_key()?;
@ -247,11 +252,11 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
use p256::ecdsa::signature::Signer; use p256::ecdsa::signature::Signer;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); return Err(ErrorCode::NoSpace.into());
} }
match &self.key { match &self.key {
@ -266,7 +271,7 @@ impl CryptoKeyPair for KeyPair {
KeyType::Public(_) => todo!(), KeyType::Public(_) => todo!(),
} }
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
use p256::ecdsa::signature::Verifier; use p256::ecdsa::signature::Verifier;
let verifying_key = VerifyingKey::from_affine(self.public_key_point()).unwrap(); let verifying_key = VerifyingKey::from_affine(self.public_key_point()).unwrap();
@ -274,7 +279,7 @@ impl CryptoKeyPair for KeyPair {
verifying_key verifying_key
.verify(msg, &signature) .verify(msg, &signature)
.map_err(|_| Error::InvalidSignature)?; .map_err(|_| ErrorCode::InvalidSignature)?;
Ok(()) Ok(())
} }
@ -291,7 +296,7 @@ pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Resu
.expand(info, key) .expand(info, key)
.map_err(|e| { .map_err(|e| {
error!("Error with hkdf_sha256 {:?}", e); error!("Error with hkdf_sha256 {:?}", e);
Error::TLSStack ErrorCode::TLSStack.into()
}) })
} }
@ -370,3 +375,13 @@ impl<'a> ccm::aead::Buffer for SliceBuffer<'a> {
self.len = len; self.len = len;
} }
} }
struct VecWriter<'a>(&'a mut alloc::vec::Vec<u8>);
impl<'a> Writer for VecWriter<'a> {
fn write(&mut self, slice: &[u8]) -> x509_cert::der::Result<()> {
self.0.extend_from_slice(slice);
Ok(())
}
}
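
`VecWriter` exists because the DER encoder needs a `Writer` sink and `std::io::Write` is off the table under no_std. A usage sketch that relies only on `Encode::encode` and the `Writer` impl defined above; the helper name is illustrative:

```rust
// Sketch: encode any x509_cert DER structure into an `alloc` Vec via VecWriter.
extern crate alloc;

use alloc::vec::Vec;
use x509_cert::der::Encode;

fn encode_der<T: Encode>(value: &T) -> x509_cert::der::Result<Vec<u8>> {
    let mut buf = Vec::new();
    value.encode(&mut VecWriter(&mut buf))?;
    Ok(buf)
}
```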


@ -14,8 +14,10 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
use crate::{
use crate::error::Error; error::{Error, ErrorCode},
tlv::{FromTLV, TLVWriter, TagType, ToTLV},
};
pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BITS: usize = 128;
pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8;
@ -35,24 +37,14 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32;
pub const EC_SIGNATURE_LEN_BYTES: usize = 64; pub const EC_SIGNATURE_LEN_BYTES: usize = 64;
// APIs particular to a KeyPair so a KeyPair object can be defined #[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))]
pub trait CryptoKeyPair {
fn get_csr<'a>(&self, csr: &'a mut [u8]) -> Result<&'a [u8], Error>;
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error>;
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error>;
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error>;
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error>;
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error>;
}
#[cfg(feature = "crypto_esp_mbedtls")]
mod crypto_esp_mbedtls; mod crypto_esp_mbedtls;
#[cfg(feature = "crypto_esp_mbedtls")] #[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))]
pub use self::crypto_esp_mbedtls::*; pub use self::crypto_esp_mbedtls::*;
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))]
mod crypto_mbedtls; mod crypto_mbedtls;
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))]
pub use self::crypto_mbedtls::*; pub use self::crypto_mbedtls::*;
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "crypto_openssl")]
@ -65,13 +57,58 @@ mod crypto_rustcrypto;
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "crypto_rustcrypto")]
pub use self::crypto_rustcrypto::*; pub use self::crypto_rustcrypto::*;
#[cfg(not(any(
feature = "crypto_openssl",
feature = "crypto_mbedtls",
feature = "crypto_rustcrypto"
)))]
pub mod crypto_dummy; pub mod crypto_dummy;
#[cfg(not(any(
feature = "crypto_openssl",
feature = "crypto_mbedtls",
feature = "crypto_rustcrypto"
)))]
pub use self::crypto_dummy::*;
impl<'a> FromTLV<'a> for KeyPair {
fn from_tlv(t: &crate::tlv::TLVElement<'a>) -> Result<Self, Error>
where
Self: Sized,
{
t.confirm_array()?.enter();
if let Some(mut array) = t.enter() {
let pub_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?;
let priv_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?;
KeyPair::new_from_components(pub_key, priv_key)
} else {
Err(ErrorCode::Invalid.into())
}
}
}
impl ToTLV for KeyPair {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
let mut buf = [0; 1024]; // TODO
tw.start_array(tag)?;
let size = self.get_public_key(&mut buf)?;
tw.str16(TagType::Anonymous, &buf[..size])?;
let size = self.get_private_key(&mut buf)?;
tw.str16(TagType::Anonymous, &buf[..size])?;
tw.end_container()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::error::Error; use crate::error::ErrorCode;
use super::{CryptoKeyPair, KeyPair}; use super::KeyPair;
#[test] #[test]
fn test_verify_msg_success() { fn test_verify_msg_success() {
@ -83,8 +120,9 @@ mod tests {
fn test_verify_msg_fail() { fn test_verify_msg_fail() {
let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap(); let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap();
assert_eq!( assert_eq!(
key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1), key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1)
Err(Error::InvalidSignature) .map_err(|e| e.code()),
Err(ErrorCode::InvalidSignature)
); );
} }
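
With the `CryptoKeyPair` trait removed, `KeyPair` is a single concrete type picked by the `cfg` gates above, and the new `FromTLV`/`ToTLV` impls make it directly persistable. A sketch using only the signatures visible in this hunk (construction of the `TLVWriter` itself is out of scope here, so it is taken as a parameter; `crate::crypto` as the module path is an assumption):

```rust
// Sketch: persist and restore a KeyPair through the TLV impls added above.
use crate::crypto::KeyPair;
use crate::error::Error;
use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV};

fn store_key_pair(kp: &KeyPair, tw: &mut TLVWriter) -> Result<(), Error> {
    // Serialized as an anonymous two-element array: [public key, private key].
    kp.to_tlv(tw, TagType::Anonymous)
}

fn load_key_pair(t: &TLVElement) -> Result<KeyPair, Error> {
    KeyPair::from_tlv(t)
}
```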


@ -15,14 +15,29 @@
* limitations under the License. * limitations under the License.
*/ */
use core::convert::TryInto;
use super::objects::*; use super::objects::*;
use crate::error::*; use crate::{attribute_enum, error::Error, utils::rand::Rand};
use num_derive::FromPrimitive; use strum::FromRepr;
pub const ID: u32 = 0x0028; pub const ID: u32 = 0x0028;
#[derive(FromPrimitive)] #[derive(Clone, Copy, Debug, FromRepr)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
DMRevision(AttrType<u8>) = 0,
VendorId(AttrType<u16>) = 2,
ProductId(AttrType<u16>) = 4,
HwVer(AttrType<u16>) = 7,
SwVer(AttrType<u32>) = 9,
SwVerString(AttrUtfType) = 0xa,
SerialNo(AttrUtfType) = 0x0f,
}
attribute_enum!(Attributes);
pub enum AttributesDiscriminants {
DMRevision = 0, DMRevision = 0,
VendorId = 2, VendorId = 2,
ProductId = 4, ProductId = 4,
@ -33,82 +48,106 @@ pub enum Attributes {
} }
#[derive(Default)] #[derive(Default)]
pub struct BasicInfoConfig { pub struct BasicInfoConfig<'a> {
pub vid: u16, pub vid: u16,
pub pid: u16, pub pid: u16,
pub hw_ver: u16, pub hw_ver: u16,
pub sw_ver: u32, pub sw_ver: u32,
pub sw_ver_str: String, pub sw_ver_str: &'a str,
pub serial_no: String, pub serial_no: &'a str,
/// Device name; up to 32 characters /// Device name; up to 32 characters
pub device_name: String, pub device_name: &'a str,
} }
pub struct BasicInfoCluster { pub const CLUSTER: Cluster<'static> = Cluster {
base: Cluster, id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new(
AttributesDiscriminants::DMRevision as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::VendorId as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::ProductId as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::HwVer as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SwVer as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SwVerString as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SerialNo as u16,
Access::RV,
Quality::FIXED,
),
],
commands: &[],
};
pub struct BasicInfoCluster<'a> {
data_ver: Dataver,
cfg: &'a BasicInfoConfig<'a>,
} }
impl BasicInfoCluster { impl<'a> BasicInfoCluster<'a> {
pub fn new(cfg: BasicInfoConfig) -> Result<Box<Self>, Error> { pub fn new(cfg: &'a BasicInfoConfig<'a>, rand: Rand) -> Self {
let mut cluster = Box::new(BasicInfoCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
}); cfg,
let attrs = [
Attribute::new(
Attributes::DMRevision as u16,
AttrValue::Uint8(1),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::VendorId as u16,
AttrValue::Uint16(cfg.vid),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::ProductId as u16,
AttrValue::Uint16(cfg.pid),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::HwVer as u16,
AttrValue::Uint16(cfg.hw_ver),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SwVer as u16,
AttrValue::Uint32(cfg.sw_ver),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SwVerString as u16,
AttrValue::Utf8(cfg.sw_ver_str),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SerialNo as u16,
AttrValue::Utf8(cfg.serial_no),
Access::RV,
Quality::FIXED,
),
];
cluster.base.add_attributes(&attrs[..])?;
Ok(cluster)
} }
} }
impl ClusterType for BasicInfoCluster { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
fn base(&self) -> &Cluster { if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
&self.base if attr.is_system() {
} CLUSTER.read(attr.attr_id, writer)
fn base_mut(&mut self) -> &mut Cluster { } else {
&mut self.base match attr.attr_id.try_into()? {
Attributes::DMRevision(codec) => codec.encode(writer, 1),
Attributes::VendorId(codec) => codec.encode(writer, self.cfg.vid),
Attributes::ProductId(codec) => codec.encode(writer, self.cfg.pid),
Attributes::HwVer(codec) => codec.encode(writer, self.cfg.hw_ver),
Attributes::SwVer(codec) => codec.encode(writer, self.cfg.sw_ver),
Attributes::SwVerString(codec) => codec.encode(writer, self.cfg.sw_ver_str),
Attributes::SerialNo(codec) => codec.encode(writer, self.cfg.serial_no),
}
}
} else {
Ok(())
}
}
}
impl<'a> Handler for BasicInfoCluster<'a> {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
BasicInfoCluster::read(self, attr, encoder)
}
}
impl<'a> NonBlockingHandler for BasicInfoCluster<'a> {}
impl<'a> ChangeNotifier<()> for BasicInfoCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }
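
Because `BasicInfoConfig` now borrows its strings and the cluster metadata is a `const CLUSTER`, the whole device-info block can live in a constant (flash-friendly for embedded targets). A construction sketch; the vendor/product numbers are placeholders and the module path follows the `cluster_basic_information` import in the old `core.rs` further below:

```rust
// Sketch: statically allocated device info feeding the rewritten cluster.
use crate::data_model::cluster_basic_information::{BasicInfoCluster, BasicInfoConfig};
use crate::utils::rand::Rand;

const DEV_INFO: BasicInfoConfig<'static> = BasicInfoConfig {
    vid: 0xFFF1, // placeholder test vendor ID
    pid: 0x8000, // placeholder product ID
    hw_ver: 2,
    sw_ver: 1,
    sw_ver_str: "1.0",
    serial_no: "serial-0001",
    device_name: "Demo Device",
};

fn basic_info(rand: Rand) -> BasicInfoCluster<'static> {
    BasicInfoCluster::new(&DEV_INFO, rand)
}
```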


@ -15,114 +15,156 @@
* limitations under the License. * limitations under the License.
*/ */
use core::convert::TryInto;
use super::objects::*; use super::objects::*;
use crate::{ use crate::{
cmd_enter, attribute_enum, cmd_enter, command_enum, error::Error, interaction_model::core::Transaction,
error::*, tlv::TLVElement, utils::rand::Rand,
interaction_model::{command::CommandReq, core::IMStatusCode},
}; };
use log::info; use log::info;
use num_derive::FromPrimitive; use strum::{EnumDiscriminants, FromRepr};
pub const ID: u32 = 0x0006; pub const ID: u32 = 0x0006;
#[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
OnOff = 0x0, OnOff(AttrType<bool>) = 0x0,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr, EnumDiscriminants)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
Off = 0x0, Off = 0x0,
On = 0x01, On = 0x01,
Toggle = 0x02, Toggle = 0x02,
} }
fn attr_on_off_new() -> Attribute { command_enum!(Commands);
// OnOff, Value: false
pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::OnOff as u16, AttributesDiscriminants::OnOff as u16,
AttrValue::Bool(false),
Access::RV, Access::RV,
Quality::PERSISTENT, Quality::PERSISTENT,
) ),
} ],
commands: &[
CommandsDiscriminants::Off as _,
CommandsDiscriminants::On as _,
CommandsDiscriminants::Toggle as _,
],
};
pub struct OnOffCluster { pub struct OnOffCluster {
base: Cluster, data_ver: Dataver,
on: bool,
} }
impl OnOffCluster { impl OnOffCluster {
pub fn new() -> Result<Box<Self>, Error> { pub fn new(rand: Rand) -> Self {
let mut cluster = Box::new(OnOffCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
}); on: false,
cluster.base.add_attribute(attr_on_off_new())?;
Ok(cluster)
} }
} }
impl ClusterType for OnOffCluster { pub fn set(&mut self, on: bool) {
fn base(&self) -> &Cluster { if self.on != on {
&self.base self.on = on;
self.data_ver.changed();
} }
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
let cmd = cmd_req if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
.cmd if attr.is_system() {
.path CLUSTER.read(attr.attr_id, writer)
.leaf } else {
.map(num::FromPrimitive::from_u32) match attr.attr_id.try_into()? {
.ok_or(IMStatusCode::UnsupportedCommand)? Attributes::OnOff(codec) => codec.encode(writer, self.on),
.ok_or(IMStatusCode::UnsupportedCommand)?; }
match cmd { }
} else {
Ok(())
}
}
pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
let data = data.with_dataver(self.data_ver.get())?;
match attr.attr_id.try_into()? {
Attributes::OnOff(codec) => self.set(codec.decode(data)?),
}
self.data_ver.changed();
Ok(())
}
pub fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
_data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::Off => { Commands::Off => {
cmd_enter!("Off"); cmd_enter!("Off");
let value = self self.set(false);
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap();
if AttrValue::Bool(true) == *value {
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(false))
.map_err(|_| IMStatusCode::Failure)?;
}
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
Commands::On => { Commands::On => {
cmd_enter!("On"); cmd_enter!("On");
let value = self self.set(true);
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap();
if AttrValue::Bool(false) == *value {
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(true))
.map_err(|_| IMStatusCode::Failure)?;
}
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
Commands::Toggle => { Commands::Toggle => {
cmd_enter!("Toggle"); cmd_enter!("Toggle");
let value = match self self.set(!self.on);
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap()
{
&AttrValue::Bool(v) => v,
_ => false,
};
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(!value))
.map_err(|_| IMStatusCode::Failure)?;
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
} }
transaction.complete();
self.data_ver.changed();
Ok(())
}
}
impl Handler for OnOffCluster {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
OnOffCluster::read(self, attr, encoder)
}
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
OnOffCluster::write(self, attr, data)
}
fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
OnOffCluster::invoke(self, transaction, cmd, data, encoder)
}
}
// TODO: Might be removed once the `on` member is externalized
impl NonBlockingHandler for OnOffCluster {}
impl ChangeNotifier<()> for OnOffCluster {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }
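
The handler no longer round-trips through `read_attribute_raw`/`write_attribute_raw`: the boolean lives in the struct and `set` bumps the data version only when the value actually changes. A sketch driving it directly, using only `new` and `set` from the hunk above (the `cluster_on_off` module path is assumed):

```rust
// Sketch: exercising the rewritten OnOffCluster state handling.
use crate::data_model::cluster_on_off::OnOffCluster;
use crate::utils::rand::Rand;

fn exercise(rand: Rand) {
    let mut on_off = OnOffCluster::new(rand);

    on_off.set(true);  // value changes -> data version is bumped
    on_off.set(true);  // same value -> no data version change
    on_off.set(false);
}
```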


@ -16,29 +16,59 @@
*/ */
use crate::{ use crate::{
data_model::objects::{Cluster, ClusterType}, data_model::objects::{Cluster, Handler},
error::Error, error::{Error, ErrorCode},
utils::rand::Rand,
};
use super::objects::{
AttrDataEncoder, AttrDetails, ChangeNotifier, Dataver, NonBlockingHandler, ATTRIBUTE_LIST,
FEATURE_MAP,
}; };
const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031; const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031;
pub struct TemplateCluster { pub const CLUSTER: Cluster<'static> = Cluster {
base: Cluster, id: CLUSTER_NETWORK_COMMISSIONING_ID as _,
} feature_map: 0,
attributes: &[FEATURE_MAP, ATTRIBUTE_LIST],
commands: &[],
};
impl ClusterType for TemplateCluster { pub struct TemplateCluster {
fn base(&self) -> &Cluster { data_ver: Dataver,
&self.base
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
}
} }
impl TemplateCluster { impl TemplateCluster {
pub fn new() -> Result<Box<Self>, Error> { pub fn new(rand: Rand) -> Self {
Ok(Box::new(Self { Self {
base: Cluster::new(CLUSTER_NETWORK_COMMISSIONING_ID)?, data_ver: Dataver::new(rand),
})) }
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
Err(ErrorCode::AttributeNotFound.into())
}
} else {
Ok(())
}
}
}
impl Handler for TemplateCluster {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
TemplateCluster::read(self, attr, encoder)
}
}
impl NonBlockingHandler for TemplateCluster {}
impl ChangeNotifier<()> for TemplateCluster {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }


@ -0,0 +1,301 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use core::cell::RefCell;
use super::objects::*;
use crate::{
acl::{Accessor, AclMgr},
error::*,
interaction_model::core::{Interaction, Transaction},
tlv::TLVWriter,
transport::packet::Packet,
};
pub struct DataModel<'a, T> {
pub acl_mgr: &'a RefCell<AclMgr>,
pub node: &'a Node<'a>,
pub handler: T,
}
impl<'a, T> DataModel<'a, T> {
pub const fn new(acl_mgr: &'a RefCell<AclMgr>, node: &'a Node<'a>, handler: T) -> Self {
Self {
acl_mgr,
node,
handler,
}
}
pub fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error>
where
T: Handler,
{
let accessor = Accessor::for_session(transaction.session(), self.acl_mgr);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
match interaction {
Interaction::Read(req) => {
let mut resume_path = None;
for item in self.node.read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::Write(req) => {
for item in self.node.write(&req, &accessor) {
AttrDataEncoder::handle_write(item, &mut self.handler, &mut tw)?;
}
req.complete(tx, transaction)
}
Interaction::Invoke(req) => {
for item in self.node.invoke(&req, &accessor) {
CmdDataEncoder::handle(item, &mut self.handler, transaction, &mut tw)?;
}
req.complete(tx, transaction)
}
Interaction::Subscribe(req) => {
let mut resume_path = None;
for item in self.node.subscribing_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::Timed(_) => Ok(false),
Interaction::ResumeRead(req) => {
let mut resume_path = None;
for item in self.node.resume_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::ResumeSubscribe(req) => {
let mut resume_path = None;
for item in self.node.resume_subscribing_read(&req, &accessor) {
if let Some(path) = AttrDataEncoder::handle_read(item, &self.handler, &mut tw)?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
}
}
#[cfg(feature = "nightly")]
pub async fn handle_async<'p>(
&mut self,
interaction: Interaction<'_>,
tx: &'p mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error>
where
T: super::objects::asynch::AsyncHandler,
{
let accessor = Accessor::for_session(transaction.session(), self.acl_mgr);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
match interaction {
Interaction::Read(req) => {
let mut resume_path = None;
for item in self.node.read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::Write(req) => {
for item in self.node.write(&req, &accessor) {
AttrDataEncoder::handle_write_async(item, &mut self.handler, &mut tw).await?;
}
req.complete(tx, transaction)
}
Interaction::Invoke(req) => {
for item in self.node.invoke(&req, &accessor) {
CmdDataEncoder::handle_async(item, &mut self.handler, transaction, &mut tw)
.await?;
}
req.complete(tx, transaction)
}
Interaction::Subscribe(req) => {
let mut resume_path = None;
for item in self.node.subscribing_read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::Timed(_) => Ok(false),
Interaction::ResumeRead(req) => {
let mut resume_path = None;
for item in self.node.resume_read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
Interaction::ResumeSubscribe(req) => {
let mut resume_path = None;
for item in self.node.resume_subscribing_read(&req, &accessor) {
if let Some(path) =
AttrDataEncoder::handle_read_async(item, &self.handler, &mut tw).await?
{
resume_path = Some(path);
break;
}
}
req.complete(tx, transaction, resume_path)
}
}
}
}
pub trait DataHandler {
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error>;
}
impl<T> DataHandler for &mut T
where
T: DataHandler,
{
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error> {
(**self).handle(interaction, tx, transaction)
}
}
impl<'a, T> DataHandler for DataModel<'a, T>
where
T: Handler,
{
fn handle(
&mut self,
interaction: Interaction,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error> {
DataModel::handle(self, interaction, tx, transaction)
}
}
#[cfg(feature = "nightly")]
pub mod asynch {
use crate::{
data_model::objects::asynch::AsyncHandler,
error::Error,
interaction_model::core::{Interaction, Transaction},
transport::packet::Packet,
};
use super::DataModel;
pub trait AsyncDataHandler {
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet,
transaction: &mut Transaction,
) -> Result<bool, Error>;
}
impl<T> AsyncDataHandler for &mut T
where
T: AsyncDataHandler,
{
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error> {
(**self).handle(interaction, tx, transaction).await
}
}
impl<'a, T> AsyncDataHandler for DataModel<'a, T>
where
T: AsyncHandler,
{
async fn handle(
&mut self,
interaction: Interaction<'_>,
tx: &mut Packet<'_>,
transaction: &mut Transaction<'_, '_>,
) -> Result<bool, Error> {
DataModel::handle_async(self, interaction, tx, transaction).await
}
}
}
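
The new `DataModel` is fully borrow-based: no `Arc`, no `RwLock`, just references that must outlive it. A wiring sketch that relies only on the `new` constructor defined above; the handler is whatever implements `Handler` (for example, the root endpoint handler assembled elsewhere in this branch):

```rust
// Sketch: constructing the borrow-based DataModel.
use core::cell::RefCell;

use crate::acl::AclMgr;
use crate::data_model::core::DataModel;
use crate::data_model::objects::{Handler, Node};

fn make_data_model<'a, T: Handler>(
    acl_mgr: &'a RefCell<AclMgr>,
    node: &'a Node<'a>,
    handler: T,
) -> DataModel<'a, T> {
    DataModel::new(acl_mgr, node, handler)
}
```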


@ -1,394 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use self::subscribe::SubsCtx;
use super::{
cluster_basic_information::BasicInfoConfig,
device_types::device_type_add_root_node,
objects::{self, *},
sdm::dev_att::DevAttDataFetcher,
system_model::descriptor::DescriptorCluster,
};
use crate::{
acl::{AccessReq, Accessor, AccessorSubjects, AclMgr, AuthMode},
error::*,
fabric::FabricMgr,
interaction_model::{
command::CommandReq,
core::{IMStatusCode, OpCode},
messages::{
ib::{self, AttrData, DataVersionFilter},
msg::{self, InvReq, ReadReq, WriteReq},
GenericPath,
},
InteractionConsumer, Transaction,
},
secure_channel::pake::PaseMgr,
tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
transport::{
proto_demux::ResponseRequired,
session::{Session, SessionMode},
},
};
use log::{error, info};
use std::sync::{Arc, RwLock};
#[derive(Clone)]
pub struct DataModel {
pub node: Arc<RwLock<Box<Node>>>,
acl_mgr: Arc<AclMgr>,
}
impl DataModel {
pub fn new(
dev_details: BasicInfoConfig,
dev_att: Box<dyn DevAttDataFetcher>,
fabric_mgr: Arc<FabricMgr>,
acl_mgr: Arc<AclMgr>,
pase_mgr: PaseMgr,
) -> Result<Self, Error> {
let dm = DataModel {
node: Arc::new(RwLock::new(Node::new()?)),
acl_mgr: acl_mgr.clone(),
};
{
let mut node = dm.node.write()?;
node.set_changes_cb(Box::new(dm.clone()));
device_type_add_root_node(
&mut node,
dev_details,
dev_att,
fabric_mgr,
acl_mgr,
pase_mgr,
)?;
}
Ok(dm)
}
// Encode a write attribute from a path that may or may not be wildcard
fn handle_write_attr_path(
node: &mut Node,
accessor: &Accessor,
attr_data: &AttrData,
tw: &mut TLVWriter,
) {
let gen_path = attr_data.path.to_gp();
let mut encoder = AttrWriteEncoder::new(tw, TagType::Anonymous);
encoder.set_path(gen_path);
// The unsupported pieces of the wildcard path
if attr_data.path.cluster.is_none() {
encoder.encode_status(IMStatusCode::UnsupportedCluster, 0);
return;
}
if attr_data.path.attr.is_none() {
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0);
return;
}
// Get the data
let write_data = match &attr_data.data {
EncodeValue::Closure(_) | EncodeValue::Value(_) => {
error!("Not supported");
return;
}
EncodeValue::Tlv(t) => t,
};
if gen_path.is_wildcard() {
// This is a wildcard path, skip error
// This is required because there could be access control errors too that need
// to be taken care of.
encoder.skip_error();
}
let mut attr = AttrDetails {
// will be udpated in the loop below
attr_id: 0,
list_index: attr_data.path.list_index,
fab_filter: false,
fab_idx: accessor.fab_idx,
};
let result = node.for_each_cluster_mut(&gen_path, |path, c| {
if attr_data.data_ver.is_some() && Some(c.base().get_dataver()) != attr_data.data_ver {
encoder.encode_status(IMStatusCode::DataVersionMismatch, 0);
return Ok(());
}
attr.attr_id = path.leaf.unwrap_or_default() as u16;
encoder.set_path(*path);
let mut access_req = AccessReq::new(accessor, path, Access::WRITE);
let r = match Cluster::write_attribute(c, &mut access_req, write_data, &attr) {
Ok(_) => IMStatusCode::Success,
Err(e) => e,
};
encoder.encode_status(r, 0);
Ok(())
});
if let Err(e) = result {
// We hit this only if this is a non-wildcard path and some parts of the path are missing
encoder.encode_status(e, 0);
}
}
// Handle command from a path that may or may not be wildcard
fn handle_command_path(node: &mut Node, cmd_req: &mut CommandReq) {
let wildcard = cmd_req.cmd.path.is_wildcard();
let path = cmd_req.cmd.path;
let result = node.for_each_cluster_mut(&path, |path, c| {
cmd_req.cmd.path = *path;
let result = c.handle_command(cmd_req);
if let Err(e) = result {
// It is likely that we might have to do an 'Access' aware traversal
// if there are other conditions in the wildcard scenario that shouldn't be
// encoded as CmdStatus
if !(wildcard && e == IMStatusCode::UnsupportedCommand) {
let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0);
let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous);
}
}
Ok(())
});
if !wildcard {
if let Err(e) = result {
// We hit this only if this is a non-wildcard path
let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0);
let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous);
}
}
}
fn sess_to_accessor(&self, sess: &Session) -> Accessor {
match sess.get_session_mode() {
SessionMode::Case(c) => {
let mut subject =
AccessorSubjects::new(sess.get_peer_node_id().unwrap_or_default());
for i in c.cat_ids {
if i != 0 {
let _ = subject.add_catid(i);
}
}
Accessor::new(c.fab_idx, subject, AuthMode::Case, self.acl_mgr.clone())
}
SessionMode::Pase => Accessor::new(
0,
AccessorSubjects::new(1),
AuthMode::Pase,
self.acl_mgr.clone(),
),
SessionMode::PlainText => Accessor::new(
0,
AccessorSubjects::new(1),
AuthMode::Invalid,
self.acl_mgr.clone(),
),
}
}
/// Returns true if the path matches the cluster path and the data version is a match
fn data_filter_matches(
filters: &Option<&TLVArray<DataVersionFilter>>,
path: &GenericPath,
data_ver: u32,
) -> bool {
if let Some(filters) = *filters {
for filter in filters.iter() {
// TODO: No handling of 'node' comparision yet
if Some(filter.path.endpoint) == path.endpoint
&& Some(filter.path.cluster) == path.cluster
&& filter.data_ver == data_ver
{
return true;
}
}
}
false
}
}
pub mod read;
pub mod subscribe;
/// Type of Resume Request
enum ResumeReq {
Subscribe(subscribe::SubsCtx),
Read(read::ResumeReadReq),
}
impl objects::ChangeConsumer for DataModel {
fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error> {
endpoint.add_cluster(DescriptorCluster::new(id, self.clone())?)?;
Ok(())
}
}
impl InteractionConsumer for DataModel {
fn consume_write_attr(
&self,
write_req: &WriteReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let accessor = self.sess_to_accessor(trans.session);
tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?;
let mut node = self.node.write().unwrap();
for attr_data in write_req.write_requests.iter() {
DataModel::handle_write_attr_path(&mut node, &accessor, &attr_data, tw);
}
tw.end_container()?;
Ok(())
}
fn consume_read_attr(
&self,
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let mut resume_from = None;
let root = tlv::get_root_node(rx_buf)?;
let req = ReadReq::from_tlv(&root)?;
self.handle_read_req(&req, trans, tw, &mut resume_from)?;
if resume_from.is_some() {
// This is a multi-hop read transaction, remember this read request
let resume = read::ResumeReadReq::new(rx_buf, &resume_from)?;
if !trans.exch.is_data_none() {
error!("Exchange data already set, and multi-hop read");
return Err(Error::InvalidState);
}
trans.exch.set_data_boxed(Box::new(ResumeReq::Read(resume)));
}
Ok(())
}
fn consume_invoke_cmd(
&self,
inv_req_msg: &InvReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let mut node = self.node.write().unwrap();
if let Some(inv_requests) = &inv_req_msg.inv_requests {
// Array of InvokeResponse IBs
tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?;
for i in inv_requests.iter() {
let data = if let Some(data) = i.data.unwrap_tlv() {
data
} else {
continue;
};
info!("Invoke Commmand Handler executing: {:?}", i.path);
let mut cmd_req = CommandReq {
cmd: i.path,
data,
trans,
resp: tw,
};
DataModel::handle_command_path(&mut node, &mut cmd_req);
}
tw.end_container()?;
}
Ok(())
}
fn consume_status_report(
&self,
req: &msg::StatusResp,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if let Some(mut resume) = trans.exch.take_data_boxed::<ResumeReq>() {
let result = match *resume {
ResumeReq::Read(ref mut read) => self.handle_resume_read(read, trans, tw)?,
ResumeReq::Subscribe(ref mut ctx) => ctx.handle_status_report(trans, tw, self)?,
};
trans.exch.set_data_boxed(resume);
Ok(result)
} else {
// Nothing to do for now
trans.complete();
info!("Received status report with status {:?}", req.status);
Ok((OpCode::Reserved, ResponseRequired::No))
}
}
fn consume_subscribe(
&self,
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if !trans.exch.is_data_none() {
error!("Exchange data already set!");
return Err(Error::InvalidState);
}
let ctx = SubsCtx::new(rx_buf, trans, tw, self)?;
trans
.exch
.set_data_boxed(Box::new(ResumeReq::Subscribe(ctx)));
Ok((OpCode::ReportData, ResponseRequired::Yes))
}
}
/// Encoder for generating a response to a write request
pub struct AttrWriteEncoder<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
tag: TagType,
path: GenericPath,
skip_error: bool,
}
impl<'a, 'b, 'c> AttrWriteEncoder<'a, 'b, 'c> {
pub fn new(tw: &'a mut TLVWriter<'b, 'c>, tag: TagType) -> Self {
Self {
tw,
tag,
path: Default::default(),
skip_error: false,
}
}
pub fn skip_error(&mut self) {
self.skip_error = true;
}
pub fn set_path(&mut self, path: GenericPath) {
self.path = path;
}
}
impl<'a, 'b, 'c> Encoder for AttrWriteEncoder<'a, 'b, 'c> {
fn encode(&mut self, _value: EncodeValue) {
// Only status encodes for AttrWriteResponse
}
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) {
if self.skip_error && status != IMStatusCode::Success {
// Don't encode errors
return;
}
let resp = ib::AttrStatus::new(&self.path, status, cluster_status);
let _ = resp.to_tlv(self.tw, self.tag);
}
}


@ -1,319 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
acl::{AccessReq, Accessor},
data_model::{core::DataModel, objects::*},
error::*,
interaction_model::{
core::{IMStatusCode, OpCode},
messages::{
ib::{self, DataVersionFilter},
msg::{self, ReadReq, ReportDataTag::MoreChunkedMsgs, ReportDataTag::SupressResponse},
GenericPath,
},
Transaction,
},
tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
transport::{packet::Packet, proto_demux::ResponseRequired},
utils::writebuf::WriteBuf,
wb_shrink, wb_unshrink,
};
use log::error;
/// Encoder for generating a response to a read request
pub struct AttrReadEncoder<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
data_ver: u32,
path: GenericPath,
skip_error: bool,
data_ver_filters: Option<&'a TLVArray<'a, DataVersionFilter>>,
is_buffer_full: bool,
}
impl<'a, 'b, 'c> AttrReadEncoder<'a, 'b, 'c> {
pub fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self {
Self {
tw,
data_ver: 0,
skip_error: false,
path: Default::default(),
data_ver_filters: None,
is_buffer_full: false,
}
}
pub fn skip_error(&mut self, skip: bool) {
self.skip_error = skip;
}
pub fn set_data_ver(&mut self, data_ver: u32) {
self.data_ver = data_ver;
}
pub fn set_data_ver_filters(&mut self, filters: &'a TLVArray<'a, DataVersionFilter>) {
self.data_ver_filters = Some(filters);
}
pub fn set_path(&mut self, path: GenericPath) {
self.path = path;
}
pub fn is_buffer_full(&self) -> bool {
self.is_buffer_full
}
}
impl<'a, 'b, 'c> Encoder for AttrReadEncoder<'a, 'b, 'c> {
fn encode(&mut self, value: EncodeValue) {
let resp = ib::AttrResp::Data(ib::AttrData::new(
Some(self.data_ver),
ib::AttrPath::new(&self.path),
value,
));
let anchor = self.tw.get_tail();
if resp.to_tlv(self.tw, TagType::Anonymous).is_err() {
self.is_buffer_full = true;
self.tw.rewind_to(anchor);
}
}
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) {
if !self.skip_error {
let resp =
ib::AttrResp::Status(ib::AttrStatus::new(&self.path, status, cluster_status));
let _ = resp.to_tlv(self.tw, TagType::Anonymous);
}
}
}
/// State to maintain when a Read Request needs to be resumed,
/// i.e. when the next chunk of the read needs to be returned
#[derive(Default)]
pub struct ResumeReadReq {
/// The original Read Request packet that caused chunking; this is the request
/// that is resumed when generating the next chunk.
pub pending_req: Option<Packet<'static>>,
/// The Attribute that couldn't be encoded because our buffer got full. The next chunk
/// will start encoding from this attribute onwards.
/// Note that given wildcard reads, a single pending path in the member above can generate
/// multiple encode paths. Hence this has to be maintained separately.
pub resume_from: Option<GenericPath>,
}
impl ResumeReadReq {
pub fn new(rx_buf: &[u8], resume_from: &Option<GenericPath>) -> Result<Self, Error> {
let mut packet = Packet::new_rx()?;
let dst = packet.as_borrow_slice();
let src_len = rx_buf.len();
dst[..src_len].copy_from_slice(rx_buf);
packet.get_parsebuf()?.set_len(src_len);
Ok(ResumeReadReq {
pending_req: Some(packet),
resume_from: *resume_from,
})
}
}
impl DataModel {
pub fn read_attribute_raw(
&self,
endpoint: EndptId,
cluster: ClusterId,
attr: AttrId,
) -> Result<AttrValue, IMStatusCode> {
let node = self.node.read().unwrap();
let cluster = node.get_cluster(endpoint, cluster)?;
cluster.base().read_attribute_raw(attr).map(|a| a.clone())
}
/// Encode a read attribute from a path that may or may not be a wildcard
///
/// If the buffer gets full while generating the read response, we return Err(NoSpace)
/// and record the path to resume from in *resume_from, so the next chunk can pick up
/// where this one stopped. This facilitates chunk management.
fn handle_read_attr_path(
node: &Node,
accessor: &Accessor,
attr_encoder: &mut AttrReadEncoder,
attr_details: &mut AttrDetails,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let mut status = Ok(());
let path = attr_encoder.path;
// Skip error reporting for wildcard paths, don't for concrete paths
attr_encoder.skip_error(path.is_wildcard());
let result = node.for_each_attribute(&path, |path, c| {
// Ignore processing if data filter matches.
// For a wildcard attribute, this may end up happening unnecessarily for all attributes, although
// a single skip for the cluster is sufficient. That requires us to replace this for_each with a
// for_each_cluster
let cluster_data_ver = c.base().get_dataver();
if Self::data_filter_matches(&attr_encoder.data_ver_filters, path, cluster_data_ver) {
return Ok(());
}
// The resume_from indicates that this is the next chunk of a previous Read Request. In such cases, we
// need to skip until we hit this path.
if let Some(r) = resume_from {
// If resume_from is valid, and we haven't hit the resume_from yet, skip encoding
if r != path {
return Ok(());
} else {
// Else, wipe out the resume_from so subsequent paths can be encoded
*resume_from = None;
}
}
attr_details.attr_id = path.leaf.unwrap_or_default() as u16;
// Overwrite the previous path with the concrete path
attr_encoder.set_path(*path);
// Set the cluster's data version
attr_encoder.set_data_ver(cluster_data_ver);
let mut access_req = AccessReq::new(accessor, path, Access::READ);
Cluster::read_attribute(c, &mut access_req, attr_encoder, attr_details);
if attr_encoder.is_buffer_full() {
// Buffer is full, next time resume from this attribute
*resume_from = Some(*path);
status = Err(Error::NoSpace);
}
Ok(())
});
if let Err(e) = result {
// We hit this only if this is a non-wildcard path
attr_encoder.encode_status(e, 0);
}
status
}
/// Process an array of Attribute Read Requests
///
/// When the API returns, the chunked read is still in progress if *resume_from is Some(x);
/// otherwise the read is complete
pub(super) fn handle_read_attr_array(
&self,
read_req: &ReadReq,
trans: &mut Transaction,
old_tw: &mut TLVWriter,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let old_wb = old_tw.get_buf();
// Note, this function may be called from multiple places: a) an actual read
// request, b) a resumed read request, c) a subscribe request or d) a resumed subscribe
// request. Hopefully RESERVE_SIZE is sufficient to address all those scenarios.
//
// This is the amount of space we reserve for other things to be attached towards
// the end
const RESERVE_SIZE: usize = 24;
let mut new_wb = wb_shrink!(old_wb, RESERVE_SIZE);
let mut tw = TLVWriter::new(&mut new_wb);
let mut attr_encoder = AttrReadEncoder::new(&mut tw);
if let Some(filters) = &read_req.dataver_filters {
attr_encoder.set_data_ver_filters(filters);
}
if let Some(attr_requests) = &read_req.attr_requests {
let accessor = self.sess_to_accessor(trans.session);
let mut attr_details = AttrDetails::new(accessor.fab_idx, read_req.fabric_filtered);
let node = self.node.read().unwrap();
attr_encoder
.tw
.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?;
let mut result = Ok(());
for attr_path in attr_requests.iter() {
attr_encoder.set_path(attr_path.to_gp());
// Extract the attr_path fields into various structures
attr_details.list_index = attr_path.list_index;
result = DataModel::handle_read_attr_path(
&node,
&accessor,
&mut attr_encoder,
&mut attr_details,
resume_from,
);
if result.is_err() {
break;
}
}
// Now that all the read reports are captured, let's use the old_tw that is
// the full writebuf, and hopefully has all the necessary space to store this
wb_unshrink!(old_wb, new_wb);
old_tw.end_container()?; // Finish the AttrReports
if result.is_err() {
// If there was an error, indicate chunking. The resume_read_req would have been
// already populated in the loop above.
old_tw.bool(TagType::Context(MoreChunkedMsgs as u8), true)?;
} else {
// A None resume_from indicates no chunking
*resume_from = None;
}
}
Ok(())
}
/// Handle a read request
///
/// This could be called from an actual read request or a resumed read request. Subscription
/// requests do not come to this function.
/// When the API returns, the chunked read is still in progress if *resume_from is Some(x);
/// otherwise the read is complete
pub fn handle_read_req(
&self,
read_req: &ReadReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
resume_from: &mut Option<GenericPath>,
) -> Result<(OpCode, ResponseRequired), Error> {
tw.start_struct(TagType::Anonymous)?;
self.handle_read_attr_array(read_req, trans, tw, resume_from)?;
if resume_from.is_none() {
tw.bool(TagType::Context(SupressResponse as u8), true)?;
// Mark transaction complete, if not chunked
trans.complete();
}
tw.end_container()?;
Ok((OpCode::ReportData, ResponseRequired::Yes))
}
/// Handle a resumed read request
pub fn handle_resume_read(
&self,
resume_read_req: &mut ResumeReadReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if let Some(packet) = resume_read_req.pending_req.as_mut() {
let rx_buf = packet.get_parsebuf()?.as_borrow_slice();
let root = tlv::get_root_node(rx_buf)?;
let req = ReadReq::from_tlv(&root)?;
self.handle_read_req(&req, trans, tw, &mut resume_read_req.resume_from)
} else {
// No pending req, is that even possible?
error!("This shouldn't have happened");
Ok((OpCode::Reserved, ResponseRequired::No))
}
}
}
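
For context, a minimal sketch of reading a single attribute value directly via read_attribute_raw(), outside the Interaction Model request path. The endpoint, cluster and attribute IDs are illustrative (On/Off), and `data_model` is assumed to be a DataModel instance with the log macros in scope.

match data_model.read_attribute_raw(1, 0x0006, 0x0000) {
    // The raw value is a clone of what the cluster database currently holds
    Ok(AttrValue::Bool(on)) => info!("on/off attribute is currently {}", on),
    Ok(other) => info!("attribute has a non-bool value: {:?}", other),
    // Failures are reported as Interaction Model status codes
    Err(status) => error!("read failed: {:?}", status),
}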


@ -1,142 +0,0 @@
/*
*
* Copyright (c) 2023 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::sync::atomic::{AtomicU32, Ordering};
use crate::{
error::Error,
interaction_model::{
core::OpCode,
messages::{
msg::{self, SubscribeReq, SubscribeResp},
GenericPath,
},
},
tlv::{self, get_root_node_struct, FromTLV, TLVWriter, TagType, ToTLV},
transport::proto_demux::ResponseRequired,
};
use super::{read::ResumeReadReq, DataModel, Transaction};
static SUBS_ID: AtomicU32 = AtomicU32::new(1);
#[derive(PartialEq)]
enum SubsState {
Confirming,
Confirmed,
}
pub struct SubsCtx {
state: SubsState,
id: u32,
resume_read_req: Option<ResumeReadReq>,
}
impl SubsCtx {
pub fn new(
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
) -> Result<Self, Error> {
let root = get_root_node_struct(rx_buf)?;
let req = SubscribeReq::from_tlv(&root)?;
let mut ctx = SubsCtx {
state: SubsState::Confirming,
// TODO
id: SUBS_ID.fetch_add(1, Ordering::SeqCst),
resume_read_req: None,
};
let mut resume_from = None;
ctx.do_read(&req, trans, tw, dm, &mut resume_from)?;
if resume_from.is_some() {
// This is a multi-hop read transaction, remember this read request
ctx.resume_read_req = Some(ResumeReadReq::new(rx_buf, &resume_from)?);
}
Ok(ctx)
}
pub fn handle_status_report(
&mut self,
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
) -> Result<(OpCode, ResponseRequired), Error> {
if self.state != SubsState::Confirming {
// Not relevant for us
trans.complete();
return Err(Error::Invalid);
}
// Is there a previous resume read pending
if self.resume_read_req.is_some() {
let mut resume_read_req = self.resume_read_req.take().unwrap();
if let Some(packet) = resume_read_req.pending_req.as_mut() {
let rx_buf = packet.get_parsebuf()?.as_borrow_slice();
let root = tlv::get_root_node(rx_buf)?;
let req = SubscribeReq::from_tlv(&root)?;
self.do_read(&req, trans, tw, dm, &mut resume_read_req.resume_from)?;
if resume_read_req.resume_from.is_some() {
// More chunks are pending, setup resume_read_req again
self.resume_read_req = Some(resume_read_req);
}
return Ok((OpCode::ReportData, ResponseRequired::Yes));
}
}
// Getting here implies that the read is now complete
self.confirm_subscription(trans, tw)
}
fn confirm_subscription(
&mut self,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
self.state = SubsState::Confirmed;
// TODO
let resp = SubscribeResp::new(self.id, 40);
resp.to_tlv(tw, TagType::Anonymous)?;
trans.complete();
Ok((OpCode::SubscriptResponse, ResponseRequired::Yes))
}
fn do_read(
&mut self,
req: &SubscribeReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let read_req = req.to_read_req();
tw.start_struct(TagType::Anonymous)?;
tw.u32(
TagType::Context(msg::ReportDataTag::SubscriptionId as u8),
self.id,
)?;
dm.handle_read_attr_array(&read_req, trans, tw, resume_from)?;
tw.end_container()?;
Ok(())
}
}
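
For context, a rough sketch of the message flow that drives SubsCtx, as seen from the DataModel dispatcher above; `dm`, `rx_buf`, `trans`, `tw` and `status_resp` are illustrative names for the dispatcher-side state.

// 1. SubscribeRequest arrives: the priming read is emitted as a ReportData and the
//    context is parked in the exchange as ResumeReq::Subscribe(SubsCtx).
let (_op, _) = dm.consume_subscribe(rx_buf, trans, tw)?; // returns OpCode::ReportData
// 2. The peer acknowledges each ReportData with a StatusResponse. Every ack either
//    produces the next chunk (ReportData again) or, once the priming read is done,
//    the final SubscribeResponse carrying the subscription id.
let (_op, _) = dm.consume_status_report(&status_resp, trans, tw)?;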


@ -15,60 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use super::cluster_basic_information::BasicInfoCluster; use super::objects::DeviceType;
use super::cluster_basic_information::BasicInfoConfig;
use super::cluster_on_off::OnOffCluster;
use super::objects::*;
use super::sdm::admin_commissioning::AdminCommCluster;
use super::sdm::dev_att::DevAttDataFetcher;
use super::sdm::general_commissioning::GenCommCluster;
use super::sdm::noc::NocCluster;
use super::sdm::nw_commissioning::NwCommCluster;
use super::system_model::access_control::AccessControlCluster;
use crate::acl::AclMgr;
use crate::error::*;
use crate::fabric::FabricMgr;
use crate::secure_channel::pake::PaseMgr;
use std::sync::Arc;
use std::sync::RwLockWriteGuard;
pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType { pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType {
dtype: 0x0016, dtype: 0x0016,
drev: 1, drev: 1,
}; };
type WriteNode<'a> = RwLockWriteGuard<'a, Box<Node>>; pub const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType {
pub fn device_type_add_root_node(
node: &mut WriteNode,
dev_info: BasicInfoConfig,
dev_att: Box<dyn DevAttDataFetcher>,
fabric_mgr: Arc<FabricMgr>,
acl_mgr: Arc<AclMgr>,
pase_mgr: PaseMgr,
) -> Result<EndptId, Error> {
// Add the root endpoint
let endpoint = node.add_endpoint(DEV_TYPE_ROOT_NODE)?;
if endpoint != 0 {
// Somehow endpoint 0 was already added, this shouldn't be the case
return Err(Error::Invalid);
};
// Add the mandatory clusters
node.add_cluster(0, BasicInfoCluster::new(dev_info)?)?;
let general_commissioning = GenCommCluster::new()?;
let failsafe = general_commissioning.failsafe();
node.add_cluster(0, general_commissioning)?;
node.add_cluster(0, NwCommCluster::new()?)?;
node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?;
node.add_cluster(
0,
NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?,
)?;
node.add_cluster(0, AccessControlCluster::new(acl_mgr)?)?;
Ok(endpoint)
}
const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType {
dtype: 0x0100, dtype: 0x0100,
drev: 2, drev: 2,
}; };
@ -77,9 +31,3 @@ pub const DEV_TYPE_ON_SMART_SPEAKER: DeviceType = DeviceType {
dtype: 0x0022, dtype: 0x0022,
drev: 2, drev: 2,
}; };
pub fn device_type_add_on_off_light(node: &mut WriteNode) -> Result<EndptId, Error> {
let endpoint = node.add_endpoint(DEV_TYPE_ON_OFF_LIGHT)?;
node.add_cluster(endpoint, OnOffCluster::new()?)?;
Ok(endpoint)
}


@ -20,8 +20,9 @@ pub mod device_types;
pub mod objects; pub mod objects;
pub mod cluster_basic_information; pub mod cluster_basic_information;
pub mod cluster_media_playback; // TODO pub mod cluster_media_playback;
pub mod cluster_on_off; pub mod cluster_on_off;
pub mod cluster_template; pub mod cluster_template;
pub mod root_endpoint;
pub mod sdm; pub mod sdm;
pub mod system_model; pub mod system_model;


@ -15,15 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use super::{AttrId, GlobalElements, Privilege}; use crate::data_model::objects::GlobalElements;
use crate::{
error::*, use super::{AttrId, Privilege};
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{TLVElement, TLVWriter, TagType, ToTLV},
};
use bitflags::bitflags; use bitflags::bitflags;
use log::error; use core::fmt::{self, Debug};
use std::fmt::{self, Debug, Formatter};
bitflags! { bitflags! {
#[derive(Default)] #[derive(Default)]
@ -83,110 +79,24 @@ bitflags! {
} }
} }
/* This file needs some major revamp.
* - instead of allocating all over the heap, we should use some kind of slab/block allocator
* - instead of arrays, can use linked-lists to conserve space and avoid the internal fragmentation
*/
#[derive(PartialEq, PartialOrd, Clone)]
pub enum AttrValue {
Int64(i64),
Uint8(u8),
Uint16(u16),
Uint32(u32),
Uint64(u64),
Bool(bool),
Utf8(String),
Custom,
}
impl Debug for AttrValue {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match &self {
AttrValue::Int64(v) => write!(f, "{:?}", *v),
AttrValue::Uint8(v) => write!(f, "{:?}", *v),
AttrValue::Uint16(v) => write!(f, "{:?}", *v),
AttrValue::Uint32(v) => write!(f, "{:?}", *v),
AttrValue::Uint64(v) => write!(f, "{:?}", *v),
AttrValue::Bool(v) => write!(f, "{:?}", *v),
AttrValue::Utf8(v) => write!(f, "{:?}", *v),
AttrValue::Custom => write!(f, "custom-attribute"),
}?;
Ok(())
}
}
impl ToTLV for AttrValue {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
// What is the time complexity of such long match statements?
match self {
AttrValue::Bool(v) => tw.bool(tag_type, *v),
AttrValue::Uint8(v) => tw.u8(tag_type, *v),
AttrValue::Uint16(v) => tw.u16(tag_type, *v),
AttrValue::Uint32(v) => tw.u32(tag_type, *v),
AttrValue::Uint64(v) => tw.u64(tag_type, *v),
AttrValue::Utf8(v) => tw.utf8(tag_type, v.as_bytes()),
_ => {
error!("Attribute type not yet supported");
Err(Error::AttributeNotFound)
}
}
}
}
impl AttrValue {
pub fn update_from_tlv(&mut self, tr: &TLVElement) -> Result<(), Error> {
match self {
AttrValue::Bool(v) => *v = tr.bool()?,
AttrValue::Uint8(v) => *v = tr.u8()?,
AttrValue::Uint16(v) => *v = tr.u16()?,
AttrValue::Uint32(v) => *v = tr.u32()?,
AttrValue::Uint64(v) => *v = tr.u64()?,
_ => {
error!("Attribute type not yet supported");
return Err(Error::AttributeNotFound);
}
}
Ok(())
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Attribute { pub struct Attribute {
pub(super) id: AttrId, pub id: AttrId,
pub(super) value: AttrValue, pub quality: Quality,
pub(super) quality: Quality, pub access: Access,
pub(super) access: Access,
}
impl Default for Attribute {
fn default() -> Attribute {
Attribute {
id: 0,
value: AttrValue::Bool(true),
quality: Default::default(),
access: Default::default(),
}
}
} }
impl Attribute { impl Attribute {
pub fn new(id: AttrId, value: AttrValue, access: Access, quality: Quality) -> Self { pub const fn new(id: AttrId, access: Access, quality: Quality) -> Self {
Attribute { Self {
id, id,
value,
access, access,
quality, quality,
} }
} }
pub fn set_value(&mut self, value: AttrValue) -> Result<(), Error> { pub fn is_system(&self) -> bool {
if !self.quality.contains(Quality::FIXED) { Self::is_system_attr(self.id)
self.value = value;
Ok(())
} else {
Err(Error::Invalid)
}
} }
pub fn is_system_attr(attr_id: AttrId) -> bool { pub fn is_system_attr(attr_id: AttrId) -> bool {
@ -194,9 +104,9 @@ impl Attribute {
} }
} }
impl std::fmt::Display for Attribute { impl core::fmt::Display for Attribute {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {:?}", self.id, self.value) write!(f, "{}", self.id)
} }
} }


@ -15,25 +15,31 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{
acl::AccessReq,
data_model::objects::{Access, AttrValue, Attribute, EncodeValue, Quality},
error::*,
interaction_model::{command::CommandReq, core::IMStatusCode},
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{Nullable, TLVElement, TLVWriter, TagType},
};
use log::error; use log::error;
use num_derive::FromPrimitive; use strum::FromRepr;
use rand::Rng;
use std::fmt::{self, Debug};
use super::{AttrId, ClusterId, Encoder}; use crate::{
acl::{AccessReq, Accessor},
attribute_enum,
data_model::objects::*,
error::{Error, ErrorCode},
interaction_model::{
core::IMStatusCode,
messages::{
ib::{AttrPath, AttrStatus, CmdPath, CmdStatus},
GenericPath,
},
},
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{Nullable, TLVWriter, TagType},
};
use core::{
convert::TryInto,
fmt::{self, Debug},
};
pub const ATTRS_PER_CLUSTER: usize = 10; #[derive(Clone, Copy, Debug, Eq, PartialEq, FromRepr)]
pub const CMDS_PER_CLUSTER: usize = 8; #[repr(u16)]
#[derive(FromPrimitive, Debug)]
pub enum GlobalElements { pub enum GlobalElements {
_ClusterRevision = 0xFFFD, _ClusterRevision = 0xFFFD,
FeatureMap = 0xFFFC, FeatureMap = 0xFFFC,
@ -44,297 +50,292 @@ pub enum GlobalElements {
FabricIndex = 0xFE, FabricIndex = 0xFE,
} }
attribute_enum!(GlobalElements);
pub const FEATURE_MAP: Attribute =
Attribute::new(GlobalElements::FeatureMap as _, Access::RV, Quality::NONE);
pub const ATTRIBUTE_LIST: Attribute = Attribute::new(
GlobalElements::AttributeList as _,
Access::RV,
Quality::NONE,
);
// TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write
// methods? // methods?
/// The Attribute Details structure records the details about the attribute under consideration. /// The Attribute Details structure records the details about the attribute under consideration.
/// Typically this structure is progressively built as we proceed through the request processing. #[derive(Debug)]
pub struct AttrDetails { pub struct AttrDetails<'a> {
/// Fabric Filtering Activated pub node: &'a Node<'a>,
pub fab_filter: bool, /// The actual endpoint ID
/// The current Fabric Index pub endpoint_id: EndptId,
pub fab_idx: u8, /// The actual cluster ID
/// List Index, if any pub cluster_id: ClusterId,
pub list_index: Option<Nullable<u16>>,
/// The actual attribute ID /// The actual attribute ID
pub attr_id: AttrId, pub attr_id: AttrId,
/// List Index, if any
pub list_index: Option<Nullable<u16>>,
/// The current Fabric Index
pub fab_idx: u8,
/// Fabric Filtering Activated
pub fab_filter: bool,
pub dataver: Option<u32>,
pub wildcard: bool,
} }
impl AttrDetails { impl<'a> AttrDetails<'a> {
pub fn new(fab_idx: u8, fab_filter: bool) -> Self { pub fn is_system(&self) -> bool {
Self { Attribute::is_system_attr(self.attr_id)
fab_filter,
fab_idx,
list_index: None,
attr_id: 0,
}
}
} }
pub trait ClusterType { pub fn path(&self) -> AttrPath {
// TODO: 5 methods is going to be quite expensive for vtables of all the clusters AttrPath {
fn base(&self) -> &Cluster; endpoint: Some(self.endpoint_id),
fn base_mut(&mut self) -> &mut Cluster; cluster: Some(self.cluster_id),
fn read_custom_attribute(&self, _encoder: &mut dyn Encoder, _attr: &AttrDetails) {} attr: Some(self.attr_id),
list_index: self.list_index,
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { ..Default::default()
let cmd = cmd_req.cmd.path.leaf.map(|a| a as u16);
println!("Received command: {:?}", cmd);
Err(IMStatusCode::UnsupportedCommand)
}
/// Write an attribute
///
/// Note that if this method is defined, you must handle the write for all the attributes. Even those
/// that are not 'custom'. This is different from how you handle the read_custom_attribute() method.
/// The reason for this being, you may want to handle an attribute write request even though it is a
/// standard attribute like u16, u32 etc.
///
/// If you wish to update the standard attribute in the data model database, you must call the
/// write_attribute_from_tlv() method from the base cluster, as is shown here in the default case
fn write_attribute(
&mut self,
attr: &AttrDetails,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
self.base_mut().write_attribute_from_tlv(attr.attr_id, data)
} }
} }
pub struct Cluster { pub fn status(&self, status: IMStatusCode) -> Result<Option<AttrStatus>, Error> {
pub(super) id: ClusterId, if self.should_report(status) {
attributes: Vec<Attribute>, Ok(Some(AttrStatus::new(
data_ver: u32, &GenericPath {
endpoint: Some(self.endpoint_id),
cluster: Some(self.cluster_id),
leaf: Some(self.attr_id as _),
},
status,
0,
)))
} else {
Ok(None)
}
} }
impl Cluster { fn should_report(&self, status: IMStatusCode) -> bool {
pub fn new(id: ClusterId) -> Result<Cluster, Error> { !self.wildcard
let mut c = Cluster { || !matches!(
id, status,
attributes: Vec::with_capacity(ATTRS_PER_CLUSTER), IMStatusCode::UnsupportedEndpoint
data_ver: rand::thread_rng().gen_range(0..0xFFFFFFFF), | IMStatusCode::UnsupportedCluster
}; | IMStatusCode::UnsupportedAttribute
c.add_default_attributes()?; | IMStatusCode::UnsupportedCommand
Ok(c) | IMStatusCode::UnsupportedAccess
| IMStatusCode::UnsupportedRead
| IMStatusCode::UnsupportedWrite
| IMStatusCode::DataVersionMismatch
)
}
} }
pub fn id(&self) -> ClusterId { #[derive(Debug)]
self.id pub struct CmdDetails<'a> {
pub node: &'a Node<'a>,
pub endpoint_id: EndptId,
pub cluster_id: ClusterId,
pub cmd_id: CmdId,
pub wildcard: bool,
} }
pub fn get_dataver(&self) -> u32 { impl<'a> CmdDetails<'a> {
self.data_ver pub fn path(&self) -> CmdPath {
CmdPath::new(
Some(self.endpoint_id),
Some(self.cluster_id),
Some(self.cmd_id),
)
} }
pub fn set_feature_map(&mut self, map: u32) -> Result<(), Error> { pub fn success(&self, tracker: &CmdDataTracker) -> Option<CmdStatus> {
self.write_attribute_raw(GlobalElements::FeatureMap as u16, AttrValue::Uint32(map)) if tracker.needs_status() {
.map_err(|_| Error::Invalid)?; self.status(IMStatusCode::Success)
Ok(()) } else {
None
}
} }
fn add_default_attributes(&mut self) -> Result<(), Error> { pub fn status(&self, status: IMStatusCode) -> Option<CmdStatus> {
// Default feature map is 0 if self.should_report(status) {
self.add_attribute(Attribute::new( Some(CmdStatus::new(
GlobalElements::FeatureMap as u16, CmdPath::new(
AttrValue::Uint32(0), Some(self.endpoint_id),
Access::RV, Some(self.cluster_id),
Quality::NONE, Some(self.cmd_id),
))?; ),
status,
self.add_attribute(Attribute::new( 0,
GlobalElements::AttributeList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
)) ))
}
pub fn add_attributes(&mut self, attrs: &[Attribute]) -> Result<(), Error> {
if self.attributes.len() + attrs.len() <= self.attributes.capacity() {
self.attributes.extend_from_slice(attrs);
Ok(())
} else { } else {
Err(Error::NoSpace) None
} }
} }
pub fn add_attribute(&mut self, attr: Attribute) -> Result<(), Error> { fn should_report(&self, status: IMStatusCode) -> bool {
if self.attributes.len() < self.attributes.capacity() { !self.wildcard
self.attributes.push(attr); || !matches!(
Ok(()) status,
} else { IMStatusCode::UnsupportedEndpoint
Err(Error::NoSpace) | IMStatusCode::UnsupportedCluster
| IMStatusCode::UnsupportedAttribute
| IMStatusCode::UnsupportedCommand
| IMStatusCode::UnsupportedAccess
| IMStatusCode::UnsupportedRead
| IMStatusCode::UnsupportedWrite
)
} }
} }
fn get_attribute_index(&self, attr_id: AttrId) -> Option<usize> { #[derive(Debug, Clone)]
self.attributes.iter().position(|c| c.id == attr_id) pub struct Cluster<'a> {
pub id: ClusterId,
pub feature_map: u32,
pub attributes: &'a [Attribute],
pub commands: &'a [CmdId],
} }
fn get_attribute(&self, attr_id: AttrId) -> Result<&Attribute, Error> { impl<'a> Cluster<'a> {
let index = self pub const fn new(
.get_attribute_index(attr_id) id: ClusterId,
.ok_or(Error::AttributeNotFound)?; feature_map: u32,
Ok(&self.attributes[index]) attributes: &'a [Attribute],
commands: &'a [CmdId],
) -> Self {
Self {
id,
feature_map,
attributes,
commands,
}
} }
fn get_attribute_mut(&mut self, attr_id: AttrId) -> Result<&mut Attribute, Error> { pub fn match_attributes(
let index = self
.get_attribute_index(attr_id)
.ok_or(Error::AttributeNotFound)?;
Ok(&mut self.attributes[index])
}
// Returns a slice of attribute, with either a single attribute or all (wildcard)
pub fn get_wildcard_attribute(
&self, &self,
attribute: Option<AttrId>, attr: Option<AttrId>,
) -> Result<(&[Attribute], bool), IMStatusCode> { ) -> impl Iterator<Item = &'_ Attribute> + '_ {
if let Some(a) = attribute { self.attributes
if let Some(i) = self.get_attribute_index(a) { .iter()
Ok((&self.attributes[i..i + 1], false)) .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true))
}
pub fn match_commands(&self, cmd: Option<CmdId>) -> impl Iterator<Item = CmdId> + '_ {
self.commands
.iter()
.filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true))
.copied()
}
pub fn check_attribute(
&self,
accessor: &Accessor,
ep: EndptId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
let attribute = self
.attributes
.iter()
.find(|attribute| attribute.id == attr)
.ok_or(IMStatusCode::UnsupportedAttribute)?;
Self::check_attr_access(
accessor,
GenericPath::new(Some(ep), Some(self.id), Some(attr as _)),
write,
attribute.access,
)
}
pub fn check_command(
&self,
accessor: &Accessor,
ep: EndptId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.commands
.iter()
.find(|id| **id == cmd)
.ok_or(IMStatusCode::UnsupportedCommand)?;
Self::check_cmd_access(
accessor,
GenericPath::new(Some(ep), Some(self.id), Some(cmd)),
)
}
pub(crate) fn check_attr_access(
accessor: &Accessor,
path: GenericPath,
write: bool,
target_perms: Access,
) -> Result<(), IMStatusCode> {
let mut access_req = AccessReq::new(
accessor,
path,
if write { Access::WRITE } else { Access::READ },
);
if !target_perms.contains(access_req.operation()) {
Err(if matches!(access_req.operation(), Access::WRITE) {
IMStatusCode::UnsupportedWrite
} else { } else {
Err(IMStatusCode::UnsupportedAttribute) IMStatusCode::UnsupportedRead
})?;
} }
access_req.set_target_perms(target_perms);
if access_req.allow() {
Ok(())
} else { } else {
Ok((&self.attributes[..], true)) Err(IMStatusCode::UnsupportedAccess)
} }
} }
pub fn read_attribute( pub(crate) fn check_cmd_access(
c: &dyn ClusterType, accessor: &Accessor,
access_req: &mut AccessReq, path: GenericPath,
encoder: &mut dyn Encoder, ) -> Result<(), IMStatusCode> {
attr: &AttrDetails, let mut access_req = AccessReq::new(accessor, path, Access::WRITE);
) {
let mut error = IMStatusCode::Success; access_req.set_target_perms(
let base = c.base(); Access::WRITE
let a = if let Ok(a) = base.get_attribute(attr.attr_id) { .union(Access::NEED_OPERATE)
a .union(Access::NEED_MANAGE)
.union(Access::NEED_ADMIN),
); // TODO
if access_req.allow() {
Ok(())
} else { } else {
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); Err(IMStatusCode::UnsupportedAccess)
return;
};
if !a.access.contains(Access::READ) {
error = IMStatusCode::UnsupportedRead;
}
access_req.set_target_perms(a.access);
if !access_req.allow() {
error = IMStatusCode::UnsupportedAccess;
}
if error != IMStatusCode::Success {
encoder.encode_status(error, 0);
} else if Attribute::is_system_attr(attr.attr_id) {
c.base().read_system_attribute(encoder, a)
} else if a.value != AttrValue::Custom {
encoder.encode(EncodeValue::Value(&a.value))
} else {
c.read_custom_attribute(encoder, attr)
} }
} }
fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) { pub fn read(&self, attr: AttrId, mut writer: AttrDataWriter) -> Result<(), Error> {
let _ = tw.start_array(tag); match attr.try_into()? {
for a in &self.attributes {
let _ = tw.u16(TagType::Anonymous, a.id);
}
let _ = tw.end_container();
}
fn read_system_attribute(&self, encoder: &mut dyn Encoder, attr: &Attribute) {
let global_attr: Option<GlobalElements> = num::FromPrimitive::from_u16(attr.id);
if let Some(global_attr) = global_attr {
match global_attr {
GlobalElements::AttributeList => { GlobalElements::AttributeList => {
encoder.encode(EncodeValue::Closure(&|tag, tw| { self.encode_attribute_ids(AttrDataWriter::TAG, &mut writer)?;
self.encode_attribute_ids(tag, tw) writer.complete()
}));
return;
} }
GlobalElements::FeatureMap => { GlobalElements::FeatureMap => writer.set(self.feature_map),
encoder.encode(EncodeValue::Value(&attr.value)); other => {
return; error!("This attribute is not yet handled {:?}", other);
Err(ErrorCode::AttributeNotFound.into())
} }
_ => {
error!("This attribute not yet handled {:?}", global_attr);
}
}
}
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0)
}
pub fn read_attribute_raw(&self, attr_id: AttrId) -> Result<&AttrValue, IMStatusCode> {
let a = self
.get_attribute(attr_id)
.map_err(|_| IMStatusCode::UnsupportedAttribute)?;
Ok(&a.value)
}
pub fn write_attribute(
c: &mut dyn ClusterType,
access_req: &mut AccessReq,
data: &TLVElement,
attr: &AttrDetails,
) -> Result<(), IMStatusCode> {
let base = c.base_mut();
let a = if let Ok(a) = base.get_attribute_mut(attr.attr_id) {
a
} else {
return Err(IMStatusCode::UnsupportedAttribute);
};
if !a.access.contains(Access::WRITE) {
return Err(IMStatusCode::UnsupportedWrite);
}
access_req.set_target_perms(a.access);
if !access_req.allow() {
return Err(IMStatusCode::UnsupportedAccess);
}
c.write_attribute(attr, data)
}
pub fn write_attribute_from_tlv(
&mut self,
attr_id: AttrId,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
let a = self.get_attribute_mut(attr_id)?;
if a.value != AttrValue::Custom {
let mut value = a.value.clone();
value
.update_from_tlv(data)
.map_err(|_| IMStatusCode::Failure)?;
a.set_value(value)
.map(|_| {
self.cluster_changed();
})
.map_err(|_| IMStatusCode::UnsupportedWrite)
} else {
Err(IMStatusCode::UnsupportedAttribute)
} }
} }
pub fn write_attribute_raw(&mut self, attr_id: AttrId, value: AttrValue) -> Result<(), Error> { fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) -> Result<(), Error> {
let a = self.get_attribute_mut(attr_id)?; tw.start_array(tag)?;
a.set_value(value).map(|_| { for a in self.attributes {
self.cluster_changed(); tw.u16(TagType::Anonymous, a.id)?;
})
} }
/// This method must be called for any changes to the data model tw.end_container()
/// Currently this only increments the data version, but we can reuse the same
/// for raising events too
pub fn cluster_changed(&mut self) {
self.data_ver = self.data_ver.wrapping_add(1);
} }
} }
impl std::fmt::Display for Cluster { impl<'a> core::fmt::Display for Cluster<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "id:{}, ", self.id)?; write!(f, "id:{}, ", self.id)?;
write!(f, "attrs[")?; write!(f, "attrs[")?;


@ -0,0 +1,55 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::utils::rand::Rand;
pub struct Dataver {
ver: u32,
changed: bool,
}
impl Dataver {
pub fn new(rand: Rand) -> Self {
let mut buf = [0; 4];
rand(&mut buf);
Self {
ver: u32::from_be_bytes(buf),
changed: false,
}
}
pub fn get(&self) -> u32 {
self.ver
}
pub fn changed(&mut self) -> u32 {
(self.ver, _) = self.ver.overflowing_add(1);
self.changed = true;
self.get()
}
pub fn consume_change<T>(&mut self, change: T) -> Option<T> {
if self.changed {
self.changed = false;
Some(change)
} else {
None
}
}
}
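
For context, a minimal usage sketch for Dataver; `rand` stands for the platform-provided Rand function that the application hands over at construction time.

let mut dataver = Dataver::new(rand);
let initial = dataver.get();    // advertised as the cluster's DataVersion
let bumped = dataver.changed(); // call on every attribute mutation
assert_eq!(bumped, initial.wrapping_add(1));
// The reporting machinery can later drain the "changed" flag exactly once:
assert!(dataver.consume_change(()).is_some());
assert!(dataver.consume_change(()).is_none());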


@ -15,21 +15,31 @@
* limitations under the License. * limitations under the License.
*/ */
use std::fmt::{Debug, Formatter}; use core::fmt::{Debug, Formatter};
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use crate::interaction_model::core::{IMStatusCode, Transaction};
use crate::interaction_model::messages::ib::{
AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag,
};
use crate::interaction_model::messages::GenericPath;
use crate::tlv::UtfStr;
use crate::{ use crate::{
error::Error, error::{Error, ErrorCode},
interaction_model::core::IMStatusCode, interaction_model::messages::ib::{AttrDataTag, AttrRespTag},
tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV},
}; };
use log::error; use log::error;
use super::{AttrDetails, CmdDetails, Handler};
// TODO: Should this return an IMStatusCode Error? But if yes, the higher layer // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer
// may have already started encoding the 'success' headers, we might not to manage // may have already started encoding the 'success' headers, we might not want to manage
// the tw.rewind() in that case, if we add this support // the tw.rewind() in that case, if we add this support
pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter);
#[derive(Copy, Clone)] #[derive(Clone)]
/// A structure for encoding various types of values /// A structure for encoding various types of values
pub enum EncodeValue<'a> { pub enum EncodeValue<'a> {
/// This indicates a value that is dynamically generated. This variant /// This indicates a value that is dynamically generated. This variant
@ -56,13 +66,13 @@ impl<'a> EncodeValue<'a> {
impl<'a> PartialEq for EncodeValue<'a> { impl<'a> PartialEq for EncodeValue<'a> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match *self { match self {
EncodeValue::Closure(_) => { EncodeValue::Closure(_) => {
error!("PartialEq not yet supported"); error!("PartialEq not yet supported");
false false
} }
EncodeValue::Tlv(a) => { EncodeValue::Tlv(a) => {
if let EncodeValue::Tlv(b) = *other { if let EncodeValue::Tlv(b) = other {
a == b a == b
} else { } else {
false false
@ -78,8 +88,8 @@ impl<'a> PartialEq for EncodeValue<'a> {
} }
impl<'a> Debug for EncodeValue<'a> { impl<'a> Debug for EncodeValue<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
match *self { match self {
EncodeValue::Closure(_) => write!(f, "Contains closure"), EncodeValue::Closure(_) => write!(f, "Contains closure"),
EncodeValue::Tlv(t) => write!(f, "{:?}", t), EncodeValue::Tlv(t) => write!(f, "{:?}", t),
EncodeValue::Value(_) => write!(f, "Contains EncodeValue"), EncodeValue::Value(_) => write!(f, "Contains EncodeValue"),
@ -103,21 +113,477 @@ impl<'a> ToTLV for EncodeValue<'a> {
impl<'a> FromTLV<'a> for EncodeValue<'a> { impl<'a> FromTLV<'a> for EncodeValue<'a> {
fn from_tlv(data: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(data: &TLVElement<'a>) -> Result<Self, Error> {
Ok(EncodeValue::Tlv(*data)) Ok(EncodeValue::Tlv(data.clone()))
} }
} }
/// An object that can encode EncodeValue into the necessary hierarchical structure pub struct AttrDataEncoder<'a, 'b, 'c> {
/// as expected by the Interaction Model dataver_filter: Option<u32>,
pub trait Encoder { path: AttrPath,
/// Encode a given value tw: &'a mut TLVWriter<'b, 'c>,
fn encode(&mut self, value: EncodeValue);
/// Encode a status report
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16);
} }
#[derive(ToTLV, Copy, Clone)] impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> {
pub struct DeviceType { pub fn handle_read<T: Handler>(
pub dtype: u16, item: Result<AttrDetails, AttrStatus>,
pub drev: u16, handler: &T,
tw: &mut TLVWriter,
) -> Result<Option<GenericPath>, Error> {
let status = match item {
Ok(attr) => {
let encoder = AttrDataEncoder::new(&attr, tw);
match handler.read(&attr, encoder) {
Ok(()) => None,
Err(e) => {
if e.code() == ErrorCode::NoSpace {
return Ok(Some(attr.path().to_gp()));
} else {
attr.status(e.into())?
}
}
}
}
Err(status) => Some(status),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(None)
}
pub fn handle_write<T: Handler>(
item: Result<(AttrDetails, TLVElement), AttrStatus>,
handler: &mut T,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let status = match item {
Ok((attr, data)) => match handler.write(&attr, AttrData::new(attr.dataver, &data)) {
Ok(()) => attr.status(IMStatusCode::Success)?,
Err(error) => attr.status(error.into())?,
},
Err(status) => Some(status),
};
if let Some(status) = status {
status.to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
#[cfg(feature = "nightly")]
pub async fn handle_read_async<T: super::asynch::AsyncHandler>(
item: Result<AttrDetails<'_>, AttrStatus>,
handler: &T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<Option<GenericPath>, Error> {
let status = match item {
Ok(attr) => {
let encoder = AttrDataEncoder::new(&attr, tw);
match handler.read(&attr, encoder).await {
Ok(()) => None,
Err(e) => {
if e.code() == ErrorCode::NoSpace {
return Ok(Some(attr.path().to_gp()));
} else {
attr.status(e.into())?
}
}
}
}
Err(status) => Some(status),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(None)
}
#[cfg(feature = "nightly")]
pub async fn handle_write_async<T: super::asynch::AsyncHandler>(
item: Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>,
handler: &mut T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> {
let status = match item {
Ok((attr, data)) => match handler
.write(&attr, AttrData::new(attr.dataver, &data))
.await
{
Ok(()) => attr.status(IMStatusCode::Success)?,
Err(error) => attr.status(error.into())?,
},
Err(status) => Some(status),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self {
Self {
dataver_filter: attr.dataver,
path: attr.path(),
tw,
}
}
pub fn with_dataver(self, dataver: u32) -> Result<Option<AttrDataWriter<'a, 'b, 'c>>, Error> {
if self
.dataver_filter
.map(|dataver_filter| dataver_filter != dataver)
.unwrap_or(true)
{
let mut writer = AttrDataWriter::new(self.tw);
writer.start_struct(TagType::Anonymous)?;
writer.start_struct(TagType::Context(AttrRespTag::Data as _))?;
writer.u32(TagType::Context(AttrDataTag::DataVer as _), dataver)?;
self.path
.to_tlv(&mut writer, TagType::Context(AttrDataTag::Path as _))?;
Ok(Some(writer))
} else {
Ok(None)
}
}
}
pub struct AttrDataWriter<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
anchor: usize,
completed: bool,
}
impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> {
pub const TAG: TagType = TagType::Context(AttrDataTag::Data as _);
fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self {
let anchor = tw.get_tail();
Self {
tw,
anchor,
completed: false,
}
}
pub fn set<T: ToTLV>(self, value: T) -> Result<(), Error> {
value.to_tlv(self.tw, Self::TAG)?;
self.complete()
}
pub fn complete(mut self) -> Result<(), Error> {
self.tw.end_container()?;
self.tw.end_container()?;
self.completed = true;
Ok(())
}
fn reset(&mut self) {
self.tw.rewind_to(self.anchor);
}
}
impl<'a, 'b, 'c> Drop for AttrDataWriter<'a, 'b, 'c> {
fn drop(&mut self) {
if !self.completed {
self.reset();
}
}
}
impl<'a, 'b, 'c> Deref for AttrDataWriter<'a, 'b, 'c> {
type Target = TLVWriter<'b, 'c>;
fn deref(&self) -> &Self::Target {
self.tw
}
}
impl<'a, 'b, 'c> DerefMut for AttrDataWriter<'a, 'b, 'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tw
}
}
pub struct AttrData<'a> {
for_dataver: Option<u32>,
data: &'a TLVElement<'a>,
}
impl<'a> AttrData<'a> {
pub fn new(for_dataver: Option<u32>, data: &'a TLVElement<'a>) -> Self {
Self { for_dataver, data }
}
pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> {
if let Some(req_dataver) = self.for_dataver {
if req_dataver != dataver {
Err(ErrorCode::DataVersionMismatch)?;
}
}
Ok(self.data)
}
}
#[derive(Default)]
pub struct CmdDataTracker {
skip_status: bool,
}
impl CmdDataTracker {
pub const fn new() -> Self {
Self { skip_status: false }
}
pub(crate) fn complete(&mut self) {
self.skip_status = true;
}
pub fn needs_status(&self) -> bool {
!self.skip_status
}
}
pub struct CmdDataEncoder<'a, 'b, 'c> {
tracker: &'a mut CmdDataTracker,
path: CmdPath,
tw: &'a mut TLVWriter<'b, 'c>,
}
impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> {
pub fn handle<T: Handler>(
item: Result<(CmdDetails, TLVElement), CmdStatus>,
handler: &mut T,
transaction: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let status = match item {
Ok((cmd, data)) => {
let mut tracker = CmdDataTracker::new();
let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw);
match handler.invoke(transaction, &cmd, &data, encoder) {
Ok(()) => cmd.success(&tracker),
Err(error) => {
error!("Error invoking command: {}", error);
cmd.status(error.into())
}
}
}
Err(status) => {
error!("Error invoking command: {:?}", status);
Some(status)
}
};
if let Some(status) = status {
InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
#[cfg(feature = "nightly")]
pub async fn handle_async<T: super::asynch::AsyncHandler>(
item: Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>,
handler: &mut T,
transaction: &mut Transaction<'_, '_>,
tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> {
let status = match item {
Ok((cmd, data)) => {
let mut tracker = CmdDataTracker::new();
let encoder = CmdDataEncoder::new(&cmd, &mut tracker, tw);
match handler.invoke(transaction, &cmd, &data, encoder).await {
Ok(()) => cmd.success(&tracker),
Err(error) => cmd.status(error.into()),
}
}
Err(status) => Some(status),
};
if let Some(status) = status {
InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new(
cmd: &CmdDetails,
tracker: &'a mut CmdDataTracker,
tw: &'a mut TLVWriter<'b, 'c>,
) -> Self {
Self {
tracker,
path: cmd.path(),
tw,
}
}
pub fn with_command(mut self, cmd: u16) -> Result<CmdDataWriter<'a, 'b, 'c>, Error> {
let mut writer = CmdDataWriter::new(self.tracker, self.tw);
writer.start_struct(TagType::Anonymous)?;
writer.start_struct(TagType::Context(InvRespTag::Cmd as _))?;
self.path.path.leaf = Some(cmd as _);
self.path
.to_tlv(&mut writer, TagType::Context(CmdDataTag::Path as _))?;
Ok(writer)
}
}
pub struct CmdDataWriter<'a, 'b, 'c> {
tracker: &'a mut CmdDataTracker,
tw: &'a mut TLVWriter<'b, 'c>,
anchor: usize,
completed: bool,
}
impl<'a, 'b, 'c> CmdDataWriter<'a, 'b, 'c> {
pub const TAG: TagType = TagType::Context(CmdDataTag::Data as _);
fn new(tracker: &'a mut CmdDataTracker, tw: &'a mut TLVWriter<'b, 'c>) -> Self {
let anchor = tw.get_tail();
Self {
tracker,
tw,
anchor,
completed: false,
}
}
pub fn set<T: ToTLV>(self, value: T) -> Result<(), Error> {
value.to_tlv(self.tw, Self::TAG)?;
self.complete()
}
pub fn complete(mut self) -> Result<(), Error> {
self.tw.end_container()?;
self.tw.end_container()?;
self.completed = true;
self.tracker.complete();
Ok(())
}
fn reset(&mut self) {
self.tw.rewind_to(self.anchor);
}
}
impl<'a, 'b, 'c> Drop for CmdDataWriter<'a, 'b, 'c> {
fn drop(&mut self) {
if !self.completed {
self.reset();
}
}
}
impl<'a, 'b, 'c> Deref for CmdDataWriter<'a, 'b, 'c> {
type Target = TLVWriter<'b, 'c>;
fn deref(&self) -> &Self::Target {
self.tw
}
}
impl<'a, 'b, 'c> DerefMut for CmdDataWriter<'a, 'b, 'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tw
}
}
#[derive(Copy, Clone, Debug)]
pub struct AttrType<T>(PhantomData<fn() -> T>);
impl<T> AttrType<T> {
pub const fn new() -> Self {
Self(PhantomData)
}
pub fn encode(&self, writer: AttrDataWriter, value: T) -> Result<(), Error>
where
T: ToTLV,
{
writer.set(value)
}
pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<T, Error>
where
T: FromTLV<'a>,
{
T::from_tlv(data)
}
}
impl<T> Default for AttrType<T> {
fn default() -> Self {
Self::new()
}
}
#[derive(Copy, Clone, Debug, Default)]
pub struct AttrUtfType;
impl AttrUtfType {
pub const fn new() -> Self {
Self
}
pub fn encode(&self, writer: AttrDataWriter, value: &str) -> Result<(), Error> {
writer.set(UtfStr::new(value.as_bytes()))
}
pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<&'a str, IMStatusCode> {
data.str().map_err(|_| IMStatusCode::InvalidDataType)
}
}
#[allow(unused_macros)]
#[macro_export]
macro_rules! attribute_enum {
($en:ty) => {
impl core::convert::TryFrom<$crate::data_model::objects::AttrId> for $en {
type Error = $crate::error::Error;
fn try_from(id: $crate::data_model::objects::AttrId) -> Result<Self, Self::Error> {
<$en>::from_repr(id)
.ok_or_else(|| $crate::error::ErrorCode::AttributeNotFound.into())
}
}
};
}
#[allow(unused_macros)]
#[macro_export]
macro_rules! command_enum {
($en:ty) => {
impl core::convert::TryFrom<$crate::data_model::objects::CmdId> for $en {
type Error = $crate::error::Error;
fn try_from(id: $crate::data_model::objects::CmdId) -> Result<Self, Self::Error> {
<$en>::from_repr(id).ok_or_else(|| $crate::error::ErrorCode::CommandNotFound.into())
}
}
};
} }
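
For context, a minimal sketch of a cluster read handler written against this encoder API. The OnOffCluster struct and its fields are illustrative, ON_OFF_CLUSTER refers to the const metadata sketched after the Cluster diff above, and ToTLV is assumed to be implemented for bool as for the other primitive types.

struct OnOffCluster {
    data_ver: Dataver,
    on: bool,
}

impl Handler for OnOffCluster {
    fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
        if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
            if attr.is_system() {
                // Global attributes are answered from the const cluster metadata
                ON_OFF_CLUSTER.read(attr.attr_id, writer)
            } else {
                // Cluster-specific attribute: encode the value and close the AttrData IB
                writer.set(self.on)
            }
        } else {
            // The client's data-version filter matched: nothing to report
            Ok(())
        }
    }
}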


@ -15,104 +15,85 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{data_model::objects::ClusterType, error::*, interaction_model::core::IMStatusCode}; use crate::{acl::Accessor, interaction_model::core::IMStatusCode};
use std::fmt; use core::fmt;
use super::{ClusterId, DeviceType}; use super::{AttrId, Attribute, Cluster, ClusterId, CmdId, DeviceType, EndptId};
pub const CLUSTERS_PER_ENDPT: usize = 9; #[derive(Debug, Clone)]
pub struct Endpoint<'a> {
pub struct Endpoint { pub id: EndptId,
dev_type: DeviceType, pub device_type: DeviceType,
clusters: Vec<Box<dyn ClusterType>>, pub clusters: &'a [Cluster<'a>],
} }
pub type BoxedClusters = [Box<dyn ClusterType>]; impl<'a> Endpoint<'a> {
pub fn match_attributes(
impl Endpoint {
pub fn new(dev_type: DeviceType) -> Result<Box<Endpoint>, Error> {
Ok(Box::new(Endpoint {
dev_type,
clusters: Vec::with_capacity(CLUSTERS_PER_ENDPT),
}))
}
pub fn add_cluster(&mut self, cluster: Box<dyn ClusterType>) -> Result<(), Error> {
if self.clusters.len() < self.clusters.capacity() {
self.clusters.push(cluster);
Ok(())
} else {
Err(Error::NoSpace)
}
}
pub fn get_dev_type(&self) -> &DeviceType {
&self.dev_type
}
fn get_cluster_index(&self, cluster_id: ClusterId) -> Option<usize> {
self.clusters.iter().position(|c| c.base().id == cluster_id)
}
pub fn get_cluster(&self, cluster_id: ClusterId) -> Result<&dyn ClusterType, Error> {
let index = self
.get_cluster_index(cluster_id)
.ok_or(Error::ClusterNotFound)?;
Ok(self.clusters[index].as_ref())
}
pub fn get_cluster_mut(
&mut self,
cluster_id: ClusterId,
) -> Result<&mut dyn ClusterType, Error> {
let index = self
.get_cluster_index(cluster_id)
.ok_or(Error::ClusterNotFound)?;
Ok(self.clusters[index].as_mut())
}
// Returns a slice of clusters, with either a single cluster or all (wildcard)
pub fn get_wildcard_clusters(
&self, &self,
cluster: Option<ClusterId>, cl: Option<ClusterId>,
) -> Result<(&BoxedClusters, bool), IMStatusCode> { attr: Option<AttrId>,
if let Some(c) = cluster { ) -> impl Iterator<Item = (&'_ Cluster, &'_ Attribute)> + '_ {
if let Some(i) = self.get_cluster_index(c) { self.match_clusters(cl).flat_map(move |cluster| {
Ok((&self.clusters[i..i + 1], false)) cluster
} else { .match_attributes(attr)
Err(IMStatusCode::UnsupportedCluster) .map(move |attr| (cluster, attr))
})
} }
} else {
Ok((self.clusters.as_slice(), true)) pub fn match_commands(
&self,
cl: Option<ClusterId>,
cmd: Option<CmdId>,
) -> impl Iterator<Item = (&'_ Cluster, CmdId)> + '_ {
self.match_clusters(cl)
.flat_map(move |cluster| cluster.match_commands(cmd).map(move |cmd| (cluster, cmd)))
}
pub fn check_attribute(
&self,
accessor: &Accessor,
cl: ClusterId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
self.check_cluster(cl)
.and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write))
}
pub fn check_command(
&self,
accessor: &Accessor,
cl: ClusterId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.check_cluster(cl)
.and_then(|cluster| cluster.check_command(accessor, self.id, cmd))
}
pub fn match_clusters(&self, cl: Option<ClusterId>) -> impl Iterator<Item = &'_ Cluster> + '_ {
self.clusters
.iter()
.filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true))
}
pub fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> {
self.clusters
.iter()
.find(|cluster| cluster.id == cl)
.ok_or(IMStatusCode::UnsupportedCluster)
} }
} }
// Returns a slice of clusters, with either a single cluster or all (wildcard) impl<'a> core::fmt::Display for Endpoint<'a> {
pub fn get_wildcard_clusters_mut(
&mut self,
cluster: Option<ClusterId>,
) -> Result<(&mut BoxedClusters, bool), IMStatusCode> {
if let Some(c) = cluster {
if let Some(i) = self.get_cluster_index(c) {
Ok((&mut self.clusters[i..i + 1], false))
} else {
Err(IMStatusCode::UnsupportedCluster)
}
} else {
Ok((&mut self.clusters[..], true))
}
}
}
impl std::fmt::Display for Endpoint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "clusters:[")?; write!(f, "clusters:[")?;
let mut comma = ""; let mut comma = "";
for element in self.clusters.iter() { for cluster in self.clusters {
write!(f, "{} {{ {} }}", comma, element.base())?; write!(f, "{} {{ {} }}", comma, cluster)?;
comma = ", "; comma = ", ";
} }
write!(f, "]") write!(f, "]")
} }
} }
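
For context, a minimal sketch of resolving a (possibly wildcard) attribute path against an endpoint's const metadata using the iterator-based API above; `endpoint` is assumed to be an Endpoint value and the cluster id is illustrative.

// A `None` attribute id acts as a wildcard and expands to every attribute
// of every matching cluster on this endpoint.
for (cluster, attribute) in endpoint.match_attributes(Some(0x0006), None) {
    info!("endpoint {} cluster {} attr {}", endpoint.id, cluster.id, attribute.id);
}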


@ -0,0 +1,361 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::{Error, ErrorCode},
interaction_model::core::Transaction,
tlv::TLVElement,
};
use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails};
pub trait ChangeNotifier<T> {
fn consume_change(&mut self) -> Option<T>;
}
pub trait Handler {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>;
fn write(&mut self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
fn invoke(
&mut self,
_transaction: &mut Transaction,
_cmd: &CmdDetails,
_data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
Err(ErrorCode::CommandNotFound.into())
}
}
impl<T> Handler for &mut T
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
(**self).read(attr, encoder)
}
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
(**self).write(attr, data)
}
fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
(**self).invoke(transaction, cmd, data, encoder)
}
}
pub trait NonBlockingHandler: Handler {}
impl<T> NonBlockingHandler for &mut T where T: NonBlockingHandler {}
pub struct EmptyHandler;
impl EmptyHandler {
pub const fn chain<H>(
self,
handler_endpoint: u16,
handler_cluster: u32,
handler: H,
) -> ChainedHandler<H, Self> {
ChainedHandler {
handler_endpoint,
handler_cluster,
handler,
next: self,
}
}
}
impl Handler for EmptyHandler {
fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
}
impl NonBlockingHandler for EmptyHandler {}
impl ChangeNotifier<(u16, u32)> for EmptyHandler {
fn consume_change(&mut self) -> Option<(u16, u32)> {
None
}
}
pub struct ChainedHandler<H, T> {
pub handler_endpoint: u16,
pub handler_cluster: u32,
pub handler: H,
pub next: T,
}
impl<H, T> ChainedHandler<H, T> {
pub const fn chain<H2>(
self,
handler_endpoint: u16,
handler_cluster: u32,
handler: H2,
) -> ChainedHandler<H2, Self> {
ChainedHandler {
handler_endpoint,
handler_cluster,
handler,
next: self,
}
}
}
impl<H, T> Handler for ChainedHandler<H, T>
where
H: Handler,
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id {
self.handler.read(attr, encoder)
} else {
self.next.read(attr, encoder)
}
}
fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id {
self.handler.write(attr, data)
} else {
self.next.write(attr, data)
}
}
fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(transaction, cmd, data, encoder)
} else {
self.next.invoke(transaction, cmd, data, encoder)
}
}
}
impl<H, T> NonBlockingHandler for ChainedHandler<H, T>
where
H: NonBlockingHandler,
T: NonBlockingHandler,
{
}
impl<H, T> ChangeNotifier<(u16, u32)> for ChainedHandler<H, T>
where
H: ChangeNotifier<()>,
T: ChangeNotifier<(u16, u32)>,
{
fn consume_change(&mut self) -> Option<(u16, u32)> {
if self.handler.consume_change().is_some() {
Some((self.handler_endpoint, self.handler_cluster))
} else {
self.next.consume_change()
}
}
}
#[allow(unused_macros)]
#[macro_export]
macro_rules! handler_chain_type {
($h:ty) => {
$crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler>
};
($h1:ty $(, $rest:ty)+) => {
$crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)>
};
($h:ty | $f:ty) => {
$crate::data_model::objects::ChainedHandler<$h, $f>
};
($h1:ty $(, $rest:ty)+ | $f:ty) => {
$crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+ | $f)>
};
}
#[cfg(feature = "nightly")]
pub mod asynch {
use crate::{
data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails},
error::{Error, ErrorCode},
interaction_model::core::Transaction,
tlv::TLVElement,
};
use super::{ChainedHandler, EmptyHandler, Handler, NonBlockingHandler};
pub trait AsyncHandler {
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error>;
async fn write<'a>(
&'a mut self,
_attr: &'a AttrDetails<'_>,
_data: AttrData<'a>,
) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
async fn invoke<'a>(
&'a mut self,
_transaction: &'a mut Transaction<'_, '_>,
_cmd: &'a CmdDetails<'_>,
_data: &'a TLVElement<'_>,
_encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Err(ErrorCode::CommandNotFound.into())
}
}
impl<T> AsyncHandler for &mut T
where
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).read(attr, encoder).await
}
async fn write<'a>(
&'a mut self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
(**self).write(attr, data).await
}
async fn invoke<'a>(
&'a mut self,
transaction: &'a mut Transaction<'_, '_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).invoke(transaction, cmd, data, encoder).await
}
}
pub struct Asyncify<T>(pub T);
impl<T> AsyncHandler for Asyncify<T>
where
T: NonBlockingHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Handler::read(&self.0, attr, encoder)
}
async fn write<'a>(
&'a mut self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
Handler::write(&mut self.0, attr, data)
}
async fn invoke<'a>(
&'a mut self,
transaction: &'a mut Transaction<'_, '_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Handler::invoke(&mut self.0, transaction, cmd, data, encoder)
}
}
impl AsyncHandler for EmptyHandler {
async fn read<'a>(
&'a self,
_attr: &'a AttrDetails<'_>,
_encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
}
impl<H, T> AsyncHandler for ChainedHandler<H, T>
where
H: AsyncHandler,
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id
{
self.handler.read(attr, encoder).await
} else {
self.next.read(attr, encoder).await
}
}
async fn write<'a>(
&'a mut self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id
{
self.handler.write(attr, data).await
} else {
self.next.write(attr, data).await
}
}
async fn invoke<'a>(
&'a mut self,
transaction: &'a mut Transaction<'_, '_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(transaction, cmd, data, encoder).await
} else {
self.next.invoke(transaction, cmd, data, encoder).await
}
}
}
}
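
A usage sketch (not part of the commits above) of how a custom cluster handler could plug into this chaining API, assuming the items defined in this file (Handler, NonBlockingHandler, EmptyHandler, AttrDetails, AttrDataEncoder, Error, ErrorCode and the handler_chain_type! macro) are in scope; OnOffHandler and the endpoint/cluster ids are purely illustrative:

// Hypothetical cluster handler; only `read` is overridden, so writes and invokes
// fall back to the default AttributeNotFound/CommandNotFound errors.
struct OnOffHandler;

impl Handler for OnOffHandler {
    fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> {
        // A real handler would encode the requested attribute value here.
        Err(ErrorCode::AttributeNotFound.into())
    }
}

impl NonBlockingHandler for OnOffHandler {}

// The chain type can be spelled out with the macro above...
type MyHandler = handler_chain_type!(OnOffHandler);

// ...and built by chaining onto EmptyHandler: endpoint 1, cluster 0x0006 (illustrative ids).
fn my_handler() -> MyHandler {
    EmptyHandler.chain(1, 0x0006, OnOffHandler)
}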


@@ -14,11 +14,8 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
-use crate::error::Error;
-
-pub type EndptId = u16;
-pub type ClusterId = u32;
-pub type AttrId = u16;
-pub type CmdId = u32;
+use crate::tlv::{TLVWriter, TagType, ToTLV};

mod attribute;
pub use attribute::*;
@@ -37,3 +34,20 @@ pub use privilege::*;
mod encoder;
pub use encoder::*;
+
+mod handler;
+pub use handler::*;
+
+mod dataver;
+pub use dataver::*;
+
+pub type EndptId = u16;
+pub type ClusterId = u32;
+pub type AttrId = u16;
+pub type CmdId = u32;
+
+#[derive(Debug, ToTLV, Copy, Clone)]
+pub struct DeviceType {
+    pub dtype: u16,
+    pub drev: u16,
+}


@@ -16,283 +16,515 @@
 */
use crate::{
-    data_model::objects::{ClusterType, Endpoint},
-    error::*,
-    interaction_model::{core::IMStatusCode, messages::GenericPath},
+    acl::Accessor,
+    data_model::objects::Endpoint,
+    interaction_model::{
+        core::{IMStatusCode, ResumeReadReq, ResumeSubscribeReq},
+        messages::{
+            ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter},
+            msg::{InvReq, ReadReq, SubscribeReq, WriteReq},
+            GenericPath,
+        },
+    },
    // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
+    tlv::{TLVArray, TLVArrayIter, TLVElement},
};
-use std::fmt;
+use core::{
+    fmt,
+    iter::{once, Once},
+};

-use super::{ClusterId, DeviceType, EndptId};
+use super::{AttrDetails, AttrId, Attribute, Cluster, ClusterId, CmdDetails, CmdId, EndptId};

-pub trait ChangeConsumer {
-    fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error>;
-}
-
-pub const ENDPTS_PER_ACC: usize = 3;
-
-pub type BoxedEndpoints = [Option<Box<Endpoint>>];
-
-#[derive(Default)]
-pub struct Node {
-    endpoints: [Option<Box<Endpoint>>; ENDPTS_PER_ACC],
-    changes_cb: Option<Box<dyn ChangeConsumer>>,
-}
-
-impl std::fmt::Display for Node {
+pub enum WildcardIter<T, E> {
+    None,
+    Single(Once<E>),
+    Wildcard(T),
+}
+
+impl<T, E> Iterator for WildcardIter<T, E>
+where
+    T: Iterator<Item = E>,
+{
+    type Item = E;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        match self {
+            Self::None => None,
+            Self::Single(iter) => iter.next(),
+            Self::Wildcard(iter) => iter.next(),
+        }
+    }
+}
+
+pub trait Iterable {
type Item;
type Iterator<'a>: Iterator<Item = Self::Item>
where
Self: 'a;
fn iter(&self) -> Self::Iterator<'_>;
}
impl<'a> Iterable for Option<&'a TLVArray<'a, DataVersionFilter>> {
type Item = DataVersionFilter;
type Iterator<'i> = WildcardIter<TLVArrayIter<'i, DataVersionFilter>, DataVersionFilter> where Self: 'i;
fn iter(&self) -> Self::Iterator<'_> {
if let Some(filters) = self {
WildcardIter::Wildcard(filters.iter())
} else {
WildcardIter::None
}
}
}
impl<'a> Iterable for &'a [DataVersionFilter] {
type Item = DataVersionFilter;
type Iterator<'i> = core::iter::Cloned<core::slice::Iter<'i, DataVersionFilter>> where Self: 'i;
fn iter(&self) -> Self::Iterator<'_> {
let slice: &[DataVersionFilter] = self;
slice.iter().cloned()
}
}
#[derive(Debug, Clone)]
pub struct Node<'a> {
pub id: u16,
pub endpoints: &'a [Endpoint<'a>],
}
impl<'a> Node<'a> {
pub fn read<'s, 'm>(
&'s self,
req: &'m ReadReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.attr_requests
.iter()
.flat_map(|attr_requests| attr_requests.iter()),
req.dataver_filters.as_ref(),
req.fabric_filtered,
accessor,
None,
)
}
pub fn resume_read<'s, 'm>(
&'s self,
req: &'m ResumeReadReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.paths.iter().cloned(),
req.filters.as_slice(),
req.fabric_filtered,
accessor,
Some(req.resume_path.clone()),
)
}
pub fn subscribing_read<'s, 'm>(
&'s self,
req: &'m SubscribeReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.attr_requests
.iter()
.flat_map(|attr_requests| attr_requests.iter()),
req.dataver_filters.as_ref(),
req.fabric_filtered,
accessor,
None,
)
}
pub fn resume_subscribing_read<'s, 'm>(
&'s self,
req: &'m ResumeSubscribeReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.paths.iter().cloned(),
req.filters.as_slice(),
req.fabric_filtered,
accessor,
Some(req.resume_path.clone().unwrap()),
)
}
fn read_attr_requests<'s, 'm, P, D>(
&'s self,
attr_requests: P,
dataver_filters: D,
fabric_filtered: bool,
accessor: &'m Accessor<'m>,
from: Option<GenericPath>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
P: Iterator<Item = AttrPath> + 'm,
D: Iterable<Item = DataVersionFilter> + Clone + 'm,
{
attr_requests.flat_map(move |path| {
if path.to_gp().is_wildcard() {
let dataver_filters = dataver_filters.clone();
let from = from.clone();
let iter = self
.match_attributes(path.endpoint, path.cluster, path.attr)
.skip_while(move |(ep, cl, attr)| {
!Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _)
})
.filter(move |(ep, cl, attr)| {
Cluster::check_attr_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)),
false,
attr.access,
)
.is_ok()
})
.map(move |(ep, cl, attr)| {
let dataver = dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep.id && filter.path.cluster == cl.id)
.then_some(filter.data_ver)
});
Ok(AttrDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
attr_id: attr.id,
list_index: path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: fabric_filtered,
dataver,
wildcard: true,
})
});
WildcardIter::Wildcard(iter)
} else {
let ep = path.endpoint.unwrap();
let cl = path.cluster.unwrap();
let attr = path.attr.unwrap();
let result = match self.check_attribute(accessor, ep, cl, attr, false) {
Ok(()) => {
let dataver = dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep && filter.path.cluster == cl)
.then_some(filter.data_ver)
});
Ok(AttrDetails {
node: self,
endpoint_id: ep,
cluster_id: cl,
attr_id: attr,
list_index: path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: fabric_filtered,
dataver,
wildcard: false,
})
}
Err(err) => Err(AttrStatus::new(&path.to_gp(), err, 0)),
};
WildcardIter::Single(once(result))
}
})
}
pub fn write<'m>(
&'m self,
req: &'m WriteReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(AttrDetails, TLVElement<'m>), AttrStatus>> + 'm {
req.write_requests.iter().flat_map(move |attr_data| {
if attr_data.path.cluster.is_none() {
WildcardIter::Single(once(Err(AttrStatus::new(
&attr_data.path.to_gp(),
IMStatusCode::UnsupportedCluster,
0,
))))
} else if attr_data.path.attr.is_none() {
WildcardIter::Single(once(Err(AttrStatus::new(
&attr_data.path.to_gp(),
IMStatusCode::UnsupportedAttribute,
0,
))))
} else if attr_data.path.to_gp().is_wildcard() {
let iter = self
.match_attributes(
attr_data.path.endpoint,
attr_data.path.cluster,
attr_data.path.attr,
)
.filter(move |(ep, cl, attr)| {
Cluster::check_attr_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)),
true,
attr.access,
)
.is_ok()
})
.map(move |(ep, cl, attr)| {
Ok((
AttrDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
attr_id: attr.id,
list_index: attr_data.path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: false,
dataver: attr_data.data_ver,
wildcard: true,
},
attr_data.data.clone().unwrap_tlv().unwrap(),
))
});
WildcardIter::Wildcard(iter)
} else {
let ep = attr_data.path.endpoint.unwrap();
let cl = attr_data.path.cluster.unwrap();
let attr = attr_data.path.attr.unwrap();
let result = match self.check_attribute(accessor, ep, cl, attr, true) {
Ok(()) => Ok((
AttrDetails {
node: self,
endpoint_id: ep,
cluster_id: cl,
attr_id: attr,
list_index: attr_data.path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: false,
dataver: attr_data.data_ver,
wildcard: false,
},
attr_data.data.unwrap_tlv().unwrap(),
)),
Err(err) => Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0)),
};
WildcardIter::Single(once(result))
}
})
}
pub fn invoke<'m>(
&'m self,
req: &'m InvReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(CmdDetails, TLVElement<'m>), CmdStatus>> + 'm {
req.inv_requests
.iter()
.flat_map(|inv_requests| inv_requests.iter())
.flat_map(move |cmd_data| {
if cmd_data.path.path.is_wildcard() {
let iter = self
.match_commands(
cmd_data.path.path.endpoint,
cmd_data.path.path.cluster,
cmd_data.path.path.leaf.map(|leaf| leaf as _),
)
.filter(move |(ep, cl, cmd)| {
Cluster::check_cmd_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(*cmd)),
)
.is_ok()
})
.map(move |(ep, cl, cmd)| {
Ok((
CmdDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
cmd_id: cmd,
wildcard: true,
},
cmd_data.data.clone().unwrap_tlv().unwrap(),
))
});
WildcardIter::Wildcard(iter)
} else {
let ep = cmd_data.path.path.endpoint.unwrap();
let cl = cmd_data.path.path.cluster.unwrap();
let cmd = cmd_data.path.path.leaf.unwrap();
let result = match self.check_command(accessor, ep, cl, cmd) {
Ok(()) => Ok((
CmdDetails {
node: self,
endpoint_id: cmd_data.path.path.endpoint.unwrap(),
cluster_id: cmd_data.path.path.cluster.unwrap(),
cmd_id: cmd_data.path.path.leaf.unwrap(),
wildcard: false,
},
cmd_data.data.unwrap_tlv().unwrap(),
)),
Err(err) => Err(CmdStatus::new(cmd_data.path, err, 0)),
};
WildcardIter::Single(once(result))
}
})
}
fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool {
if let Some(path) = path {
path.endpoint.map(|id| id == ep).unwrap_or(true)
&& path.cluster.map(|id| id == cl).unwrap_or(true)
&& path.leaf.map(|id| id == leaf).unwrap_or(true)
} else {
true
}
}
pub fn match_attributes(
&self,
ep: Option<EndptId>,
cl: Option<ClusterId>,
attr: Option<AttrId>,
) -> impl Iterator<Item = (&'_ Endpoint, &'_ Cluster, &'_ Attribute)> + '_ {
self.match_endpoints(ep).flat_map(move |endpoint| {
endpoint
.match_attributes(cl, attr)
.map(move |(cl, attr)| (endpoint, cl, attr))
})
}
pub fn match_commands(
&self,
ep: Option<EndptId>,
cl: Option<ClusterId>,
cmd: Option<CmdId>,
) -> impl Iterator<Item = (&'_ Endpoint, &'_ Cluster, CmdId)> + '_ {
self.match_endpoints(ep).flat_map(move |endpoint| {
endpoint
.match_commands(cl, cmd)
.map(move |(cl, cmd)| (endpoint, cl, cmd))
})
}
pub fn check_attribute(
&self,
accessor: &Accessor,
ep: EndptId,
cl: ClusterId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
self.check_endpoint(ep)
.and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write))
}
pub fn check_command(
&self,
accessor: &Accessor,
ep: EndptId,
cl: ClusterId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.check_endpoint(ep)
.and_then(|endpoint| endpoint.check_command(accessor, cl, cmd))
}
pub fn match_endpoints(&self, ep: Option<EndptId>) -> impl Iterator<Item = &'_ Endpoint> + '_ {
self.endpoints
.iter()
.filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true))
}
pub fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> {
self.endpoints
.iter()
.find(|endpoint| endpoint.id == ep)
.ok_or(IMStatusCode::UnsupportedEndpoint)
}
}
+impl<'a> core::fmt::Display for Node<'a> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "node:")?;
-        for (i, element) in self.endpoints.iter().enumerate() {
-            if let Some(e) = element {
-                writeln!(f, "endpoint {}: {}", i, e)?;
-            }
+        for (index, endpoint) in self.endpoints.iter().enumerate() {
+            writeln!(f, "endpoint {}: {}", index, endpoint)?;
        }
        write!(f, "")
    }
}

+pub struct DynamicNode<'a, const N: usize> {
+    id: u16,
+    endpoints: heapless::Vec<Endpoint<'a>, N>,
+}
+
+impl<'a, const N: usize> DynamicNode<'a, N> {
+    pub const fn new(id: u16) -> Self {
+        Self {
+            id,
+            endpoints: heapless::Vec::new(),
+        }
+    }
+
+    pub fn node(&self) -> Node<'_> {
+        Node {
+            id: self.id,
+            endpoints: &self.endpoints,
+        }
+    }
+
+    pub fn add(&mut self, endpoint: Endpoint<'a>) -> Result<(), Endpoint<'a>> {
+        if !self.endpoints.iter().any(|ep| ep.id == endpoint.id) {
+            self.endpoints.push(endpoint)
+        } else {
+            Err(endpoint)
+        }
+    }
+
+    pub fn remove(&mut self, endpoint_id: u16) -> Option<Endpoint<'a>> {
+        let index = self
+            .endpoints
+            .iter()
+            .enumerate()
+            .find_map(|(index, ep)| (ep.id == endpoint_id).then_some(index));
+
+        if let Some(index) = index {
+            Some(self.endpoints.swap_remove(index))
+        } else {
+            None
+        }
+    }
+}
+
+impl<'a, const N: usize> core::fmt::Display for DynamicNode<'a, N> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.node().fmt(f)
+    }
+}

-impl Node {
-    pub fn new() -> Result<Box<Node>, Error> {
-        let node = Box::default();
-        Ok(node)
-    }
-
-    pub fn set_changes_cb(&mut self, consumer: Box<dyn ChangeConsumer>) {
-        self.changes_cb = Some(consumer);
-    }
-
-    pub fn add_endpoint(&mut self, dev_type: DeviceType) -> Result<EndptId, Error> {
-        let index = self
-            .endpoints
-            .iter()
-            .position(|x| x.is_none())
-            .ok_or(Error::NoSpace)?;
-        let mut endpoint = Endpoint::new(dev_type)?;
-        if let Some(cb) = &self.changes_cb {
-            cb.endpoint_added(index as EndptId, &mut endpoint)?;
-        }
-        self.endpoints[index] = Some(endpoint);
-        Ok(index as EndptId)
-    }
-
-    pub fn get_endpoint(&self, endpoint_id: EndptId) -> Result<&Endpoint, Error> {
-        if (endpoint_id as usize) < ENDPTS_PER_ACC {
-            let endpoint = self.endpoints[endpoint_id as usize]
-                .as_ref()
-                .ok_or(Error::EndpointNotFound)?;
-            Ok(endpoint)
-        } else {
-            Err(Error::EndpointNotFound)
-        }
-    }
-
-    pub fn get_endpoint_mut(&mut self, endpoint_id: EndptId) -> Result<&mut Endpoint, Error> {
-        if (endpoint_id as usize) < ENDPTS_PER_ACC {
-            let endpoint = self.endpoints[endpoint_id as usize]
.as_mut()
.ok_or(Error::EndpointNotFound)?;
Ok(endpoint)
} else {
Err(Error::EndpointNotFound)
}
}
pub fn get_cluster_mut(
&mut self,
e: EndptId,
c: ClusterId,
) -> Result<&mut dyn ClusterType, Error> {
self.get_endpoint_mut(e)?.get_cluster_mut(c)
}
pub fn get_cluster(&self, e: EndptId, c: ClusterId) -> Result<&dyn ClusterType, Error> {
self.get_endpoint(e)?.get_cluster(c)
}
pub fn add_cluster(
&mut self,
endpoint_id: EndptId,
cluster: Box<dyn ClusterType>,
) -> Result<(), Error> {
let endpoint_id = endpoint_id as usize;
if endpoint_id < ENDPTS_PER_ACC {
self.endpoints[endpoint_id]
.as_mut()
.ok_or(Error::NoEndpoint)?
.add_cluster(cluster)
} else {
Err(Error::Invalid)
}
}
// Returns a slice of endpoints, with either a single endpoint or all (wildcard)
pub fn get_wildcard_endpoints(
&self,
endpoint: Option<EndptId>,
) -> Result<(&BoxedEndpoints, usize, bool), IMStatusCode> {
if let Some(e) = endpoint {
let e = e as usize;
if self.endpoints.len() <= e || self.endpoints[e].is_none() {
Err(IMStatusCode::UnsupportedEndpoint)
} else {
Ok((&self.endpoints[e..e + 1], e, false))
}
} else {
Ok((&self.endpoints[..], 0, true))
}
}
pub fn get_wildcard_endpoints_mut(
&mut self,
endpoint: Option<EndptId>,
) -> Result<(&mut BoxedEndpoints, usize, bool), IMStatusCode> {
if let Some(e) = endpoint {
let e = e as usize;
if self.endpoints.len() <= e || self.endpoints[e].is_none() {
Err(IMStatusCode::UnsupportedEndpoint)
} else {
Ok((&mut self.endpoints[e..e + 1], e, false))
}
} else {
Ok((&mut self.endpoints[..], 0, true))
}
}
/// Run a closure for all endpoints as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_endpoint<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &Endpoint) -> Result<(), IMStatusCode>,
{
let mut current_path = *path;
let (endpoints, mut endpoint_id, wildcard) = self.get_wildcard_endpoints(path.endpoint)?;
for e in endpoints.iter() {
if let Some(e) = e {
current_path.endpoint = Some(endpoint_id as EndptId);
f(&current_path, e.as_ref())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
endpoint_id += 1;
}
Ok(())
}
/// Run a closure for all endpoints (mutable) as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_endpoint_mut<T>(
&mut self,
path: &GenericPath,
mut f: T,
) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &mut Endpoint) -> Result<(), IMStatusCode>,
{
let mut current_path = *path;
let (endpoints, mut endpoint_id, wildcard) =
self.get_wildcard_endpoints_mut(path.endpoint)?;
for e in endpoints.iter_mut() {
if let Some(e) = e {
current_path.endpoint = Some(endpoint_id as EndptId);
f(&current_path, e.as_mut())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
endpoint_id += 1;
}
Ok(())
}
/// Run a closure for all clusters as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_cluster<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_endpoint(path, |p, e| {
let mut current_path = *p;
let (clusters, wildcard) = e.get_wildcard_clusters(p.cluster)?;
for c in clusters.iter() {
current_path.cluster = Some(c.base().id);
f(&current_path, c.as_ref())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
}
/// Run a closure for all clusters (mutable) as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_cluster_mut<T>(
&mut self,
path: &GenericPath,
mut f: T,
) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &mut dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_endpoint_mut(path, |p, e| {
let mut current_path = *p;
let (clusters, wildcard) = e.get_wildcard_clusters_mut(p.cluster)?;
for c in clusters.iter_mut() {
current_path.cluster = Some(c.base().id);
f(&current_path, c.as_mut())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
}
/// Run a closure for all attributes as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_attribute<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_cluster(path, |current_path, c| {
let mut current_path = *current_path;
let (attributes, wildcard) = c
.base()
.get_wildcard_attribute(path.leaf.map(|at| at as u16))?;
for a in attributes.iter() {
current_path.leaf = Some(a.id as u32);
f(&current_path, c).or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
    }
}
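
For orientation (not part of the diff): Node is now a borrowed view over a slice of endpoints, and DynamicNode is a fixed-capacity owning variant built on heapless::Vec. A rough sketch of how the two might be used together; the empty cluster list, endpoint id and device type parameter are placeholders:

fn build_node(device_type: DeviceType) {
    // Owning node with room for up to 4 endpoints (const-generic capacity).
    let mut dynamic: DynamicNode<'static, 4> = DynamicNode::new(0);

    // `add` refuses duplicate endpoint ids and hands the endpoint back on failure.
    dynamic
        .add(Endpoint {
            id: 1,
            device_type,
            clusters: &[], // a real endpoint would reference this branch's cluster consts
        })
        .ok();

    // Borrow an immutable `Node` view; this is what the read/write/invoke
    // iterators above operate on.
    let node = dynamic.node();
    let _attr_count = node.match_attributes(None, None, None).count();
}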


@@ -16,7 +16,7 @@
 */
use crate::{
-    error::Error,
+    error::{Error, ErrorCode},
    tlv::{FromTLV, TLVElement, ToTLV},
};
use log::error;
@@ -47,12 +47,12 @@ impl FromTLV<'_> for Privilege {
            1 => Ok(Privilege::VIEW),
            2 => {
                error!("ProxyView privilege not yet supporteds");
-                Err(Error::Invalid)
+                Err(ErrorCode::Invalid.into())
            }
            3 => Ok(Privilege::OPERATE),
            4 => Ok(Privilege::MANAGE),
            5 => Ok(Privilege::ADMIN),
-            _ => Err(Error::Invalid),
+            _ => Err(ErrorCode::Invalid.into()),
        }
    }
}


@@ -0,0 +1,125 @@
use core::{borrow::Borrow, cell::RefCell};
use crate::{
acl::AclMgr,
fabric::FabricMgr,
handler_chain_type,
mdns::Mdns,
secure_channel::pake::PaseMgr,
utils::{epoch::Epoch, rand::Rand},
};
use super::{
cluster_basic_information::{self, BasicInfoCluster, BasicInfoConfig},
objects::{Cluster, EmptyHandler, Endpoint, EndptId},
sdm::{
admin_commissioning::{self, AdminCommCluster},
dev_att::DevAttDataFetcher,
failsafe::FailSafe,
general_commissioning::{self, GenCommCluster},
noc::{self, NocCluster},
nw_commissioning::{self, NwCommCluster},
},
system_model::{
access_control::{self, AccessControlCluster},
descriptor::{self, DescriptorCluster},
},
};
pub type RootEndpointHandler<'a> = handler_chain_type!(
DescriptorCluster<'static>,
BasicInfoCluster<'a>,
GenCommCluster<'a>,
NwCommCluster,
AdminCommCluster<'a>,
NocCluster<'a>,
AccessControlCluster<'a>
);
pub const CLUSTERS: [Cluster<'static>; 7] = [
descriptor::CLUSTER,
cluster_basic_information::CLUSTER,
general_commissioning::CLUSTER,
nw_commissioning::CLUSTER,
admin_commissioning::CLUSTER,
noc::CLUSTER,
access_control::CLUSTER,
];
pub fn endpoint(id: EndptId) -> Endpoint<'static> {
Endpoint {
id,
device_type: super::device_types::DEV_TYPE_ROOT_NODE,
clusters: &CLUSTERS,
}
}
pub fn handler<'a, T>(endpoint_id: u16, matter: &'a T) -> RootEndpointHandler<'a>
where
T: Borrow<BasicInfoConfig<'a>>
+ Borrow<dyn DevAttDataFetcher + 'a>
+ Borrow<RefCell<PaseMgr>>
+ Borrow<RefCell<FabricMgr>>
+ Borrow<RefCell<AclMgr>>
+ Borrow<RefCell<FailSafe>>
+ Borrow<dyn Mdns + 'a>
+ Borrow<Epoch>
+ Borrow<Rand>
+ 'a,
{
wrap(
endpoint_id,
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
*matter.borrow(),
*matter.borrow(),
)
}
#[allow(clippy::too_many_arguments)]
pub fn wrap<'a>(
endpoint_id: u16,
basic_info: &'a BasicInfoConfig<'a>,
dev_att: &'a dyn DevAttDataFetcher,
pase: &'a RefCell<PaseMgr>,
fabric: &'a RefCell<FabricMgr>,
acl: &'a RefCell<AclMgr>,
failsafe: &'a RefCell<FailSafe>,
mdns: &'a dyn Mdns,
epoch: Epoch,
rand: Rand,
) -> RootEndpointHandler<'a> {
EmptyHandler
.chain(
endpoint_id,
access_control::ID,
AccessControlCluster::new(acl, rand),
)
.chain(
endpoint_id,
noc::ID,
NocCluster::new(dev_att, fabric, acl, failsafe, mdns, epoch, rand),
)
.chain(
endpoint_id,
admin_commissioning::ID,
AdminCommCluster::new(pase, mdns, rand),
)
.chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand))
.chain(
endpoint_id,
general_commissioning::ID,
GenCommCluster::new(failsafe, rand),
)
.chain(
endpoint_id,
cluster_basic_information::ID,
BasicInfoCluster::new(basic_info, rand),
)
.chain(endpoint_id, descriptor::ID, DescriptorCluster::new(rand))
}
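
A wiring sketch, not taken from the commits: endpoint(0) supplies the endpoint-0 metadata and handler(0, ...) the matching chained handler. It assumes the relevant items above are imported and that the branch's Matter<'a> object satisfies the Borrow bounds listed for handler(); treat both as assumptions.

// `matter` is assumed to implement the Borrow<...> bounds required by `handler`.
fn root_setup<'a>(matter: &'a Matter<'a>) -> RootEndpointHandler<'a> {
    // Endpoint 0 metadata (root-node device type plus the seven root clusters above)...
    let _ep0: Endpoint<'static> = endpoint(0);

    // ...and the matching statically-typed handler chain, wired to the managers owned by `matter`.
    handler(0, matter)
}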


@ -15,15 +15,21 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::cmd_enter; use core::cell::RefCell;
use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::core::Transaction;
use crate::mdns::Mdns;
use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::pake::PaseMgr;
use crate::secure_channel::spake2p::VerifierData; use crate::secure_channel::spake2p::VerifierData;
use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement};
use crate::{error::*, interaction_model::command::CommandReq}; use crate::utils::rand::Rand;
use log::{error, info}; use crate::{attribute_enum, cmd_enter};
use crate::{command_enum, error::*};
use log::info;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use strum::{EnumDiscriminants, FromRepr};
pub const ID: u32 = 0x003C; pub const ID: u32 = 0x003C;
@ -34,120 +40,54 @@ pub enum WindowStatus {
BasicWindowOpen = 2, BasicWindowOpen = 2,
} }
#[derive(FromPrimitive)] #[derive(Copy, Clone, Debug, FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
WindowStatus = 0, WindowStatus(AttrType<u8>) = 0,
AdminFabricIndex = 1, AdminFabricIndex(AttrType<Nullable<u8>>) = 1,
AdminVendorId = 2, AdminVendorId(AttrType<Nullable<u8>>) = 2,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
OpenCommWindow = 0x00, OpenCommWindow = 0x00,
OpenBasicCommWindow = 0x01, OpenBasicCommWindow = 0x01,
RevokeComm = 0x02, RevokeComm = 0x02,
} }
fn attr_window_status_new() -> Attribute { command_enum!(Commands);
pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::WindowStatus as u16, AttributesDiscriminants::WindowStatus as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_admin_fabid_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::AdminFabricIndex as u16, AttributesDiscriminants::AdminFabricIndex as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NULLABLE, Quality::NULLABLE,
) ),
}
fn attr_admin_vid_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::AdminVendorId as u16, AttributesDiscriminants::AdminVendorId as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NULLABLE, Quality::NULLABLE,
) ),
} ],
commands: &[
pub struct AdminCommCluster { Commands::OpenCommWindow as _,
pase_mgr: PaseMgr, Commands::OpenBasicCommWindow as _,
base: Cluster, Commands::RevokeComm as _,
} ],
};
impl ClusterType for AdminCommCluster {
fn base(&self) -> &Cluster {
&self.base
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
}
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) {
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::WindowStatus) => {
let status = 1_u8;
encoder.encode(EncodeValue::Value(&status))
}
Some(Attributes::AdminVendorId) => {
let vid = Nullable::NotNull(1_u8);
encoder.encode(EncodeValue::Value(&vid))
}
Some(Attributes::AdminFabricIndex) => {
let vid = Nullable::NotNull(1_u8);
encoder.encode(EncodeValue::Value(&vid))
}
_ => {
error!("Unsupported Attribute: this shouldn't happen");
}
}
}
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> {
let cmd = cmd_req
.cmd
.path
.leaf
.map(num::FromPrimitive::from_u32)
.ok_or(IMStatusCode::UnsupportedCommand)?
.ok_or(IMStatusCode::UnsupportedCommand)?;
match cmd {
Commands::OpenCommWindow => self.handle_command_opencomm_win(cmd_req),
_ => Err(IMStatusCode::UnsupportedCommand),
}
}
}
impl AdminCommCluster {
pub fn new(pase_mgr: PaseMgr) -> Result<Box<Self>, Error> {
let mut c = Box::new(AdminCommCluster {
pase_mgr,
base: Cluster::new(ID)?,
});
c.base.add_attribute(attr_window_status_new())?;
c.base.add_attribute(attr_admin_fabid_new())?;
c.base.add_attribute(attr_admin_vid_new())?;
Ok(c)
}
fn handle_command_opencomm_win(
&mut self,
cmd_req: &mut CommandReq,
) -> Result<(), IMStatusCode> {
cmd_enter!("Open Commissioning Window");
let req =
OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?;
let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0);
self.pase_mgr
.enable_pase_session(verifier, req.discriminator)?;
Err(IMStatusCode::Success)
}
}
#[derive(FromTLV)] #[derive(FromTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
@ -158,3 +98,88 @@ pub struct OpenCommWindowReq<'a> {
iterations: u32, iterations: u32,
salt: OctetStr<'a>, salt: OctetStr<'a>,
} }
pub struct AdminCommCluster<'a> {
data_ver: Dataver,
pase_mgr: &'a RefCell<PaseMgr>,
mdns: &'a dyn Mdns,
}
impl<'a> AdminCommCluster<'a> {
pub fn new(pase_mgr: &'a RefCell<PaseMgr>, mdns: &'a dyn Mdns, rand: Rand) -> Self {
Self {
data_ver: Dataver::new(rand),
pase_mgr,
mdns,
}
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::WindowStatus(codec) => codec.encode(writer, 1),
Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::NotNull(1)),
Attributes::AdminFabricIndex(codec) => {
codec.encode(writer, Nullable::NotNull(1))
}
}
}
} else {
Ok(())
}
}
pub fn invoke(
&mut self,
cmd: &CmdDetails,
data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?,
_ => Err(ErrorCode::CommandNotFound)?,
}
self.data_ver.changed();
Ok(())
}
fn handle_command_opencomm_win(&mut self, data: &TLVElement) -> Result<(), Error> {
cmd_enter!("Open Commissioning Window");
let req = OpenCommWindowReq::from_tlv(data)?;
let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0);
self.pase_mgr
.borrow_mut()
.enable_pase_session(verifier, req.discriminator, self.mdns)?;
Ok(())
}
}
impl<'a> Handler for AdminCommCluster<'a> {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
AdminCommCluster::read(self, attr, encoder)
}
fn invoke(
&mut self,
_transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
AdminCommCluster::invoke(self, cmd, data, encoder)
}
}
impl<'a> NonBlockingHandler for AdminCommCluster<'a> {}
impl<'a> ChangeNotifier<()> for AdminCommCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
}
}
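
For reference (not from the commits), the same attribute-enum pattern reduced to its skeleton: the enum discriminant is the attribute id, the payload is the AttrType codec used to encode it, and attribute_enum! is assumed to supply the TryFrom<u16> conversion used for dispatch. DemoAttributes, the id, the data version argument and the value 42 are made up.

use strum::{EnumDiscriminants, FromRepr};

#[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum DemoAttributes {
    // Attribute 0x0000 is a plain u8; the codec knows how to TLV-encode it.
    Level(AttrType<u8>) = 0,
}

attribute_enum!(DemoAttributes);

pub fn read_demo(attr: &AttrDetails, encoder: AttrDataEncoder, dataver: u32) -> Result<(), Error> {
    // Skip encoding entirely if the client already has this data version
    // (`dataver` stands in for a Dataver::get() value).
    if let Some(writer) = encoder.with_dataver(dataver)? {
        match attr.attr_id.try_into()? {
            DemoAttributes::Level(codec) => codec.encode(writer, 42),
        }
    } else {
        Ok(())
    }
}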


@ -15,9 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{error::Error, transport::session::SessionMode}; use crate::{
error::{Error, ErrorCode},
transport::session::SessionMode,
};
use log::error; use log::error;
use std::sync::RwLock;
#[derive(PartialEq)] #[derive(PartialEq)]
#[allow(dead_code)] #[allow(dead_code)]
@ -42,26 +44,20 @@ pub enum State {
Armed(ArmedCtx), Armed(ArmedCtx),
} }
pub struct FailSafeInner { pub struct FailSafe {
state: State, state: State,
} }
pub struct FailSafe {
state: RwLock<FailSafeInner>,
}
impl FailSafe { impl FailSafe {
pub fn new() -> Self { #[inline(always)]
Self { pub const fn new() -> Self {
state: RwLock::new(FailSafeInner { state: State::Idle }), Self { state: State::Idle }
}
} }
pub fn arm(&self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { pub fn arm(&mut self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state {
State::Idle => { State::Idle => {
inner.state = State::Armed(ArmedCtx { self.state = State::Armed(ArmedCtx {
session_mode, session_mode,
timeout, timeout,
noc_state: NocState::NocNotRecvd, noc_state: NocState::NocNotRecvd,
@ -69,7 +65,8 @@ impl FailSafe {
} }
State::Armed(c) => { State::Armed(c) => {
if c.session_mode != session_mode { if c.session_mode != session_mode {
return Err(Error::Invalid); error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode);
Err(ErrorCode::Invalid)?;
} }
// re-arm // re-arm
c.timeout = timeout; c.timeout = timeout;
@ -78,58 +75,55 @@ impl FailSafe {
Ok(()) Ok(())
} }
pub fn disarm(&self, session_mode: SessionMode) -> Result<(), Error> { pub fn disarm(&mut self, session_mode: SessionMode) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state {
State::Idle => { State::Idle => {
error!("Received Fail-Safe Disarm without it being armed"); error!("Received Fail-Safe Disarm without it being armed");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
State::Armed(c) => { State::Armed(c) => {
match c.noc_state { match c.noc_state {
NocState::NocNotRecvd => return Err(Error::Invalid), NocState::NocNotRecvd => Err(ErrorCode::Invalid)?,
NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => { NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => {
if let SessionMode::Case(c) = session_mode { if let SessionMode::Case(c) = session_mode {
if c.fab_idx != idx { if c.fab_idx != idx {
error!( error!(
"Received disarm in separate session from previous Add/Update NOC" "Received disarm in separate session from previous Add/Update NOC"
); );
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} else { } else {
error!("Received disarm in a non-CASE session"); error!("Received disarm in a non-CASE session");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} }
} }
inner.state = State::Idle; self.state = State::Idle;
} }
} }
Ok(()) Ok(())
} }
pub fn is_armed(&self) -> bool { pub fn is_armed(&self) -> bool {
self.state.read().unwrap().state != State::Idle self.state != State::Idle
} }
pub fn record_add_noc(&self, fabric_index: u8) -> Result<(), Error> { pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state { State::Idle => Err(ErrorCode::Invalid.into()),
State::Idle => Err(Error::Invalid),
State::Armed(c) => { State::Armed(c) => {
if c.noc_state == NocState::NocNotRecvd { if c.noc_state == NocState::NocNotRecvd {
c.noc_state = NocState::AddNocRecvd(fabric_index); c.noc_state = NocState::AddNocRecvd(fabric_index);
Ok(()) Ok(())
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
} }
} }
pub fn allow_noc_change(&self) -> Result<bool, Error> { pub fn allow_noc_change(&self) -> Result<bool, Error> {
let mut inner = self.state.write()?; let allow = match &self.state {
let allow = match &mut inner.state {
State::Idle => false, State::Idle => false,
State::Armed(c) => c.noc_state == NocState::NocNotRecvd, State::Armed(c) => c.noc_state == NocState::NocNotRecvd,
}; };


@ -15,16 +15,18 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::cmd_enter; use core::cell::RefCell;
use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::data_model::sdm::failsafe::FailSafe; use crate::data_model::sdm::failsafe::FailSafe;
use crate::interaction_model::core::IMStatusCode; use crate::interaction_model::core::Transaction;
use crate::interaction_model::messages::ib; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr};
use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; use crate::utils::rand::Rand;
use crate::{error::*, interaction_model::command::CommandReq}; use crate::{attribute_enum, cmd_enter};
use log::{error, info}; use crate::{command_enum, error::*};
use num_derive::FromPrimitive; use log::info;
use std::sync::Arc; use strum::{EnumDiscriminants, FromRepr};
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
#[allow(dead_code)] #[allow(dead_code)]
@ -38,65 +40,80 @@ enum CommissioningError {
pub const ID: u32 = 0x0030; pub const ID: u32 = 0x0030;
#[derive(FromPrimitive)] #[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
BreadCrumb = 0, BreadCrumb(AttrType<u64>) = 0,
BasicCommissioningInfo = 1, BasicCommissioningInfo(()) = 1,
RegConfig = 2, RegConfig(AttrType<u8>) = 2,
LocationCapability = 3, LocationCapability(AttrType<u8>) = 3,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
ArmFailsafe = 0x00, ArmFailsafe = 0x00,
ArmFailsafeResp = 0x01,
SetRegulatoryConfig = 0x02, SetRegulatoryConfig = 0x02,
SetRegulatoryConfigResp = 0x03,
CommissioningComplete = 0x04, CommissioningComplete = 0x04,
}
command_enum!(Commands);
#[repr(u16)]
pub enum RespCommands {
ArmFailsafeResp = 0x01,
SetRegulatoryConfigResp = 0x03,
CommissioningCompleteResp = 0x05, CommissioningCompleteResp = 0x05,
} }
#[derive(FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")]
struct CommonResponse<'a> {
error_code: u8,
debug_txt: UtfStr<'a>,
}
pub enum RegLocationType { pub enum RegLocationType {
Indoor = 0, Indoor = 0,
Outdoor = 1, Outdoor = 1,
IndoorOutdoor = 2, IndoorOutdoor = 2,
} }
fn attr_bread_crumb_new(bread_crumb: u64) -> Attribute { pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::BreadCrumb as u16, AttributesDiscriminants::BreadCrumb as u16,
AttrValue::Uint64(bread_crumb), Access::READ.union(Access::WRITE).union(Access::NEED_ADMIN),
Access::READ | Access::WRITE | Access::NEED_ADMIN,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_reg_config_new(reg_config: RegLocationType) -> Attribute {
Attribute::new( Attribute::new(
Attributes::RegConfig as u16, AttributesDiscriminants::RegConfig as u16,
AttrValue::Uint8(reg_config as u8),
Access::RV, Access::RV,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_location_capability_new(reg_config: RegLocationType) -> Attribute {
Attribute::new( Attribute::new(
Attributes::LocationCapability as u16, AttributesDiscriminants::LocationCapability as u16,
AttrValue::Uint8(reg_config as u8),
Access::RV, Access::RV,
Quality::FIXED, Quality::FIXED,
) ),
}
fn attr_comm_info_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::BasicCommissioningInfo as u16, AttributesDiscriminants::BasicCommissioningInfo as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::FIXED, Quality::FIXED,
) ),
} ],
commands: &[
Commands::ArmFailsafe as _,
Commands::SetRegulatoryConfig as _,
Commands::CommissioningComplete as _,
],
};
#[derive(FromTLV, ToTLV)] #[derive(FromTLV, ToTLV)]
struct FailSafeParams { struct FailSafeParams {
@ -104,144 +121,151 @@ struct FailSafeParams {
bread_crumb: u8, bread_crumb: u8,
} }
pub struct GenCommCluster { pub struct GenCommCluster<'a> {
data_ver: Dataver,
expiry_len: u16, expiry_len: u16,
failsafe: Arc<FailSafe>, failsafe: &'a RefCell<FailSafe>,
base: Cluster,
} }
impl ClusterType for GenCommCluster { impl<'a> GenCommCluster<'a> {
fn base(&self) -> &Cluster { pub fn new(failsafe: &'a RefCell<FailSafe>, rand: Rand) -> Self {
&self.base Self {
} data_ver: Dataver::new(rand),
fn base_mut(&mut self) -> &mut Cluster { failsafe,
&mut self.base
}
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) {
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::BasicCommissioningInfo) => {
encoder.encode(EncodeValue::Closure(&|tag, tw| {
let _ = tw.start_struct(tag);
let _ = tw.u16(TagType::Context(0), self.expiry_len);
let _ = tw.end_container();
}))
}
_ => {
error!("Unsupported Attribute: this shouldn't happen");
}
}
}
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> {
let cmd = cmd_req
.cmd
.path
.leaf
.map(num::FromPrimitive::from_u32)
.ok_or(IMStatusCode::UnsupportedCommand)?
.ok_or(IMStatusCode::UnsupportedCommand)?;
match cmd {
Commands::ArmFailsafe => self.handle_command_armfailsafe(cmd_req),
Commands::SetRegulatoryConfig => self.handle_command_setregulatoryconfig(cmd_req),
Commands::CommissioningComplete => self.handle_command_commissioningcomplete(cmd_req),
_ => Err(IMStatusCode::UnsupportedCommand),
}
}
}
impl GenCommCluster {
pub fn new() -> Result<Box<Self>, Error> {
let failsafe = Arc::new(FailSafe::new());
let mut c = Box::new(GenCommCluster {
// TODO: Arch-Specific // TODO: Arch-Specific
expiry_len: 120, expiry_len: 120,
failsafe, }
base: Cluster::new(ID)?,
});
c.base.add_attribute(attr_bread_crumb_new(0))?;
// TODO: Arch-Specific
c.base
.add_attribute(attr_reg_config_new(RegLocationType::IndoorOutdoor))?;
// TODO: Arch-Specific
c.base
.add_attribute(attr_location_capability_new(RegLocationType::IndoorOutdoor))?;
c.base.add_attribute(attr_comm_info_new())?;
Ok(c)
} }
pub fn failsafe(&self) -> Arc<FailSafe> { pub fn failsafe(&self) -> &RefCell<FailSafe> {
self.failsafe.clone() self.failsafe
} }
fn handle_command_armfailsafe(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::BreadCrumb(codec) => codec.encode(writer, 0),
// TODO: Arch-Specific
Attributes::RegConfig(codec) => {
codec.encode(writer, RegLocationType::IndoorOutdoor as _)
}
// TODO: Arch-Specific
Attributes::LocationCapability(codec) => {
codec.encode(writer, RegLocationType::IndoorOutdoor as _)
}
Attributes::BasicCommissioningInfo(_) => {
writer.start_struct(AttrDataWriter::TAG)?;
writer.u16(TagType::Context(0), self.expiry_len)?;
writer.end_container()?;
writer.complete()
}
}
}
} else {
Ok(())
}
}
pub fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::ArmFailsafe => self.handle_command_armfailsafe(transaction, data, encoder)?,
Commands::SetRegulatoryConfig => {
self.handle_command_setregulatoryconfig(transaction, data, encoder)?
}
Commands::CommissioningComplete => {
self.handle_command_commissioningcomplete(transaction, encoder)?;
}
}
self.data_ver.changed();
Ok(())
}
fn handle_command_armfailsafe(
&mut self,
transaction: &mut Transaction,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("ARM Fail Safe"); cmd_enter!("ARM Fail Safe");
let p = FailSafeParams::from_tlv(&cmd_req.data)?; let p = FailSafeParams::from_tlv(data)?;
let mut status = CommissioningError::Ok as u8;
if self let status = if self
.failsafe .failsafe
.arm(p.expiry_len, cmd_req.trans.session.get_session_mode()) .borrow_mut()
.arm(
p.expiry_len,
transaction.session().get_session_mode().clone(),
)
.is_err() .is_err()
{ {
status = CommissioningError::ErrBusyWithOtherAdmin as u8; CommissioningError::ErrBusyWithOtherAdmin as u8
} } else {
CommissioningError::Ok as u8
};
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: status, error_code: status,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::ArmFailsafeResp as _)?
Commands::ArmFailsafeResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
); transaction.complete();
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
fn handle_command_setregulatoryconfig( fn handle_command_setregulatoryconfig(
&mut self, &mut self,
cmd_req: &mut CommandReq, transaction: &mut Transaction,
) -> Result<(), IMStatusCode> { data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("Set Regulatory Config"); cmd_enter!("Set Regulatory Config");
let country_code = cmd_req let country_code = data
.data
.find_tag(1) .find_tag(1)
.map_err(|_| IMStatusCode::InvalidCommand)? .map_err(|_| ErrorCode::InvalidCommand)?
.slice() .slice()
.map_err(|_| IMStatusCode::InvalidCommand)?; .map_err(|_| ErrorCode::InvalidCommand)?;
info!("Received country code: {:?}", country_code); info!("Received country code: {:?}", country_code);
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: 0, error_code: 0,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::SetRegulatoryConfigResp as _)?
Commands::SetRegulatoryConfigResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
); transaction.complete();
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
fn handle_command_commissioningcomplete( fn handle_command_commissioningcomplete(
&mut self, &mut self,
cmd_req: &mut CommandReq, transaction: &mut Transaction,
) -> Result<(), IMStatusCode> { encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("Commissioning Complete"); cmd_enter!("Commissioning Complete");
let mut status: u8 = CommissioningError::Ok as u8; let mut status: u8 = CommissioningError::Ok as u8;
// Has to be a Case Session // Has to be a Case Session
if cmd_req.trans.session.get_local_fabric_idx().is_none() { if transaction.session().get_local_fabric_idx().is_none() {
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
} }
@ -249,7 +273,8 @@ impl GenCommCluster {
// scope that is for this session // scope that is for this session
if self if self
.failsafe .failsafe
.disarm(cmd_req.trans.session.get_session_mode()) .borrow_mut()
.disarm(transaction.session().get_session_mode().clone())
.is_err() .is_err()
{ {
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
@ -257,22 +282,38 @@ impl GenCommCluster {
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: status, error_code: status,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::CommissioningCompleteResp as _)?
Commands::CommissioningCompleteResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
); transaction.complete();
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
} }
#[derive(FromTLV, ToTLV)] impl<'a> Handler for GenCommCluster<'a> {
struct CommonResponse { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
error_code: u8, GenCommCluster::read(self, attr, encoder)
debug_txt: String, }
fn invoke(
&mut self,
transaction: &mut Transaction,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
GenCommCluster::invoke(self, transaction, cmd, data, encoder)
}
}
impl<'a> NonBlockingHandler for GenCommCluster<'a> {}
impl<'a> ChangeNotifier<()> for GenCommCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
}
} }
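
The invoke path above follows a pattern used across this branch: decode the request with FromTLV, act, then emit a typed response through the CmdDataEncoder instead of hand-writing an InvResp. A condensed sketch, not from the commits; DemoResp and the response command id 0x01 are illustrative only.

#[derive(FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")]
struct DemoResp<'a> {
    error_code: u8,
    debug_txt: UtfStr<'a>,
}

fn respond_ok(transaction: &mut Transaction, encoder: CmdDataEncoder) -> Result<(), Error> {
    let resp = DemoResp {
        error_code: 0,
        debug_txt: UtfStr::new(b""),
    };

    // Wraps the payload in the invoke-response envelope for response command 0x01.
    encoder.with_command(0x01)?.set(resp)?;

    transaction.complete();

    Ok(())
}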

File diff suppressed because it is too large.


@@ -16,38 +16,59 @@
 */
use crate::{
-    data_model::objects::{Cluster, ClusterType},
-    error::Error,
+    data_model::objects::{
+        AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler,
+        NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP,
+    },
+    error::{Error, ErrorCode},
+    utils::rand::Rand,
};

pub const ID: u32 = 0x0031;

-pub struct NwCommCluster {
-    base: Cluster,
-}
-
-impl ClusterType for NwCommCluster {
-    fn base(&self) -> &Cluster {
-        &self.base
-    }
-    fn base_mut(&mut self) -> &mut Cluster {
-        &mut self.base
-    }
-}
-
enum FeatureMap {
    _Wifi = 0x01,
    _Thread = 0x02,
    Ethernet = 0x04,
}

+pub const CLUSTER: Cluster<'static> = Cluster {
+    id: ID as _,
+    feature_map: FeatureMap::Ethernet as _,
+    attributes: &[FEATURE_MAP, ATTRIBUTE_LIST],
+    commands: &[],
+};
+
+pub struct NwCommCluster {
+    data_ver: Dataver,
+}
+
impl NwCommCluster {
-    pub fn new() -> Result<Box<Self>, Error> {
-        let mut c = Box::new(Self {
-            base: Cluster::new(ID)?,
-        });
-        // TODO: Arch-Specific
-        c.base.set_feature_map(FeatureMap::Ethernet as u32)?;
-        Ok(c)
-    }
+    pub fn new(rand: Rand) -> Self {
+        Self {
+            data_ver: Dataver::new(rand),
+        }
+    }
}

+impl Handler for NwCommCluster {
+    fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
+        if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
+            if attr.is_system() {
+                CLUSTER.read(attr.attr_id, writer)
+            } else {
+                Err(ErrorCode::AttributeNotFound.into())
+            }
+        } else {
+            Ok(())
+        }
+    }
+}
+
+impl NonBlockingHandler for NwCommCluster {}
+
+impl ChangeNotifier<()> for NwCommCluster {
+    fn consume_change(&mut self) -> Option<()> {
+        self.data_ver.consume_change(())
+    }
+}
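
This cluster is the smallest instance of the new cluster skeleton: const Cluster metadata, a Dataver for versioning, and a ChangeNotifier so subscriptions can pick up changes. A sketch of that skeleton for a hypothetical cluster (not from the commits):

struct DemoCluster {
    data_ver: Dataver,
}

impl DemoCluster {
    fn new(rand: Rand) -> Self {
        Self { data_ver: Dataver::new(rand) }
    }

    fn mutate(&mut self) {
        // ...update whatever state the cluster owns...
        // Bump the data version so pending subscriptions see the change.
        self.data_ver.changed();
    }
}

impl ChangeNotifier<()> for DemoCluster {
    fn consume_change(&mut self) -> Option<()> {
        // Reports the pending change, if any, then resets.
        self.data_ver.consume_change(())
    }
}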


@ -15,46 +15,135 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::cell::RefCell;
use core::convert::TryInto;
use num_derive::FromPrimitive; use strum::{EnumDiscriminants, FromRepr};
use crate::acl::{self, AclEntry, AclMgr}; use crate::acl::{self, AclEntry, AclMgr};
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::error::*;
use crate::interaction_model::core::IMStatusCode;
use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation};
use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV};
use crate::utils::rand::Rand;
use crate::{attribute_enum, error::*};
use log::{error, info}; use log::{error, info};
pub const ID: u32 = 0x001F; pub const ID: u32 = 0x001F;
#[derive(FromPrimitive)] #[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
Acl = 0, Acl(()) = 0,
Extension = 1, Extension(()) = 1,
SubjectsPerEntry = 2, SubjectsPerEntry(AttrType<u16>) = 2,
TargetsPerEntry = 3, TargetsPerEntry(AttrType<u16>) = 3,
EntriesPerFabric = 4, EntriesPerFabric(AttrType<u16>) = 4,
} }
pub struct AccessControlCluster { attribute_enum!(Attributes);
base: Cluster,
acl_mgr: Arc<AclMgr>, pub const CLUSTER: Cluster<'static> = Cluster {
id: ID,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new(
AttributesDiscriminants::Acl as u16,
Access::RWFA,
Quality::NONE,
),
Attribute::new(
AttributesDiscriminants::Extension as u16,
Access::RWFA,
Quality::NONE,
),
Attribute::new(
AttributesDiscriminants::SubjectsPerEntry as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::TargetsPerEntry as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::EntriesPerFabric as u16,
Access::RV,
Quality::FIXED,
),
],
commands: &[],
};
pub struct AccessControlCluster<'a> {
data_ver: Dataver,
acl_mgr: &'a RefCell<AclMgr>,
} }
impl AccessControlCluster { impl<'a> AccessControlCluster<'a> {
pub fn new(acl_mgr: Arc<AclMgr>) -> Result<Box<Self>, Error> { pub fn new(acl_mgr: &'a RefCell<AclMgr>, rand: Rand) -> Self {
let mut c = Box::new(AccessControlCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
acl_mgr, acl_mgr,
}); }
c.base.add_attribute(attr_acl_new())?; }
c.base.add_attribute(attr_extension_new())?;
c.base.add_attribute(attr_subjects_per_entry_new())?; pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
c.base.add_attribute(attr_targets_per_entry_new())?; if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
c.base.add_attribute(attr_entries_per_fabric_new())?; if attr.is_system() {
Ok(c) CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::Acl(_) => {
writer.start_array(AttrDataWriter::TAG)?;
self.acl_mgr.borrow().for_each_acl(|entry| {
if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx {
entry.to_tlv(&mut writer, TagType::Anonymous)?;
}
Ok(())
})?;
writer.end_container()?;
writer.complete()
}
Attributes::Extension(_) => {
// Empty for now
writer.start_array(AttrDataWriter::TAG)?;
writer.end_container()?;
writer.complete()
}
Attributes::SubjectsPerEntry(codec) => {
codec.encode(writer, acl::SUBJECTS_PER_ENTRY as u16)
}
Attributes::TargetsPerEntry(codec) => {
codec.encode(writer, acl::TARGETS_PER_ENTRY as u16)
}
Attributes::EntriesPerFabric(codec) => {
codec.encode(writer, acl::ENTRIES_PER_FABRIC as u16)
}
}
}
} else {
Ok(())
}
}
pub fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
match attr.attr_id.try_into()? {
Attributes::Acl(_) => {
attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| {
self.write_acl_attr(&op, data, attr.fab_idx)
})
}
_ => {
error!("Attribute not yet supported: this shouldn't happen");
Err(ErrorCode::AttributeNotFound.into())
}
}
} }
/// Write the ACL Attribute /// Write the ACL Attribute
@@ -66,141 +155,59 @@ impl AccessControlCluster {
op: &ListOperation, op: &ListOperation,
data: &TLVElement, data: &TLVElement,
fab_idx: u8, fab_idx: u8,
) -> Result<(), IMStatusCode> { ) -> Result<(), Error> {
info!("Performing ACL operation {:?}", op); info!("Performing ACL operation {:?}", op);
let result = match op { match op {
ListOperation::AddItem | ListOperation::EditItem(_) => { ListOperation::AddItem | ListOperation::EditItem(_) => {
let mut acl_entry = let mut acl_entry = AclEntry::from_tlv(data)?;
AclEntry::from_tlv(data).map_err(|_| IMStatusCode::ConstraintError)?;
info!("ACL {:?}", acl_entry); info!("ACL {:?}", acl_entry);
// Overwrite the fabric index with our accessing fabric index // Overwrite the fabric index with our accessing fabric index
acl_entry.fab_idx = Some(fab_idx); acl_entry.fab_idx = Some(fab_idx);
if let ListOperation::EditItem(index) = op { if let ListOperation::EditItem(index) = op {
self.acl_mgr.edit(*index as u8, fab_idx, acl_entry) self.acl_mgr
.borrow_mut()
.edit(*index as u8, fab_idx, acl_entry)
} else { } else {
self.acl_mgr.add(acl_entry) self.acl_mgr.borrow_mut().add(acl_entry)
} }
} }
ListOperation::DeleteItem(index) => self.acl_mgr.delete(*index as u8, fab_idx), ListOperation::DeleteItem(index) => {
ListOperation::DeleteList => self.acl_mgr.delete_for_fabric(fab_idx), self.acl_mgr.borrow_mut().delete(*index as u8, fab_idx)
}; }
match result { ListOperation::DeleteList => self.acl_mgr.borrow_mut().delete_for_fabric(fab_idx),
Ok(_) => Ok(()),
Err(Error::NoSpace) => Err(IMStatusCode::ResourceExhausted),
_ => Err(IMStatusCode::ConstraintError),
} }
} }
} }
impl ClusterType for AccessControlCluster { impl<'a> Handler for AccessControlCluster<'a> {
fn base(&self) -> &Cluster { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
&self.base AccessControlCluster::read(self, attr, encoder)
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { fn write(&mut self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
match num::FromPrimitive::from_u16(attr.attr_id) { AccessControlCluster::write(self, attr, data)
Some(Attributes::Acl) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
let _ = tw.start_array(tag);
let _ = self.acl_mgr.for_each_acl(|entry| {
if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx {
let _ = entry.to_tlv(tw, TagType::Anonymous);
}
});
let _ = tw.end_container();
})),
Some(Attributes::Extension) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
// Empty for now
let _ = tw.start_array(tag);
let _ = tw.end_container();
})),
_ => {
error!("Attribute not yet supported: this shouldn't happen");
}
} }
} }
fn write_attribute( impl<'a> NonBlockingHandler for AccessControlCluster<'a> {}
&mut self,
attr: &AttrDetails,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
let result = if let Some(Attributes::Acl) = num::FromPrimitive::from_u16(attr.attr_id) {
attr_list_write(attr, data, |op, data| {
self.write_acl_attr(&op, data, attr.fab_idx)
})
} else {
error!("Attribute not yet supported: this shouldn't happen");
Err(IMStatusCode::NotFound)
};
if result.is_ok() {
self.base.cluster_changed();
}
result
}
}
fn attr_acl_new() -> Attribute { impl<'a> ChangeNotifier<()> for AccessControlCluster<'a> {
Attribute::new( fn consume_change(&mut self) -> Option<()> {
Attributes::Acl as u16, self.data_ver.consume_change(())
AttrValue::Custom,
Access::RWFA,
Quality::NONE,
)
} }
fn attr_extension_new() -> Attribute {
Attribute::new(
Attributes::Extension as u16,
AttrValue::Custom,
Access::RWFA,
Quality::NONE,
)
}
fn attr_subjects_per_entry_new() -> Attribute {
Attribute::new(
Attributes::SubjectsPerEntry as u16,
AttrValue::Uint16(acl::SUBJECTS_PER_ENTRY as u16),
Access::RV,
Quality::FIXED,
)
}
fn attr_targets_per_entry_new() -> Attribute {
Attribute::new(
Attributes::TargetsPerEntry as u16,
AttrValue::Uint16(acl::TARGETS_PER_ENTRY as u16),
Access::RV,
Quality::FIXED,
)
}
fn attr_entries_per_fabric_new() -> Attribute {
Attribute::new(
Attributes::EntriesPerFabric as u16,
AttrValue::Uint16(acl::ENTRIES_PER_FABRIC as u16),
Access::RV,
Quality::FIXED,
)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc; use core::cell::RefCell;
use crate::{ use crate::{
acl::{AclEntry, AclMgr, AuthMode}, acl::{AclEntry, AclMgr, AuthMode},
data_model::{ data_model::objects::{AttrDataEncoder, AttrDetails, Node, Privilege},
core::read::AttrReadEncoder,
objects::{AttrDetails, ClusterType, Privilege},
},
interaction_model::messages::ib::ListOperation, interaction_model::messages::ib::ListOperation,
tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV},
utils::writebuf::WriteBuf, utils::{rand::dummy_rand, writebuf::WriteBuf},
}; };
use super::AccessControlCluster; use super::AccessControlCluster;
@@ -209,26 +216,27 @@ mod tests {
/// Add an ACL entry /// Add an ACL entry
fn acl_cluster_add() { fn acl_cluster_add() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap();
// Test, ACL has fabric index 2, but the accessing fabric is 1 // Test, ACL has fabric index 2, but the accessing fabric is 1
// the fabric index in the TLV should be ignored and the ACL should be created with entry 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1
let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1); let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1);
assert_eq!(result, Ok(())); assert!(result.is_ok());
let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case);
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier); assert_eq!(*a, verifier);
Ok(())
}) })
.unwrap(); .unwrap();
} }
@@ -237,38 +245,39 @@ mod tests {
/// - The listindex used for edit should be relative to the current fabric /// - The listindex used for edit should be relative to the current fabric
fn acl_cluster_edit() { fn acl_cluster_edit() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let mut verifier = [ let mut verifier = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in verifier { for i in &verifier {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap();
// Test, Edit Fabric 2's index 1 - with accessing fabric as 2 - allow // Test, Edit Fabric 2's index 1 - with accessing fabric as 2 - allow
let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2);
// Fabric 2's index 1, is actually our index 2, update the verifier // Fabric 2's index 1, is actually our index 2, update the verifier
verifier[2] = new; verifier[2] = new;
assert_eq!(result, Ok(())); assert!(result.is_ok());
// Also validate in the acl_mgr that the entries are in the right order // Also validate in the acl_mgr that the entries are in the right order
let mut index = 0; let mut index = 0;
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier[index]); assert_eq!(*a, verifier[index]);
index += 1; index += 1;
Ok(())
}) })
.unwrap(); .unwrap();
} }
@@ -277,30 +286,32 @@ mod tests {
/// - The listindex used for delete should be relative to the current fabric /// - The listindex used for delete should be relative to the current fabric
fn acl_cluster_delete() { fn acl_cluster_delete() {
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let input = [ let input = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in input { for i in &input {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let mut acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
// data is don't-care actually // data is don't-care actually
let data = TLVElement::new(TagType::Anonymous, ElementType::True); let data = TLVElement::new(TagType::Anonymous, ElementType::True);
// Test, Delete Fabric 1's index 0 // Test, Delete Fabric 1's index 0
let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1);
assert_eq!(result, Ok(())); assert!(result.is_ok());
let verifier = [input[0], input[2]]; let verifier = [input[0].clone(), input[2].clone()];
// Also validate in the acl_mgr that the entries are in the right order // Also validate in the acl_mgr that the entries are in the right order
let mut index = 0; let mut index = 0;
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier[index]); assert_eq!(*a, verifier[index]);
index += 1; index += 1;
Ok(())
}) })
.unwrap(); .unwrap();
} }
@@ -309,84 +320,126 @@ mod tests {
/// - acl read with and without fabric filtering /// - acl read with and without fabric filtering
fn acl_cluster_read() { fn acl_cluster_read() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let input = [ let input = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in input { for i in input {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i).unwrap();
} }
let acl = AccessControlCluster::new(acl_mgr).unwrap(); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
// Test 1, all 3 entries are read in the response without fabric filtering // Test 1, all 3 entries are read in the response without fabric filtering
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
let attr_details = AttrDetails { id: 0,
endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 1, fab_idx: 1,
fab_filter: false, fab_filter: false,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254,
// 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24,
// 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54,
1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 3, 24, 54, 4, 24, 36, 254, 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24,
24 36, 254, 2, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
writebuf.reset(0); writebuf.reset();
// Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
id: 0,
let attr_details = AttrDetails { endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 1, fab_idx: 1,
fab_filter: true, fab_filter: true,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 1, 24, 24, 24, 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 1, 24, 24, 24, 24 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
writebuf.reset(0); writebuf.reset();
// Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
id: 0,
let attr_details = AttrDetails { endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 2, fab_idx: 2,
fab_filter: true, fab_filter: true,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254,
// 2, 24, 24, 24, 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54,
2, 24, 24, 24, 24 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
} }
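A minimal sketch of how the reworked Access Control cluster is now wired up, mirroring the updated unit tests above: the cluster borrows a `RefCell<AclMgr>` and takes a `Rand` function instead of owning an `Arc<AclMgr>`. Crate-internal paths and the `dummy_rand` helper are the ones used in the test code of this diff; a real device would pass its platform entropy source instead.

```rust
// Sketch only, following the constructor and Handler impl shown in this diff.
use core::cell::RefCell;

use crate::acl::AclMgr;
use crate::utils::rand::dummy_rand;

use super::AccessControlCluster;

fn build_acl_cluster() {
    // Shared state is now borrow-checked at runtime instead of Arc-owned.
    let acl_mgr = RefCell::new(AclMgr::new());

    // The Rand function seeds the cluster's Dataver; dummy_rand is the test
    // helper imported above, not a real entropy source.
    let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);

    // `acl` implements Handler / NonBlockingHandler and can be plugged into
    // the endpoint's handler chain.
    let _ = acl;
}
```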


@@ -15,18 +15,20 @@
* limitations under the License. * limitations under the License.
*/ */
use num_derive::FromPrimitive; use core::convert::TryInto;
use crate::data_model::core::DataModel; use strum::FromRepr;
use crate::attribute_enum;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::error::*; use crate::error::Error;
use crate::interaction_model::messages::GenericPath;
use crate::tlv::{TLVWriter, TagType, ToTLV}; use crate::tlv::{TLVWriter, TagType, ToTLV};
use log::error; use crate::utils::rand::Rand;
pub const ID: u32 = 0x001D; pub const ID: u32 = 0x001D;
#[derive(FromPrimitive)] #[derive(FromRepr)]
#[repr(u16)]
#[allow(clippy::enum_variant_names)] #[allow(clippy::enum_variant_names)]
pub enum Attributes { pub enum Attributes {
DeviceTypeList = 0, DeviceTypeList = 0,
@@ -35,134 +37,210 @@ pub enum Attributes {
PartsList = 3, PartsList = 3,
} }
pub struct DescriptorCluster { attribute_enum!(Attributes);
base: Cluster,
endpoint_id: EndptId,
data_model: DataModel,
}
impl DescriptorCluster { pub const CLUSTER: Cluster<'static> = Cluster {
pub fn new(endpoint_id: EndptId, data_model: DataModel) -> Result<Box<Self>, Error> { id: ID as _,
let mut c = Box::new(DescriptorCluster { feature_map: 0,
endpoint_id, attributes: &[
data_model, FEATURE_MAP,
base: Cluster::new(ID)?, ATTRIBUTE_LIST,
}); Attribute::new(Attributes::DeviceTypeList as u16, Access::RV, Quality::NONE),
let attrs = [ Attribute::new(Attributes::ServerList as u16, Access::RV, Quality::NONE),
Attribute::new( Attribute::new(Attributes::PartsList as u16, Access::RV, Quality::NONE),
Attributes::DeviceTypeList as u16, Attribute::new(Attributes::ClientList as u16, Access::RV, Quality::NONE),
AttrValue::Custom, ],
Access::RV, commands: &[],
Quality::NONE,
),
Attribute::new(
Attributes::ServerList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
Attribute::new(
Attributes::PartsList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
Attribute::new(
Attributes::ClientList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
];
c.base.add_attributes(&attrs[..])?;
Ok(c)
}
fn encode_devtype_list(&self, tag: TagType, tw: &mut TLVWriter) {
let path = GenericPath {
endpoint: Some(self.endpoint_id),
cluster: None,
leaf: None,
}; };
let _ = tw.start_array(tag);
let dm = self.data_model.node.read().unwrap(); struct StandardPartsMatcher;
let _ = dm.for_each_endpoint(&path, |_, e| {
let dev_type = e.get_dev_type(); impl PartsMatcher for StandardPartsMatcher {
let _ = dev_type.to_tlv(tw, TagType::Anonymous); fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
Ok(()) our_endpoint == 0 && endpoint != our_endpoint
}); }
let _ = tw.end_container();
} }
fn encode_server_list(&self, tag: TagType, tw: &mut TLVWriter) { struct AggregatorPartsMatcher;
let path = GenericPath {
endpoint: Some(self.endpoint_id), impl PartsMatcher for AggregatorPartsMatcher {
cluster: None, fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
leaf: None, endpoint != our_endpoint && endpoint != 0
}; }
let _ = tw.start_array(tag);
let dm = self.data_model.node.read().unwrap();
let _ = dm.for_each_cluster(&path, |_current_path, c| {
let _ = tw.u32(TagType::Anonymous, c.base().id());
Ok(())
});
let _ = tw.end_container();
} }
fn encode_parts_list(&self, tag: TagType, tw: &mut TLVWriter) { pub trait PartsMatcher {
let path = GenericPath { fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool;
endpoint: None,
cluster: None,
leaf: None,
};
let _ = tw.start_array(tag);
if self.endpoint_id == 0 {
// TODO: If endpoint is another than 0, need to figure out what to do
let dm = self.data_model.node.read().unwrap();
let _ = dm.for_each_endpoint(&path, |current_path, _| {
if let Some(endpoint_id) = current_path.endpoint {
if endpoint_id != 0 {
let _ = tw.u16(TagType::Anonymous, endpoint_id);
}
}
Ok(())
});
}
let _ = tw.end_container();
} }
fn encode_client_list(&self, tag: TagType, tw: &mut TLVWriter) { impl<T> PartsMatcher for &T
where
T: PartsMatcher,
{
fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
(**self).describe(our_endpoint, endpoint)
}
}
impl<T> PartsMatcher for &mut T
where
T: PartsMatcher,
{
fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
(**self).describe(our_endpoint, endpoint)
}
}
pub struct DescriptorCluster<'a> {
matcher: &'a dyn PartsMatcher,
data_ver: Dataver,
}
impl DescriptorCluster<'static> {
pub fn new(rand: Rand) -> Self {
Self::new_matching(&StandardPartsMatcher, rand)
}
pub fn new_aggregator(rand: Rand) -> Self {
Self::new_matching(&AggregatorPartsMatcher, rand)
}
}
impl<'a> DescriptorCluster<'a> {
pub fn new_matching(matcher: &'a dyn PartsMatcher, rand: Rand) -> DescriptorCluster<'a> {
Self {
matcher,
data_ver: Dataver::new(rand),
}
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::DeviceTypeList => {
self.encode_devtype_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::ServerList => {
self.encode_server_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::PartsList => {
self.encode_parts_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::ClientList => {
self.encode_client_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
}
}
} else {
Ok(())
}
}
fn encode_devtype_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if endpoint.id == endpoint_id {
let dev_type = endpoint.device_type;
dev_type.to_tlv(tw, TagType::Anonymous)?;
}
}
tw.end_container()
}
fn encode_server_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if endpoint.id == endpoint_id {
for cluster in endpoint.clusters {
tw.u32(TagType::Anonymous, cluster.id as _)?;
}
}
}
tw.end_container()
}
fn encode_parts_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if self.matcher.describe(endpoint_id, endpoint.id) {
tw.u16(TagType::Anonymous, endpoint.id)?;
}
}
tw.end_container()
}
fn encode_client_list(
&self,
_node: &Node,
_endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
// No Clients supported // No Clients supported
let _ = tw.start_array(tag); tw.start_array(tag)?;
let _ = tw.end_container(); tw.end_container()
} }
} }
impl ClusterType for DescriptorCluster { impl<'a> Handler for DescriptorCluster<'a> {
fn base(&self) -> &Cluster { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
&self.base DescriptorCluster::read(self, attr, encoder)
} }
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { impl<'a> NonBlockingHandler for DescriptorCluster<'a> {}
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::DeviceTypeList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { impl<'a> ChangeNotifier<()> for DescriptorCluster<'a> {
self.encode_devtype_list(tag, tw) fn consume_change(&mut self) -> Option<()> {
})), self.data_ver.consume_change(())
Some(Attributes::ServerList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_server_list(tag, tw)
})),
Some(Attributes::PartsList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_parts_list(tag, tw)
})),
Some(Attributes::ClientList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_client_list(tag, tw)
})),
_ => {
error!("Attribute not supported: this shouldn't happen");
}
}
} }
} }
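A short sketch of the configurable PartsList behaviour added here: the descriptor cluster now delegates PartsList membership to a `PartsMatcher`, with ready-made standard and aggregator policies plus a constructor for custom matchers. The `OddEndpointsMatcher` below is purely illustrative, and `dummy_rand` again stands in for real entropy.

```rust
// Sketch of the three ways to construct the reworked DescriptorCluster.
use crate::data_model::objects::EndptId;
use crate::utils::rand::dummy_rand;

use super::{DescriptorCluster, PartsMatcher};

// Illustrative policy only: report just the odd-numbered endpoints.
struct OddEndpointsMatcher;

impl PartsMatcher for OddEndpointsMatcher {
    fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
        endpoint != our_endpoint && endpoint % 2 == 1
    }
}

fn build_descriptors() {
    // Endpoint 0: the standard policy (every endpoint other than 0).
    let _root = DescriptorCluster::new(dummy_rand);

    // Aggregator endpoint: everything except endpoint 0 and itself.
    let _aggregator = DescriptorCluster::new_aggregator(dummy_rand);

    // Any custom policy via the new PartsMatcher trait.
    static ODD: OddEndpointsMatcher = OddEndpointsMatcher;
    let _custom = DescriptorCluster::new_matching(&ODD, dummy_rand);
}
```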


@@ -15,15 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{array::TryFromSliceError, fmt, str::Utf8Error};
array::TryFromSliceError, fmt, string::FromUtf8Error, sync::PoisonError, time::SystemTimeError,
};
use async_channel::{SendError, TryRecvError}; #[derive(Debug, PartialEq, Eq, Clone, Copy)]
use log::error; pub enum ErrorCode {
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Error {
AttributeNotFound, AttributeNotFound,
AttributeIsCustom, AttributeIsCustom,
BufferTooSmall, BufferTooSmall,
@@ -31,6 +26,13 @@ pub enum Error {
CommandNotFound, CommandNotFound,
Duplicate, Duplicate,
EndpointNotFound, EndpointNotFound,
InvalidAction,
InvalidCommand,
InvalidDataType,
UnsupportedAccess,
ResourceExhausted,
Busy,
DataVersionMismatch,
Crypto, Crypto,
TLSStack, TLSStack,
MdnsError, MdnsError,
@@ -71,78 +73,155 @@ pub enum Error {
Utf8Fail, Utf8Fail,
} }
impl From<std::io::Error> for Error { impl From<ErrorCode> for Error {
fn from(_e: std::io::Error) -> Self { fn from(code: ErrorCode) -> Self {
// Keep things simple for now Self::new(code)
Self::StdIoError
} }
} }
impl<T> From<PoisonError<T>> for Error { pub struct Error {
fn from(_e: PoisonError<T>) -> Self { code: ErrorCode,
Self::RwLock #[cfg(all(feature = "std", feature = "backtrace"))]
backtrace: std::backtrace::Backtrace,
}
impl Error {
pub fn new(code: ErrorCode) -> Self {
Self {
code,
#[cfg(all(feature = "std", feature = "backtrace"))]
backtrace: std::backtrace::Backtrace::capture(),
}
}
pub const fn code(&self) -> ErrorCode {
self.code
}
#[cfg(all(feature = "std", feature = "backtrace"))]
pub const fn backtrace(&self) -> &std::backtrace::Backtrace {
&self.backtrace
}
pub fn remap<F>(self, matcher: F, to: Self) -> Self
where
F: FnOnce(&Self) -> bool,
{
if matcher(&self) {
to
} else {
self
}
}
pub fn map_invalid(self, to: Self) -> Self {
self.remap(
|e| matches!(e.code(), ErrorCode::Invalid | ErrorCode::InvalidData),
to,
)
}
pub fn map_invalid_command(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidCommand))
}
pub fn map_invalid_action(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidAction))
}
pub fn map_invalid_data_type(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidDataType))
}
}
#[cfg(feature = "std")]
impl From<std::io::Error> for Error {
fn from(_e: std::io::Error) -> Self {
// Keep things simple for now
Self::new(ErrorCode::StdIoError)
}
}
#[cfg(feature = "std")]
impl<T> From<std::sync::PoisonError<T>> for Error {
fn from(_e: std::sync::PoisonError<T>) -> Self {
Self::new(ErrorCode::RwLock)
} }
} }
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "crypto_openssl")]
impl From<openssl::error::ErrorStack> for Error { impl From<openssl::error::ErrorStack> for Error {
fn from(e: openssl::error::ErrorStack) -> Self { fn from(e: openssl::error::ErrorStack) -> Self {
error!("Error in TLS: {}", e); ::log::error!("Error in TLS: {}", e);
Self::TLSStack Self::new(ErrorCode::TLSStack)
} }
} }
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))]
impl From<mbedtls::Error> for Error { impl From<mbedtls::Error> for Error {
fn from(e: mbedtls::Error) -> Self { fn from(e: mbedtls::Error) -> Self {
error!("Error in TLS: {}", e); ::log::error!("Error in TLS: {}", e);
Self::TLSStack Self::new(ErrorCode::TLSStack)
}
}
#[cfg(target_os = "espidf")]
impl From<esp_idf_sys::EspError> for Error {
fn from(e: esp_idf_sys::EspError) -> Self {
::log::error!("Error in ESP: {}", e);
Self::new(ErrorCode::TLSStack) // TODO: Not a good mapping
} }
} }
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "crypto_rustcrypto")]
impl From<ccm::aead::Error> for Error { impl From<ccm::aead::Error> for Error {
fn from(_e: ccm::aead::Error) -> Self { fn from(_e: ccm::aead::Error) -> Self {
Self::Crypto Self::new(ErrorCode::Crypto)
} }
} }
impl From<SystemTimeError> for Error { #[cfg(feature = "std")]
fn from(_e: SystemTimeError) -> Self { impl From<std::time::SystemTimeError> for Error {
Self::SysTimeFail fn from(_e: std::time::SystemTimeError) -> Self {
Error::new(ErrorCode::SysTimeFail)
} }
} }
impl From<TryFromSliceError> for Error { impl From<TryFromSliceError> for Error {
fn from(_e: TryFromSliceError) -> Self { fn from(_e: TryFromSliceError) -> Self {
Self::Invalid Self::new(ErrorCode::Invalid)
} }
} }
impl<T> From<SendError<T>> for Error { impl From<Utf8Error> for Error {
fn from(e: SendError<T>) -> Self { fn from(_e: Utf8Error) -> Self {
error!("Error in channel send {}", e); Self::new(ErrorCode::Utf8Fail)
Self::Invalid
} }
} }
impl From<FromUtf8Error> for Error { impl fmt::Debug for Error {
fn from(_e: FromUtf8Error) -> Self { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Self::Utf8Fail #[cfg(not(all(feature = "std", feature = "backtrace")))]
} {
write!(f, "Error::{}", self)?;
} }
impl From<TryRecvError> for Error { #[cfg(all(feature = "std", feature = "backtrace"))]
fn from(e: TryRecvError) -> Self { {
error!("Error in channel try_recv {}", e); writeln!(f, "Error::{} {{", self)?;
Self::Invalid write!(f, "{}", self.backtrace())?;
writeln!(f, "}}")?;
}
Ok(())
} }
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self) write!(f, "{:?}", self.code())
} }
} }
#[cfg(feature = "std")]
impl std::error::Error for Error {} impl std::error::Error for Error {}
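A brief sketch of what call sites look like after the error rework in this hunk: the old fat `Error` enum becomes a plain `ErrorCode`, and `Error` wraps it, capturing a backtrace only when the `std` and `backtrace` features are enabled. The helper functions below are illustrative, not part of the crate.

```rust
// Sketch of the ErrorCode/Error split introduced in this diff.
use crate::error::{Error, ErrorCode};

fn nth_entry(table: &[u32], idx: usize) -> Result<u32, Error> {
    // ErrorCode converts into Error via From, so `?` and `.into()` at call
    // sites keep working as before.
    table
        .get(idx)
        .copied()
        .ok_or_else(|| Error::new(ErrorCode::NotFound))
}

fn describe(err: &Error) -> &'static str {
    // Matching now happens on err.code() rather than on the error value itself.
    match err.code() {
        ErrorCode::NoSpace => "resource exhausted",
        ErrorCode::NotFound => "not found",
        _ => "other",
    }
}
```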


@@ -15,56 +15,26 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, MutexGuard, RwLock}; use core::fmt::Write;
use byteorder::{BigEndian, ByteOrder, LittleEndian}; use byteorder::{BigEndian, ByteOrder, LittleEndian};
use log::{error, info}; use heapless::{String, Vec};
use owning_ref::RwLockReadGuardRef; use log::info;
use crate::{ use crate::{
cert::Cert, cert::{Cert, MAX_CERT_TLV_LEN},
crypto::{self, crypto_dummy::KeyPairDummy, hkdf_sha256, CryptoKeyPair, HmacSha256, KeyPair}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair},
error::Error, error::{Error, ErrorCode},
group_keys::KeySet, group_keys::KeySet,
mdns::{self, Mdns}, mdns::{Mdns, ServiceMode},
sys::{Psm, SysMdnsService}, tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr},
tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf,
}; };
const MAX_CERT_TLV_LEN: usize = 350;
const COMPRESSED_FABRIC_ID_LEN: usize = 8; const COMPRESSED_FABRIC_ID_LEN: usize = 8;
macro_rules! fb_key {
($index:ident, $key:ident) => {
&format!("fb{}{}", $index, $key)
};
}
const ST_VID: &str = "vid";
const ST_RCA: &str = "rca";
const ST_ICA: &str = "ica";
const ST_NOC: &str = "noc";
const ST_IPK: &str = "ipk";
const ST_LBL: &str = "label";
const ST_PBKEY: &str = "pubkey";
const ST_PRKEY: &str = "privkey";
#[allow(dead_code)] #[allow(dead_code)]
pub struct Fabric { #[derive(Debug, ToTLV)]
node_id: u64,
fabric_id: u64,
vendor_id: u16,
key_pair: Box<dyn CryptoKeyPair>,
pub root_ca: Cert,
pub icac: Option<Cert>,
pub noc: Cert,
pub ipk: KeySet,
label: String,
compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN],
mdns_service: Option<SysMdnsService>,
}
#[derive(ToTLV)]
#[tlvargs(lifetime = "'a", start = 1)] #[tlvargs(lifetime = "'a", start = 1)]
pub struct FabricDescriptor<'a> { pub struct FabricDescriptor<'a> {
root_public_key: OctetStr<'a>, root_public_key: OctetStr<'a>,
@@ -77,64 +47,70 @@ pub struct FabricDescriptor<'a> {
pub fab_idx: Option<u8>, pub fab_idx: Option<u8>,
} }
#[derive(Debug, ToTLV, FromTLV)]
pub struct Fabric {
node_id: u64,
fabric_id: u64,
vendor_id: u16,
key_pair: KeyPair,
pub root_ca: Vec<u8, { MAX_CERT_TLV_LEN }>,
pub icac: Option<Vec<u8, { MAX_CERT_TLV_LEN }>>,
pub noc: Vec<u8, { MAX_CERT_TLV_LEN }>,
pub ipk: KeySet,
label: String<32>,
mdns_service_name: String<33>,
}
impl Fabric { impl Fabric {
pub fn new( pub fn new(
key_pair: KeyPair, key_pair: KeyPair,
root_ca: Cert, root_ca: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
icac: Option<Cert>, icac: Option<heapless::Vec<u8, { MAX_CERT_TLV_LEN }>>,
noc: Cert, noc: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
ipk: &[u8], ipk: &[u8],
vendor_id: u16, vendor_id: u16,
label: &str,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let node_id = noc.get_node_id()?; let (node_id, fabric_id) = {
let fabric_id = noc.get_fabric_id()?; let noc_p = Cert::new(&noc)?;
(noc_p.get_node_id()?, noc_p.get_fabric_id()?)
let mut f = Self {
node_id,
fabric_id,
vendor_id,
key_pair: Box::new(key_pair),
root_ca,
icac,
noc,
ipk: KeySet::default(),
compressed_id: [0; COMPRESSED_FABRIC_ID_LEN],
label: "".into(),
mdns_service: None,
}; };
Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?;
f.ipk = KeySet::new(ipk, &f.compressed_id)?;
let mut mdns_service_name = String::with_capacity(33); let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN];
for c in f.compressed_id {
mdns_service_name.push_str(&format!("{:02X}", c)); let ipk = {
let root_ca_p = Cert::new(&root_ca)?;
Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?;
KeySet::new(ipk, &compressed_id)?
};
let mut mdns_service_name = heapless::String::<33>::new();
for c in compressed_id {
let mut hex = heapless::String::<4>::new();
write!(&mut hex, "{:02X}", c).unwrap();
mdns_service_name.push_str(&hex).unwrap();
} }
mdns_service_name.push('-'); mdns_service_name.push('-').unwrap();
let mut node_id_be: [u8; 8] = [0; 8]; let mut node_id_be: [u8; 8] = [0; 8];
BigEndian::write_u64(&mut node_id_be, node_id); BigEndian::write_u64(&mut node_id_be, node_id);
for c in node_id_be { for c in node_id_be {
mdns_service_name.push_str(&format!("{:02X}", c)); let mut hex = heapless::String::<4>::new();
write!(&mut hex, "{:02X}", c).unwrap();
mdns_service_name.push_str(&hex).unwrap();
} }
info!("MDNS Service Name: {}", mdns_service_name); info!("MDNS Service Name: {}", mdns_service_name);
f.mdns_service = Some(
Mdns::get()?.publish_service(&mdns_service_name, mdns::ServiceMode::Commissioned)?,
);
Ok(f)
}
pub fn dummy() -> Result<Self, Error> {
Ok(Self { Ok(Self {
node_id: 0, node_id,
fabric_id: 0, fabric_id,
vendor_id: 0, vendor_id,
key_pair: Box::new(KeyPairDummy::new()?), key_pair,
root_ca: Cert::default(), root_ca,
icac: Some(Cert::default()), icac,
noc: Cert::default(), noc,
ipk: KeySet::default(), ipk,
label: "".into(), label: label.into(),
compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], mdns_service_name,
mdns_service: None,
}) })
} }
@@ -147,14 +123,14 @@ impl Fabric {
0x69, 0x63, 0x69, 0x63,
]; ];
hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out) hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out)
.map_err(|_| Error::NoSpace) .map_err(|_| Error::from(ErrorCode::NoSpace))
} }
pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> { pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> {
let mut mac = HmacSha256::new(self.ipk.op_key())?; let mut mac = HmacSha256::new(self.ipk.op_key())?;
mac.update(random)?; mac.update(random)?;
mac.update(self.root_ca.get_pubkey())?; mac.update(self.get_root_ca()?.get_pubkey())?;
let mut buf: [u8; 8] = [0; 8]; let mut buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut buf, self.fabric_id); LittleEndian::write_u64(&mut buf, self.fabric_id);
@@ -168,7 +144,7 @@ impl Fabric {
if id.as_slice() == target { if id.as_slice() == target {
Ok(()) Ok(())
} else { } else {
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
} }
@@ -184,255 +160,180 @@ impl Fabric {
self.fabric_id self.fabric_id
} }
pub fn get_fabric_desc(&self, fab_idx: u8) -> FabricDescriptor { pub fn get_root_ca(&self) -> Result<Cert<'_>, Error> {
FabricDescriptor { Cert::new(&self.root_ca)
root_public_key: OctetStr::new(self.root_ca.get_pubkey()), }
pub fn get_fabric_desc<'a>(
&'a self,
fab_idx: u8,
root_ca_cert: &'a Cert,
) -> Result<FabricDescriptor<'a>, Error> {
let desc = FabricDescriptor {
root_public_key: OctetStr::new(root_ca_cert.get_pubkey()),
vendor_id: self.vendor_id, vendor_id: self.vendor_id,
fabric_id: self.fabric_id, fabric_id: self.fabric_id,
node_id: self.node_id, node_id: self.node_id,
label: UtfStr(self.label.as_bytes()), label: UtfStr(self.label.as_bytes()),
fab_idx: Some(fab_idx), fab_idx: Some(fab_idx),
}
}
fn rm_store(&self, index: usize, psm: &MutexGuard<Psm>) {
psm.rm(fb_key!(index, ST_RCA));
psm.rm(fb_key!(index, ST_ICA));
psm.rm(fb_key!(index, ST_NOC));
psm.rm(fb_key!(index, ST_IPK));
psm.rm(fb_key!(index, ST_LBL));
psm.rm(fb_key!(index, ST_PBKEY));
psm.rm(fb_key!(index, ST_PRKEY));
psm.rm(fb_key!(index, ST_VID));
}
fn store(&self, index: usize, psm: &MutexGuard<Psm>) -> Result<(), Error> {
let mut key = [0u8; MAX_CERT_TLV_LEN];
let len = self.root_ca.as_tlv(&mut key)?;
psm.set_kv_slice(fb_key!(index, ST_RCA), &key[..len])?;
let len = if let Some(icac) = &self.icac {
icac.as_tlv(&mut key)?
} else {
0
};
psm.set_kv_slice(fb_key!(index, ST_ICA), &key[..len])?;
let len = self.noc.as_tlv(&mut key)?;
psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?;
psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?;
psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?;
let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES];
let len = self.key_pair.get_public_key(&mut key)?;
let key = &key[..len];
psm.set_kv_slice(fb_key!(index, ST_PBKEY), key)?;
let mut key = [0_u8; crypto::BIGNUM_LEN_BYTES];
let len = self.key_pair.get_private_key(&mut key)?;
let key = &key[..len];
psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?;
psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?;
Ok(())
}
fn load(index: usize, psm: &MutexGuard<Psm>) -> Result<Self, Error> {
let mut root_ca = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_RCA), &mut root_ca)?;
let root_ca = Cert::new(root_ca.as_slice())?;
let mut icac = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_ICA), &mut icac)?;
let icac = if !icac.is_empty() {
Some(Cert::new(icac.as_slice())?)
} else {
None
}; };
let mut noc = Vec::new(); Ok(desc)
psm.get_kv_slice(fb_key!(index, ST_NOC), &mut noc)?;
let noc = Cert::new(noc.as_slice())?;
let mut ipk = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?;
let mut label = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?;
let label = String::from_utf8(label).map_err(|_| {
error!("Couldn't read label");
Error::Invalid
})?;
let mut pub_key = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?;
let mut priv_key = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_PRKEY), &mut priv_key)?;
let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?;
let mut vendor_id = 0;
psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?;
let f = Fabric::new(
keypair,
root_ca,
icac,
noc,
ipk.as_slice(),
vendor_id as u16,
);
f.map(|mut f| {
f.label = label;
f
})
} }
} }
pub const MAX_SUPPORTED_FABRICS: usize = 3; pub const MAX_SUPPORTED_FABRICS: usize = 3;
#[derive(Default)]
pub struct FabricMgrInner { type FabricEntries = Vec<Option<Fabric>, MAX_SUPPORTED_FABRICS>;
// The outside world expects Fabric Index to be one more than the actual one
// since 0 is not allowed. Need to handle this cleanly somehow
pub fabrics: [Option<Fabric>; MAX_SUPPORTED_FABRICS],
}
pub struct FabricMgr { pub struct FabricMgr {
inner: RwLock<FabricMgrInner>, fabrics: FabricEntries,
psm: Arc<Mutex<Psm>>, changed: bool,
} }
impl FabricMgr { impl FabricMgr {
pub fn new() -> Result<Self, Error> { #[inline(always)]
let dummy_fabric = Fabric::dummy()?; pub const fn new() -> Self {
let mut mgr = FabricMgrInner::default(); Self {
mgr.fabrics[0] = Some(dummy_fabric); fabrics: FabricEntries::new(),
let mut fm = Self { changed: false,
inner: RwLock::new(mgr), }
psm: Psm::get()?,
};
fm.load()?;
Ok(fm)
} }
fn store(&self, index: usize, fabric: &Fabric) -> Result<(), Error> { pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> {
let psm = self.psm.lock().unwrap(); for fabric in self.fabrics.iter().flatten() {
fabric.store(index, &psm) mdns.remove(&fabric.mdns_service_name)?;
} }
fn load(&mut self) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
let mut mgr = self.inner.write()?;
let psm = self.psm.lock().unwrap(); tlv::from_tlv(&mut self.fabrics, &root)?;
for i in 0..MAX_SUPPORTED_FABRICS {
let result = Fabric::load(i, &psm); for fabric in self.fabrics.iter().flatten() {
if let Ok(fabric) = result { mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
info!("Adding new fabric at index {}", i);
mgr.fabrics[i] = Some(fabric);
}
} }
self.changed = false;
Ok(()) Ok(())
} }
pub fn add(&self, f: Fabric) -> Result<u8, Error> { pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
let mut mgr = self.inner.write()?; if self.changed {
let index = mgr let mut wb = WriteBuf::new(buf);
.fabrics let mut tw = TLVWriter::new(&mut wb);
.iter()
.position(|f| f.is_none())
.ok_or(Error::NoSpace)?;
self.store(index, &f)?; self.fabrics
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
mgr.fabrics[index] = Some(f); self.changed = false;
Ok(index as u8)
let len = tw.get_tail();
Ok(Some(&buf[..len]))
} else {
Ok(None)
}
} }
pub fn remove(&self, fab_idx: u8) -> Result<(), Error> { pub fn is_changed(&self) -> bool {
let fab_idx = fab_idx as usize; self.changed
let mut mgr = self.inner.write().unwrap(); }
let psm = self.psm.lock().unwrap();
if let Some(f) = &mgr.fabrics[fab_idx] { pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result<u8, Error> {
f.rm_store(fab_idx, &psm); let slot = self.fabrics.iter().position(|x| x.is_none());
mgr.fabrics[fab_idx] = None;
if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS {
mdns.add(&f.mdns_service_name, ServiceMode::Commissioned)?;
self.changed = true;
if let Some(index) = slot {
self.fabrics[index] = Some(f);
Ok((index + 1) as u8)
} else {
self.fabrics
.push(Some(f))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
Ok(self.fabrics.len() as u8)
}
} else {
Err(ErrorCode::NoSpace.into())
}
}
pub fn remove(&mut self, fab_idx: u8, mdns: &dyn Mdns) -> Result<(), Error> {
if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() {
if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() {
mdns.remove(&f.mdns_service_name)?;
self.changed = true;
Ok(()) Ok(())
} else { } else {
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
}
} else {
Err(ErrorCode::NotFound.into())
} }
} }
pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<usize, Error> { pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<usize, Error> {
let mgr = self.inner.read()?; for (index, fabric) in self.fabrics.iter().enumerate() {
for i in 0..MAX_SUPPORTED_FABRICS { if let Some(fabric) = fabric {
if let Some(fabric) = &mgr.fabrics[i] {
if fabric.match_dest_id(random, target).is_ok() { if fabric.match_dest_id(random, target).is_ok() {
return Ok(i); return Ok(index + 1);
} }
} }
} }
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
pub fn get_fabric<'ret, 'me: 'ret>( pub fn get_fabric(&self, idx: usize) -> Result<Option<&Fabric>, Error> {
&'me self, if idx == 0 {
idx: usize, Ok(None)
) -> Result<RwLockReadGuardRef<'ret, FabricMgrInner, Option<Fabric>>, Error> { } else {
Ok(RwLockReadGuardRef::new(self.inner.read()?).map(|fm| &fm.fabrics[idx])) Ok(self.fabrics[idx - 1].as_ref())
}
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
let mgr = self.inner.read().unwrap(); !self.fabrics.iter().any(Option::is_some)
for i in 1..MAX_SUPPORTED_FABRICS {
if mgr.fabrics[i].is_some() {
return false;
}
}
true
} }
pub fn used_count(&self) -> usize { pub fn used_count(&self) -> usize {
let mgr = self.inner.read().unwrap(); self.fabrics.iter().filter(|f| f.is_some()).count()
let mut count = 0;
for i in 1..MAX_SUPPORTED_FABRICS {
if mgr.fabrics[i].is_some() {
count += 1;
}
}
count
} }
// Parameters to T are the Fabric and its Fabric Index // Parameters to T are the Fabric and its Fabric Index
pub fn for_each<T>(&self, mut f: T) -> Result<(), Error> pub fn for_each<T>(&self, mut f: T) -> Result<(), Error>
where where
T: FnMut(&Fabric, u8), T: FnMut(&Fabric, u8) -> Result<(), Error>,
{ {
let mgr = self.inner.read().unwrap(); for (index, fabric) in self.fabrics.iter().enumerate() {
for i in 1..MAX_SUPPORTED_FABRICS { if let Some(fabric) = fabric {
if let Some(fabric) = &mgr.fabrics[i] { f(fabric, (index + 1) as u8)?;
f(fabric, i as u8)
} }
} }
Ok(()) Ok(())
} }
pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> {
let index = index as usize; if !label.is_empty()
let mut mgr = self.inner.write()?; && self
if !label.is_empty() { .fabrics
for i in 1..MAX_SUPPORTED_FABRICS { .iter()
if let Some(fabric) = &mgr.fabrics[i] { .filter_map(|f| f.as_ref())
if fabric.label == label { .any(|f| f.label == label)
return Err(Error::Invalid); {
} return Err(ErrorCode::Invalid.into());
}
}
}
if let Some(fabric) = &mut mgr.fabrics[index] {
let old = fabric.label.clone();
fabric.label = label;
let psm = self.psm.lock().unwrap();
if fabric.store(index, &psm).is_err() {
fabric.label = old;
return Err(Error::StdIoError);
} }
let index = (index - 1) as usize;
if let Some(fabric) = &mut self.fabrics[index] {
fabric.label = label.into();
self.changed = true;
} }
Ok(()) Ok(())
} }
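A minimal sketch of the new persistence flow for `FabricMgr`, following the `load`/`store` signatures above: fabrics serialise to TLV into a caller-provided buffer, and mDNS registration is injected as `&dyn Mdns`, replacing the old per-key `Psm` storage. The buffer size and the `nvs_write` hook are assumptions made for illustration.

```rust
// Sketch only: restore fabrics at boot and persist them when they change.
use crate::error::Error;
use crate::fabric::FabricMgr;
use crate::mdns::Mdns;

fn restore_and_persist(
    fab_mgr: &mut FabricMgr,
    mdns: &dyn Mdns,
    saved: Option<&[u8]>,          // TLV blob from a previous store(), if any
    nvs_write: impl FnOnce(&[u8]), // hypothetical platform persistence hook
) -> Result<(), Error> {
    // load() re-registers each fabric's mDNS service as it deserialises.
    if let Some(data) = saved {
        fab_mgr.load(data, mdns)?;
    }

    // store() yields the TLV blob only when something changed since the last
    // load()/store(), so callers can skip redundant flash writes.
    let mut buf = [0u8; 4096];
    if let Some(tlv) = fab_mgr.store(&mut buf)? {
        nvs_write(tlv);
    }

    Ok(())
}
```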


@@ -15,39 +15,18 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, Once}; use crate::{
crypto::{self, SYMM_KEY_LEN_BYTES},
error::{Error, ErrorCode},
tlv::{FromTLV, TLVWriter, TagType, ToTLV},
};
use crate::{crypto, error::Error}; type KeySetKey = [u8; SYMM_KEY_LEN_BYTES];
// This is just makeshift implementation for now, not used anywhere #[derive(Debug, Default, FromTLV, ToTLV)]
pub struct GroupKeys {}
static mut G_GRP_KEYS: Option<Arc<Mutex<GroupKeys>>> = None;
static INIT: Once = Once::new();
impl GroupKeys {
fn new() -> Self {
Self {}
}
pub fn get() -> Result<Arc<Mutex<Self>>, Error> {
unsafe {
INIT.call_once(|| {
G_GRP_KEYS = Some(Arc::new(Mutex::new(GroupKeys::new())));
});
Ok(G_GRP_KEYS.as_ref().ok_or(Error::Invalid)?.clone())
}
}
pub fn insert_key() -> Result<(), Error> {
Ok(())
}
}
#[derive(Debug, Default)]
pub struct KeySet { pub struct KeySet {
pub epoch_key: [u8; crypto::SYMM_KEY_LEN_BYTES], pub epoch_key: KeySetKey,
pub op_key: [u8; crypto::SYMM_KEY_LEN_BYTES], pub op_key: KeySetKey,
} }
impl KeySet { impl KeySet {
@@ -63,7 +42,8 @@ impl KeySet {
0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30,
]; ];
crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey).map_err(|_| Error::NoSpace) crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey)
.map_err(|_| ErrorCode::NoSpace.into())
} }
pub fn op_key(&self) -> &[u8] { pub fn op_key(&self) -> &[u8] {
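Tying this to the fabric changes above, a tiny sketch: `KeySet` now derives `FromTLV`/`ToTLV` so it can be persisted as part of the `Fabric` record, and the operational key is derived from the epoch key (IPK) plus the compressed fabric ID. The signature is assumed to match the `KeySet::new(ipk, &compressed_id)` call used in `Fabric::new` earlier in this diff.

```rust
// Sketch: derive a KeySet the way Fabric::new() does in this branch.
use crate::error::Error;
use crate::group_keys::KeySet;

fn derive_keyset(ipk_epoch_key: &[u8], compressed_fabric_id: &[u8]) -> Result<KeySet, Error> {
    // Internally this runs hkdf_sha256 with the fixed "GroupKey v1.0" info
    // bytes shown in the hunk above; derivation failure maps to ErrorCode::NoSpace.
    KeySet::new(ipk_epoch_key, compressed_fabric_id)
}
```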


@@ -1,88 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use super::core::IMStatusCode;
use super::core::OpCode;
use super::messages::ib;
use super::messages::msg;
use super::messages::msg::InvReq;
use super::InteractionModel;
use super::Transaction;
use crate::{
error::*,
tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType},
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use log::error;
#[macro_export]
macro_rules! cmd_enter {
($e:expr) => {{
use colored::Colorize;
info! {"{} {}", "Handling Command".cyan(), $e.cyan()}
}};
}
pub struct CommandReq<'a, 'b, 'c, 'd, 'e> {
pub cmd: ib::CmdPath,
pub data: TLVElement<'a>,
pub resp: &'a mut TLVWriter<'b, 'c>,
pub trans: &'a mut Transaction<'d, 'e>,
}
impl InteractionModel {
pub fn handle_invoke_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
if InteractionModel::req_timeout_handled(trans, proto_tx)? {
return Ok(ResponseRequired::Yes);
}
proto_tx.set_proto_opcode(OpCode::InvokeResponse as u8);
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let root = get_root_node_struct(rx_buf)?;
let inv_req = InvReq::from_tlv(&root)?;
let timed_tx = trans.get_timeout().map(|_| true);
let timed_request = inv_req.timed_request.filter(|a| *a);
// Either both should be None, or both should be Some(true)
if timed_tx != timed_request {
InteractionModel::create_status_response(proto_tx, IMStatusCode::TimedRequestMisMatch)?;
return Ok(ResponseRequired::Yes);
}
tw.start_struct(TagType::Anonymous)?;
// Suppress Response -> TODO: Need to revisit this for cases where we send a command back
tw.bool(
TagType::Context(msg::InvRespTag::SupressResponse as u8),
false,
)?;
self.consumer
.consume_invoke_cmd(&inv_req, trans, &mut tw)
.map_err(|e| {
error!("Error in handling command: {:?}", e);
print_tlv_list(rx_buf);
e
})?;
tw.end_container()?;
Ok(ResponseRequired::Yes)
}
}

File diff suppressed because it is too large.


@ -17,13 +17,13 @@
use crate::{ use crate::{
data_model::objects::{ClusterId, EndptId}, data_model::objects::{ClusterId, EndptId},
error::Error, error::{Error, ErrorCode},
tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVWriter, TagType, ToTLV},
}; };
// A generic path with endpoint, clusters, and a leaf // A generic path with endpoint, clusters, and a leaf
// The leaf could be command, attribute, event // The leaf could be command, attribute, event
#[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct GenericPath { pub struct GenericPath {
pub endpoint: Option<EndptId>, pub endpoint: Option<EndptId>,
@@ -48,7 +48,7 @@ impl GenericPath {
cluster: Some(c), cluster: Some(c),
leaf: Some(l), leaf: Some(l),
} => Ok((e, c, l)), } => Ok((e, c, l)),
_ => Err(Error::Invalid), _ => Err(ErrorCode::Invalid.into()),
} }
} }
/// Returns true, if the path is wildcard /// Returns true, if the path is wildcard
@@ -69,7 +69,7 @@ pub mod msg {
use crate::{ use crate::{
error::Error, error::Error,
interaction_model::core::IMStatusCode, interaction_model::core::IMStatusCode,
tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
}; };
use super::ib::{ use super::ib::{
@@ -77,7 +77,7 @@ pub mod msg {
EventPath, EventPath,
}; };
#[derive(Default, FromTLV, ToTLV)] #[derive(Debug, Default, FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct SubscribeReq<'a> { pub struct SubscribeReq<'a> {
pub keep_subs: bool, pub keep_subs: bool,
@@ -106,16 +106,6 @@ pub mod msg {
self.attr_requests = Some(TLVArray::new(requests)); self.attr_requests = Some(TLVArray::new(requests));
self self
} }
pub fn to_read_req(&self) -> ReadReq<'a> {
ReadReq {
attr_requests: self.attr_requests,
event_requests: self.event_requests,
event_filters: self.event_filters,
fabric_filtered: self.fabric_filtered,
dataver_filters: self.dataver_filters,
}
}
} }
#[derive(Debug, FromTLV, ToTLV)] #[derive(Debug, FromTLV, ToTLV)]
@@ -160,13 +150,6 @@ pub mod msg {
pub inv_requests: Option<TLVArray<'a, CmdData<'a>>>, pub inv_requests: Option<TLVArray<'a, CmdData<'a>>>,
} }
// This enum is helpful when we are constructing the response
// step by step in incremental manner
pub enum InvRespTag {
SupressResponse = 0,
InvokeResponses = 1,
}
#[derive(FromTLV, ToTLV, Debug)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct InvResp<'a> { pub struct InvResp<'a> {
@@ -174,7 +157,14 @@ pub mod msg {
pub inv_responses: Option<TLVArray<'a, ib::InvResp<'a>>>, pub inv_responses: Option<TLVArray<'a, ib::InvResp<'a>>>,
} }
#[derive(Default, ToTLV, FromTLV)] // This enum is helpful when we are constructing the response
// step by step in incremental manner
pub enum InvRespTag {
SupressResponse = 0,
InvokeResponses = 1,
}
#[derive(Default, ToTLV, FromTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct ReadReq<'a> { pub struct ReadReq<'a> {
pub attr_requests: Option<TLVArray<'a, AttrPath>>, pub attr_requests: Option<TLVArray<'a, AttrPath>>,
@@ -198,17 +188,17 @@ pub mod msg {
} }
} }
#[derive(ToTLV, FromTLV)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'b")] #[tlvargs(lifetime = "'a")]
pub struct WriteReq<'a, 'b> { pub struct WriteReq<'a> {
pub supress_response: Option<bool>, pub supress_response: Option<bool>,
timed_request: Option<bool>, timed_request: Option<bool>,
pub write_requests: TLVArray<'a, AttrData<'b>>, pub write_requests: TLVArray<'a, AttrData<'a>>,
more_chunked: Option<bool>, more_chunked: Option<bool>,
} }
impl<'a, 'b> WriteReq<'a, 'b> { impl<'a> WriteReq<'a> {
pub fn new(supress_response: bool, write_requests: &'a [AttrData<'b>]) -> Self { pub fn new(supress_response: bool, write_requests: &'a [AttrData<'a>]) -> Self {
let mut w = Self { let mut w = Self {
supress_response: None, supress_response: None,
write_requests: TLVArray::new(write_requests), write_requests: TLVArray::new(write_requests),
@@ -223,7 +213,7 @@ pub mod msg {
} }
// Report Data // Report Data
#[derive(FromTLV, ToTLV)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct ReportDataMsg<'a> { pub struct ReportDataMsg<'a> {
pub subscription_id: Option<u32>, pub subscription_id: Option<u32>,
@@ -243,7 +233,7 @@ pub mod msg {
} }
// Write Response // Write Response
#[derive(ToTLV, FromTLV)] #[derive(ToTLV, FromTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct WriteResp<'a> { pub struct WriteResp<'a> {
pub write_responses: TLVArray<'a, AttrStatus>, pub write_responses: TLVArray<'a, AttrStatus>,
@@ -255,11 +245,11 @@ pub mod msg {
} }
pub mod ib { pub mod ib {
use std::fmt::Debug; use core::fmt::Debug;
use crate::{ use crate::{
data_model::objects::{AttrDetails, AttrId, ClusterId, EncodeValue, EndptId}, data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId},
error::Error, error::{Error, ErrorCode},
interaction_model::core::IMStatusCode, interaction_model::core::IMStatusCode,
tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV},
}; };
@@ -268,7 +258,7 @@ pub mod ib {
use super::GenericPath; use super::GenericPath;
// Command Response // Command Response
#[derive(Clone, Copy, FromTLV, ToTLV, Debug)] #[derive(Clone, FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub enum InvResp<'a> { pub enum InvResp<'a> {
Cmd(CmdData<'a>), Cmd(CmdData<'a>),
@@ -276,18 +266,6 @@ pub mod ib {
} }
impl<'a> InvResp<'a> { impl<'a> InvResp<'a> {
pub fn cmd_new(
endpoint: EndptId,
cluster: ClusterId,
cmd: u16,
data: EncodeValue<'a>,
) -> Self {
Self::Cmd(CmdData::new(
CmdPath::new(Some(endpoint), Some(cluster), Some(cmd)),
data,
))
}
pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self { pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self {
Self::Status(CmdStatus { Self::Status(CmdStatus {
path: cmd_path, path: cmd_path,
@@ -296,7 +274,24 @@ pub mod ib {
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone, PartialEq, Debug)] impl<'a> From<CmdData<'a>> for InvResp<'a> {
fn from(value: CmdData<'a>) -> Self {
Self::Cmd(value)
}
}
pub enum InvRespTag {
Cmd = 0,
Status = 1,
}
impl<'a> From<CmdStatus> for InvResp<'a> {
fn from(value: CmdStatus) -> Self {
Self::Status(value)
}
}
#[derive(FromTLV, ToTLV, Clone, PartialEq, Debug)]
pub struct CmdStatus { pub struct CmdStatus {
path: CmdPath, path: CmdPath,
status: Status, status: Status,
@ -314,7 +309,7 @@ pub mod ib {
} }
} }
#[derive(Debug, Clone, Copy, FromTLV, ToTLV)] #[derive(Debug, Clone, FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct CmdData<'a> { pub struct CmdData<'a> {
pub path: CmdPath, pub path: CmdPath,
@ -327,8 +322,13 @@ pub mod ib {
} }
} }
pub enum CmdDataTag {
Path = 0,
Data = 1,
}
// Status // Status
#[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)]
pub struct Status { pub struct Status {
pub status: IMStatusCode, pub status: IMStatusCode,
pub cluster_status: u16, pub cluster_status: u16,
@ -344,7 +344,7 @@ pub mod ib {
} }
// Attribute Response // Attribute Response
#[derive(Clone, Copy, FromTLV, ToTLV, PartialEq, Debug)] #[derive(Clone, FromTLV, ToTLV, PartialEq, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub enum AttrResp<'a> { pub enum AttrResp<'a> {
Status(AttrStatus), Status(AttrStatus),
@ -352,10 +352,6 @@ pub mod ib {
} }
impl<'a> AttrResp<'a> { impl<'a> AttrResp<'a> {
pub fn new(data_ver: u32, path: &AttrPath, data: EncodeValue<'a>) -> Self {
AttrResp::Data(AttrData::new(Some(data_ver), *path, data))
}
pub fn unwrap_data(self) -> AttrData<'a> { pub fn unwrap_data(self) -> AttrData<'a> {
match self { match self {
AttrResp::Data(d) => d, AttrResp::Data(d) => d,
@ -366,8 +362,25 @@ pub mod ib {
} }
} }
impl<'a> From<AttrData<'a>> for AttrResp<'a> {
fn from(value: AttrData<'a>) -> Self {
Self::Data(value)
}
}
impl<'a> From<AttrStatus> for AttrResp<'a> {
fn from(value: AttrStatus) -> Self {
Self::Status(value)
}
}
pub enum AttrRespTag {
Status = 0,
Data = 1,
}
// Attribute Data // Attribute Data
#[derive(Clone, Copy, PartialEq, FromTLV, ToTLV, Debug)] #[derive(Clone, PartialEq, FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct AttrData<'a> { pub struct AttrData<'a> {
pub data_ver: Option<u32>, pub data_ver: Option<u32>,
@ -385,6 +398,12 @@ pub mod ib {
} }
} }
pub enum AttrDataTag {
DataVer = 0,
Path = 1,
Data = 2,
}
#[derive(Debug)] #[derive(Debug)]
/// Operations on an Interaction Model List /// Operations on an Interaction Model List
pub enum ListOperation { pub enum ListOperation {
@ -399,13 +418,9 @@ pub mod ib {
} }
/// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function /// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function
pub fn attr_list_write<F>( pub fn attr_list_write<F>(attr: &AttrDetails, data: &TLVElement, mut f: F) -> Result<(), Error>
attr: &AttrDetails,
data: &TLVElement,
mut f: F,
) -> Result<(), IMStatusCode>
where where
F: FnMut(ListOperation, &TLVElement) -> Result<(), IMStatusCode>, F: FnMut(ListOperation, &TLVElement) -> Result<(), Error>,
{ {
if let Some(Nullable::NotNull(index)) = attr.list_index { if let Some(Nullable::NotNull(index)) = attr.list_index {
// If list index is valid, // If list index is valid,
@ -422,7 +437,7 @@ pub mod ib {
f(ListOperation::DeleteList, data)?; f(ListOperation::DeleteList, data)?;
// Now the data must be a list, that should be added item by item // Now the data must be a list, that should be added item by item
let container = data.enter().ok_or(Error::Invalid)?; let container = data.enter().ok_or(ErrorCode::Invalid)?;
for d in container { for d in container {
f(ListOperation::AddItem, &d)?; f(ListOperation::AddItem, &d)?;
} }
@ -433,7 +448,7 @@ pub mod ib {
} }
} }
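A hypothetical caller of attr_list_write, for illustration only (the `attr`/`data` bindings, the closure bodies, and the chosen error code are placeholders, not the crate's actual cluster logic):
attr_list_write(&attr, &data, |op, element| match op {
    ListOperation::AddItem => {
        // append `element` to the cluster's backing list
        Ok(())
    }
    ListOperation::DeleteList => {
        // clear the cluster's backing list
        Ok(())
    }
    // index-based edit/delete variants would be handled here as well
    _ => Err(ErrorCode::Invalid.into()),
})?;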
#[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)]
pub struct AttrStatus { pub struct AttrStatus {
path: AttrPath, path: AttrPath,
status: Status, status: Status,
@ -449,7 +464,7 @@ pub mod ib {
} }
// Attribute Path // Attribute Path
#[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct AttrPath { pub struct AttrPath {
pub tag_compression: Option<bool>, pub tag_compression: Option<bool>,
@ -476,7 +491,7 @@ pub mod ib {
} }
// Command Path // Command Path
#[derive(Default, Debug, Copy, Clone, PartialEq)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct CmdPath { pub struct CmdPath {
pub path: GenericPath, pub path: GenericPath,
} }
@ -499,13 +514,13 @@ pub mod ib {
pub fn new( pub fn new(
endpoint: Option<EndptId>, endpoint: Option<EndptId>,
cluster: Option<ClusterId>, cluster: Option<ClusterId>,
command: Option<u16>, command: Option<CmdId>,
) -> Self { ) -> Self {
Self { Self {
path: GenericPath { path: GenericPath {
endpoint, endpoint,
cluster, cluster,
leaf: command.map(|a| a as u32), leaf: command,
}, },
} }
} }
@ -519,7 +534,7 @@ pub mod ib {
if c.path.leaf.is_none() { if c.path.leaf.is_none() {
error!("Wildcard command parameter not supported"); error!("Wildcard command parameter not supported");
Err(Error::CommandNotFound) Err(ErrorCode::CommandNotFound.into())
} else { } else {
Ok(c) Ok(c)
} }
@ -532,20 +547,20 @@ pub mod ib {
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct ClusterPath { pub struct ClusterPath {
pub node: Option<u64>, pub node: Option<u64>,
pub endpoint: EndptId, pub endpoint: EndptId,
pub cluster: ClusterId, pub cluster: ClusterId,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct DataVersionFilter { pub struct DataVersionFilter {
pub path: ClusterPath, pub path: ClusterPath,
pub data_ver: u32, pub data_ver: u32,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct EventPath { pub struct EventPath {
pub node: Option<u64>, pub node: Option<u64>,
@ -555,7 +570,7 @@ pub mod ib {
pub is_urgent: Option<bool>, pub is_urgent: Option<bool>,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct EventFilter { pub struct EventFilter {
pub node: Option<u64>, pub node: Option<u64>,
pub event_min: Option<u64>, pub event_min: Option<u64>,


@ -15,73 +15,5 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{
error::Error,
tlv::TLVWriter,
transport::{exchange::Exchange, proto_demux::ResponseRequired, session::SessionHandle},
};
use self::{
core::OpCode,
messages::msg::{InvReq, StatusResp, WriteReq},
};
#[derive(PartialEq)]
pub enum TransactionState {
Ongoing,
Complete,
Terminate,
}
pub struct Transaction<'a, 'b> {
pub state: TransactionState,
pub session: &'a mut SessionHandle<'b>,
pub exch: &'a mut Exchange,
}
pub trait InteractionConsumer {
fn consume_invoke_cmd(
&self,
req: &InvReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_read_attr(
&self,
// TODO: This handling is different from the other APIs here, identify
// consistent options for this trait
req: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_write_attr(
&self,
req: &WriteReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_status_report(
&self,
_req: &StatusResp,
_trans: &mut Transaction,
_tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error>;
fn consume_subscribe(
&self,
_req: &[u8],
_trans: &mut Transaction,
_tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error>;
}
pub struct InteractionModel {
consumer: Box<dyn InteractionConsumer>,
}
pub mod command;
pub mod core; pub mod core;
pub mod messages; pub mod messages;
pub mod read;
pub mod write;


@ -1,42 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::Error,
interaction_model::core::OpCode,
tlv::TLVWriter,
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use super::{InteractionModel, Transaction};
impl InteractionModel {
pub fn handle_read_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
proto_tx.set_proto_opcode(OpCode::ReportData as u8);
let proto_tx_wb = proto_tx.get_writebuf()?;
let mut tw = TLVWriter::new(proto_tx_wb);
self.consumer.consume_read_attr(rx_buf, trans, &mut tw)?;
Ok(ResponseRequired::Yes)
}
}


@ -1,58 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use log::error;
use crate::{
error::Error,
tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType},
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use super::{core::OpCode, messages::msg::WriteReq, InteractionModel, Transaction};
impl InteractionModel {
pub fn handle_write_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
if InteractionModel::req_timeout_handled(trans, proto_tx)? {
return Ok(ResponseRequired::Yes);
}
proto_tx.set_proto_opcode(OpCode::WriteResponse as u8);
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let root = get_root_node_struct(rx_buf)?;
let write_req = WriteReq::from_tlv(&root)?;
let supress_response = write_req.supress_response.unwrap_or_default();
tw.start_struct(TagType::Anonymous)?;
self.consumer
.consume_write_attr(&write_req, trans, &mut tw)?;
tw.end_container()?;
trans.complete();
if supress_response {
error!("Supress response is set, is this the expected handling?");
Ok(ResponseRequired::No)
} else {
Ok(ResponseRequired::Yes)
}
}
}


@ -23,7 +23,7 @@
//! Currently Ethernet based transport is supported. //! Currently Ethernet based transport is supported.
//! //!
//! # Examples //! # Examples
//! ``` //! TODO: Fix once new API has stabilized a bit
//! use matter::{Matter, CommissioningData}; //! use matter::{Matter, CommissioningData};
//! use matter::data_model::device_types::device_type_add_on_off_light; //! use matter::data_model::device_types::device_type_add_on_off_light;
//! use matter::data_model::cluster_basic_information::BasicInfoConfig; //! use matter::data_model::cluster_basic_information::BasicInfoConfig;
@ -65,8 +65,11 @@
//! } //! }
//! // Start the Matter Daemon //! // Start the Matter Daemon
//! // matter.start_daemon().unwrap(); //! // matter.start_daemon().unwrap();
//! ``` //!
//! Start off exploring by going to the [Matter] object. //! Start off exploring by going to the [Matter] object.
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))]
#![cfg_attr(feature = "nightly", allow(incomplete_features))]
pub mod acl; pub mod acl;
pub mod cert; pub mod cert;
@ -80,8 +83,8 @@ pub mod group_keys;
pub mod interaction_model; pub mod interaction_model;
pub mod mdns; pub mod mdns;
pub mod pairing; pub mod pairing;
pub mod persist;
pub mod secure_channel; pub mod secure_channel;
pub mod sys;
pub mod tlv; pub mod tlv;
pub mod transport; pub mod transport;
pub mod utils; pub mod utils;


@ -15,35 +15,73 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, Once}; use core::fmt::Write;
use crate::{ use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error};
error::Error,
sys::{sys_publish_service, SysMdnsService},
transport::udp::MATTER_PORT,
};
#[derive(Default)] #[cfg(all(feature = "std", target_os = "macos"))]
/// The mDNS service handler pub mod astro;
pub struct MdnsInner { pub mod builtin;
/// Vendor ID pub mod proto;
vid: u16,
/// Product ID pub trait Mdns {
pid: u16, fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error>;
/// Device name fn remove(&self, service: &str) -> Result<(), Error>;
device_name: String,
} }
pub struct Mdns { impl<T> Mdns for &mut T
inner: Mutex<MdnsInner>, where
T: Mdns,
{
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
(**self).add(service, mode)
} }
const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; fn remove(&self, service: &str) -> Result<(), Error> {
const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; (**self).remove(service)
}
}
static mut G_MDNS: Option<Arc<Mdns>> = None; impl<T> Mdns for &T
static INIT: Once = Once::new(); where
T: Mdns,
{
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
(**self).add(service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
(**self).remove(service)
}
}
#[cfg(all(feature = "std", target_os = "macos"))]
pub type DefaultMdns<'a> = astro::Mdns<'a>;
#[cfg(all(feature = "std", target_os = "macos"))]
pub type DefaultMdnsRunner<'a> = astro::MdnsRunner<'a>;
#[cfg(not(all(feature = "std", target_os = "macos")))]
pub type DefaultMdns<'a> = builtin::Mdns<'a>;
#[cfg(not(all(feature = "std", target_os = "macos")))]
pub type DefaultMdnsRunner<'a> = builtin::MdnsRunner<'a>;
pub struct DummyMdns;
impl Mdns for DummyMdns {
fn add(&self, _service: &str, _mode: ServiceMode) -> Result<(), Error> {
Ok(())
}
fn remove(&self, _service: &str) -> Result<(), Error> {
Ok(())
}
}
pub type Service<'a> = proto::Service<'a>;
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ServiceMode { pub enum ServiceMode {
/// The commissioned state /// The commissioned state
Commissioned, Commissioned,
@ -51,69 +89,88 @@ pub enum ServiceMode {
Commissionable(u16), Commissionable(u16),
} }
impl Mdns { impl ServiceMode {
fn new() -> Self { pub fn service<R, F: for<'a> FnOnce(&Service<'a>) -> Result<R, Error>>(
Self { &self,
inner: Mutex::new(MdnsInner { dev_att: &BasicInfoConfig,
..Default::default() matter_port: u16,
name: &str,
f: F,
) -> Result<R, Error> {
match self {
Self::Commissioned => f(&Service {
name,
service: "_matter",
protocol: "_tcp",
port: matter_port,
service_subtypes: &[],
txt_kvs: &[],
}), }),
}
}
/// Get a handle to the globally unique mDNS instance
pub fn get() -> Result<Arc<Self>, Error> {
unsafe {
INIT.call_once(|| {
G_MDNS = Some(Arc::new(Mdns::new()));
});
Ok(G_MDNS.as_ref().ok_or(Error::Invalid)?.clone())
}
}
/// Set mDNS service specific values
/// Values like vid, pid, discriminator etc
// TODO: More things like device-type etc can be added here
pub fn set_values(&self, vid: u16, pid: u16, device_name: &str) {
let mut inner = self.inner.lock().unwrap();
inner.vid = vid;
inner.pid = pid;
inner.device_name = device_name.chars().take(32).collect();
}
/// Publish a mDNS service
/// name - is the service name (comma separated subtypes may follow)
/// mode - the current service mode
#[allow(clippy::needless_pass_by_value)]
pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<SysMdnsService, Error> {
match mode {
ServiceMode::Commissioned => {
sys_publish_service(name, "_matter._tcp", MATTER_PORT, &[])
}
ServiceMode::Commissionable(discriminator) => { ServiceMode::Commissionable(discriminator) => {
let inner = self.inner.lock().unwrap(); let discriminator_str = Self::get_discriminator_str(*discriminator);
let short = compute_short_discriminator(discriminator); let vp = Self::get_vp(dev_att.vid, dev_att.pid);
let serv_type = format!("_matterc._udp,_S{},_L{}", short, discriminator);
let str_discriminator = format!("{}", discriminator); let txt_kvs = &[
let txt_kvs = [ ("D", discriminator_str.as_str()),
["D", &str_discriminator], ("CM", "1"),
["CM", "1"], ("DN", dev_att.device_name),
["DN", &inner.device_name], ("VP", &vp),
["VP", &format!("{}+{}", inner.vid, inner.pid)], ("SII", "5000"), /* Sleepy Idle Interval */
["SII", "5000"], /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */
["SAI", "300"], /* Sleepy Active Interval */ ("PH", "33"), /* Pairing Hint */
["PH", "33"], /* Pairing Hint */ ("PI", ""), /* Pairing Instruction */
["PI", ""], /* Pairing Instruction */
]; ];
sys_publish_service(name, &serv_type, MATTER_PORT, &txt_kvs)
f(&Service {
name,
service: "_matterc",
protocol: "_udp",
port: matter_port,
service_subtypes: &[
&Self::get_long_service_subtype(*discriminator),
&Self::get_short_service_type(*discriminator),
],
txt_kvs,
})
} }
} }
} }
fn get_long_service_subtype(discriminator: u16) -> heapless::String<32> {
let mut serv_type = heapless::String::new();
write!(&mut serv_type, "_L{}", discriminator).unwrap();
serv_type
}
fn get_short_service_type(discriminator: u16) -> heapless::String<32> {
let short = Self::compute_short_discriminator(discriminator);
let mut serv_type = heapless::String::new();
write!(&mut serv_type, "_S{}", short).unwrap();
serv_type
}
fn get_discriminator_str(discriminator: u16) -> heapless::String<5> {
discriminator.into()
}
fn get_vp(vid: u16, pid: u16) -> heapless::String<11> {
let mut vp = heapless::String::new();
write!(&mut vp, "{}+{}", vid, pid).unwrap();
vp
} }
fn compute_short_discriminator(discriminator: u16) -> u16 { fn compute_short_discriminator(discriminator: u16) -> u16 {
const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00;
const SHORT_DISCRIMINATOR_SHIFT: u16 = 8;
(discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT
} }
}
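A hypothetical wiring of the new mDNS front-end, for orientation only (hostname, IP address, interface index, port, and `DEV_DET` are placeholder values the caller would supply):
let mdns = DefaultMdns::new(0, "matter-demo", [192, 168, 1, 10], None, 0, &DEV_DET, 5540);
let mut runner = DefaultMdnsRunner::new(&mdns);
// Advertise a commissionable node with discriminator 250:
mdns.add("matter-demo", ServiceMode::Commissionable(250))?;
// ... then drive the responder from an async task:
// runner.run_udp().await?;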
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -122,11 +179,11 @@ mod tests {
#[test] #[test]
fn can_compute_short_discriminator() { fn can_compute_short_discriminator() {
let discriminator: u16 = 0b0000_1111_0000_0000; let discriminator: u16 = 0b0000_1111_0000_0000;
let short = compute_short_discriminator(discriminator); let short = ServiceMode::compute_short_discriminator(discriminator);
assert_eq!(short, 0b1111); assert_eq!(short, 0b1111);
let discriminator: u16 = 840; let discriminator: u16 = 840;
let short = compute_short_discriminator(discriminator); let short = ServiceMode::compute_short_discriminator(discriminator);
assert_eq!(short, 3); assert_eq!(short, 3);
} }
} }

matter/src/mdns/astro.rs (new file, 107 lines)

@ -0,0 +1,107 @@
use core::cell::RefCell;
use std::collections::HashMap;
use crate::{
data_model::cluster_basic_information::BasicInfoConfig,
error::{Error, ErrorCode},
transport::pipe::Pipe,
};
use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService};
use log::info;
use super::ServiceMode;
pub struct Mdns<'a> {
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
services: RefCell<HashMap<String, RegisteredDnsService>>,
}
impl<'a> Mdns<'a> {
pub fn new(
_id: u16,
_hostname: &str,
_ip: [u8; 4],
_ipv6: Option<[u8; 16]>,
_interface: u32,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
) -> Self {
Self::native_new(dev_det, matter_port)
}
pub fn native_new(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> Self {
Self {
dev_det,
matter_port,
services: RefCell::new(HashMap::new()),
}
}
pub fn add(&self, name: &str, mode: ServiceMode) -> Result<(), Error> {
info!("Registering mDNS service {}/{:?}", name, mode);
let _ = self.remove(name);
mode.service(self.dev_det, self.matter_port, name, |service| {
let composite_service_type = if !service.service_subtypes.is_empty() {
format!(
"{}.{},{}",
service.service,
service.protocol,
service.service_subtypes.join(",")
)
} else {
format!("{}.{}", service.service, service.protocol)
};
let mut builder = DNSServiceBuilder::new(&composite_service_type, service.port)
.with_name(service.name);
for kvs in service.txt_kvs {
info!("mDNS TXT key {} val {}", kvs.0, kvs.1);
builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string());
}
let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?;
self.services.borrow_mut().insert(service.name.into(), svc);
Ok(())
})
}
pub fn remove(&self, name: &str) -> Result<(), Error> {
if self.services.borrow_mut().remove(name).is_some() {
info!("Deregistering mDNS service {}", name);
}
Ok(())
}
}
pub struct MdnsRunner<'a>(&'a Mdns<'a>);
impl<'a> MdnsRunner<'a> {
pub const fn new(mdns: &'a Mdns<'a>) -> Self {
Self(mdns)
}
pub async fn run_udp(&mut self) -> Result<(), Error> {
core::future::pending::<Result<(), Error>>().await
}
pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> {
core::future::pending::<Result<(), Error>>().await
}
}
impl<'a> super::Mdns for Mdns<'a> {
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
Mdns::add(self, service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
Mdns::remove(self, service)
}
}

matter/src/mdns/builtin.rs (new file, 319 lines)

@ -0,0 +1,319 @@
use core::{cell::RefCell, mem::MaybeUninit, pin::pin};
use domain::base::name::FromStrError;
use domain::base::{octets::ParseError, ShortBuf};
use embassy_futures::select::{select, select3};
use embassy_time::{Duration, Timer};
use log::info;
use crate::data_model::cluster_basic_information::BasicInfoConfig;
use crate::error::{Error, ErrorCode};
use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::transport::packet::{MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE};
use crate::transport::pipe::{Chunk, Pipe};
use crate::transport::udp::UdpListener;
use crate::utils::select::{EitherUnwrap, Notification};
use super::{
proto::{Host, Services},
Service, ServiceMode,
};
const IP_BIND_ADDR: IpAddr = IpAddr::V6(Ipv6Addr::UNSPECIFIED);
const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb);
const PORT: u16 = 5353;
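// Note: 224.0.0.251 / ff02::fb and UDP port 5353 are the standard mDNS multicast groups and port (RFC 6762).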
type MdnsTxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>;
type MdnsRxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>;
pub struct Mdns<'a> {
host: Host<'a>,
interface: u32,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
services: RefCell<heapless::Vec<(heapless::String<40>, ServiceMode), 4>>,
notification: Notification,
}
impl<'a> Mdns<'a> {
#[inline(always)]
pub const fn new(
id: u16,
hostname: &'a str,
ip: [u8; 4],
ipv6: Option<[u8; 16]>,
interface: u32,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
) -> Self {
Self {
host: Host {
id,
hostname,
ip,
ipv6,
},
interface,
dev_det,
matter_port,
services: RefCell::new(heapless::Vec::new()),
notification: Notification::new(),
}
}
pub fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
let mut services = self.services.borrow_mut();
services.retain(|(name, _)| name != service);
services
.push((service.into(), mode))
.map_err(|_| ErrorCode::NoSpace)?;
self.notification.signal(());
Ok(())
}
pub fn remove(&self, service: &str) -> Result<(), Error> {
let mut services = self.services.borrow_mut();
services.retain(|(name, _)| name != service);
Ok(())
}
pub fn for_each<F>(&self, mut callback: F) -> Result<(), Error>
where
F: FnMut(&Service) -> Result<(), Error>,
{
let services = self.services.borrow();
for (service, mode) in &*services {
mode.service(self.dev_det, self.matter_port, service, |service| {
callback(service)
})?;
}
Ok(())
}
}
pub struct MdnsRunner<'a>(&'a Mdns<'a>);
impl<'a> MdnsRunner<'a> {
pub const fn new(mdns: &'a Mdns<'a>) -> Self {
Self(mdns)
}
pub async fn run_udp(&mut self) -> Result<(), Error> {
let mut tx_buf = MdnsTxBuf::uninit();
let mut rx_buf = MdnsRxBuf::uninit();
let tx_buf = &mut tx_buf;
let rx_buf = &mut rx_buf;
let tx_pipe = Pipe::new(unsafe { tx_buf.assume_init_mut() });
let rx_pipe = Pipe::new(unsafe { rx_buf.assume_init_mut() });
let tx_pipe = &tx_pipe;
let rx_pipe = &rx_pipe;
let mut udp = UdpListener::new(SocketAddr::new(IP_BIND_ADDR, PORT)).await?;
udp.join_multicast_v6(IPV6_BROADCAST_ADDR, self.0.interface)?;
udp.join_multicast_v4(IP_BROADCAST_ADDR, Ipv4Addr::from(self.0.host.ip))?;
let udp = &udp;
let mut tx = pin!(async move {
loop {
{
let mut data = tx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end])
.await?;
data.chunk = None;
tx_pipe.data_consumed_notification.signal(());
}
}
tx_pipe.data_supplied_notification.wait().await;
}
});
let mut rx = pin!(async move {
loop {
{
let mut data = rx_pipe.data.lock().await;
if data.chunk.is_none() {
let (len, addr) = udp.recv(data.buf).await?;
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: Address::Udp(addr),
});
rx_pipe.data_supplied_notification.signal(());
}
}
rx_pipe.data_consumed_notification.wait().await;
}
});
let mut run = pin!(async move { self.run(tx_pipe, rx_pipe).await });
select3(&mut tx, &mut rx, &mut run).await.unwrap()
}
pub async fn run(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> {
let mut broadcast = pin!(self.broadcast(tx_pipe));
let mut respond = pin!(self.respond(rx_pipe, tx_pipe));
select(&mut broadcast, &mut respond).await.unwrap()
}
#[allow(clippy::await_holding_refcell_ref)]
async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
select(
self.0.notification.wait(),
Timer::after(Duration::from_secs(30)),
)
.await;
for addr in [
IpAddr::V4(IP_BROADCAST_ADDR),
IpAddr::V6(IPV6_BROADCAST_ADDR),
] {
loop {
let sent = {
let mut data = tx_pipe.data.lock().await;
if data.chunk.is_none() {
let len = self.0.host.broadcast(&self.0, data.buf, 60)?;
if len > 0 {
info!("Broadasting mDNS entry to {}:{}", addr, PORT);
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: Address::Udp(SocketAddr::new(addr, PORT)),
});
tx_pipe.data_supplied_notification.signal(());
}
true
} else {
false
}
};
if sent {
break;
} else {
tx_pipe.data_consumed_notification.wait().await;
}
}
}
}
}
#[allow(clippy::await_holding_refcell_ref)]
async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
{
let mut rx_data = rx_pipe.data.lock().await;
if let Some(rx_chunk) = rx_data.chunk {
let data = &rx_data.buf[rx_chunk.start..rx_chunk.end];
loop {
let sent = {
let mut tx_data = tx_pipe.data.lock().await;
if tx_data.chunk.is_none() {
let len = self.0.host.respond(&self.0, data, tx_data.buf, 60)?;
if len > 0 {
info!("Replying to mDNS query from {}", rx_chunk.addr);
tx_data.chunk = Some(Chunk {
start: 0,
end: len,
addr: rx_chunk.addr,
});
tx_pipe.data_supplied_notification.signal(());
}
true
} else {
false
}
};
if sent {
break;
} else {
tx_pipe.data_consumed_notification.wait().await;
}
}
// info!("Got mDNS query");
rx_data.chunk = None;
rx_pipe.data_consumed_notification.signal(());
}
}
rx_pipe.data_supplied_notification.wait().await;
}
}
}
impl<'a> super::Mdns for Mdns<'a> {
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
Mdns::add(self, service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
Mdns::remove(self, service)
}
}
impl<'a> Services for Mdns<'a> {
type Error = crate::error::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Error>
where
F: FnMut(&Service) -> Result<(), Error>,
{
Mdns::for_each(self, callback)
}
}
impl From<ShortBuf> for Error {
fn from(_e: ShortBuf) -> Self {
Self::new(ErrorCode::NoSpace)
}
}
impl From<ParseError> for Error {
fn from(_e: ParseError) -> Self {
Self::new(ErrorCode::MdnsError)
}
}
impl From<FromStrError> for Error {
fn from(_e: FromStrError) -> Self {
Self::new(ErrorCode::MdnsError)
}
}

matter/src/mdns/proto.rs (new file, 508 lines)

@ -0,0 +1,508 @@
use core::fmt::Write;
use core::str::FromStr;
use domain::{
base::{
header::Flags,
iana::Class,
message_builder::AnswerBuilder,
name::FromStrError,
octets::{Octets256, Octets64, OctetsBuilder, ParseError},
Dname, Message, MessageBuilder, Record, Rtype, ShortBuf, ToDname,
},
rdata::{Aaaa, Ptr, Srv, Txt, A},
};
use log::trace;
pub trait Services {
type Error: From<ShortBuf> + From<ParseError> + From<FromStrError>;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>;
}
impl<T> Services for &mut T
where
T: Services,
{
type Error = T::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>,
{
(**self).for_each(callback)
}
}
impl<T> Services for &T
where
T: Services,
{
type Error = T::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>,
{
(**self).for_each(callback)
}
}
pub struct Host<'a> {
pub id: u16,
pub hostname: &'a str,
pub ip: [u8; 4],
pub ipv6: Option<[u8; 16]>,
}
impl<'a> Host<'a> {
pub fn broadcast<T: Services>(
&self,
services: T,
buf: &mut [u8],
ttl_sec: u32,
) -> Result<usize, T::Error> {
let buf = Buf(buf, 0);
let message = MessageBuilder::from_target(buf)?;
let mut answer = message.answer();
self.set_broadcast(services, &mut answer, ttl_sec)?;
let buf = answer.finish();
Ok(buf.1)
}
pub fn respond<T: Services>(
&self,
services: T,
data: &[u8],
buf: &mut [u8],
ttl_sec: u32,
) -> Result<usize, T::Error> {
let buf = Buf(buf, 0);
let message = MessageBuilder::from_target(buf)?;
let mut answer = message.answer();
if self.set_response(data, services, &mut answer, ttl_sec)? {
let buf = answer.finish();
Ok(buf.1)
} else {
Ok(0)
}
}
fn set_broadcast<T, F>(
&self,
services: F,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), F::Error>
where
T: OctetsBuilder + AsMut<[u8]>,
F: Services,
{
self.set_header(answer);
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
services.for_each(|service| {
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
Ok(())
})?;
Ok(())
}
fn set_response<T, F>(
&self,
data: &[u8],
services: F,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<bool, F::Error>
where
T: OctetsBuilder + AsMut<[u8]>,
F: Services,
{
self.set_header(answer);
let message = Message::from_octets(data)?;
let mut replied = false;
for question in message.question() {
trace!("Handling question {:?}", question);
let question = question?;
match question.qtype() {
Rtype::A
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?) =>
{
self.add_ipv4(answer, ttl_sec)?;
replied = true;
}
Rtype::Aaaa
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?) =>
{
self.add_ipv6(answer, ttl_sec)?;
replied = true;
}
Rtype::Srv => {
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Ptr => {
services.for_each(|service| {
if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) {
service.add_service_type(answer, ttl_sec)?;
replied = true;
} else if question.qname().name_eq(&service.service_type_fqdn(true)?) {
// TODO
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Txt => {
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Any => {
// A / AAAA
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?)
{
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
replied = true;
}
// PTR
services.for_each(|service| {
if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) {
service.add_service_type(answer, ttl_sec)?;
replied = true;
} else if question.qname().name_eq(&service.service_type_fqdn(true)?) {
// TODO
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
// SRV
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
_ => (),
}
}
Ok(replied)
}
fn set_header<T: OctetsBuilder + AsMut<[u8]>>(&self, answer: &mut AnswerBuilder<T>) {
let header = answer.header_mut();
header.set_id(self.id);
header.set_opcode(domain::base::iana::Opcode::Query);
header.set_rcode(domain::base::iana::Rcode::NoError);
let mut flags = Flags::new();
flags.qr = true;
flags.aa = true;
header.set_flags(flags);
}
fn add_ipv4<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, A>::new(
Self::host_fqdn(self.hostname, false).unwrap(),
Class::In,
ttl_sec,
A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]),
))
}
fn add_ipv6<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
if let Some(ip) = &self.ipv6 {
answer.push(Record::<Dname<Octets64>, Aaaa>::new(
Self::host_fqdn(self.hostname, false).unwrap(),
Class::In,
ttl_sec,
Aaaa::new((*ip).into()),
))
} else {
Ok(())
}
}
fn host_fqdn(hostname: &str, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut host_fqdn = heapless::String::<60>::new();
write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap();
Dname::from_str(&host_fqdn)
}
}
pub struct Service<'a> {
pub name: &'a str,
pub service: &'a str,
pub protocol: &'a str,
pub port: u16,
pub service_subtypes: &'a [&'a str],
pub txt_kvs: &'a [(&'a str, &'a str)],
}
impl<'a> Service<'a> {
fn add_service<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
hostname: &str,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Srv<_>>::new(
self.service_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()),
))
}
fn add_service_type<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
Self::dns_sd_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_type_fqdn(false).unwrap()),
))?;
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
self.service_type_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_fqdn(false).unwrap()),
))
}
fn add_service_subtypes<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
for service_subtype in self.service_subtypes {
self.add_service_subtype(answer, service_subtype, ttl_sec)?;
}
Ok(())
}
fn add_service_subtype<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
service_subtype: &str,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
Self::dns_sd_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()),
))?;
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
self.service_subtype_fqdn(service_subtype, false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_fqdn(false).unwrap()),
))
}
fn add_txt<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
// The only way I found to create multiple entries in a Txt record:
// each entry is a length byte followed by that many bytes of data
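// e.g. the pair ("D", "840") is emitted as [5, b'D', b'=', b'8', b'4', b'0']:
// one length byte covering key + '=' + value, followed by those bytes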
let mut octets = Octets256::new();
//octets.append_slice(&[1u8, b'X'])?;
//octets.append_slice(&[2u8, b'A', b'B'])?;
//octets.append_slice(&[0u8])?;
for (k, v) in self.txt_kvs {
octets.append_slice(&[(k.len() + v.len() + 1) as u8])?;
octets.append_slice(k.as_bytes())?;
octets.append_slice(&[b'='])?;
octets.append_slice(v.as_bytes())?;
}
let txt = Txt::from_octets(&mut octets).unwrap();
answer.push(Record::<Dname<Octets64>, Txt<_>>::new(
self.service_fqdn(false).unwrap(),
Class::In,
ttl_sec,
txt,
))
}
fn service_fqdn(&self, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_fqdn = heapless::String::<60>::new();
write!(
service_fqdn,
"{}.{}.{}.local{}",
self.name, self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_fqdn)
}
fn service_type_fqdn(&self, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_type_fqdn = heapless::String::<60>::new();
write!(
service_type_fqdn,
"{}.{}.local{}",
self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_type_fqdn)
}
fn service_subtype_fqdn(
&self,
service_subtype: &str,
suffix: bool,
) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_subtype_fqdn = heapless::String::<40>::new();
write!(
service_subtype_fqdn,
"{}._sub.{}.{}.local{}",
service_subtype, self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_subtype_fqdn)
}
fn dns_sd_fqdn(suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
if suffix {
Dname::from_str("_services._dns-sd._udp.local.")
} else {
Dname::from_str("_services._dns-sd._udp.local")
}
}
}
struct Buf<'a>(pub &'a mut [u8], pub usize);
impl<'a> OctetsBuilder for Buf<'a> {
type Octets = Self;
fn append_slice(&mut self, slice: &[u8]) -> Result<(), ShortBuf> {
if self.1 + slice.len() <= self.0.len() {
let end = self.1 + slice.len();
self.0[self.1..end].copy_from_slice(slice);
self.1 = end;
Ok(())
} else {
Err(ShortBuf)
}
}
fn truncate(&mut self, len: usize) {
self.1 = len;
}
fn freeze(self) -> Self::Octets {
self
}
fn len(&self) -> usize {
self.1
}
fn is_empty(&self) -> bool {
self.1 == 0
}
}
impl<'a> AsMut<[u8]> for Buf<'a> {
fn as_mut(&mut self) -> &mut [u8] {
&mut self.0[..self.1]
}
}


@ -15,56 +15,66 @@
* limitations under the License. * limitations under the License.
*/ */
use core::fmt::Write;
use super::*; use super::*;
pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> String { pub fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> {
// 0: no Vendor ID and Product ID present in Manual Pairing Code // 0: no Vendor ID and Product ID present in Manual Pairing Code
const VID_PID_PRESENT: u8 = 0; const VID_PID_PRESENT: u8 = 0;
let passwd = passwd_from_comm_data(comm_data); let passwd = passwd_from_comm_data(comm_data);
let CommissioningData { discriminator, .. } = comm_data; let CommissioningData { discriminator, .. } = comm_data;
let mut digits = String::new(); let mut digits = heapless::String::<32>::new();
digits.push_str(&((VID_PID_PRESENT << 2) | (discriminator >> 10) as u8).to_string()); write!(
digits.push_str(&format!( &mut digits,
"{:0>5}", "{}{:0>5}{:0>4}",
((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16 (VID_PID_PRESENT << 2) | (discriminator >> 10) as u8,
)); ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16,
digits.push_str(&format!("{:0>4}", passwd >> 14)); passwd >> 14
)
.unwrap();
let check_digit = digits.calculate_verhoeff_check_digit(); let mut final_digits = heapless::String::<32>::new();
digits.push_str(&check_digit.to_string()); write!(
&mut final_digits,
"{}{}",
digits,
digits.calculate_verhoeff_check_digit()
)
.unwrap();
digits final_digits
} }
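Worked example, matching the first test case below: with passcode 123456 and discriminator 250, the leading digit is (0 << 2) | (250 >> 10) = 0; the next five digits are ((250 & 0x300) << 6) | (123456 & 0x3FFF) = 8768, zero-padded to "08768"; the last four are 123456 >> 14 = 7, zero-padded to "0007". That gives "0087680007", and appending its Verhoeff check digit (1) yields the expected "00876800071".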
pub(super) fn pretty_print_pairing_code(pairing_code: &str) { pub(super) fn pretty_print_pairing_code(pairing_code: &str) {
assert!(pairing_code.len() == 11); assert!(pairing_code.len() == 11);
let mut pretty = String::new(); let mut pretty = heapless::String::<32>::new();
pretty.push_str(&pairing_code[..4]); pretty.push_str(&pairing_code[..4]).unwrap();
pretty.push('-'); pretty.push('-').unwrap();
pretty.push_str(&pairing_code[4..8]); pretty.push_str(&pairing_code[4..8]).unwrap();
pretty.push('-'); pretty.push('-').unwrap();
pretty.push_str(&pairing_code[8..]); pretty.push_str(&pairing_code[8..]).unwrap();
info!("Pairing Code: {}", pretty); info!("Pairing Code: {}", pretty);
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::secure_channel::spake2p::VerifierData; use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand};
#[test] #[test]
fn can_compute_pairing_code() { fn can_compute_pairing_code() {
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(123456), verifier: VerifierData::new_with_pw(123456, dummy_rand),
discriminator: 250, discriminator: 250,
}; };
let pairing_code = compute_pairing_code(&comm_data); let pairing_code = compute_pairing_code(&comm_data);
assert_eq!(pairing_code, "00876800071"); assert_eq!(pairing_code, "00876800071");
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(34567890), verifier: VerifierData::new_with_pw(34567890, dummy_rand),
discriminator: 2976, discriminator: 2976,
}; };
let pairing_code = compute_pairing_code(&comm_data); let pairing_code = compute_pairing_code(&comm_data);


@ -22,7 +22,6 @@ pub mod qr;
pub mod vendor_identifiers; pub mod vendor_identifiers;
use log::info; use log::info;
use qrcode::{render::unicode, QrCode, Version};
use verhoeff::Verhoeff; use verhoeff::Verhoeff;
use crate::{ use crate::{
@ -32,7 +31,7 @@ use crate::{
use self::{ use self::{
code::{compute_pairing_code, pretty_print_pairing_code}, code::{compute_pairing_code, pretty_print_pairing_code},
qr::{payload_base38_representation, print_qr_code, QrSetupPayload}, qr::{compute_qr_code, print_qr_code},
}; };
pub struct DiscoveryCapabilities { pub struct DiscoveryCapabilities {
@ -86,13 +85,15 @@ pub fn print_pairing_code_and_qr(
dev_det: &BasicInfoConfig, dev_det: &BasicInfoConfig,
comm_data: &CommissioningData, comm_data: &CommissioningData,
discovery_capabilities: DiscoveryCapabilities, discovery_capabilities: DiscoveryCapabilities,
) { buf: &mut [u8],
) -> Result<(), Error> {
let pairing_code = compute_pairing_code(comm_data); let pairing_code = compute_pairing_code(comm_data);
let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities); let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?;
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode");
pretty_print_pairing_code(&pairing_code); pretty_print_pairing_code(&pairing_code);
print_qr_code(&data_str); print_qr_code(qr_code);
Ok(())
} }
pub(self) fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 { pub(self) fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 {


@ -15,9 +15,8 @@
* limitations under the License. * limitations under the License.
*/ */
use std::collections::BTreeMap;
use crate::{ use crate::{
error::ErrorCode,
tlv::{TLVWriter, TagType}, tlv::{TLVWriter, TagType},
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
@ -45,6 +44,7 @@ const TOTAL_PAYLOAD_DATA_SIZE_IN_BITS: usize = VERSION_FIELD_LENGTH_IN_BITS
+ PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS + PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS
+ SETUP_PINCODE_FIELD_LENGTH_IN_BITS + SETUP_PINCODE_FIELD_LENGTH_IN_BITS
+ PADDING_FIELD_LENGTH_IN_BITS; + PADDING_FIELD_LENGTH_IN_BITS;
const TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES: usize = TOTAL_PAYLOAD_DATA_SIZE_IN_BITS / 8; const TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES: usize = TOTAL_PAYLOAD_DATA_SIZE_IN_BITS / 8;
// Spec 5.1.4.2 CHIP-Common Reserved Tags // Spec 5.1.4.2 CHIP-Common Reserved Tags
@ -55,7 +55,7 @@ const SERIAL_NUMBER_TAG: u8 = 0x00;
// const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04; // const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04;
pub enum QRCodeInfoType { pub enum QRCodeInfoType {
String(String), String(heapless::String<128>), // TODO: Big enough?
Int32(i32), Int32(i32),
Int64(i64), Int64(i64),
UInt32(u32), UInt32(u32),
@ -63,7 +63,7 @@ pub enum QRCodeInfoType {
} }
pub enum SerialNumber { pub enum SerialNumber {
String(String), String(heapless::String<128>),
UInt32(u32), UInt32(u32),
} }
@ -78,10 +78,10 @@ pub struct QrSetupPayload<'data> {
version: u8, version: u8,
flow_type: CommissionningFlowType, flow_type: CommissionningFlowType,
discovery_capabilities: DiscoveryCapabilities, discovery_capabilities: DiscoveryCapabilities,
dev_det: &'data BasicInfoConfig, dev_det: &'data BasicInfoConfig<'data>,
comm_data: &'data CommissioningData, comm_data: &'data CommissioningData,
// we use a BTreeMap to keep the order of the optional data stable // The vec is ordered by the tag of OptionalQRCodeInfo
optional_data: BTreeMap<u8, OptionalQRCodeInfo>, optional_data: heapless::Vec<OptionalQRCodeInfo, 16>,
} }
impl<'data> QrSetupPayload<'data> { impl<'data> QrSetupPayload<'data> {
@ -98,11 +98,11 @@ impl<'data> QrSetupPayload<'data> {
discovery_capabilities, discovery_capabilities,
dev_det, dev_det,
comm_data, comm_data,
optional_data: BTreeMap::new(), optional_data: heapless::Vec::new(),
}; };
if !dev_det.serial_no.is_empty() { if !dev_det.serial_no.is_empty() {
result.add_serial_number(SerialNumber::String(dev_det.serial_no.clone())); result.add_serial_number(SerialNumber::String(dev_det.serial_no.into()));
} }
result result
@ -132,13 +132,11 @@ impl<'data> QrSetupPayload<'data> {
/// * `tag` - tag number in the [0x80-0xFF] range /// * `tag` - tag number in the [0x80-0xFF] range
/// * `data` - Data to add /// * `data` - Data to add
pub fn add_optional_vendor_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> { pub fn add_optional_vendor_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> {
if !is_vendor_tag(tag) { if is_vendor_tag(tag) {
return Err(Error::InvalidArgument); self.add_optional_data(tag, data)
} else {
Err(ErrorCode::InvalidArgument.into())
} }
self.optional_data
.insert(tag, OptionalQRCodeInfo { tag, data });
Ok(())
} }
/// A function to add an optional QR Code info CHIP object /// A function to add an optional QR Code info CHIP object
@ -150,16 +148,26 @@ impl<'data> QrSetupPayload<'data> {
tag: u8, tag: u8,
data: QRCodeInfoType, data: QRCodeInfoType,
) -> Result<(), Error> { ) -> Result<(), Error> {
if !is_common_tag(tag) { if is_common_tag(tag) {
return Err(Error::InvalidArgument); self.add_optional_data(tag, data)
} else {
Err(ErrorCode::InvalidArgument.into())
}
} }
self.optional_data fn add_optional_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> {
.insert(tag, OptionalQRCodeInfo { tag, data }); let item = OptionalQRCodeInfo { tag, data };
Ok(()) let index = self.optional_data.iter().position(|info| tag < info.tag);
if let Some(index) = index {
self.optional_data.insert(index, item)
} else {
self.optional_data.push(item)
}
.map_err(|_| ErrorCode::NoSpace.into())
} }
pub fn get_all_optional_data(&self) -> &BTreeMap<u8, OptionalQRCodeInfo> { pub fn get_all_optional_data(&self) -> &[OptionalQRCodeInfo] {
&self.optional_data &self.optional_data
} }
@ -245,35 +253,30 @@ pub enum CommissionningFlowType {
Custom = 2, Custom = 2,
} }
struct TlvData { pub(super) fn payload_base38_representation<'a>(
max_data_length_in_bytes: u32, payload: &QrSetupPayload,
data_length_in_bytes: Option<usize>, buf: &'a mut [u8],
data: Option<Vec<u8>>, ) -> Result<&'a str, Error> {
} if payload.is_valid() {
let (str_buf, bits_buf, tlv_buf) = if payload.has_tlv() {
let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2);
pub(super) fn payload_base38_representation(payload: &QrSetupPayload) -> Result<String, Error> { let (bits_buf, tlv_buf) = buf.split_at_mut(buf.len() / 3);
let (mut bits, tlv_data) = if payload.has_tlv() {
let buffer_size = estimate_buffer_size(payload)?; (str_buf, bits_buf, Some(tlv_buf))
(
vec![0; buffer_size],
Some(TlvData {
max_data_length_in_bytes: buffer_size as u32,
data_length_in_bytes: None,
data: None,
}),
)
} else { } else {
(vec![0; TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES], None) let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2);
(str_buf, buf, None)
}; };
if !payload.is_valid() { payload_base38_representation_with_tlv(payload, str_buf, bits_buf, tlv_buf)
return Err(Error::InvalidArgument); } else {
Err(ErrorCode::InvalidArgument.into())
}
} }
payload_base38_representation_with_tlv(payload, &mut bits, tlv_data) pub fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
}
fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
// Estimate the size of the needed buffer; initialize with the size of the standard payload. // Estimate the size of the needed buffer; initialize with the size of the standard payload.
let mut estimate = TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES; let mut estimate = TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES;
@ -294,15 +297,14 @@ fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
size size
}; };
let vendor_data = payload.get_all_optional_data(); for data in payload.get_all_optional_data() {
vendor_data.values().for_each(|data| {
estimate += data_item_size_estimate(&data.data); estimate += data_item_size_estimate(&data.data);
}); }
estimate = estimate_struct_overhead(estimate); estimate = estimate_struct_overhead(estimate);
if estimate > u32::MAX as usize { if estimate > u32::MAX as usize {
return Err(Error::NoMemory); Err(ErrorCode::NoMemory)?;
} }
Ok(estimate) Ok(estimate)
@ -317,18 +319,38 @@ fn estimate_struct_overhead(first_field_size: usize) -> usize {
first_field_size + 4 + 2 first_field_size + 4 + 2
} }
pub(super) fn print_qr_code(qr_data: &str) { pub(super) fn print_qr_code(qr_code: &str) {
let needed_version = compute_qr_version(qr_data); info!("QR Code: {}", qr_code);
#[cfg(feature = "std")]
{
use qrcode::{render::unicode, QrCode, Version};
let needed_version = compute_qr_version(qr_code);
let code = let code =
QrCode::with_version(qr_data, Version::Normal(needed_version), qrcode::EcLevel::M).unwrap(); QrCode::with_version(qr_code, Version::Normal(needed_version), qrcode::EcLevel::M)
.unwrap();
let image = code let image = code
.render::<unicode::Dense1x2>() .render::<unicode::Dense1x2>()
.dark_color(unicode::Dense1x2::Light) .dark_color(unicode::Dense1x2::Light)
.light_color(unicode::Dense1x2::Dark) .light_color(unicode::Dense1x2::Dark)
.build(); .build();
info!("\n{}", image); info!("\n{}", image);
} }
}
pub fn compute_qr_code<'a>(
dev_det: &BasicInfoConfig,
comm_data: &CommissioningData,
discovery_capabilities: DiscoveryCapabilities,
buf: &'a mut [u8],
) -> Result<&'a str, Error> {
let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities);
payload_base38_representation(&qr_code_data, buf)
}
#[cfg(feature = "std")]
fn compute_qr_version(qr_data: &str) -> i16 { fn compute_qr_version(qr_data: &str) -> i16 {
match qr_data.len() { match qr_data.len() {
0..=38 => 2, 0..=38 => 2,
@ -346,11 +368,11 @@ fn populate_bits(
total_payload_data_size_in_bits: usize, total_payload_data_size_in_bits: usize,
) -> Result<(), Error> { ) -> Result<(), Error> {
if *offset + number_of_bits > total_payload_data_size_in_bits { if *offset + number_of_bits > total_payload_data_size_in_bits {
return Err(Error::InvalidArgument); Err(ErrorCode::InvalidArgument)?;
} }
if input >= 1u64 << number_of_bits { if input >= 1u64 << number_of_bits {
return Err(Error::InvalidArgument); Err(ErrorCode::InvalidArgument)?;
} }
let mut index = *offset; let mut index = *offset;
@ -368,70 +390,90 @@ fn populate_bits(
Ok(()) Ok(())
} }
fn payload_base38_representation_with_tlv( fn payload_base38_representation_with_tlv<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
bits: &mut [u8], str_buf: &'a mut [u8],
mut tlv_data: Option<TlvData>, bits_buf: &mut [u8],
) -> Result<String, Error> { tlv_buf: Option<&mut [u8]>,
if let Some(tlv_data) = tlv_data.as_mut() { ) -> Result<&'a str, Error> {
generate_tlv_from_optional_data(payload, tlv_data)?; let tlv_data = if let Some(tlv_buf) = tlv_buf {
Some(generate_tlv_from_optional_data(payload, tlv_buf)?)
} else {
None
};
let bits = generate_bit_set(payload, bits_buf, tlv_data)?;
let prefix = "MT:";
if str_buf.len() < prefix.as_bytes().len() {
Err(ErrorCode::NoSpace)?;
} }
let bytes_written = generate_bit_set(payload, bits, tlv_data)?; str_buf[..prefix.as_bytes().len()].copy_from_slice(prefix.as_bytes());
let base38_encoded = base38::encode(&*bits, Some(bytes_written));
Ok(format!("MT:{}", base38_encoded)) let mut offset = prefix.len();
for c in base38::encode(bits) {
let mut char_buf = [0; 4];
let str = c.encode_utf8(&mut char_buf);
if str_buf.len() - offset < str.as_bytes().len() {
Err(ErrorCode::NoSpace)?;
} }
fn generate_tlv_from_optional_data( str_buf[offset..offset + str.as_bytes().len()].copy_from_slice(str.as_bytes());
offset += str.as_bytes().len();
}
Ok(core::str::from_utf8(&str_buf[..offset])?)
}
fn generate_tlv_from_optional_data<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
tlv_data: &mut TlvData, tlv_buf: &'a mut [u8],
) -> Result<(), Error> { ) -> Result<&'a [u8], Error> {
let size_needed = tlv_data.max_data_length_in_bytes as usize; let mut wb = WriteBuf::new(tlv_buf);
let mut tlv_buffer = vec![0u8; size_needed];
let mut wb = WriteBuf::new(&mut tlv_buffer, size_needed);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
let data = payload.get_all_optional_data();
for (tag, value) in data { for info in payload.get_all_optional_data() {
match &value.data { match &info.data {
QRCodeInfoType::String(data) => tw.utf8(TagType::Context(*tag), data.as_bytes())?, QRCodeInfoType::String(data) => tw.utf8(TagType::Context(info.tag), data.as_bytes())?,
QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(*tag), *data)?, QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(info.tag), *data)?,
QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(*tag), *data)?, QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(info.tag), *data)?,
QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(*tag), *data)?, QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(info.tag), *data)?,
QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(*tag), *data)?, QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(info.tag), *data)?,
} }
} }
tw.end_container()?; tw.end_container()?;
tlv_data.data_length_in_bytes = Some(tw.get_tail());
tlv_data.data = Some(tlv_buffer);
Ok(()) let tail = tw.get_tail();
Ok(&tlv_buf[..tail])
} }
fn generate_bit_set( fn generate_bit_set<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
bits: &mut [u8], bits_buf: &'a mut [u8],
tlv_data: Option<TlvData>, tlv_data: Option<&[u8]>,
) -> Result<usize, Error> { ) -> Result<&'a [u8], Error> {
let mut offset: usize = 0; let total_payload_size_in_bits =
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + tlv_data.map(|tlv_data| tlv_data.len() * 8).unwrap_or(0);
let total_payload_size_in_bits = if let Some(tlv_data) = &tlv_data { if bits_buf.len() * 8 < total_payload_size_in_bits {
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + (tlv_data.data_length_in_bytes.unwrap_or_default() * 8) Err(ErrorCode::BufferTooSmall)?;
} else {
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS
};
if bits.len() * 8 < total_payload_size_in_bits {
return Err(Error::BufferTooSmall);
}; };
let passwd = passwd_from_comm_data(payload.comm_data); let passwd = passwd_from_comm_data(payload.comm_data);
let mut offset: usize = 0;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.version as u64, payload.version as u64,
VERSION_FIELD_LENGTH_IN_BITS, VERSION_FIELD_LENGTH_IN_BITS,
@ -439,7 +481,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.dev_det.vid as u64, payload.dev_det.vid as u64,
VENDOR_IDFIELD_LENGTH_IN_BITS, VENDOR_IDFIELD_LENGTH_IN_BITS,
@ -447,7 +489,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.dev_det.pid as u64, payload.dev_det.pid as u64,
PRODUCT_IDFIELD_LENGTH_IN_BITS, PRODUCT_IDFIELD_LENGTH_IN_BITS,
@ -455,7 +497,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.flow_type as u64, payload.flow_type as u64,
COMMISSIONING_FLOW_FIELD_LENGTH_IN_BITS, COMMISSIONING_FLOW_FIELD_LENGTH_IN_BITS,
@ -463,7 +505,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.discovery_capabilities.as_bits() as u64, payload.discovery_capabilities.as_bits() as u64,
RENDEZVOUS_INFO_FIELD_LENGTH_IN_BITS, RENDEZVOUS_INFO_FIELD_LENGTH_IN_BITS,
@ -471,7 +513,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.comm_data.discriminator as u64, payload.comm_data.discriminator as u64,
PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS, PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS,
@ -479,7 +521,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
passwd as u64, passwd as u64,
SETUP_PINCODE_FIELD_LENGTH_IN_BITS, SETUP_PINCODE_FIELD_LENGTH_IN_BITS,
@ -487,7 +529,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
0, 0,
PADDING_FIELD_LENGTH_IN_BITS, PADDING_FIELD_LENGTH_IN_BITS,
@ -495,26 +537,22 @@ fn generate_bit_set(
)?; )?;
if let Some(tlv_data) = tlv_data { if let Some(tlv_data) = tlv_data {
populate_tlv_bits(bits, &mut offset, tlv_data, total_payload_size_in_bits)?; populate_tlv_bits(bits_buf, &mut offset, tlv_data, total_payload_size_in_bits)?;
} }
let bytes_written = (offset + 7) / 8; let bytes_written = (offset + 7) / 8;
Ok(bytes_written)
Ok(&bits_buf[..bytes_written])
} }
fn populate_tlv_bits( fn populate_tlv_bits(
bits: &mut [u8], bits_buf: &mut [u8],
offset: &mut usize, offset: &mut usize,
tlv_data: TlvData, tlv_data: &[u8],
total_payload_size_in_bits: usize, total_payload_size_in_bits: usize,
) -> Result<(), Error> { ) -> Result<(), Error> {
if let (Some(data), Some(data_length_in_bytes)) = (tlv_data.data, tlv_data.data_length_in_bytes) for b in tlv_data {
{ populate_bits(bits_buf, offset, *b as u64, 8, total_payload_size_in_bits)?;
for b in data.iter().take(data_length_in_bytes) {
populate_bits(bits, offset, *b as u64, 8, total_payload_size_in_bits)?;
}
} else {
return Err(Error::InvalidArgument);
} }
Ok(()) Ok(())
@ -532,16 +570,15 @@ fn is_common_tag(tag: u8) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::secure_channel::spake2p::VerifierData; use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand};
#[test] #[test]
fn can_base38_encode() { fn can_base38_encode() {
const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00"; const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00";
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(34567890), verifier: VerifierData::new_with_pw(34567890, dummy_rand),
discriminator: 2976, discriminator: 2976,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
@ -552,7 +589,9 @@ mod tests {
let disc_cap = DiscoveryCapabilities::new(false, true, false); let disc_cap = DiscoveryCapabilities::new(false, true, false);
let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap);
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
@ -561,19 +600,21 @@ mod tests {
const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40"; const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40";
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(20202021), verifier: VerifierData::new_with_pw(20202021, dummy_rand),
discriminator: 3840, discriminator: 3840,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
vid: 65521, vid: 65521,
pid: 32769, pid: 32769,
serial_no: "1234567890".to_string(), serial_no: "1234567890",
..Default::default() ..Default::default()
}; };
let disc_cap = DiscoveryCapabilities::new(true, false, false); let disc_cap = DiscoveryCapabilities::new(true, false, false);
let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap);
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
@ -588,13 +629,13 @@ mod tests {
const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550; const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550;
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(20202021), verifier: VerifierData::new_with_pw(20202021, dummy_rand),
discriminator: 3840, discriminator: 3840,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
vid: 65521, vid: 65521,
pid: 32769, pid: 32769,
serial_no: "1234567890".to_string(), serial_no: "1234567890",
..Default::default() ..Default::default()
}; };
@ -604,7 +645,7 @@ mod tests {
qr_code_data qr_code_data
.add_optional_vendor_data( .add_optional_vendor_data(
OPTIONAL_DEFAULT_STRING_TAG, OPTIONAL_DEFAULT_STRING_TAG,
QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.to_string()), QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.into()),
) )
.expect("Failed to add optional data"); .expect("Failed to add optional data");
@ -617,7 +658,9 @@ mod tests {
) )
.expect("Failed to add optional data"); .expect("Failed to add optional data");
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
} }
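
For reference, the generate_bit_set routine above packs the fixed onboarding-payload fields in the order required by the Matter specification before the optional TLV section is appended. Assuming the usual spec widths for the *_FIELD_LENGTH_IN_BITS constants (they are defined elsewhere in this module), the fixed part works out to:

    version (3) + vendor id (16) + product id (16) + commissioning flow (2)
    + discovery capabilities (8) + discriminator (12) + setup passcode (27)
    + padding (4) = 88 bits = 11 bytes

so TOTAL_PAYLOAD_DATA_SIZE_IN_BITS is presumably 88, and the bits_buf.len() * 8 < total_payload_size_in_bits check simply demands those 11 bytes plus 8 bits for every byte of optional TLV.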

matter/src/persist.rs (new file, 84 lines)

@ -0,0 +1,84 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#[cfg(feature = "std")]
pub use file_psm::*;
#[cfg(feature = "std")]
mod file_psm {
use std::fs;
use std::io::{Read, Write};
use std::path::PathBuf;
use log::info;
use crate::error::{Error, ErrorCode};
pub struct FilePsm {
dir: PathBuf,
}
impl FilePsm {
pub fn new(dir: PathBuf) -> Result<Self, Error> {
fs::create_dir_all(&dir)?;
Ok(Self { dir })
}
pub fn load<'a>(&self, key: &str, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
let path = self.dir.join(key);
match fs::File::open(path) {
Ok(mut file) => {
let mut offset = 0;
loop {
if offset == buf.len() {
Err(ErrorCode::NoSpace)?;
}
let len = file.read(&mut buf[offset..])?;
if len == 0 {
break;
}
offset += len;
}
let data = &buf[..offset];
info!("Key {}: loaded {} bytes {:?}", key, data.len(), data);
Ok(Some(data))
}
Err(_) => Ok(None),
}
}
pub fn store(&self, key: &str, data: &[u8]) -> Result<(), Error> {
let path = self.dir.join(key);
let mut file = fs::File::create(path)?;
file.write_all(data)?;
info!("Key {}: stored {} bytes {:?}", key, data.len(), data);
Ok(())
}
}
}
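
A minimal usage sketch for the new FilePsm store follows; the directory path, the "fabrics" key and the payload bytes are illustrative assumptions, and Error is the crate's own error type:

    use std::path::PathBuf;

    fn psm_roundtrip() -> Result<(), Error> {
        // Create (or reuse) the backing directory and the store handle.
        let psm = FilePsm::new(PathBuf::from("/tmp/matter-store"))?;

        // Persist an opaque blob under a string key.
        psm.store("fabrics", b"serialized fabric TLV")?;

        // Load it back into a caller-provided buffer; `data` borrows from `buf`.
        let mut buf = [0u8; 4096];
        if let Some(data) = psm.load("fabrics", &mut buf)? {
            assert_eq!(data, &b"serialized fabric TLV"[..]);
        }

        Ok(())
    }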


@ -15,35 +15,33 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::cell::RefCell;
use log::{error, trace}; use log::{error, trace};
use owning_ref::RwLockReadGuardRef;
use rand::prelude::*;
use crate::{ use crate::{
cert::Cert, cert::Cert,
crypto::{self, CryptoKeyPair, KeyPair, Sha256}, crypto::{self, KeyPair, Sha256},
error::Error, error::{Error, ErrorCode},
fabric::{Fabric, FabricMgr, FabricMgrInner}, fabric::{Fabric, FabricMgr},
secure_channel::common::SCStatusCodes, secure_channel::common::SCStatusCodes,
secure_channel::common::{self, OpCode}, secure_channel::common::{self, OpCode},
tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType},
transport::{ transport::{
network::Address, network::Address,
proto_demux::{ProtoCtx, ResponseRequired}, proto_ctx::ProtoCtx,
queue::{Msg, WorkQ},
session::{CaseDetails, CloneData, NocCatIds, SessionMode}, session::{CaseDetails, CloneData, NocCatIds, SessionMode},
}, },
utils::writebuf::WriteBuf, utils::{rand::Rand, writebuf::WriteBuf},
}; };
#[derive(PartialEq)] #[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum State { enum State {
Sigma1Rx, Sigma1Rx,
Sigma3Rx, Sigma3Rx,
} }
#[derive(Debug, Clone)]
pub struct CaseSession { pub struct CaseSession {
state: State, state: State,
peer_sessid: u16, peer_sessid: u16,
@ -54,6 +52,7 @@ pub struct CaseSession {
peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES],
local_fabric_idx: usize, local_fabric_idx: usize,
} }
impl CaseSession { impl CaseSession {
pub fn new(peer_sessid: u16, local_sessid: u16) -> Result<Self, Error> { pub fn new(peer_sessid: u16, local_sessid: u16) -> Result<Self, Error> {
Ok(Self { Ok(Self {
@ -69,46 +68,53 @@ impl CaseSession {
} }
} }
pub struct Case { pub struct Case<'a> {
fabric_mgr: Arc<FabricMgr>, fabric_mgr: &'a RefCell<FabricMgr>,
rand: Rand,
} }
impl Case { impl<'a> Case<'a> {
pub fn new(fabric_mgr: Arc<FabricMgr>) -> Self { #[inline(always)]
Self { fabric_mgr } pub fn new(fabric_mgr: &'a RefCell<FabricMgr>, rand: Rand) -> Self {
Self { fabric_mgr, rand }
} }
pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub fn casesigma3_handler(
&mut self,
ctx: &mut ProtoCtx,
) -> Result<(bool, Option<CloneData>), Error> {
let mut case_session = ctx let mut case_session = ctx
.exch_ctx .exch_ctx
.exch .exch
.take_data_boxed::<CaseSession>() .take_case_session()
.ok_or(Error::InvalidState)?; .ok_or(ErrorCode::InvalidState)?;
if case_session.state != State::Sigma1Rx { if case_session.state != State::Sigma1Rx {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
case_session.state = State::Sigma3Rx; case_session.state = State::Sigma3Rx;
let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( common::create_sc_status_report(
&mut ctx.tx, ctx.tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes); return Ok((true, None));
} }
// Safe to unwrap here // Safe to unwrap here
let fabric = fabric.as_ref().as_ref().unwrap(); let fabric = fabric.unwrap();
let root = get_root_node_struct(ctx.rx.as_borrow_slice())?; let root = get_root_node_struct(ctx.rx.as_slice())?;
let encrypted = root.find_tag(1)?.slice()?; let encrypted = root.find_tag(1)?.slice()?;
let mut decrypted: [u8; 800] = [0; 800]; let mut decrypted: [u8; 800] = [0; 800];
if encrypted.len() > decrypted.len() { if encrypted.len() > decrypted.len() {
error!("Data too large"); error!("Data too large");
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let decrypted = &mut decrypted[..encrypted.len()]; let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted); decrypted.copy_from_slice(encrypted);
@ -126,13 +132,9 @@ impl Case {
} }
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) { if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) {
error!("Certificate Chain doesn't match: {}", e); error!("Certificate Chain doesn't match: {}", e);
common::create_sc_status_report( common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?;
&mut ctx.tx,
common::SCStatusCodes::InvalidParameter,
None,
)?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes); return Ok((true, None));
} }
if Case::validate_sigma3_sign( if Case::validate_sigma3_sign(
@ -145,19 +147,15 @@ impl Case {
.is_err() .is_err()
{ {
error!("Sigma3 Signature doesn't match"); error!("Sigma3 Signature doesn't match");
common::create_sc_status_report( common::create_sc_status_report(ctx.tx, common::SCStatusCodes::InvalidParameter, None)?;
&mut ctx.tx,
common::SCStatusCodes::InvalidParameter,
None,
)?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes); return Ok((true, None));
} }
// Only now do we add this message to the TT Hash // Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default(); let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids); initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(ctx.rx.as_borrow_slice())?; case_session.tt_hash.update(ctx.rx.as_slice())?;
let clone_data = Case::get_session_clone_data( let clone_data = Case::get_session_clone_data(
fabric.ipk.op_key(), fabric.ipk.op_key(),
fabric.get_node_id(), fabric.get_node_id(),
@ -166,48 +164,42 @@ impl Case {
&case_session, &case_session,
&peer_catids, &peer_catids,
)?; )?;
// Queue a transport mgr request to add a new session
WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?;
common::create_sc_status_report( common::create_sc_status_report(ctx.tx, SCStatusCodes::SessionEstablishmentSuccess, None)?;
&mut ctx.tx, ctx.exch_ctx.exch.clear_data();
SCStatusCodes::SessionEstablishmentSuccess,
None,
)?;
ctx.exch_ctx.exch.clear_data_boxed();
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
Ok((true, Some(clone_data)))
Ok(ResponseRequired::Yes)
} }
pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8);
let rx_buf = ctx.rx.as_borrow_slice(); let rx_buf = ctx.rx.as_slice();
let root = get_root_node_struct(rx_buf)?; let root = get_root_node_struct(rx_buf)?;
let r = Sigma1Req::from_tlv(&root)?; let r = Sigma1Req::from_tlv(&root)?;
let local_fabric_idx = self let local_fabric_idx = self
.fabric_mgr .fabric_mgr
.borrow_mut()
.match_dest_id(r.initiator_random.0, r.dest_id.0); .match_dest_id(r.initiator_random.0, r.dest_id.0);
if local_fabric_idx.is_err() { if local_fabric_idx.is_err() {
error!("Fabric Index mismatch"); error!("Fabric Index mismatch");
common::create_sc_status_report( common::create_sc_status_report(
&mut ctx.tx, ctx.tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes); return Ok(true);
} }
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id();
let mut case_session = Box::new(CaseSession::new(r.initiator_sessid, local_sessid)?); let mut case_session = CaseSession::new(r.initiator_sessid, local_sessid)?;
case_session.tt_hash.update(rx_buf)?; case_session.tt_hash.update(rx_buf)?;
case_session.local_fabric_idx = local_fabric_idx?; case_session.local_fabric_idx = local_fabric_idx?;
if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES {
error!("Invalid public key length"); error!("Invalid public key length");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0); case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0);
trace!( trace!(
@ -216,19 +208,19 @@ impl Case {
); );
// Create an ephemeral Key Pair // Create an ephemeral Key Pair
let key_pair = KeyPair::new()?; let key_pair = KeyPair::new(self.rand)?;
let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?; let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?;
// Derive the Shared Secret // Derive the Shared Secret
let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?; let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?;
if len != 32 { if len != 32 {
error!("Derived secret length incorrect"); error!("Derived secret length incorrect");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
// println!("Derived secret: {:x?} len: {}", secret, len); // println!("Derived secret: {:x?} len: {}", secret, len);
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
rand::thread_rng().fill_bytes(&mut our_random); (self.rand)(&mut our_random);
// Derive the Encrypted Part // Derive the Encrypted Part
const MAX_ENCRYPTED_SIZE: usize = 800; const MAX_ENCRYPTED_SIZE: usize = 800;
@ -236,19 +228,21 @@ impl Case {
let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE];
let encrypted_len = { let encrypted_len = {
let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES];
let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( common::create_sc_status_report(
&mut ctx.tx, ctx.tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes); return Ok(true);
} }
let sign_len = Case::get_sigma2_sign( let sign_len = Case::get_sigma2_sign(
&fabric, fabric.unwrap(),
&case_session.our_pub_key, &case_session.our_pub_key,
&case_session.peer_pub_key, &case_session.peer_pub_key,
&mut signature, &mut signature,
@ -256,7 +250,8 @@ impl Case {
let signature = &signature[..sign_len]; let signature = &signature[..sign_len];
Case::get_sigma2_encryption( Case::get_sigma2_encryption(
&fabric, fabric.unwrap(),
self.rand,
&our_random, &our_random,
&mut case_session, &mut case_session,
signature, signature,
@ -273,9 +268,9 @@ impl Case {
tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
tw.str16(TagType::Context(4), encrypted)?; tw.str16(TagType::Context(4), encrypted)?;
tw.end_container()?; tw.end_container()?;
case_session.tt_hash.update(ctx.tx.as_borrow_slice())?; case_session.tt_hash.update(ctx.tx.as_mut_slice())?;
ctx.exch_ctx.exch.set_data_boxed(case_session); ctx.exch_ctx.exch.set_case_session(case_session);
Ok(ResponseRequired::Yes) Ok(true)
} }
fn get_session_clone_data( fn get_session_clone_data(
@ -322,8 +317,8 @@ impl Case {
case_session: &CaseSession, case_session: &CaseSession,
) -> Result<(), Error> { ) -> Result<(), Error> {
const MAX_TBS_SIZE: usize = 800; const MAX_TBS_SIZE: usize = 800;
let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; let mut buf = [0; MAX_TBS_SIZE];
let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); let mut write_buf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16(TagType::Context(1), initiator_noc)?; tw.str16(TagType::Context(1), initiator_noc)?;
@ -343,20 +338,22 @@ impl Case {
let mut verifier = noc.verify_chain_start(); let mut verifier = noc.verify_chain_start();
if fabric.get_fabric_id() != noc.get_fabric_id()? { if fabric.get_fabric_id() != noc.get_fabric_id()? {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
if let Some(icac) = icac { if let Some(icac) = icac {
// If ICAC is present handle it // If ICAC is present handle it
if let Ok(fid) = icac.get_fabric_id() { if let Ok(fid) = icac.get_fabric_id() {
if fid != fabric.get_fabric_id() { if fid != fabric.get_fabric_id() {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} }
verifier = verifier.add_cert(icac)?; verifier = verifier.add_cert(icac)?;
} }
verifier.add_cert(&fabric.root_ca)?.finalise()?; verifier
.add_cert(&Cert::new(&fabric.root_ca)?)?
.finalise()?;
Ok(()) Ok(())
} }
@ -370,18 +367,18 @@ impl Case {
0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73,
]; ];
if key.len() < 48 { if key.len() < 48 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
let tt = tt.clone(); let tt = tt.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Session Key: key: {:x?}", key); // println!("Session Key: key: {:x?}", key);
Ok(()) Ok(())
@ -418,20 +415,20 @@ impl Case {
) -> Result<(), Error> { ) -> Result<(), Error> {
const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33]; const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33];
if key.len() < 16 { if key.len() < 16 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
let tt = tt.clone(); let tt = tt.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Sigma3Key: key: {:x?}", key); // println!("Sigma3Key: key: {:x?}", key);
Ok(()) Ok(())
@ -445,39 +442,37 @@ impl Case {
) -> Result<(), Error> { ) -> Result<(), Error> {
const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32]; const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32];
if key.len() < 16 { if key.len() < 16 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
salt.extend_from_slice(our_random); salt.extend_from_slice(our_random).unwrap();
salt.extend_from_slice(&case_session.our_pub_key); salt.extend_from_slice(&case_session.our_pub_key).unwrap();
let tt = case_session.tt_hash.clone(); let tt = case_session.tt_hash.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Sigma2Key: key: {:x?}", key); // println!("Sigma2Key: key: {:x?}", key);
Ok(()) Ok(())
} }
fn get_sigma2_encryption( fn get_sigma2_encryption(
fabric: &RwLockReadGuardRef<FabricMgrInner, Option<Fabric>>, fabric: &Fabric,
rand: Rand,
our_random: &[u8], our_random: &[u8],
case_session: &mut CaseSession, case_session: &mut CaseSession,
signature: &[u8], signature: &[u8],
out: &mut [u8], out: &mut [u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
let mut resumption_id: [u8; 16] = [0; 16]; let mut resumption_id: [u8; 16] = [0; 16];
rand::thread_rng().fill_bytes(&mut resumption_id); rand(&mut resumption_id);
// We are guaranteed this unwrap will work
let fabric = fabric.as_ref().as_ref().unwrap();
let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES]; let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES];
Case::get_sigma2_key( Case::get_sigma2_key(
@ -487,12 +482,12 @@ impl Case {
&mut sigma2_key, &mut sigma2_key,
)?; )?;
let mut write_buf = WriteBuf::new(out, out.len()); let mut write_buf = WriteBuf::new(out);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_ref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))? tw.str16(TagType::Context(2), icac_cert)?
}; };
tw.str8(TagType::Context(3), signature)?; tw.str8(TagType::Context(3), signature)?;
@ -521,21 +516,20 @@ impl Case {
} }
fn get_sigma2_sign( fn get_sigma2_sign(
fabric: &RwLockReadGuardRef<FabricMgrInner, Option<Fabric>>, fabric: &Fabric,
our_pub_key: &[u8], our_pub_key: &[u8],
peer_pub_key: &[u8], peer_pub_key: &[u8],
signature: &mut [u8], signature: &mut [u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
// We are guaranteed this unwrap will work // We are guaranteed this unwrap will work
let fabric = fabric.as_ref().as_ref().unwrap();
const MAX_TBS_SIZE: usize = 800; const MAX_TBS_SIZE: usize = 800;
let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; let mut buf = [0; MAX_TBS_SIZE];
let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); let mut write_buf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_deref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))?; tw.str16(TagType::Context(2), icac_cert)?;
} }
tw.str8(TagType::Context(3), our_pub_key)?; tw.str8(TagType::Context(3), our_pub_key)?;
tw.str8(TagType::Context(4), peer_pub_key)?; tw.str8(TagType::Context(4), peer_pub_key)?;
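
Two patterns recur throughout this rewrite of case.rs: shared state is borrowed as a RefCell instead of being owned behind Arc/RwLock, and randomness is injected as a Rand value instead of calling rand::thread_rng() directly, so the module no longer drags in std. The HKDF salts likewise move from Vec to a fixed-capacity heapless::Vec<u8, 256>, so key derivation does not allocate. A minimal construction sketch, under the assumption (suggested by the calls above) that Rand is a plain fn(&mut [u8]); the fixed_rand helper is hypothetical:

    use core::cell::RefCell;

    // Hypothetical deterministic filler, for illustration only.
    fn fixed_rand(buf: &mut [u8]) {
        buf.fill(0x42);
    }

    // `fabric_mgr` is assumed to be an already-initialized FabricMgr.
    let fabric_mgr = RefCell::new(fabric_mgr);
    let mut case = Case::new(&fabric_mgr, fixed_rand);
    // SecureChannel then drives case.casesigma1_handler(ctx) / casesigma3_handler(ctx).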


@ -15,25 +15,16 @@
* limitations under the License. * limitations under the License.
*/ */
use boxslab::Slab;
use log::info;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use crate::{ use crate::{error::Error, transport::packet::Packet};
error::Error,
transport::{
exchange::Exchange,
packet::{Packet, PacketPool},
session::SessionHandle,
},
};
use super::status_report::{create_status_report, GeneralCode}; use super::status_report::{create_status_report, GeneralCode};
/* Secure Channel Protocol ID as per the Matter Spec */ /* Secure Channel Protocol ID as per the Matter Spec */
pub const PROTO_ID_SECURE_CHANNEL: usize = 0x00; pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00;
#[derive(FromPrimitive, Debug)] #[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)]
pub enum OpCode { pub enum OpCode {
MsgCounterSyncReq = 0x00, MsgCounterSyncReq = 0x00,
MsgCounterSyncResp = 0x01, MsgCounterSyncResp = 0x01,
@ -78,6 +69,7 @@ pub fn create_sc_status_report(
| SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::NoSharedTrustRoots
| SCStatusCodes::SessionNotFound => GeneralCode::Failure, | SCStatusCodes::SessionNotFound => GeneralCode::Failure,
}; };
create_status_report( create_status_report(
proto_tx, proto_tx,
general_code, general_code,
@ -88,14 +80,16 @@ pub fn create_sc_status_report(
} }
pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) {
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); proto_tx.reset();
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8);
proto_tx.unset_reliable(); proto_tx.unset_reliable();
} }
pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { // TODO
info!("Sending standalone ACK"); // pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> {
let mut ack_packet = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; // info!("Sending standalone ACK");
create_mrp_standalone_ack(&mut ack_packet); // let mut ack_packet = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?;
exch.send(ack_packet, sess) // create_mrp_standalone_ack(&mut ack_packet);
} // exch.send(ack_packet, sess)
// }
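
With the old exchange-level helper commented out pending the new transport API, sending a standalone ack is left to the caller. A minimal sketch, assuming ctx is the &mut ProtoCtx that the other handlers in this branch receive:

    // Stamp a standalone MRP ack into the caller-owned TX packet; the helper
    // resets the packet, sets the secure-channel proto id and the
    // MRPStandAloneAck opcode, and clears the reliability flag (a standalone
    // ack must not itself ask to be acknowledged).
    create_mrp_standalone_ack(ctx.tx);
    // The caller then transmits ctx.tx; no TLV payload is attached.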


@ -15,65 +15,96 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::{borrow::Borrow, cell::RefCell};
use crate::{ use crate::{
error::*, error::*,
fabric::FabricMgr, fabric::FabricMgr,
mdns::Mdns,
secure_channel::common::*, secure_channel::common::*,
tlv, tlv,
transport::proto_demux::{self, ProtoCtx, ResponseRequired}, transport::{proto_ctx::ProtoCtx, session::CloneData},
utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info}; use log::{error, info};
use num;
use super::{case::Case, pake::PaseMgr}; use super::{case::Case, pake::PaseMgr};
/* Handle messages related to the Secure Channel /* Handle messages related to the Secure Channel
*/ */
pub struct SecureChannel { pub struct SecureChannel<'a> {
case: Case, case: Case<'a>,
pase: PaseMgr, pase: &'a RefCell<PaseMgr>,
mdns: &'a dyn Mdns,
} }
impl SecureChannel { impl<'a> SecureChannel<'a> {
pub fn new(pase: PaseMgr, fabric_mgr: Arc<FabricMgr>) -> SecureChannel { #[inline(always)]
SecureChannel { pub fn new<
T: Borrow<RefCell<FabricMgr>>
+ Borrow<RefCell<PaseMgr>>
+ Borrow<dyn Mdns + 'a>
+ Borrow<Epoch>
+ Borrow<Rand>,
>(
matter: &'a T,
) -> Self {
Self::wrap(
matter.borrow(),
matter.borrow(),
matter.borrow(),
*matter.borrow(),
)
}
#[inline(always)]
pub fn wrap(
pase: &'a RefCell<PaseMgr>,
fabric: &'a RefCell<FabricMgr>,
mdns: &'a dyn Mdns,
rand: Rand,
) -> Self {
Self {
case: Case::new(fabric, rand),
pase, pase,
case: Case::new(fabric_mgr), mdns,
}
} }
} }
impl proto_demux::HandleProto for SecureChannel { pub fn handle(&mut self, ctx: &mut ProtoCtx) -> Result<(bool, Option<CloneData>), Error> {
fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { let proto_opcode: OpCode = ctx.rx.get_proto_opcode()?;
let proto_opcode: OpCode =
num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16);
info!("Received Opcode: {:?}", proto_opcode); info!("Received Opcode: {:?}", proto_opcode);
info!("Received Data:"); info!("Received Data:");
tlv::print_tlv_list(ctx.rx.as_borrow_slice()); tlv::print_tlv_list(ctx.rx.as_slice());
let result = match proto_opcode { let (reply, clone_data) = match proto_opcode {
OpCode::MRPStandAloneAck => Ok(ResponseRequired::No), OpCode::MRPStandAloneAck => Ok((false, None)),
OpCode::PBKDFParamRequest => self.pase.pbkdfparamreq_handler(ctx), OpCode::PBKDFParamRequest => self
OpCode::PASEPake1 => self.pase.pasepake1_handler(ctx), .pase
OpCode::PASEPake3 => self.pase.pasepake3_handler(ctx), .borrow_mut()
OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), .pbkdfparamreq_handler(ctx)
.map(|reply| (reply, None)),
OpCode::PASEPake1 => self
.pase
.borrow_mut()
.pasepake1_handler(ctx)
.map(|reply| (reply, None)),
OpCode::PASEPake3 => self.pase.borrow_mut().pasepake3_handler(ctx, self.mdns),
OpCode::CASESigma1 => self.case.casesigma1_handler(ctx).map(|reply| (reply, None)),
OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), OpCode::CASESigma3 => self.case.casesigma3_handler(ctx),
_ => { _ => {
error!("OpCode Not Handled: {:?}", proto_opcode); error!("OpCode Not Handled: {:?}", proto_opcode);
Err(Error::InvalidOpcode) Err(ErrorCode::InvalidOpcode.into())
} }
}; }?;
if result == Ok(ResponseRequired::Yes) {
if reply {
info!("Sending response"); info!("Sending response");
tlv::print_tlv_list(ctx.tx.as_borrow_slice()); tlv::print_tlv_list(ctx.tx.as_mut_slice());
}
result
} }
fn get_proto_id(&self) -> usize { Ok((reply, clone_data))
PROTO_ID_SECURE_CHANNEL
} }
} }
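
SecureChannel no longer implements the proto_demux trait; the transport layer is expected to call handle directly and act on the returned pair. A sketch of that calling contract, assuming matter is the crate's Matter instance (which supplies the Borrow impls the new constructor asks for) and ctx is the current ProtoCtx:

    let mut sc = SecureChannel::new(&matter);
    let (reply, clone_data) = sc.handle(&mut ctx)?;
    if reply {
        // transmit ctx.tx back to the peer
    }
    if let Some(clone_data) = clone_data {
        // a CASE session was just established; hand the CloneData to the
        // session manager (this duty now sits with the caller, since the old
        // WorkQ::sync_send(Msg::NewSession(..)) path is gone)
    }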


@ -15,40 +15,17 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; #[cfg(not(any(
feature = "crypto_openssl",
// This trait allows us to switch between crypto providers like OpenSSL and mbedTLS for Spake2 feature = "crypto_mbedtls",
// Currently this is only validate for a verifier(responder) feature = "crypto_rustcrypto"
)))]
// A verifier will typically do: pub use super::crypto_dummy::CryptoSpake2;
// Step 1: w0 and L #[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))]
// set_w0_from_w0s pub use super::crypto_esp_mbedtls::CryptoSpake2;
// set_L #[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))]
// Step 2: get_pB pub use super::crypto_mbedtls::CryptoSpake2;
// Step 3: get_TT_as_verifier(pA) #[cfg(feature = "crypto_openssl")]
// Step 4: Computation of cA and cB happens outside since it doesn't use either BigNum or EcPoint pub use super::crypto_openssl::CryptoSpake2;
pub trait CryptoSpake2 { #[cfg(feature = "crypto_rustcrypto")]
fn new() -> Result<Self, Error> pub use super::crypto_rustcrypto::CryptoSpake2;
where
Self: Sized;
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error>;
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>;
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error>;
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn set_L(&mut self, l: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn get_TT_as_verifier(
&mut self,
context: &[u8],
pA: &[u8],
pB: &[u8],
out: &mut [u8],
) -> Result<(), Error>;
}
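
The runtime CryptoSpake2 trait is gone: each backend now exports a concrete CryptoSpake2 type and the cfg gates above pick exactly one at compile time, which avoids dyn dispatch and keeps no_std builds free of unused backends. Under that scheme, the verifier flow that the removed trait comment described (Step 1: w0 and L, Step 2: pB, Step 3: transcript, Step 4: cA/cB outside) looks roughly like the sketch below; the helper name, the 65-byte uncompressed point and the 32-byte SHA-256 transcript size are assumptions:

    use crate::secure_channel::crypto::CryptoSpake2;
    use crate::{error::Error, utils::rand::Rand};

    fn verifier_transcript(
        context: &[u8],
        w0s: &[u8],
        l: &[u8],
        p_a: &[u8],
        rand: Rand,
    ) -> Result<[u8; 32], Error> {
        let mut spake = CryptoSpake2::new()?;
        spake.set_w0_from_w0s(w0s)?;       // Step 1: w0 ...
        spake.set_L(l)?;                   // ... and L from the verifier data
        let mut p_b = [0u8; 65];           // our share pB (uncompressed P-256 point)
        spake.get_pB(&mut p_b, rand)?;     // Step 2
        let mut tt = [0u8; 32];            // transcript hash
        spake.get_TT_as_verifier(context, p_a, &p_b, &mut tt)?; // Step 3
        Ok(tt)                             // Step 4 (cA/cB) stays with the caller
    }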


@ -0,0 +1,76 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
};
#[allow(non_snake_case)]
pub struct CryptoSpake2 {}
impl CryptoSpake2 {
#[allow(non_snake_case)]
pub fn new() -> Result<Self, Error> {
Ok(Self {})
}
// Computes w0 from w0s respectively
pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
#[allow(dead_code)]
pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn get_pB(&mut self, _pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn get_TT_as_verifier(
&mut self,
_context: &[u8],
_pA: &[u8],
_pB: &[u8],
_out: &mut [u8],
) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
}


@ -16,8 +16,7 @@
*/ */
use crate::error::Error; use crate::error::Error;
use crate::utils::rand::Rand;
use super::crypto::CryptoSpake2;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
@ -36,16 +35,16 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoEspMbedTls {} pub struct CryptoSpake2 {}
impl CryptoSpake2 for CryptoEspMbedTls { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
Ok(CryptoEspMbedTls {}) Ok(Self {})
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -53,7 +52,7 @@ impl CryptoSpake2 for CryptoEspMbedTls {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -61,17 +60,17 @@ impl CryptoSpake2 for CryptoEspMbedTls {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -79,7 +78,16 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { #[allow(dead_code)]
pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec,
// L = w1 * P
// where P is the generator of the underlying elliptic curve
Ok(())
}
#[allow(non_snake_case)]
pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -90,7 +98,7 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -101,13 +109,10 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
} }
impl CryptoEspMbedTls {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoEspMbedTls; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use openssl::bn::BigNum; use openssl::bn::BigNum;
use openssl::ec::{EcPoint, PointConversionForm}; use openssl::ec::{EcPoint, PointConversionForm};
@ -116,13 +121,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
CryptoEspMbedTls::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.X, t.X,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -136,12 +140,11 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
CryptoEspMbedTls::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.Y, t.Y,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -155,12 +158,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoEspMbedTls::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -191,12 +194,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap();
let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoEspMbedTls::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,


@ -15,14 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use alloc::sync::Arc;
ops::{Mul, Sub}, use core::ops::{Mul, Sub};
sync::Arc,
use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
}; };
use crate::error::Error;
use super::crypto::CryptoSpake2;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use mbedtls::{ use mbedtls::{
@ -33,6 +33,8 @@ use mbedtls::{
rng::{CtrDrbg, OsEntropy}, rng::{CtrDrbg, OsEntropy},
}; };
extern crate alloc;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1,
@ -50,7 +52,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoMbedTLS { pub struct CryptoSpake2 {
group: EcGroup, group: EcGroup,
order: Mpi, order: Mpi,
xy: Mpi, xy: Mpi,
@ -62,15 +64,15 @@ pub struct CryptoMbedTLS {
pB: EcPoint, pB: EcPoint,
} }
impl CryptoSpake2 for CryptoMbedTLS { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?; let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?;
let order = group.order()?; let order = group.order()?;
let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?; let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?;
let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?; let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?;
Ok(CryptoMbedTLS { Ok(Self {
group, group,
order, order,
xy: Mpi::new(0)?, xy: Mpi::new(0)?,
@ -84,7 +86,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -94,7 +96,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -104,24 +106,25 @@ impl CryptoSpake2 for CryptoMbedTLS {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = Mpi::from_binary(w0)?; self.w0 = Mpi::from_binary(w0)?;
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = Mpi::from_binary(w1)?; self.w1 = Mpi::from_binary(w1)?;
Ok(()) Ok(())
} }
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = EcPoint::from_binary(&self.group, l)?; self.L = EcPoint::from_binary(&self.group, l)?;
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -132,7 +135,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -150,14 +153,14 @@ impl CryptoSpake2 for CryptoMbedTLS {
let pB_internal = pB_internal.as_slice(); let pB_internal = pB_internal.as_slice();
if pB_internal.len() != pB.len() { if pB_internal.len() != pB.len() {
error!("pB length mismatch"); error!("pB length mismatch");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
pB.copy_from_slice(pB_internal); pB.copy_from_slice(pB_internal);
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -166,21 +169,21 @@ impl CryptoSpake2 for CryptoMbedTLS {
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut TT = Md::new(mbedtls::hash::Type::Sha256)?; let mut TT = Md::new(mbedtls::hash::Type::Sha256)?;
// context // context
CryptoMbedTLS::add_to_tt(&mut TT, context)?; Self::add_to_tt(&mut TT, context)?;
// 2 empty identifiers // 2 empty identifiers
CryptoMbedTLS::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
CryptoMbedTLS::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
// M // M
CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_M_BIN)?; Self::add_to_tt(&mut TT, &MATTER_M_BIN)?;
// N // N
CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_N_BIN)?; Self::add_to_tt(&mut TT, &MATTER_N_BIN)?;
// X = pA // X = pA
CryptoMbedTLS::add_to_tt(&mut TT, pA)?; Self::add_to_tt(&mut TT, pA)?;
// Y = pB // Y = pB
CryptoMbedTLS::add_to_tt(&mut TT, pB)?; Self::add_to_tt(&mut TT, pB)?;
let X = EcPoint::from_binary(&self.group, pA)?; let X = EcPoint::from_binary(&self.group, pA)?;
let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( let (Z, V) = Self::get_ZV_as_verifier(
&self.w0, &self.w0,
&self.L, &self.L,
&mut self.M, &mut self.M,
@ -193,24 +196,22 @@ impl CryptoSpake2 for CryptoMbedTLS {
// Z // Z
let tmp = Z.to_binary(&self.group, false)?; let tmp = Z.to_binary(&self.group, false)?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// V // V
let tmp = V.to_binary(&self.group, false)?; let tmp = V.to_binary(&self.group, false)?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// w0 // w0
let tmp = self.w0.to_binary()?; let tmp = self.w0.to_binary()?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
TT.finish(out)?; TT.finish(out)?;
Ok(()) Ok(())
} }
}
impl CryptoMbedTLS {
fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> {
let mut len_buf: [u8; 8] = [0; 8]; let mut len_buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut len_buf, buf.len() as u64); LittleEndian::write_u64(&mut len_buf, buf.len() as u64);
@ -247,7 +248,7 @@ impl CryptoMbedTLS {
let mut tmp = x.mul(w0)?; let mut tmp = x.mul(w0)?;
tmp = tmp.modulo(order)?; tmp = tmp.modulo(order)?;
let inverted_N = CryptoMbedTLS::invert(group, N)?; let inverted_N = Self::invert(group, N)?;
let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?; let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
@ -283,7 +284,7 @@ impl CryptoMbedTLS {
let mut tmp = y.mul(w0)?; let mut tmp = y.mul(w0)?;
tmp = tmp.modulo(order)?; tmp = tmp.modulo(order)?;
let inverted_M = CryptoMbedTLS::invert(group, M)?; let inverted_M = Self::invert(group, M)?;
let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?; let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
@ -302,8 +303,7 @@ impl CryptoMbedTLS {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoMbedTLS; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use mbedtls::bignum::Mpi; use mbedtls::bignum::Mpi;
use mbedtls::ecp::EcPoint; use mbedtls::ecp::EcPoint;
@ -312,7 +312,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = Mpi::from_binary(&t.x).unwrap(); let x = Mpi::from_binary(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator().unwrap(); let P = c.group.generator().unwrap();
@ -326,7 +326,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = Mpi::from_binary(&t.y).unwrap(); let y = Mpi::from_binary(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator().unwrap(); let P = c.group.generator().unwrap();
@ -339,12 +339,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = Mpi::from_binary(&t.x).unwrap(); let x = Mpi::from_binary(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap(); let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap();
let (Z, V) = CryptoMbedTLS::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -364,12 +364,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = Mpi::from_binary(&t.y).unwrap(); let y = Mpi::from_binary(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_binary(&c.group, &t.X).unwrap(); let X = EcPoint::from_binary(&c.group, &t.X).unwrap();
let L = EcPoint::from_binary(&c.group, &t.L).unwrap(); let L = EcPoint::from_binary(&c.group, &t.L).unwrap();
let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,
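
For orientation, the point arithmetic that do_add_mul, get_ZV_as_prover and get_ZV_as_verifier implement here (and in the OpenSSL backend below) is the standard SPAKE2+ relation. With generator P, curve order p and cofactor h = 1 for P-256, and taking the exact constants from the SPAKE2+ draft rather than from this diff:

    w0 = w0s mod p,    w1 = w1s mod p,    L = w1·P
    pA = X = x·P + w0·M        pB = Y = y·P + w0·N
    prover:    Z = h·x·(Y − w0·N),    V = h·w1·(Y − w0·N)
    verifier:  Z = h·y·(X − w0·M),    V = h·y·L

which is why the code above multiplies the random scalar against the peer's share while adding the inverted M or N point scaled by x·w0 (respectively y·w0) reduced modulo the order.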


@ -15,9 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
};
use super::crypto::CryptoSpake2;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use openssl::{ use openssl::{
@ -44,7 +46,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoOpenSSL { pub struct CryptoSpake2 {
group: EcGroup, group: EcGroup,
bn_ctx: BigNumContext, bn_ctx: BigNumContext,
// Stores the randomly generated x or y depending upon who we are // Stores the randomly generated x or y depending upon who we are
@ -58,9 +60,9 @@ pub struct CryptoOpenSSL {
order: BigNum, order: BigNum,
} }
impl CryptoSpake2 for CryptoOpenSSL { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let mut bn_ctx = BigNumContext::new()?; let mut bn_ctx = BigNumContext::new()?;
let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?; let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?;
@ -70,7 +72,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
let mut order = BigNum::new()?; let mut order = BigNum::new()?;
group.as_ref().order(&mut order, &mut bn_ctx)?; group.as_ref().order(&mut order, &mut bn_ctx)?;
Ok(CryptoOpenSSL { Ok(Self {
group, group,
bn_ctx, bn_ctx,
xy: BigNum::new()?, xy: BigNum::new()?,
@ -85,7 +87,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -96,7 +98,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -107,24 +109,25 @@ impl CryptoSpake2 for CryptoOpenSSL {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = BigNum::from_slice(w0)?; self.w0 = BigNum::from_slice(w0)?;
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = BigNum::from_slice(w1)?; self.w1 = BigNum::from_slice(w1)?;
Ok(()) Ok(())
} }
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?;
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -135,7 +138,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -143,7 +146,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
// - pB = Y // - pB = Y
self.order.rand_range(&mut self.xy)?; self.order.rand_range(&mut self.xy)?;
let P = self.group.generator(); let P = self.group.generator();
self.pB = CryptoOpenSSL::do_add_mul( self.pB = Self::do_add_mul(
P, P,
&self.xy, &self.xy,
&self.N, &self.N,
@ -159,14 +162,14 @@ impl CryptoSpake2 for CryptoOpenSSL {
let pB_internal = pB_internal.as_slice(); let pB_internal = pB_internal.as_slice();
if pB_internal.len() != pB.len() { if pB_internal.len() != pB.len() {
error!("pB length mismatch"); error!("pB length mismatch");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
pB.copy_from_slice(pB_internal); pB.copy_from_slice(pB_internal);
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -175,21 +178,21 @@ impl CryptoSpake2 for CryptoOpenSSL {
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut TT = Hasher::new(MessageDigest::sha256())?; let mut TT = Hasher::new(MessageDigest::sha256())?;
// context // context
CryptoOpenSSL::add_to_tt(&mut TT, context)?; Self::add_to_tt(&mut TT, context)?;
// 2 empty identifiers // 2 empty identifiers
CryptoOpenSSL::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
CryptoOpenSSL::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
// M // M
CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_M_BIN)?; Self::add_to_tt(&mut TT, &MATTER_M_BIN)?;
// N // N
CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_N_BIN)?; Self::add_to_tt(&mut TT, &MATTER_N_BIN)?;
// X = pA // X = pA
CryptoOpenSSL::add_to_tt(&mut TT, pA)?; Self::add_to_tt(&mut TT, pA)?;
// Y = pB // Y = pB
CryptoOpenSSL::add_to_tt(&mut TT, pB)?; Self::add_to_tt(&mut TT, pB)?;
let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?; let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?;
let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( let (Z, V) = Self::get_ZV_as_verifier(
&self.w0, &self.w0,
&self.L, &self.L,
&mut self.M, &mut self.M,
@ -207,7 +210,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
&mut self.bn_ctx, &mut self.bn_ctx,
)?; )?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// V // V
let tmp = V.to_bytes( let tmp = V.to_bytes(
@ -216,20 +219,18 @@ impl CryptoSpake2 for CryptoOpenSSL {
&mut self.bn_ctx, &mut self.bn_ctx,
)?; )?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// w0 // w0
let tmp = self.w0.to_vec(); let tmp = self.w0.to_vec();
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
let h = TT.finish()?; let h = TT.finish()?;
TT_hash.copy_from_slice(h.as_ref()); TT_hash.copy_from_slice(h.as_ref());
Ok(()) Ok(())
} }
}
impl CryptoOpenSSL {
fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> {
let mut len_buf: [u8; 8] = [0; 8]; let mut len_buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut len_buf, buf.len() as u64); LittleEndian::write_u64(&mut len_buf, buf.len() as u64);
@ -286,11 +287,11 @@ impl CryptoOpenSSL {
let mut tmp = BigNum::new()?; let mut tmp = BigNum::new()?;
tmp.mod_mul(x, w0, order, bn_ctx)?; tmp.mod_mul(x, w0, order, bn_ctx)?;
N.invert(group, bn_ctx)?; N.invert(group, bn_ctx)?;
let Z = CryptoOpenSSL::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; let Z = Self::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
tmp.mod_mul(w1, w0, order, bn_ctx)?; tmp.mod_mul(w1, w0, order, bn_ctx)?;
let V = CryptoOpenSSL::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; let V = Self::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?;
Ok((Z, V)) Ok((Z, V))
} }
@ -321,7 +322,7 @@ impl CryptoOpenSSL {
let mut tmp = BigNum::new()?; let mut tmp = BigNum::new()?;
tmp.mod_mul(y, w0, order, bn_ctx)?; tmp.mod_mul(y, w0, order, bn_ctx)?;
M.invert(group, bn_ctx)?; M.invert(group, bn_ctx)?;
let Z = CryptoOpenSSL::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; let Z = Self::do_add_mul(X, y, M, &tmp, group, bn_ctx)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
let mut V = EcPoint::new(group)?; let mut V = EcPoint::new(group)?;
@ -333,8 +334,7 @@ impl CryptoOpenSSL {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoOpenSSL; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use openssl::bn::BigNum; use openssl::bn::BigNum;
use openssl::ec::{EcPoint, PointConversionForm}; use openssl::ec::{EcPoint, PointConversionForm};
@ -343,12 +343,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = CryptoOpenSSL::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.X, t.X,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -362,11 +362,11 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = CryptoOpenSSL::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.Y, t.Y,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -380,12 +380,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoOpenSSL::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -416,12 +416,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap();
let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,


@ -21,11 +21,12 @@ use elliptic_curve::ops::*;
use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint};
use elliptic_curve::Field; use elliptic_curve::Field;
use elliptic_curve::PrimeField; use elliptic_curve::PrimeField;
use rand_core::CryptoRng;
use rand_core::RngCore;
use sha2::Digest; use sha2::Digest;
use crate::error::Error; use crate::error::Error;
use crate::utils::rand::Rand;
use super::crypto::CryptoSpake2;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
@ -44,7 +45,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoRustCrypto { pub struct CryptoSpake2 {
xy: p256::Scalar, xy: p256::Scalar,
w0: p256::Scalar, w0: p256::Scalar,
w1: p256::Scalar, w1: p256::Scalar,
@ -54,15 +55,15 @@ pub struct CryptoRustCrypto {
pB: p256::EncodedPoint, pB: p256::EncodedPoint,
} }
impl CryptoSpake2 for CryptoRustCrypto { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let M = p256::EncodedPoint::from_bytes(MATTER_M_BIN).unwrap(); let M = p256::EncodedPoint::from_bytes(MATTER_M_BIN).unwrap();
let N = p256::EncodedPoint::from_bytes(MATTER_N_BIN).unwrap(); let N = p256::EncodedPoint::from_bytes(MATTER_N_BIN).unwrap();
let L = p256::EncodedPoint::default(); let L = p256::EncodedPoint::default();
let pB = p256::EncodedPoint::default(); let pB = p256::EncodedPoint::default();
Ok(CryptoRustCrypto { Ok(Self {
xy: p256::Scalar::ZERO, xy: p256::Scalar::ZERO,
w0: p256::Scalar::ZERO, w0: p256::Scalar::ZERO,
w1: p256::Scalar::ZERO, w1: p256::Scalar::ZERO,
@ -74,7 +75,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -103,7 +104,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -132,14 +133,14 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = self.w0 =
p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w0)) p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w0))
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = self.w1 =
p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w1)) p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w1))
.unwrap(); .unwrap();
@ -148,12 +149,13 @@ impl CryptoSpake2 for CryptoRustCrypto {
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = p256::EncodedPoint::from_bytes(l).unwrap(); self.L = p256::EncodedPoint::from_bytes(l).unwrap();
Ok(()) Ok(())
} }
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -163,14 +165,14 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
// - Y = y*P + w0*N // - Y = y*P + w0*N
// - pB = Y // - pB = Y
let mut rng = rand::thread_rng(); let mut rand = RandRngCore(rand);
self.xy = p256::Scalar::random(&mut rng); self.xy = p256::Scalar::random(&mut rand);
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let N = p256::AffinePoint::from_encoded_point(&self.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&self.N).unwrap();
@ -182,7 +184,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -222,9 +224,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
}
impl CryptoRustCrypto {
fn add_to_tt(tt: &mut sha2::Sha256, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut sha2::Sha256, buf: &[u8]) -> Result<(), Error> {
tt.update((buf.len() as u64).to_le_bytes()); tt.update((buf.len() as u64).to_le_bytes());
if !buf.is_empty() { if !buf.is_empty() {
@ -266,11 +266,11 @@ impl CryptoRustCrypto {
let mut tmp = x * w0; let mut tmp = x * w0;
let N_neg = N.neg(); let N_neg = N.neg();
let Z = CryptoRustCrypto::do_add_mul(Y, x, N_neg, tmp)?; let Z = Self::do_add_mul(Y, x, N_neg, tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
tmp = w1 * w0; tmp = w1 * w0;
let V = CryptoRustCrypto::do_add_mul(Y, w1, N_neg, tmp)?; let V = Self::do_add_mul(Y, w1, N_neg, tmp)?;
Ok((Z, V)) Ok((Z, V))
} }
@ -297,27 +297,55 @@ impl CryptoRustCrypto {
let tmp = y * w0; let tmp = y * w0;
let M_neg = M.neg(); let M_neg = M.neg();
let Z = CryptoRustCrypto::do_add_mul(X, y, M_neg, tmp)?; let Z = Self::do_add_mul(X, y, M_neg, tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
let V = (L * y).to_encoded_point(false); let V = (L * y).to_encoded_point(false);
Ok((Z, V)) Ok((Z, V))
} }
} }
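
As a reading aid: the Z/V routines above implement the SPAKE2+ shared-element formulas from the draft cited in the code, with the P-256 cofactor h = 1 (hence the "No-Op" comments).

\[
\begin{aligned}
\text{prover: }   & Z = h\,x\,(Y - w_0 N), \qquad V = h\,w_1\,(Y - w_0 N) \\
\text{verifier: } & Z = h\,y\,(X - w_0 M), \qquad V = h\,y\,L, \quad L = w_1 P
\end{aligned}
\]
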
pub struct RandRngCore(pub Rand);
impl RngCore for RandRngCore {
fn next_u32(&mut self) -> u32 {
let mut buf = [0; 4];
self.fill_bytes(&mut buf);
u32::from_be_bytes(buf)
}
fn next_u64(&mut self) -> u64 {
let mut buf = [0; 8];
self.fill_bytes(&mut buf);
u64::from_be_bytes(buf)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
(self.0)(dest);
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
self.fill_bytes(dest);
Ok(())
}
}
impl CryptoRng for RandRngCore {}
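
RandRngCore above adapts the crate's platform-injected Rand callback (assumed to be a plain fn(&mut [u8])) into a rand_core RNG, so p256::Scalar::random no longer needs rand::thread_rng(). A minimal standalone sketch of the same adapter idea, with a stand-in entropy source and no external crates:

type Rand = fn(&mut [u8]); // assumed shape of the crate's Rand alias

// Stand-in entropy source; a real port would call the platform RNG here.
fn dummy_rand(buf: &mut [u8]) {
    for (i, b) in buf.iter_mut().enumerate() {
        *b = i as u8;
    }
}

// Draw a fixed-width integer from the byte-filling callback, as RandRngCore's next_u32 does.
fn next_u32(rand: Rand) -> u32 {
    let mut buf = [0u8; 4];
    rand(&mut buf);
    u32::from_be_bytes(buf)
}

fn main() {
    let rand: Rand = dummy_rand;
    assert_eq!(next_u32(rand), u32::from_be_bytes([0, 1, 2, 3]));
}
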
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use elliptic_curve::sec1::FromEncodedPoint; use elliptic_curve::sec1::FromEncodedPoint;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
#[test] #[test]
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = p256::Scalar::from_repr( let x = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.x), *elliptic_curve::generic_array::GenericArray::from_slice(&t.x),
) )
@ -325,7 +353,7 @@ mod tests {
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap();
let r: p256::EncodedPoint = CryptoRustCrypto::do_add_mul(P, x, M, c.w0).unwrap(); let r: p256::EncodedPoint = CryptoSpake2::do_add_mul(P, x, M, c.w0).unwrap();
assert_eq!(&t.X, r.as_bytes()); assert_eq!(&t.X, r.as_bytes());
} }
} }
@ -334,7 +362,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = p256::Scalar::from_repr( let y = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.y), *elliptic_curve::generic_array::GenericArray::from_slice(&t.y),
) )
@ -342,7 +370,7 @@ mod tests {
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap();
let r = CryptoRustCrypto::do_add_mul(P, y, N, c.w0).unwrap(); let r = CryptoSpake2::do_add_mul(P, y, N, c.w0).unwrap();
assert_eq!(&t.Y, r.as_bytes()); assert_eq!(&t.Y, r.as_bytes());
} }
} }
@ -351,7 +379,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = p256::Scalar::from_repr( let x = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.x), *elliptic_curve::generic_array::GenericArray::from_slice(&t.x),
) )
@ -361,7 +389,7 @@ mod tests {
let Y = p256::EncodedPoint::from_bytes(t.Y).unwrap(); let Y = p256::EncodedPoint::from_bytes(t.Y).unwrap();
let Y = p256::AffinePoint::from_encoded_point(&Y).unwrap(); let Y = p256::AffinePoint::from_encoded_point(&Y).unwrap();
let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap();
let (Z, V) = CryptoRustCrypto::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap(); let (Z, V) = CryptoSpake2::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap();
assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.Z, Z.as_bytes());
assert_eq!(&t.V, V.as_bytes()); assert_eq!(&t.V, V.as_bytes());
@ -372,7 +400,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = p256::Scalar::from_repr( let y = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.y), *elliptic_curve::generic_array::GenericArray::from_slice(&t.y),
) )
@ -383,7 +411,7 @@ mod tests {
let L = p256::EncodedPoint::from_bytes(t.L).unwrap(); let L = p256::EncodedPoint::from_bytes(t.L).unwrap();
let L = p256::AffinePoint::from_encoded_point(&L).unwrap(); let L = p256::AffinePoint::from_encoded_point(&L).unwrap();
let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap();
let (Z, V) = CryptoRustCrypto::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap(); let (Z, V) = CryptoSpake2::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap();
assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.Z, Z.as_bytes());
assert_eq!(&t.V, V.as_bytes()); assert_eq!(&t.V, V.as_bytes());


@ -17,10 +17,16 @@
pub mod case; pub mod case;
pub mod common; pub mod common;
#[cfg(feature = "crypto_esp_mbedtls")] #[cfg(not(any(
pub mod crypto_esp_mbedtls; feature = "crypto_openssl",
#[cfg(feature = "crypto_mbedtls")] feature = "crypto_mbedtls",
pub mod crypto_mbedtls; feature = "crypto_rustcrypto"
)))]
mod crypto_dummy;
#[cfg(all(feature = "crypto_mbedtls", target_os = "espidf"))]
mod crypto_esp_mbedtls;
#[cfg(all(feature = "crypto_mbedtls", not(target_os = "espidf")))]
mod crypto_mbedtls;
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "crypto_openssl")]
pub mod crypto_openssl; pub mod crypto_openssl;
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "crypto_rustcrypto")]
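
The gating above selects exactly one crypto backend per build: a crypto_* feature picks openssl/mbedtls/rustcrypto (mbedtls further split on target_os = "espidf"), and crypto_dummy is compiled only when no crypto feature is enabled. A minimal sketch of the same mutually exclusive cfg pattern; the feature name is illustrative, not from the crate:

// Compiled only when the hypothetical "backend_a" feature is enabled.
#[cfg(feature = "backend_a")]
mod backend {
    pub const NAME: &str = "backend_a";
}

// Fallback, compiled only when the feature is absent.
#[cfg(not(feature = "backend_a"))]
mod backend {
    pub const NAME: &str = "dummy";
}

fn main() {
    // Exactly one `backend` module exists in any given build.
    println!("selected backend: {}", backend::NAME);
}
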


@ -15,10 +15,7 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{fmt::Write, time::Duration};
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use super::{ use super::{
common::{create_sc_status_report, SCStatusCodes}, common::{create_sc_status_report, SCStatusCodes},
@ -26,98 +23,118 @@ use super::{
}; };
use crate::{ use crate::{
crypto, crypto,
error::Error, error::{Error, ErrorCode},
mdns::{self, Mdns}, mdns::{Mdns, ServiceMode},
secure_channel::common::OpCode, secure_channel::common::OpCode,
sys::SysMdnsService, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV},
tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV},
transport::{ transport::{
exchange::ExchangeCtx, exchange::ExchangeCtx,
network::Address, network::Address,
proto_demux::{ProtoCtx, ResponseRequired}, proto_ctx::ProtoCtx,
queue::{Msg, WorkQ},
session::{CloneData, SessionMode}, session::{CloneData, SessionMode},
}, },
utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info}; use log::{error, info};
use rand::prelude::*;
#[allow(clippy::large_enum_variant)]
enum PaseMgrState { enum PaseMgrState {
Enabled(PAKE, SysMdnsService), Enabled(Pake, heapless::String<16>),
Disabled, Disabled,
} }
pub struct PaseMgrInternal { pub struct PaseMgr {
state: PaseMgrState, state: PaseMgrState,
epoch: Epoch,
rand: Rand,
} }
#[derive(Clone)]
// Could this lock be avoided?
pub struct PaseMgr(Arc<Mutex<PaseMgrInternal>>);
impl PaseMgr { impl PaseMgr {
pub fn new() -> Self { #[inline(always)]
Self(Arc::new(Mutex::new(PaseMgrInternal { pub fn new(epoch: Epoch, rand: Rand) -> Self {
Self {
state: PaseMgrState::Disabled, state: PaseMgrState::Disabled,
}))) epoch,
rand,
}
}
pub fn is_pase_session_enabled(&self) -> bool {
matches!(&self.state, PaseMgrState::Enabled(_, _))
} }
pub fn enable_pase_session( pub fn enable_pase_session(
&mut self, &mut self,
verifier: VerifierData, verifier: VerifierData,
discriminator: u16, discriminator: u16,
mdns: &dyn Mdns,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut s = self.0.lock().unwrap(); let mut buf = [0; 8];
let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); (self.rand)(&mut buf);
let name = format!("{:016X}", name); let num = u64::from_be_bytes(buf);
let mdns = Mdns::get()?
.publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; let mut mdns_service_name = heapless::String::<16>::new();
s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); write!(&mut mdns_service_name, "{:016X}", num).unwrap();
mdns.add(
&mdns_service_name,
ServiceMode::Commissionable(discriminator),
)?;
self.state = PaseMgrState::Enabled(
Pake::new(verifier, self.epoch, self.rand),
mdns_service_name,
);
Ok(()) Ok(())
} }
pub fn disable_pase_session(&mut self) { pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> {
let mut s = self.0.lock().unwrap(); if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state {
s.state = PaseMgrState::Disabled; mdns.remove(mdns_service_name)?;
}
self.state = PaseMgrState::Disabled;
Ok(())
} }
/// If the PASE Session is enabled, execute the closure, /// If the PASE Session is enabled, execute the closure,
/// if not enabled, generate SC Status Report /// if not enabled, generate SC Status Report
fn if_enabled<F>(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error> fn if_enabled<F, T>(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<Option<T>, Error>
where where
F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>, F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<T, Error>,
{ {
let mut s = self.0.lock().unwrap(); if let PaseMgrState::Enabled(pake, _) = &mut self.state {
if let PaseMgrState::Enabled(pake, _) = &mut s.state { let data = f(pake, ctx)?;
f(pake, ctx)
Ok(Some(data))
} else { } else {
error!("PASE Not enabled"); error!("PASE Not enabled");
create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None) create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?;
Ok(None)
} }
} }
pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8); ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?; self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?;
Ok(ResponseRequired::Yes) Ok(true)
} }
pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8); ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?; self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?;
Ok(ResponseRequired::Yes) Ok(true)
} }
pub fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub fn pasepake3_handler(
self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?; &mut self,
self.disable_pase_session(); ctx: &mut ProtoCtx,
Ok(ResponseRequired::Yes) mdns: &dyn Mdns,
} ) -> Result<(bool, Option<CloneData>), Error> {
} let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?;
self.disable_pase_session(mdns)?;
impl Default for PaseMgr { Ok((true, clone_data.flatten()))
fn default() -> Self {
Self::new()
} }
} }
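
The rewritten enable_pase_session avoids both rand::thread_rng() and the heap-allocating format!(): it draws 8 bytes from the injected Rand and formats them into a fixed-capacity heapless::String via core::fmt::Write. A small sketch of that formatting step (assumes the heapless crate as a dependency):

use core::fmt::Write;

// Format a 64-bit value as a 16-hex-digit mDNS service name without heap allocation.
fn commissionable_name(num: u64) -> heapless::String<16> {
    let mut name = heapless::String::<16>::new();
    // 16 hex digits always fit the 16-byte capacity, so the write cannot fail.
    write!(&mut name, "{:016X}", num).unwrap();
    name
}

fn main() {
    assert_eq!(commissionable_name(0xDEAD_BEEF).as_str(), "00000000DEADBEEF");
}
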
@ -131,53 +148,54 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60);
const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys";
struct SessionData { struct SessionData {
start_time: SystemTime, start_time: Duration,
exch_id: u16, exch_id: u16,
peer_addr: Address, peer_addr: Address,
spake2p: Box<Spake2P>, spake2p: Spake2P,
} }
impl SessionData { impl SessionData {
fn is_sess_expired(&self) -> Result<bool, Error> { fn is_sess_expired(&self, epoch: Epoch) -> Result<bool, Error> {
if SystemTime::now().duration_since(self.start_time)? > PASE_DISCARD_TIMEOUT_SECS { Ok(epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS)
Ok(true)
} else {
Ok(false)
}
} }
} }
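
With SystemTime gone, the session start is stored as a Duration snapshot from the injected Epoch callback (assumed to be a plain fn() -> Duration), so expiry becomes a pure duration comparison that also works in no_std. A minimal sketch:

use core::time::Duration;

type Epoch = fn() -> Duration; // assumed shape of the crate's Epoch alias

const PASE_DISCARD_TIMEOUT: Duration = Duration::from_secs(60);

fn is_sess_expired(start_time: Duration, epoch: Epoch) -> bool {
    epoch() - start_time > PASE_DISCARD_TIMEOUT
}

fn main() {
    // Fixed fake clock for the sketch; a real build injects a SystemTime- or RTOS-backed epoch.
    let now: Epoch = || Duration::from_secs(1_000);
    assert!(is_sess_expired(Duration::from_secs(900), now));
    assert!(!is_sess_expired(Duration::from_secs(970), now));
}
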
#[allow(clippy::large_enum_variant)]
enum PakeState { enum PakeState {
Idle, Idle,
InProgress(SessionData), InProgress(SessionData),
} }
impl PakeState { impl PakeState {
const fn new() -> Self {
Self::Idle
}
fn take(&mut self) -> Result<SessionData, Error> { fn take(&mut self) -> Result<SessionData, Error> {
let new = std::mem::replace(self, PakeState::Idle); let new = core::mem::replace(self, PakeState::Idle);
if let PakeState::InProgress(s) = new { if let PakeState::InProgress(s) = new {
Ok(s) Ok(s)
} else { } else {
Err(Error::InvalidSignature) Err(ErrorCode::InvalidSignature.into())
} }
} }
fn is_idle(&self) -> bool { fn is_idle(&self) -> bool {
std::mem::discriminant(self) == std::mem::discriminant(&PakeState::Idle) core::mem::discriminant(self) == core::mem::discriminant(&PakeState::Idle)
} }
fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result<SessionData, Error> { fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result<SessionData, Error> {
let sd = self.take()?; let sd = self.take()?;
if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() { if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() {
Err(Error::InvalidState) Err(ErrorCode::InvalidState.into())
} else { } else {
Ok(sd) Ok(sd)
} }
} }
fn make_in_progress(&mut self, spake2p: Box<Spake2P>, exch_ctx: &ExchangeCtx) { fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) {
*self = PakeState::InProgress(SessionData { *self = PakeState::InProgress(SessionData {
start_time: SystemTime::now(), start_time: epoch(),
spake2p, spake2p,
exch_id: exch_ctx.exch.get_id(), exch_id: exch_ctx.exch.get_id(),
peer_addr: exch_ctx.sess.get_peer_addr(), peer_addr: exch_ctx.sess.get_peer_addr(),
@ -191,37 +209,41 @@ impl PakeState {
impl Default for PakeState { impl Default for PakeState {
fn default() -> Self { fn default() -> Self {
Self::Idle Self::new()
} }
} }
pub struct PAKE { struct Pake {
pub verifier: VerifierData, verifier: VerifierData,
state: PakeState, state: PakeState,
epoch: Epoch,
rand: Rand,
} }
impl PAKE { impl Pake {
pub fn new(verifier: VerifierData) -> Self { pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self {
// TODO: Can any PBKDF2 calculation be pre-computed here // TODO: Can any PBKDF2 calculation be pre-computed here
PAKE { Self {
verifier, verifier,
state: Default::default(), state: PakeState::new(),
epoch,
rand,
} }
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<Option<CloneData>, Error> {
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?;
let cA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?;
let (status_code, Ke) = sd.spake2p.handle_cA(cA); let (status_code, ke) = sd.spake2p.handle_cA(cA);
if status_code == SCStatusCodes::SessionEstablishmentSuccess { let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
// Get the keys // Get the keys
let Ke = Ke.ok_or(Error::Invalid)?; let ke = ke.ok_or(ErrorCode::Invalid)?;
let mut session_keys: [u8; 48] = [0; 48]; let mut session_keys: [u8; 48] = [0; 48];
crypto::hkdf_sha256(&[], Ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// Create a session // Create a session
let data = sd.spake2p.get_app_data(); let data = sd.spake2p.get_app_data();
@ -242,23 +264,25 @@ impl PAKE {
.copy_from_slice(&session_keys[32..48]); .copy_from_slice(&session_keys[32..48]);
// Queue a transport mgr request to add a new session // Queue a transport mgr request to add a new session
WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; Some(clone_data)
} } else {
None
};
create_sc_status_report(&mut ctx.tx, status_code, None)?; create_sc_status_report(ctx.tx, status_code, None)?;
ctx.exch_ctx.exch.close(); ctx.exch_ctx.exch.close();
Ok(()) Ok(clone_data)
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> {
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?;
let pA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?;
let mut pB: [u8; 65] = [0; 65]; let mut pB: [u8; 65] = [0; 65];
let mut cB: [u8; 32] = [0; 32]; let mut cB: [u8; 32] = [0; 32];
sd.spake2p.start_verifier(&self.verifier)?; sd.spake2p.start_verifier(&self.verifier)?;
sd.spake2p.handle_pA(pA, &mut pB, &mut cB)?; sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?;
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?);
let resp = Pake1Resp { let resp = Pake1Resp {
@ -275,30 +299,30 @@ impl PAKE {
pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> {
if !self.state.is_idle() { if !self.state.is_idle() {
let sd = self.state.take()?; let sd = self.state.take()?;
if sd.is_sess_expired()? { if sd.is_sess_expired(self.epoch)? {
info!("Previous session expired, clearing it"); info!("Previous session expired, clearing it");
self.state = PakeState::Idle; self.state = PakeState::Idle;
} else { } else {
info!("Previous session in-progress, denying new request"); info!("Previous session in-progress, denying new request");
// little-endian timeout (here we've hardcoded 500ms) // little-endian timeout (here we've hardcoded 500ms)
create_sc_status_report(&mut ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?;
return Ok(()); return Ok(());
} }
} }
let root = tlv::get_root_node(ctx.rx.as_borrow_slice())?; let root = tlv::get_root_node(ctx.rx.as_slice())?;
let a = PBKDFParamReq::from_tlv(&root)?; let a = PBKDFParamReq::from_tlv(&root)?;
if a.passcode_id != 0 { if a.passcode_id != 0 {
error!("Can't yet handle passcode_id != 0"); error!("Can't yet handle passcode_id != 0");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
rand::thread_rng().fill_bytes(&mut our_random); (self.rand)(&mut our_random);
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id();
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
let mut spake2p = Box::new(Spake2P::new()); let mut spake2p = Spake2P::new();
spake2p.set_app_data(spake2p_data); spake2p.set_app_data(spake2p_data);
// Generate response // Generate response
@ -318,8 +342,9 @@ impl PAKE {
} }
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
spake2p.set_context(ctx.rx.as_borrow_slice(), ctx.tx.as_borrow_slice())?; spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?;
self.state.make_in_progress(spake2p, &ctx.exch_ctx); self.state
.make_in_progress(self.epoch, spake2p, &ctx.exch_ctx);
Ok(()) Ok(())
} }


@ -17,33 +17,20 @@
use crate::{ use crate::{
crypto::{self, HmacSha256}, crypto::{self, HmacSha256},
sys, utils::rand::Rand,
}; };
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use rand::prelude::*;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use crate::{ use crate::{
crypto::{pbkdf2_hmac, Sha256}, crypto::{pbkdf2_hmac, Sha256},
error::Error, error::{Error, ErrorCode},
}; };
#[cfg(feature = "crypto_openssl")]
use super::crypto_openssl::CryptoOpenSSL;
#[cfg(feature = "crypto_mbedtls")]
use super::crypto_mbedtls::CryptoMbedTLS;
#[cfg(feature = "crypto_esp_mbedtls")]
use super::crypto_esp_mbedtls::CryptoEspMbedTls;
#[cfg(feature = "crypto_rustcrypto")]
use super::crypto_rustcrypto::CryptoRustCrypto;
use super::{common::SCStatusCodes, crypto::CryptoSpake2}; use super::{common::SCStatusCodes, crypto::CryptoSpake2};
// This file handle Spake2+ specific instructions. In itself, this file is // This file handles Spake2+ specific instructions. In itself, this file is
// independent from the BigNum and EC operations that are typically required // independent from the BigNum and EC operations that are typically required
// Spake2+. We use the CryptoSpake2 trait object that allows us to abstract // Spake2+. We use the CryptoSpake2 trait object that allows us to abstract
// out the specific implementations. // out the specific implementations.
@ -51,6 +38,8 @@ use super::{common::SCStatusCodes, crypto::CryptoSpake2};
// In the case of the verifier, we don't actually release the Ke until we // In the case of the verifier, we don't actually release the Ke until we
// validate that the cA is confirmed. // validate that the cA is confirmed.
pub const SPAKE2_ITERATION_COUNT: u32 = 2000;
#[derive(PartialEq, Copy, Clone, Debug)] #[derive(PartialEq, Copy, Clone, Debug)]
pub enum Spake2VerifierState { pub enum Spake2VerifierState {
// Initialised - w0, L are set // Initialised - w0, L are set
@ -74,7 +63,7 @@ pub struct Spake2P {
context: Option<Sha256>, context: Option<Sha256>,
Ke: [u8; 16], Ke: [u8; 16],
cA: [u8; 32], cA: [u8; 32],
crypto_spake2: Option<Box<dyn CryptoSpake2>>, crypto_spake2: Option<CryptoSpake2>,
app_data: u32, app_data: u32,
} }
@ -87,24 +76,8 @@ const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1;
const MAX_SALT_SIZE_BYTES: usize = 32; const MAX_SALT_SIZE_BYTES: usize = 32;
const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES;
#[cfg(feature = "crypto_openssl")] fn crypto_spake2_new() -> Result<CryptoSpake2, Error> {
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> { CryptoSpake2::new()
Ok(Box::new(CryptoOpenSSL::new()?))
}
#[cfg(feature = "crypto_mbedtls")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoMbedTLS::new()?))
}
#[cfg(feature = "crypto_esp_mbedtls")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoEspMbedTls::new()?))
}
#[cfg(feature = "crypto_rustcrypto")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoRustCrypto::new()?))
} }
impl Default for Spake2P { impl Default for Spake2P {
@ -129,13 +102,13 @@ pub enum VerifierOption {
} }
impl VerifierData { impl VerifierData {
pub fn new_with_pw(pw: u32) -> Self { pub fn new_with_pw(pw: u32, rand: Rand) -> Self {
let mut s = Self { let mut s = Self {
salt: [0; MAX_SALT_SIZE_BYTES], salt: [0; MAX_SALT_SIZE_BYTES],
count: sys::SPAKE2_ITERATION_COUNT, count: SPAKE2_ITERATION_COUNT,
data: VerifierOption::Password(pw), data: VerifierOption::Password(pw),
}; };
rand::thread_rng().fill_bytes(&mut s.salt); rand(&mut s.salt);
s s
} }
@ -158,7 +131,7 @@ impl VerifierData {
} }
impl Spake2P { impl Spake2P {
pub fn new() -> Self { pub const fn new() -> Self {
Spake2P { Spake2P {
mode: Spake2Mode::Unknown, mode: Spake2Mode::Unknown,
context: None, context: None,
@ -198,7 +171,7 @@ impl Spake2P {
match verifier.data { match verifier.data {
VerifierOption::Password(pw) => { VerifierOption::Password(pw) => {
// Derive w0 and L from the password // Derive w0 and L from the password
let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)];
Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s); Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s);
let w0s_len = w0w1s.len() / 2; let w0s_len = w0w1s.len() / 2;
@ -223,13 +196,19 @@ impl Spake2P {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pA(&mut self, pA: &[u8], pB: &mut [u8], cB: &mut [u8]) -> Result<(), Error> { pub fn handle_pA(
&mut self,
pA: &[u8],
pB: &mut [u8],
cB: &mut [u8],
rand: Rand,
) -> Result<(), Error> {
if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) { if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) {
return Err(Error::InvalidState); Err(ErrorCode::InvalidState)?;
} }
if let Some(crypto_spake2) = &mut self.crypto_spake2 { if let Some(crypto_spake2) = &mut self.crypto_spake2 {
crypto_spake2.get_pB(pB)?; crypto_spake2.get_pB(pB, rand)?;
if let Some(context) = self.context.take() { if let Some(context) = self.context.take() {
let mut hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
context.finish(&mut hash)?; context.finish(&mut hash)?;
@ -278,13 +257,13 @@ impl Spake2P {
if ke_internal.len() == Ke.len() { if ke_internal.len() == Ke.len() {
Ke.copy_from_slice(ke_internal); Ke.copy_from_slice(ke_internal);
} else { } else {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
// Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys") // Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys")
let mut KcAKcB: [u8; 32] = [0; 32]; let mut KcAKcB: [u8; 32] = [0; 32];
crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB) crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
let KcA = &KcAKcB[0..(KcAKcB.len() / 2)]; let KcA = &KcAKcB[0..(KcAKcB.len() / 2)];
let KcB = &KcAKcB[(KcAKcB.len() / 2)..]; let KcB = &KcAKcB[(KcAKcB.len() / 2)..];
@ -317,7 +296,7 @@ mod tests {
0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82, 0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82,
0x36, 0x0, 0x36, 0x0,
]; ];
let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)];
Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s); Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s);
assert_eq!( assert_eq!(
w0w1s, w0w1s,


@ -39,6 +39,7 @@ pub enum GeneralCode {
PermissionDenied = 15, PermissionDenied = 15,
DataLoss = 16, DataLoss = 16,
} }
pub fn create_status_report( pub fn create_status_report(
proto_tx: &mut Packet, proto_tx: &mut Packet,
general_code: GeneralCode, general_code: GeneralCode,
@ -46,7 +47,8 @@ pub fn create_status_report(
proto_code: u16, proto_code: u16,
proto_data: Option<&[u8]>, proto_data: Option<&[u8]>,
) -> Result<(), Error> { ) -> Result<(), Error> {
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); proto_tx.reset();
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
proto_tx.set_proto_opcode(OpCode::StatusReport as u8); proto_tx.set_proto_opcode(OpCode::StatusReport as u8);
let wb = proto_tx.get_writebuf()?; let wb = proto_tx.get_writebuf()?;
wb.le_u16(general_code as u16)?; wb.le_u16(general_code as u16)?;


@ -1,31 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#[cfg(target_os = "macos")]
mod sys_macos;
#[cfg(target_os = "macos")]
pub use self::sys_macos::*;
#[cfg(target_os = "linux")]
mod sys_linux;
#[cfg(target_os = "linux")]
pub use self::sys_linux::*;
#[cfg(any(target_os = "macos", target_os = "linux"))]
mod posix;
#[cfg(any(target_os = "macos", target_os = "linux"))]
pub use self::posix::*;


@ -1,96 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::{
convert::TryInto,
fs::{remove_file, DirBuilder, File},
io::{Read, Write},
sync::{Arc, Mutex, Once},
};
use crate::error::Error;
pub const SPAKE2_ITERATION_COUNT: u32 = 2000;
// The Packet Pool that is allocated from. POSIX systems can use
// higher values unlike embedded systems
pub const MAX_PACKET_POOL_SIZE: usize = 25;
pub struct Psm {}
static mut G_PSM: Option<Arc<Mutex<Psm>>> = None;
static INIT: Once = Once::new();
const PSM_DIR: &str = "/tmp/matter_psm";
macro_rules! psm_path {
($key:ident) => {
format!("{}/{}", PSM_DIR, $key)
};
}
impl Psm {
fn new() -> Result<Self, Error> {
let result = DirBuilder::new().create(PSM_DIR);
if let Err(e) = result {
if e.kind() != std::io::ErrorKind::AlreadyExists {
return Err(e.into());
}
}
Ok(Self {})
}
pub fn get() -> Result<Arc<Mutex<Self>>, Error> {
unsafe {
INIT.call_once(|| {
G_PSM = Some(Arc::new(Mutex::new(Psm::new().unwrap())));
});
Ok(G_PSM.as_ref().ok_or(Error::Invalid)?.clone())
}
}
pub fn set_kv_slice(&self, key: &str, val: &[u8]) -> Result<(), Error> {
let mut f = File::create(psm_path!(key))?;
f.write_all(val)?;
Ok(())
}
pub fn get_kv_slice(&self, key: &str, val: &mut Vec<u8>) -> Result<usize, Error> {
let mut f = File::open(psm_path!(key))?;
let len = f.read_to_end(val)?;
Ok(len)
}
pub fn set_kv_u64(&self, key: &str, val: u64) -> Result<(), Error> {
let mut f = File::create(psm_path!(key))?;
f.write_all(&val.to_be_bytes())?;
Ok(())
}
pub fn get_kv_u64(&self, key: &str, val: &mut u64) -> Result<(), Error> {
let mut f = File::open(psm_path!(key))?;
let mut vec = Vec::new();
let _ = f.read_to_end(&mut vec)?;
*val = u64::from_be_bytes(vec.as_slice().try_into()?);
Ok(())
}
pub fn rm(&self, key: &str) {
let _ = remove_file(psm_path!(key));
}
}


@ -1,58 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use lazy_static::lazy_static;
use libmdns::{Responder, Service};
use log::info;
use std::sync::{Arc, Mutex};
use std::vec::Vec;
#[allow(dead_code)]
pub struct SysMdnsService {
service: Service,
}
lazy_static! {
static ref RESPONDER: Arc<Mutex<Responder>> = Arc::new(Mutex::new(Responder::new().unwrap()));
}
/// Publish a mDNS service
/// name - can be a service name (comma separate subtypes may follow)
/// regtype - registration type (e.g. _matter_.tcp etc)
/// port - the port
pub fn sys_publish_service(
name: &str,
regtype: &str,
port: u16,
txt_kvs: &[[&str; 2]],
) -> Result<SysMdnsService, Error> {
info!("mDNS Registration Type {}", regtype);
info!("mDNS properties {:?}", txt_kvs);
let mut properties = Vec::new();
for kvs in txt_kvs {
info!("mDNS TXT key {} val {}", kvs[0], kvs[1]);
properties.push(format!("{}={}", kvs[0], kvs[1]));
}
let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect();
let responder = RESPONDER.lock().map_err(|_| Error::MdnsError)?;
let service = responder.register(regtype.to_owned(), name.to_owned(), port, &properties);
Ok(SysMdnsService { service })
}


@ -1,46 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService};
use log::info;
#[allow(dead_code)]
pub struct SysMdnsService {
s: RegisteredDnsService,
}
/// Publish a mDNS service
/// name - can be a service name (comma separate subtypes may follow)
/// regtype - registration type (e.g. _matter_.tcp etc)
/// port - the port
pub fn sys_publish_service(
name: &str,
regtype: &str,
port: u16,
txt_kvs: &[[&str; 2]],
) -> Result<SysMdnsService, Error> {
let mut builder = DNSServiceBuilder::new(regtype, port).with_name(name);
info!("mDNS Registration Type {}", regtype);
for kvs in txt_kvs {
info!("mDNS TXT key {} val {}", kvs[0], kvs[1]);
builder = builder.with_key_value(kvs[0].to_string(), kvs[1].to_string());
}
let s = builder.register().map_err(|_| Error::MdnsError)?;
Ok(SysMdnsService { s })
}


@ -15,11 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use crate::error::{Error, ErrorCode};
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use core::fmt;
use log::{error, info}; use log::{error, info};
use std::fmt;
use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK};
@ -33,14 +33,7 @@ impl<'a> TLVList<'a> {
} }
} }
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Pointer<'a> {
buf: &'a [u8],
current: usize,
left: usize,
}
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ElementType<'a> { pub enum ElementType<'a> {
S8(i8), S8(i8),
S16(i16), S16(i16),
@ -63,9 +56,9 @@ pub enum ElementType<'a> {
Str32l, Str32l,
Str64l, Str64l,
Null, Null,
Struct(Pointer<'a>), Struct(&'a [u8]),
Array(Pointer<'a>), Array(&'a [u8]),
List(Pointer<'a>), List(&'a [u8]),
EndCnt, EndCnt,
Last, Last,
} }
@ -204,44 +197,11 @@ static VALUE_EXTRACTOR: [ExtractValue; MAX_VALUE_INDEX] = [
// Null 20 // Null 20
{ |_t| (0, ElementType::Null) }, { |_t| (0, ElementType::Null) },
// Struct 21 // Struct 21
{ { |t| (0, ElementType::Struct(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::Struct(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// Array 22 // Array 22
{ { |t| (0, ElementType::Array(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::Array(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// List 23 // List 23
{ { |t| (0, ElementType::List(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::List(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// EndCnt 24 // EndCnt 24
{ |_t| (0, ElementType::EndCnt) }, { |_t| (0, ElementType::EndCnt) },
]; ];
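
The old Pointer { buf, current, left } bookkeeping is replaced by a plain subslice: &t.buf[t.current..] carries both the position and the remaining length, which is why the extractors above shrink to one-liners. A tiny sketch of the equivalence (sample bytes are arbitrary):

// The subslice starting at `current` plays the role of (buf, current, left):
// it holds the bytes from `current` onwards, and its len() is what `left` tracked.
fn tail(buf: &[u8], current: usize) -> &[u8] {
    &buf[current..]
}

fn main() {
    let buf = [0x15u8, 0x24, 0x00, 0x2A, 0x18];
    let t = tail(&buf, 2);
    assert_eq!(t, &[0x00, 0x2A, 0x18]);
    assert_eq!(t.len(), buf.len() - 2); // the old `left` field, for free
}
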
@ -282,9 +242,9 @@ fn read_length_value<'a>(
// The current offset is the string size // The current offset is the string size
let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize; let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize;
// We'll consume the current offset (len) + the entire string // We'll consume the current offset (len) + the entire string
if length + size_of_length_field > t.left { if length + size_of_length_field > t.buf.len() - t.current {
// Return Error // Return Error
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
Ok(( Ok((
// return the additional size only // return the additional size only
@ -294,7 +254,7 @@ fn read_length_value<'a>(
} }
} }
#[derive(Debug, Copy, Clone)] #[derive(Debug, Clone)]
pub struct TLVElement<'a> { pub struct TLVElement<'a> {
tag_type: TagType, tag_type: TagType,
element_type: ElementType<'a>, element_type: ElementType<'a>,
@ -303,11 +263,11 @@ pub struct TLVElement<'a> {
impl<'a> PartialEq for TLVElement<'a> { impl<'a> PartialEq for TLVElement<'a> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match self.element_type { match self.element_type {
ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => { ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => {
let mut our_iter = TLVListIterator::from_pointer(a); let mut our_iter = TLVListIterator::from_buf(buf);
let mut their = match other.element_type { let mut their = match other.element_type {
ElementType::Struct(b) | ElementType::Array(b) | ElementType::List(b) => { ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => {
TLVListIterator::from_pointer(b) TLVListIterator::from_buf(buf)
} }
_ => { _ => {
// If we are a container, the other must be a container, else this is a mismatch // If we are a container, the other must be a container, else this is a mismatch
@ -318,7 +278,7 @@ impl<'a> PartialEq for TLVElement<'a> {
loop { loop {
let ours = our_iter.next(); let ours = our_iter.next();
let theirs = their.next(); let theirs = their.next();
if std::mem::discriminant(&ours) != std::mem::discriminant(&theirs) { if core::mem::discriminant(&ours) != core::mem::discriminant(&theirs) {
// One of us reached end of list, but the other didn't, that's a mismatch // One of us reached end of list, but the other didn't, that's a mismatch
return false; return false;
} }
@ -336,13 +296,13 @@ impl<'a> PartialEq for TLVElement<'a> {
} }
nest_level -= 1; nest_level -= 1;
} else { } else {
if is_container(ours.element_type) { if is_container(&ours.element_type) {
nest_level += 1; nest_level += 1;
// Only compare the discriminants in case of array/list/structures, // Only compare the discriminants in case of array/list/structures,
// instead of actual element values. Those will be subsets within this same // instead of actual element values. Those will be subsets within this same
// list that will get validated anyway // list that will get validated anyway
if std::mem::discriminant(&ours.element_type) if core::mem::discriminant(&ours.element_type)
!= std::mem::discriminant(&theirs.element_type) != core::mem::discriminant(&theirs.element_type)
{ {
return false; return false;
} }
@ -364,15 +324,11 @@ impl<'a> PartialEq for TLVElement<'a> {
impl<'a> TLVElement<'a> { impl<'a> TLVElement<'a> {
pub fn enter(&self) -> Option<TLVContainerIterator<'a>> { pub fn enter(&self) -> Option<TLVContainerIterator<'a>> {
let ptr = match self.element_type { let buf = match self.element_type {
ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => a, ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => buf,
_ => return None, _ => return None,
}; };
let list_iter = TLVListIterator { let list_iter = TLVListIterator { buf, current: 0 };
buf: ptr.buf,
current: ptr.current,
left: ptr.left,
};
Some(TLVContainerIterator { Some(TLVContainerIterator {
list_iter, list_iter,
prev_container: false, prev_container: false,
@ -390,14 +346,22 @@ impl<'a> TLVElement<'a> {
pub fn i8(&self) -> Result<i8, Error> { pub fn i8(&self) -> Result<i8, Error> {
match self.element_type { match self.element_type {
ElementType::S8(a) => Ok(a), ElementType::S8(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn u8(&self) -> Result<u8, Error> { pub fn u8(&self) -> Result<u8, Error> {
match self.element_type { match self.element_type {
ElementType::U8(a) => Ok(a), ElementType::U8(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i16(&self) -> Result<i16, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -405,7 +369,16 @@ impl<'a> TLVElement<'a> {
match self.element_type { match self.element_type {
ElementType::U8(a) => Ok(a.into()), ElementType::U8(a) => Ok(a.into()),
ElementType::U16(a) => Ok(a), ElementType::U16(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i32(&self) -> Result<i32, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a.into()),
ElementType::S32(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -414,7 +387,17 @@ impl<'a> TLVElement<'a> {
ElementType::U8(a) => Ok(a.into()), ElementType::U8(a) => Ok(a.into()),
ElementType::U16(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()),
ElementType::U32(a) => Ok(a), ElementType::U32(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i64(&self) -> Result<i64, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a.into()),
ElementType::S32(a) => Ok(a.into()),
ElementType::S64(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
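The new signed accessors mirror the unsigned ones: any narrower signed element widens into i16/i32/i64, while a type mismatch still maps to ErrorCode::TLVTypeMismatch. A minimal sketch, written as it would sit in this module's unit tests (the element is constructed literally, as the tests below do; the values are illustrative only):

let elem = TLVElement {
    tag_type: TagType::Context(0),
    element_type: ElementType::S8(-5),
};
assert_eq!(elem.i8().unwrap(), -5);
assert_eq!(elem.i16().unwrap(), -5); // S8 widens into i16/i32/i64
assert_eq!(elem.i64().unwrap(), -5);
assert!(elem.u8().is_err()); // a signed element is a TLVTypeMismatch for u8()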
@ -424,7 +407,7 @@ impl<'a> TLVElement<'a> {
ElementType::U16(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()),
ElementType::U32(a) => Ok(a.into()), ElementType::U32(a) => Ok(a.into()),
ElementType::U64(a) => Ok(a), ElementType::U64(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -434,7 +417,19 @@ impl<'a> TLVElement<'a> {
| ElementType::Utf8l(s) | ElementType::Utf8l(s)
| ElementType::Str16l(s) | ElementType::Str16l(s)
| ElementType::Utf16l(s) => Ok(s), | ElementType::Utf16l(s) => Ok(s),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn str(&self) -> Result<&'a str, Error> {
match self.element_type {
ElementType::Str8l(s)
| ElementType::Utf8l(s)
| ElementType::Str16l(s)
| ElementType::Utf16l(s) => {
Ok(core::str::from_utf8(s).map_err(|_| Error::from(ErrorCode::InvalidData))?)
}
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
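str() is a thin layer over slice() that additionally validates UTF-8 and borrows the result as &str, failing with ErrorCode::InvalidData on malformed bytes. A short hedged sketch in the same test style:

let elem = TLVElement {
    tag_type: TagType::Anonymous,
    element_type: ElementType::Utf8l(b"on"),
};
assert_eq!(elem.str().unwrap(), "on"); // same bytes as slice(), but checked as UTF-8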
@ -442,48 +437,48 @@ impl<'a> TLVElement<'a> {
match self.element_type { match self.element_type {
ElementType::False => Ok(false), ElementType::False => Ok(false),
ElementType::True => Ok(true), ElementType::True => Ok(true),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn null(&self) -> Result<(), Error> { pub fn null(&self) -> Result<(), Error> {
match self.element_type { match self.element_type {
ElementType::Null => Ok(()), ElementType::Null => Ok(()),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_struct(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_struct(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::Struct(_) => Ok(*self), ElementType::Struct(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_array(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_array(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::Array(_) => Ok(*self), ElementType::Array(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_list(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_list(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::List(_) => Ok(*self), ElementType::List(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn find_tag(&self, tag: u32) -> Result<TLVElement<'a>, Error> { pub fn find_tag(&self, tag: u32) -> Result<TLVElement<'a>, Error> {
let match_tag: TagType = TagType::Context(tag as u8); let match_tag: TagType = TagType::Context(tag as u8);
let iter = self.enter().ok_or(Error::TLVTypeMismatch)?; let iter = self.enter().ok_or(ErrorCode::TLVTypeMismatch)?;
for a in iter { for a in iter {
if match_tag == a.tag_type { if match_tag == a.tag_type {
return Ok(a); return Ok(a);
} }
} }
Err(Error::NoTagFound) Err(ErrorCode::NoTagFound.into())
} }
pub fn get_tag(&self) -> TagType { pub fn get_tag(&self) -> TagType {
@ -499,8 +494,8 @@ impl<'a> TLVElement<'a> {
false false
} }
pub fn get_element_type(&self) -> ElementType { pub fn get_element_type(&self) -> &ElementType {
self.element_type &self.element_type
} }
} }
@ -522,7 +517,7 @@ impl<'a> fmt::Display for TLVElement<'a> {
| ElementType::Utf8l(a) | ElementType::Utf8l(a)
| ElementType::Str16l(a) | ElementType::Str16l(a)
| ElementType::Utf16l(a) => { | ElementType::Utf16l(a) => {
if let Ok(s) = std::str::from_utf8(a) { if let Ok(s) = core::str::from_utf8(a) {
write!(f, "len[{}]\"{}\"", s.len(), s) write!(f, "len[{}]\"{}\"", s.len(), s)
} else { } else {
write!(f, "len[{}]{:x?}", a.len(), a) write!(f, "len[{}]{:x?}", a.len(), a)
@ -534,25 +529,19 @@ impl<'a> fmt::Display for TLVElement<'a> {
} }
// This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list // This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list
#[derive(Copy, Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct TLVListIterator<'a> { pub struct TLVListIterator<'a> {
buf: &'a [u8], buf: &'a [u8],
current: usize, current: usize,
left: usize,
} }
impl<'a> TLVListIterator<'a> { impl<'a> TLVListIterator<'a> {
fn from_pointer(p: Pointer<'a>) -> Self { fn from_buf(buf: &'a [u8]) -> Self {
Self { Self { buf, current: 0 }
buf: p.buf,
current: p.current,
left: p.left,
}
} }
fn advance(&mut self, len: usize) { fn advance(&mut self, len: usize) {
self.current += len; self.current += len;
self.left -= len;
} }
// Caller should ensure they are reading the _right_ tag at the _right_ place // Caller should ensure they are reading the _right_ tag at the _right_ place
@ -561,7 +550,7 @@ impl<'a> TLVListIterator<'a> {
return None; return None;
} }
let tag_size = TAG_SIZE_MAP[tag_type as usize]; let tag_size = TAG_SIZE_MAP[tag_type as usize];
if tag_size > self.left { if tag_size > self.buf.len() - self.current {
return None; return None;
} }
let tag = (TAG_EXTRACTOR[tag_type as usize])(self); let tag = (TAG_EXTRACTOR[tag_type as usize])(self);
@ -574,7 +563,7 @@ impl<'a> TLVListIterator<'a> {
return None; return None;
} }
let mut size = VALUE_SIZE_MAP[element_type as usize]; let mut size = VALUE_SIZE_MAP[element_type as usize];
if size > self.left { if size > self.buf.len() - self.current {
error!( error!(
"Invalid value found: {} self {:?} size {}", "Invalid value found: {} self {:?} size {}",
element_type, self, size element_type, self, size
@ -597,7 +586,7 @@ impl<'a> Iterator for TLVListIterator<'a> {
type Item = TLVElement<'a>; type Item = TLVElement<'a>;
/* Code for going to the next Element */ /* Code for going to the next Element */
fn next(&mut self) -> Option<TLVElement<'a>> { fn next(&mut self) -> Option<TLVElement<'a>> {
if self.left < 1 { if self.buf.len() - self.current < 1 {
return None; return None;
} }
/* Read Control */ /* Read Control */
@ -623,13 +612,12 @@ impl<'a> TLVList<'a> {
pub fn iter(&self) -> TLVListIterator<'a> { pub fn iter(&self) -> TLVListIterator<'a> {
TLVListIterator { TLVListIterator {
current: 0, current: 0,
left: self.buf.len(),
buf: self.buf, buf: self.buf,
} }
} }
} }
fn is_container(element_type: ElementType) -> bool { fn is_container(element_type: &ElementType) -> bool {
matches!( matches!(
element_type, element_type,
ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_) ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_)
@ -668,7 +656,7 @@ impl<'a> TLVContainerIterator<'a> {
nest_level -= 1; nest_level -= 1;
} }
_ => { _ => {
if is_container(element.element_type) { if is_container(&element.element_type) {
nest_level += 1; nest_level += 1;
} }
} }
@ -699,33 +687,38 @@ impl<'a> Iterator for TLVContainerIterator<'a> {
return None; return None;
} }
if is_container(element.element_type) { self.prev_container = is_container(&element.element_type);
self.prev_container = true;
} else {
self.prev_container = false;
}
Some(element) Some(element)
} }
} }
pub fn get_root_node(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b).iter().next().ok_or(Error::InvalidData) Ok(TLVList::new(b)
.iter()
.next()
.ok_or(ErrorCode::InvalidData)?)
} }
pub fn get_root_node_struct(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node_struct(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b) let root = TLVList::new(b)
.iter() .iter()
.next() .next()
.ok_or(Error::InvalidData)? .ok_or(ErrorCode::InvalidData)?;
.confirm_struct()
root.confirm_struct()?;
Ok(root)
} }
pub fn get_root_node_list(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node_list(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b) let root = TLVList::new(b)
.iter() .iter()
.next() .next()
.ok_or(Error::InvalidData)? .ok_or(ErrorCode::InvalidData)?;
.confirm_list()
root.confirm_list()?;
Ok(root)
} }
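With the type check split out of the parse, a wrong container kind now surfaces as a plain ErrorCode at the call site. A hedged sketch of the distinction; 0x15 and 0x18 are assumed here to be the anonymous-struct and end-of-container control bytes, matching the test vectors further down:

let b = [0x15, 0x18]; // assumed: an empty anonymous struct
let root = get_root_node_struct(&b).unwrap();
assert!(root.confirm_struct().is_ok());
assert_eq!(
    get_root_node_list(&b).map_err(|e| e.code()),
    Err(ErrorCode::TLVTypeMismatch)
);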
pub fn print_tlv_list(b: &[u8]) { pub fn print_tlv_list(b: &[u8]) {
@ -752,7 +745,7 @@ pub fn print_tlv_list(b: &[u8]) {
match a.element_type { match a.element_type {
ElementType::Struct(_) => { ElementType::Struct(_) => {
if index < MAX_DEPTH { if index < MAX_DEPTH {
println!("{}{}", space[index], a); info!("{}{}", space[index], a);
stack[index] = '}'; stack[index] = '}';
index += 1; index += 1;
} else { } else {
@ -761,7 +754,7 @@ pub fn print_tlv_list(b: &[u8]) {
} }
ElementType::Array(_) | ElementType::List(_) => { ElementType::Array(_) | ElementType::List(_) => {
if index < MAX_DEPTH { if index < MAX_DEPTH {
println!("{}{}", space[index], a); info!("{}{}", space[index], a);
stack[index] = ']'; stack[index] = ']';
index += 1; index += 1;
} else { } else {
@ -771,24 +764,25 @@ pub fn print_tlv_list(b: &[u8]) {
ElementType::EndCnt => { ElementType::EndCnt => {
if index > 0 { if index > 0 {
index -= 1; index -= 1;
println!("{}{}", space[index], stack[index]); info!("{}{}", space[index], stack[index]);
} else { } else {
error!("Incorrect TLV List"); error!("Incorrect TLV List");
} }
} }
_ => println!("{}{}", space[index], a), _ => info!("{}{}", space[index], a),
} }
} }
println!("---------"); info!("---------");
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use super::{ use super::{
get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, get_root_node_list, get_root_node_struct, ElementType, TLVElement, TLVList, TagType,
TagType,
}; };
use crate::error::Error; use crate::error::ErrorCode;
#[test] #[test]
fn test_short_length_tag() { fn test_short_length_tag() {
@ -846,11 +840,7 @@ mod tests {
tlv_iter.next(), tlv_iter.next(),
Some(TLVElement { Some(TLVElement {
tag_type: TagType::Context(0), tag_type: TagType::Context(0),
element_type: ElementType::Array(Pointer { element_type: ElementType::Array(&[]),
buf: &[21, 54, 0],
current: 3,
left: 0
}),
}) })
); );
} }
@ -1105,12 +1095,13 @@ mod tests {
.unwrap() .unwrap()
.enter() .enter()
.unwrap(); .unwrap();
println!("Command list iterator: {:?}", cmd_list_iter); info!("Command list iterator: {:?}", cmd_list_iter);
// This is an array of CommandDataIB, but we'll only use the first element // This is an array of CommandDataIB, but we'll only use the first element
let cmd_data_ib = cmd_list_iter.next().unwrap(); let cmd_data_ib = cmd_list_iter.next().unwrap();
let cmd_path = cmd_data_ib.find_tag(0).unwrap().confirm_list().unwrap(); let cmd_path = cmd_data_ib.find_tag(0).unwrap();
let cmd_path = cmd_path.confirm_list().unwrap();
assert_eq!( assert_eq!(
cmd_path.find_tag(0).unwrap(), cmd_path.find_tag(0).unwrap(),
TLVElement { TLVElement {
@ -1132,7 +1123,10 @@ mod tests {
element_type: ElementType::U32(1), element_type: ElementType::U32(1),
} }
); );
assert_eq!(cmd_path.find_tag(3), Err(Error::NoTagFound)); assert_eq!(
cmd_path.find_tag(3).map_err(|e| e.code()),
Err(ErrorCode::NoTagFound)
);
// This is the variable of the invoke command // This is the variable of the invoke command
assert_eq!( assert_eq!(
@ -1172,11 +1166,7 @@ mod tests {
0x35, 0x1, 0x18, 0x18, 0x18, 0x18, 0x35, 0x1, 0x18, 0x18, 0x18, 0x18,
]; ];
let dummy_pointer = Pointer { let dummy_pointer = &b[1..];
buf: &b,
current: 1,
left: 21,
};
// These are the decoded elements that we expect from this input // These are the decoded elements that we expect from this input
let verify_matrix: [(TagType, ElementType); 13] = [ let verify_matrix: [(TagType, ElementType); 13] = [
(TagType::Anonymous, ElementType::Struct(dummy_pointer)), (TagType::Anonymous, ElementType::Struct(dummy_pointer)),
@ -1203,8 +1193,8 @@ mod tests {
Some(a) => { Some(a) => {
assert_eq!(a.tag_type, verify_matrix[index].0); assert_eq!(a.tag_type, verify_matrix[index].0);
assert_eq!( assert_eq!(
std::mem::discriminant(&a.element_type), core::mem::discriminant(&a.element_type),
std::mem::discriminant(&verify_matrix[index].1) core::mem::discriminant(&verify_matrix[index].1)
); );
} }
} }



@ -16,10 +16,10 @@
*/ */
use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType};
use crate::error::Error; use crate::error::{Error, ErrorCode};
use core::fmt::Debug;
use core::slice::Iter; use core::slice::Iter;
use log::error; use log::error;
use std::fmt::Debug;
pub trait FromTLV<'a> { pub trait FromTLV<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error>
@ -31,33 +31,54 @@ pub trait FromTLV<'a> {
where where
Self: Sized, Self: Sized,
{ {
Err(Error::TLVNotFound) Err(ErrorCode::TLVNotFound.into())
} }
} }
impl<'a, T: Default + FromTLV<'a> + Copy, const N: usize> FromTLV<'a> for [T; N] { impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
{ {
t.confirm_array()?; t.confirm_array()?;
let mut a: [T; N] = [Default::default(); N];
let mut index = 0; let mut a = heapless::Vec::<T, N>::new();
if let Some(tlv_iter) = t.enter() { if let Some(tlv_iter) = t.enter() {
for element in tlv_iter { for element in tlv_iter {
if index < N { a.push(T::from_tlv(&element)?)
a[index] = T::from_tlv(&element)?; .map_err(|_| ErrorCode::NoSpace)?;
index += 1;
} else {
error!("Received TLV Array with elements larger than current size");
break;
} }
} }
// TODO: This was the old behavior before rebasing the
// implementation on top of heapless::Vec (to avoid requiring Copy)
// Not sure why we actually need that yet, but without it unit tests fail
while a.len() < N {
a.push(Default::default()).map_err(|_| ErrorCode::NoSpace)?;
} }
Ok(a)
a.into_array().map_err(|_| ErrorCode::Invalid.into())
} }
} }
pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>(
vec: &mut heapless::Vec<T, N>,
t: &TLVElement<'a>,
) -> Result<(), Error> {
vec.clear();
t.confirm_array()?;
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?)
.map_err(|_| ErrorCode::NoSpace)?;
}
}
Ok(())
}
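Unlike the [T; N] impl above, which pads missing trailing entries with Default, this free helper keeps only the elements actually present in the TLV array. A hedged usage sketch, assuming `arr` is a TLVElement that confirm_array() accepts (for example the result of a find_tag() call) and that the surrounding function returns Result<(), Error>:

let mut ids: heapless::Vec<u16, 8> = heapless::Vec::new();
from_tlv(&mut ids, &arr)?;
// ids.len() now equals the number of entries in the TLV array; an array with
// more than 8 entries fails with ErrorCode::NoSpace instead of being truncated.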
macro_rules! fromtlv_for { macro_rules! fromtlv_for {
($($t:ident)*) => { ($($t:ident)*) => {
$( $(
@ -70,12 +91,21 @@ macro_rules! fromtlv_for {
}; };
} }
fromtlv_for!(u8 u16 u32 u64 bool); fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool);
pub trait ToTLV { pub trait ToTLV {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>;
} }
impl<T> ToTLV for &T
where
T: ToTLV,
{
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
(**self).to_tlv(tw, tag)
}
}
macro_rules! totlv_for { macro_rules! totlv_for {
($($t:ident)*) => { ($($t:ident)*) => {
$( $(
@ -98,30 +128,39 @@ impl<T: ToTLV, const N: usize> ToTLV for [T; N] {
} }
} }
impl<'a, T: ToTLV> ToTLV for &'a [T] {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_array(tag)?;
for i in *self {
i.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
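Together with the blanket impl for &T above, this lets borrowed slices be serialized directly, without first wrapping them in a TLVArray. A minimal sketch in the style of the tests below, using the updated single-argument WriteBuf::new:

let mut buf = [0; 32];
let mut writebuf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut writebuf);
let endpoints: &[u16] = &[1, 2, 3];
endpoints.to_tlv(&mut tw, TagType::Context(0)).unwrap();
// Emits start_array, three anonymous u16 elements, then end_container.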
// Generate ToTLV for standard data types // Generate ToTLV for standard data types
totlv_for!(i8 u8 u16 u32 u64 bool); totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool);
// We define a few common data types that will be required here // We define a few common data types that will be required here
// //
// - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec
// - These only have references into the original list // - These only have references into the original list
// - String, Vec<u8>: Is the owned version of utfstr and ostr, data is cloned into this // - heapless::String<N>, heapless::Vec<u8, N>: Is the owned version of utfstr and ostr, data is cloned into this
// - String is only partially implemented // - heapless::String is only partially implemented
// //
// - TLVArray: Is an array of entries, with reference within the original list // - TLVArray: Is an array of entries, with reference within the original list
// - TLVArrayOwned: Is the owned version of this, data is cloned into this
/// Implements UTFString from the spec /// Implements UTFString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct UtfStr<'a>(pub &'a [u8]); pub struct UtfStr<'a>(pub &'a [u8]);
impl<'a> UtfStr<'a> { impl<'a> UtfStr<'a> {
pub fn new(str: &'a [u8]) -> Self { pub const fn new(str: &'a [u8]) -> Self {
Self(str) Self(str)
} }
pub fn to_string(self) -> Result<String, Error> { pub fn as_str(&self) -> Result<&str, Error> {
String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) core::str::from_utf8(self.0).map_err(|_| ErrorCode::Invalid.into())
} }
} }
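UtfStr stays a borrowing wrapper; the old to_string() conversion into an owned String is gone, and as_str() gives a UTF-8-checked &str view over the same bytes. A small sketch:

let name = UtfStr::new(b"on_off");
assert_eq!(name.as_str().unwrap(), "on_off");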
@ -138,7 +177,7 @@ impl<'a> FromTLV<'a> for UtfStr<'a> {
} }
/// Implements OctetString from the spec /// Implements OctetString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct OctetStr<'a>(pub &'a [u8]); pub struct OctetStr<'a>(pub &'a [u8]);
impl<'a> OctetStr<'a> { impl<'a> OctetStr<'a> {
@ -160,35 +199,32 @@ impl<'a> ToTLV for OctetStr<'a> {
} }
/// Implements the Owned version of Octet String /// Implements the Owned version of Octet String
impl FromTLV<'_> for Vec<u8> { impl<const N: usize> FromTLV<'_> for heapless::Vec<u8, N> {
fn from_tlv(t: &TLVElement) -> Result<Vec<u8>, Error> { fn from_tlv(t: &TLVElement) -> Result<heapless::Vec<u8, N>, Error> {
t.slice().map(|x| x.to_owned()) heapless::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into())
} }
} }
impl ToTLV for Vec<u8> { impl<const N: usize> ToTLV for heapless::Vec<u8, N> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.str16(tag, self.as_slice()) tw.str16(tag, self.as_slice())
} }
} }
/// Implements the Owned version of UTF String /// Implements the Owned version of UTF String
impl FromTLV<'_> for String { impl<const N: usize> FromTLV<'_> for heapless::String<N> {
fn from_tlv(t: &TLVElement) -> Result<String, Error> { fn from_tlv(t: &TLVElement) -> Result<heapless::String<N>, Error> {
match t.slice() { let mut string = heapless::String::new();
Ok(x) => {
if let Ok(s) = String::from_utf8(x.to_vec()) { string
Ok(s) .push_str(core::str::from_utf8(t.slice()?)?)
} else { .map_err(|_| ErrorCode::NoSpace)?;
Err(Error::Invalid)
} Ok(string)
}
Err(e) => Err(e),
}
} }
} }
impl ToTLV for String { impl<const N: usize> ToTLV for heapless::String<N> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.utf16(tag, self.as_bytes()) tw.utf16(tag, self.as_bytes())
} }
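The owned string and byte-buffer impls now target the fixed-capacity heapless types, so an over-long value surfaces as ErrorCode::NoSpace rather than growing an allocation. A hedged sketch with the FromTLV trait in scope, constructing the source element literally as the tests do:

let elem = TLVElement {
    tag_type: TagType::Anonymous,
    element_type: ElementType::Utf8l(b"kitchen"),
};
let owned = heapless::String::<16>::from_tlv(&elem).unwrap();
assert_eq!(owned.as_str(), "kitchen");
// With a capacity smaller than 7, the same call returns ErrorCode::NoSpace.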
@ -262,38 +298,7 @@ impl<T: ToTLV> ToTLV for Nullable<T> {
} }
} }
/// Owned version of a TLVArray #[derive(Clone)]
pub struct TLVArrayOwned<T>(Vec<T>);
impl<'a, T: FromTLV<'a>> FromTLV<'a> for TLVArrayOwned<T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
t.confirm_array()?;
let mut vec = Vec::<T>::new();
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?);
}
}
Ok(Self(vec))
}
}
impl<T: ToTLV> ToTLV for TLVArrayOwned<T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
tw.start_array(tag_type)?;
for t in &self.0 {
t.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
impl<T> TLVArrayOwned<T> {
pub fn iter(&self) -> Iter<T> {
self.0.iter()
}
}
#[derive(Copy, Clone)]
pub enum TLVArray<'a, T> { pub enum TLVArray<'a, T> {
// This is used for the to-tlv path // This is used for the to-tlv path
Slice(&'a [T]), Slice(&'a [T]),
@ -312,14 +317,14 @@ impl<'a, T: ToTLV> TLVArray<'a, T> {
} }
pub fn iter(&self) -> TLVArrayIter<'a, T> { pub fn iter(&self) -> TLVArrayIter<'a, T> {
match *self { match self {
Self::Slice(s) => TLVArrayIter::Slice(s.iter()), Self::Slice(s) => TLVArrayIter::Slice(s.iter()),
Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()), Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()),
} }
} }
} }
impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> {
pub fn get_index(&self, index: usize) -> T { pub fn get_index(&self, index: usize) -> T {
for (curr, element) in self.iter().enumerate() { for (curr, element) in self.iter().enumerate() {
if curr == index { if curr == index {
@ -330,12 +335,12 @@ impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> {
} }
} }
impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> {
type Item = T; type Item = T;
/* Code for going to the next Element */ /* Code for going to the next Element */
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
match self { match self {
Self::Slice(s_iter) => s_iter.next().copied(), Self::Slice(s_iter) => s_iter.next().cloned(),
Self::Ptr(p_iter) => { Self::Ptr(p_iter) => {
if let Some(tlv_iter) = p_iter.as_mut() { if let Some(tlv_iter) = p_iter.as_mut() {
let e = tlv_iter.next(); let e = tlv_iter.next();
@ -354,7 +359,7 @@ impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> {
impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T>
where where
T: ToTLV + FromTLV<'a> + Copy + PartialEq, T: ToTLV + FromTLV<'a> + Clone + PartialEq,
{ {
fn eq(&self, other: &&[T]) -> bool { fn eq(&self, other: &&[T]) -> bool {
let mut iter1 = self.iter(); let mut iter1 = self.iter();
@ -373,34 +378,46 @@ where
} }
} }
impl<'a, T: ToTLV> ToTLV for TLVArray<'a, T> { impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
match *self {
Self::Slice(s) => {
tw.start_array(tag_type)?; tw.start_array(tag_type)?;
for a in s { for a in self.iter() {
a.to_tlv(tw, TagType::Anonymous)?; a.to_tlv(tw, TagType::Anonymous)?;
} }
tw.end_container() tw.end_container()
} // match *self {
Self::Ptr(t) => t.to_tlv(tw, tag_type), // Self::Slice(s) => {
} // tw.start_array(tag_type)?;
// for a in s {
// a.to_tlv(tw, TagType::Anonymous)?;
// }
// tw.end_container()
// }
// Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV
// }
} }
} }
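With to_tlv now routed through iter(), the Slice and Ptr variants serialize the same way (the old per-variant match is kept above only as a commented-out reference). A hedged sketch of the Slice path, with the writer set up as in this file's tests:

let mut buf = [0; 32];
let mut writebuf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut writebuf);
let values = [10u8, 20, 30];
let arr = TLVArray::Slice(&values);
arr.to_tlv(&mut tw, TagType::Context(1)).unwrap();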
impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { impl<'a, T> FromTLV<'a> for TLVArray<'a, T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
t.confirm_array()?; t.confirm_array()?;
Ok(Self::Ptr(*t)) Ok(Self::Ptr(t.clone()))
} }
} }
impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "TLVArray [")?;
let mut first = true;
for i in self.iter() { for i in self.iter() {
writeln!(f, "{:?}", i)?; if !first {
write!(f, ", ")?;
} }
writeln!(f)
write!(f, "{:?}", i)?;
first = false;
}
write!(f, "]")
} }
} }
@ -423,7 +440,7 @@ impl<'a> ToTLV for TLVElement<'a> {
ElementType::EndCnt => tw.end_container(), ElementType::EndCnt => tw.end_container(),
_ => { _ => {
error!("ToTLV Not supported"); error!("ToTLV Not supported");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
} }
@ -431,7 +448,7 @@ impl<'a> ToTLV for TLVElement<'a> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV};
use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf}; use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf};
use matter_macro_derive::{FromTLV, ToTLV}; use matter_macro_derive::{FromTLV, ToTLV};
@ -442,9 +459,8 @@ mod tests {
} }
#[test] #[test]
fn test_derive_totlv() { fn test_derive_totlv() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let abc = TestDerive { let abc = TestDerive {
@ -525,9 +541,8 @@ mod tests {
#[test] #[test]
fn test_derive_totlv_fab_scoped() { fn test_derive_totlv_fab_scoped() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 }; let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 };
@ -557,9 +572,8 @@ mod tests {
enum_val = TestDeriveEnum::ValueB(10); enum_val = TestDeriveEnum::ValueB(10);
// Test ToTLV // Test ToTLV
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap(); enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap();


@ -50,11 +50,11 @@ enum WriteElementType {
} }
pub struct TLVWriter<'a, 'b> { pub struct TLVWriter<'a, 'b> {
buf: &'b mut WriteBuf<'a>, buf: &'a mut WriteBuf<'b>,
} }
impl<'a, 'b> TLVWriter<'a, 'b> { impl<'a, 'b> TLVWriter<'a, 'b> {
pub fn new(buf: &'b mut WriteBuf<'a>) -> Self { pub fn new(buf: &'a mut WriteBuf<'b>) -> Self {
TLVWriter { buf } TLVWriter { buf }
} }
@ -164,7 +164,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> {
pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> {
if data.len() > 256 { if data.len() > 256 {
error!("use str16() instead"); error!("use str16() instead");
return Err(Error::Invalid); return Err(ErrorCode::Invalid.into());
} }
self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.put_control_tag(tag_type, WriteElementType::Str8l)?;
self.buf.le_u8(data.len() as u8)?; self.buf.le_u8(data.len() as u8)?;
@ -265,7 +265,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> {
self.buf.rewind_tail_to(anchor); self.buf.rewind_tail_to(anchor);
} }
pub fn get_buf<'c>(&'c mut self) -> &'c mut WriteBuf<'a> { pub fn get_buf(&mut self) -> &mut WriteBuf<'b> {
self.buf self.buf
} }
} }
@ -277,9 +277,8 @@ mod tests {
#[test] #[test]
fn test_write_success() { fn test_write_success() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.start_struct(TagType::Anonymous).unwrap(); tw.start_struct(TagType::Anonymous).unwrap();
@ -299,9 +298,8 @@ mod tests {
#[test] #[test]
fn test_write_overflow() { fn test_write_overflow() {
let mut buf: [u8; 6] = [0; 6]; let mut buf = [0; 6];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Anonymous, 12).unwrap(); tw.u8(TagType::Anonymous, 12).unwrap();
@ -317,9 +315,8 @@ mod tests {
#[test] #[test]
fn test_put_str8() { fn test_put_str8() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Context(1), 13).unwrap(); tw.u8(TagType::Context(1), 13).unwrap();
@ -334,9 +331,8 @@ mod tests {
#[test] #[test]
fn test_put_str16_as() { fn test_put_str16_as() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Context(1), 13).unwrap(); tw.u8(TagType::Context(1), 13).unwrap();


@ -0,0 +1,251 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use log::info;
use crate::{error::*, CommissioningData, Matter};
use crate::secure_channel::common::PROTO_ID_SECURE_CHANNEL;
use crate::secure_channel::core::SecureChannel;
use crate::transport::mrp::ReliableMessage;
use crate::transport::{exchange, network::Address, packet::Packet};
use super::proto_ctx::ProtoCtx;
use super::session::CloneData;
enum RecvState {
New,
OpenExchange,
AddSession(CloneData),
EvictSession,
EvictSession2(CloneData),
Ack,
}
pub enum RecvAction<'r, 'p> {
Send(Address, &'r [u8]),
Interact(ProtoCtx<'r, 'p>),
}
pub struct RecvCompletion<'r, 'a> {
transport: &'r mut Transport<'a>,
rx: Packet<'r>,
tx: Packet<'r>,
state: RecvState,
}
impl<'r, 'a> RecvCompletion<'r, 'a> {
pub fn next_action(&mut self) -> Result<Option<RecvAction<'_, 'r>>, Error> {
loop {
// Polonius will remove the need for unsafe one day
let this = unsafe { (self as *mut RecvCompletion).as_mut().unwrap() };
if let Some(action) = this.maybe_next_action()? {
return Ok(action);
}
}
}
fn maybe_next_action(&mut self) -> Result<Option<Option<RecvAction<'_, 'r>>>, Error> {
self.transport.exch_mgr.purge();
self.tx.reset();
let (state, next) = match core::mem::replace(&mut self.state, RecvState::New) {
RecvState::New => {
self.rx.plain_hdr_decode()?;
(RecvState::OpenExchange, None)
}
RecvState::OpenExchange => match self.transport.exch_mgr.recv(&mut self.rx) {
Ok(Some(exch_ctx)) => {
if self.rx.get_proto_id() == PROTO_ID_SECURE_CHANNEL {
let mut proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx);
let mut secure_channel = SecureChannel::new(self.transport.matter);
let (reply, clone_data) = secure_channel.handle(&mut proto_ctx)?;
let state = if let Some(clone_data) = clone_data {
RecvState::AddSession(clone_data)
} else {
RecvState::Ack
};
if reply {
if proto_ctx.send()? {
(
state,
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(state, None)
}
} else {
(state, None)
}
} else {
let proto_ctx = ProtoCtx::new(exch_ctx, &self.rx, &mut self.tx);
(RecvState::Ack, Some(Some(RecvAction::Interact(proto_ctx))))
}
}
Ok(None) => (RecvState::Ack, None),
Err(e) => match e.code() {
ErrorCode::Duplicate => (RecvState::Ack, None),
ErrorCode::NoSpace => (RecvState::EvictSession, None),
_ => Err(e)?,
},
},
RecvState::AddSession(clone_data) => {
match self.transport.exch_mgr.add_session(&clone_data) {
Ok(_) => (RecvState::Ack, None),
Err(e) => match e.code() {
ErrorCode::NoSpace => (RecvState::EvictSession2(clone_data), None),
_ => Err(e)?,
},
}
}
RecvState::EvictSession => {
if self.transport.exch_mgr.evict_session(&mut self.tx)? {
(
RecvState::OpenExchange,
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::EvictSession, None)
}
}
RecvState::EvictSession2(clone_data) => {
if self.transport.exch_mgr.evict_session(&mut self.tx)? {
(
RecvState::AddSession(clone_data),
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::EvictSession2(clone_data), None)
}
}
RecvState::Ack => {
if let Some(exch_id) = self.transport.exch_mgr.pending_ack() {
info!("Sending MRP Standalone ACK for exch {}", exch_id);
ReliableMessage::prepare_ack(exch_id, &mut self.tx);
if self.transport.exch_mgr.send(exch_id, &mut self.tx)? {
(
RecvState::Ack,
Some(Some(RecvAction::Send(self.tx.peer, self.tx.as_slice()))),
)
} else {
(RecvState::Ack, None)
}
} else {
(RecvState::Ack, Some(None))
}
}
};
self.state = state;
Ok(next)
}
}
enum NotifyState {}
pub enum NotifyAction<'r, 'p> {
Send(&'r [u8]),
Notify(ProtoCtx<'r, 'p>),
}
pub struct NotifyCompletion<'r, 'a> {
// TODO
_transport: &'r mut Transport<'a>,
_rx: Packet<'r>,
_tx: Packet<'r>,
_state: NotifyState,
}
impl<'r, 'a> NotifyCompletion<'r, 'a> {
pub fn next_action(&mut self) -> Result<Option<NotifyAction<'_, 'r>>, Error> {
loop {
// Polonius will remove the need for unsafe one day
let this = unsafe { (self as *mut NotifyCompletion).as_mut().unwrap() };
if let Some(action) = this.maybe_next_action()? {
return Ok(action);
}
}
}
fn maybe_next_action(&mut self) -> Result<Option<Option<NotifyAction<'_, 'r>>>, Error> {
Ok(Some(None)) // TODO: Future
}
}
pub struct Transport<'a> {
matter: &'a Matter<'a>,
exch_mgr: exchange::ExchangeMgr,
}
impl<'a> Transport<'a> {
#[inline(always)]
pub fn new(matter: &'a Matter<'a>) -> Self {
let epoch = matter.epoch;
let rand = matter.rand;
Self {
matter,
exch_mgr: exchange::ExchangeMgr::new(epoch, rand),
}
}
pub fn matter(&self) -> &Matter<'a> {
self.matter
}
pub fn start(&mut self, dev_comm: CommissioningData, buf: &mut [u8]) -> Result<(), Error> {
info!("Starting Matter transport");
if self.matter().start_comissioning(dev_comm, buf)? {
info!("Comissioning started");
}
Ok(())
}
pub fn recv<'r>(
&'r mut self,
addr: Address,
rx_buf: &'r mut [u8],
tx_buf: &'r mut [u8],
) -> RecvCompletion<'r, 'a> {
let mut rx = Packet::new_rx(rx_buf);
let tx = Packet::new_tx(tx_buf);
rx.peer = addr;
RecvCompletion {
transport: self,
rx,
tx,
state: RecvState::New,
}
}
pub fn notify(&mut self, _tx: &mut Packet) -> Result<bool, Error> {
Ok(false)
}
}
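The completion object replaces the old blocking manager loop with a poll-style API: the embedding keeps calling next_action() and either sends the returned buffer or hands the ProtoCtx to the interaction layer, until None signals the datagram is fully processed. A hedged sketch of a driving loop inside a function returning Result<(), Error>; `matter`, `addr`, `len`, MAX_RX_BUF_SIZE and the `udp_send`/`handle_interaction` helpers are assumptions standing in for whatever the platform provides:

let mut transport = Transport::new(&matter);
let mut rx_buf = [0; MAX_RX_BUF_SIZE]; // assumed constant from the packet module
let mut tx_buf = [0; MAX_TX_BUF_SIZE];
// ... the platform reads one datagram of `len` bytes from `addr` into rx_buf ...
let mut completion = transport.recv(addr, &mut rx_buf[..len], &mut tx_buf);
while let Some(action) = completion.next_action()? {
    match action {
        RecvAction::Send(peer, data) => udp_send(peer, data)?,
        RecvAction::Interact(ctx) => handle_interaction(ctx)?,
    }
}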


@ -86,6 +86,8 @@ impl RxCtrState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use super::RxCtrState; use super::RxCtrState;
const ENCRYPTED: bool = true; const ENCRYPTED: bool = true;
@ -194,10 +196,10 @@ mod tests {
#[test] #[test]
fn unencrypted_device_reboot() { fn unencrypted_device_reboot() {
println!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532)); info!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532));
println!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535)); info!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535));
println!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32); info!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32);
println!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998)); info!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998));
let mut s = RxCtrState::new(20010); let mut s = RxCtrState::new(20010);
assert_ndup(s.recv(20011, NOT_ENCRYPTED)); assert_ndup(s.recv(20011, NOT_ENCRYPTED));


@ -15,43 +15,46 @@
* limitations under the License. * limitations under the License.
*/ */
use boxslab::{BoxSlab, Slab}; use core::fmt;
use colored::*; use core::time::Duration;
use log::{error, info, trace}; use log::{error, info, trace};
use std::any::Any; use owo_colors::OwoColorize;
use std::fmt;
use std::time::SystemTime;
use crate::error::Error; use crate::error::{Error, ErrorCode};
use crate::interaction_model::core::{ResumeReadReq, ResumeSubscribeReq};
use crate::secure_channel; use crate::secure_channel;
use crate::secure_channel::case::CaseSession;
use crate::utils::epoch::Epoch;
use crate::utils::rand::Rand;
use heapless::LinearMap; use heapless::LinearMap;
use super::packet::PacketPool;
use super::session::CloneData; use super::session::CloneData;
use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr}; use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr};
pub struct ExchangeCtx<'a> { pub struct ExchangeCtx<'a> {
pub exch: &'a mut Exchange, pub exch: &'a mut Exchange,
pub sess: SessionHandle<'a>, pub sess: SessionHandle<'a>,
pub epoch: Epoch,
} }
#[derive(Debug, PartialEq, Eq, Copy, Clone)] impl<'a> ExchangeCtx<'a> {
pub fn send(&mut self, tx: &mut Packet) -> Result<bool, Error> {
self.exch.send(tx, &mut self.sess)
}
}
#[derive(Debug, PartialEq, Eq, Copy, Clone, Default)]
pub enum Role { pub enum Role {
#[default]
Initiator = 0, Initiator = 0,
Responder = 1, Responder = 1,
} }
impl Default for Role { #[derive(Debug, PartialEq, Default)]
fn default() -> Self {
Role::Initiator
}
}
/// State of the exchange
#[derive(Debug, PartialEq)]
enum State { enum State {
/// The exchange is open and active /// The exchange is open and active
#[default]
Open, Open,
/// The exchange is closed, but keys are active since retransmissions/acks may be pending /// The exchange is closed, but keys are active since retransmissions/acks may be pending
Close, Close,
@ -59,28 +62,20 @@ enum State {
Terminate, Terminate,
} }
impl Default for State {
fn default() -> Self {
State::Open
}
}
// Instead of just doing an Option<>, we create some special handling // Instead of just doing an Option<>, we create some special handling
// where the commonly used higher layer data store doesn't have to do a Box
#[derive(Debug)] #[derive(Default)]
pub enum DataOption { pub enum DataOption {
Boxed(Box<dyn Any>), CaseSession(CaseSession),
Time(SystemTime), Time(Duration),
SuspendedReadReq(ResumeReadReq),
SubscriptionId(u32),
SuspendedSubscibeReq(ResumeSubscribeReq),
#[default]
None, None,
} }
impl Default for DataOption { #[derive(Default)]
fn default() -> Self {
DataOption::None
}
}
#[derive(Debug, Default)]
pub struct Exchange { pub struct Exchange {
id: u16, id: u16,
sess_idx: usize, sess_idx: usize,
@ -132,75 +127,117 @@ impl Exchange {
self.role self.role
} }
pub fn is_data_none(&self) -> bool { pub fn clear_data(&mut self) {
matches!(self.data, DataOption::None)
}
pub fn set_data_boxed(&mut self, data: Box<dyn Any>) {
self.data = DataOption::Boxed(data);
}
pub fn clear_data_boxed(&mut self) {
self.data = DataOption::None; self.data = DataOption::None;
} }
pub fn get_data_boxed<T: Any>(&mut self) -> Option<&mut T> { pub fn set_case_session(&mut self, session: CaseSession) {
if let DataOption::Boxed(a) = &mut self.data { self.data = DataOption::CaseSession(session);
a.downcast_mut::<T>() }
pub fn get_case_session(&mut self) -> Option<&mut CaseSession> {
if let DataOption::CaseSession(session) = &mut self.data {
Some(session)
} else { } else {
None None
} }
} }
pub fn take_data_boxed<T: Any>(&mut self) -> Option<Box<T>> { pub fn take_case_session(&mut self) -> Option<CaseSession> {
let old = std::mem::replace(&mut self.data, DataOption::None); let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::Boxed(d) = old { if let DataOption::CaseSession(session) = old {
d.downcast::<T>().ok() Some(session)
} else { } else {
self.data = old; self.data = old;
None None
} }
} }
pub fn set_data_time(&mut self, expiry_ts: Option<SystemTime>) { pub fn set_suspended_read_req(&mut self, req: ResumeReadReq) {
self.data = DataOption::SuspendedReadReq(req);
}
pub fn take_suspended_read_req(&mut self) -> Option<ResumeReadReq> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SuspendedReadReq(req) = old {
Some(req)
} else {
self.data = old;
None
}
}
pub fn set_subscription_id(&mut self, id: u32) {
self.data = DataOption::SubscriptionId(id);
}
pub fn take_subscription_id(&mut self) -> Option<u32> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SubscriptionId(id) = old {
Some(id)
} else {
self.data = old;
None
}
}
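The untyped Box<dyn Any> slot becomes a single enum with one setter/taker pair per use case, which keeps Exchange allocation-free and makes the take_* accessors one-shot. A short sketch, assuming `exchange` is a &mut Exchange:

exchange.set_subscription_id(42);
assert_eq!(exchange.take_subscription_id(), Some(42));
// take_* consumes the slot, so a second call yields None.
assert_eq!(exchange.take_subscription_id(), None);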
pub fn set_suspended_subscribe_req(&mut self, req: ResumeSubscribeReq) {
self.data = DataOption::SuspendedSubscibeReq(req);
}
pub fn take_suspended_subscribe_req(&mut self) -> Option<ResumeSubscribeReq> {
let old = core::mem::replace(&mut self.data, DataOption::None);
if let DataOption::SuspendedSubscibeReq(req) = old {
Some(req)
} else {
self.data = old;
None
}
}
pub fn set_data_time(&mut self, expiry_ts: Option<Duration>) {
if let Some(t) = expiry_ts { if let Some(t) = expiry_ts {
self.data = DataOption::Time(t); self.data = DataOption::Time(t);
} }
} }
pub fn get_data_time(&self) -> Option<SystemTime> { pub fn get_data_time(&self) -> Option<Duration> {
match self.data { match self.data {
DataOption::Time(t) => Some(t), DataOption::Time(t) => Some(t),
_ => None, _ => None,
} }
} }
pub fn send( pub(crate) fn send(
&mut self, &mut self,
mut proto_tx: BoxSlab<PacketPool>, tx: &mut Packet,
session: &mut SessionHandle, session: &mut SessionHandle,
) -> Result<(), Error> { ) -> Result<bool, Error> {
if self.state == State::Terminate { if self.state == State::Terminate {
info!("Skipping tx for terminated exchange {}", self.id); info!("Skipping tx for terminated exchange {}", self.id);
return Ok(()); return Ok(false);
} }
trace!("payload: {:x?}", proto_tx.as_borrow_slice()); trace!("payload: {:x?}", tx.as_slice());
info!( info!(
"{} with proto id: {} opcode: {}", "{} with proto id: {} opcode: {}, tlv:\n",
"Sending".blue(), "Sending".blue(),
proto_tx.get_proto_id(), tx.get_proto_id(),
proto_tx.get_proto_opcode(), tx.get_proto_raw_opcode(),
); );
proto_tx.proto.exch_id = self.id; //print_tlv_list(tx.as_slice());
tx.proto.exch_id = self.id;
if self.role == Role::Initiator { if self.role == Role::Initiator {
proto_tx.proto.set_initiator(); tx.proto.set_initiator();
} }
session.pre_send(&mut proto_tx)?; session.pre_send(tx)?;
self.mrp.pre_send(&mut proto_tx)?; self.mrp.pre_send(tx)?;
session.send(proto_tx) session.send(tx)?;
Ok(true)
} }
} }
@ -208,8 +245,8 @@ impl fmt::Display for Exchange {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!( write!(
f, f,
"exch_id: {:?}, sess_index: {}, role: {:?}, data: {:?}, mrp: {:?}, state: {:?}", "exch_id: {:?}, sess_index: {}, role: {:?}, mrp: {:?}, state: {:?}",
self.id, self.sess_idx, self.role, self.data, self.mrp, self.state self.id, self.sess_idx, self.role, self.mrp, self.state
) )
} }
} }
@ -232,20 +269,22 @@ pub fn get_complementary_role(is_initiator: bool) -> Role {
const MAX_EXCHANGES: usize = 8; const MAX_EXCHANGES: usize = 8;
#[derive(Default)]
pub struct ExchangeMgr { pub struct ExchangeMgr {
// keys: exch-id // keys: exch-id
exchanges: LinearMap<u16, Exchange, MAX_EXCHANGES>, exchanges: LinearMap<u16, Exchange, MAX_EXCHANGES>,
sess_mgr: SessionMgr, sess_mgr: SessionMgr,
epoch: Epoch,
} }
pub const MAX_MRP_ENTRIES: usize = 4; pub const MAX_MRP_ENTRIES: usize = 4;
impl ExchangeMgr { impl ExchangeMgr {
pub fn new(sess_mgr: SessionMgr) -> Self { #[inline(always)]
pub fn new(epoch: Epoch, rand: Rand) -> Self {
Self { Self {
sess_mgr, sess_mgr: SessionMgr::new(epoch, rand),
exchanges: Default::default(), exchanges: LinearMap::new(),
epoch,
} }
} }
@ -278,10 +317,10 @@ impl ExchangeMgr {
info!("Creating new exchange"); info!("Creating new exchange");
let e = Exchange::new(id, sess_idx, role); let e = Exchange::new(id, sess_idx, role);
if exchanges.insert(id, e).is_err() { if exchanges.insert(id, e).is_err() {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
} else { } else {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
} }
@ -291,54 +330,42 @@ impl ExchangeMgr {
if result.get_role() == role && sess_idx == result.sess_idx { if result.get_role() == role && sess_idx == result.sess_idx {
Ok(result) Ok(result)
} else { } else {
Err(Error::NoExchange) Err(ErrorCode::NoExchange.into())
} }
} else { } else {
error!("This should never happen"); error!("This should never happen");
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
} }
/// The Exchange Mgr receive is like a big processing function /// The Exchange Mgr receive is like a big processing function
pub fn recv(&mut self) -> Result<Option<(BoxSlab<PacketPool>, ExchangeCtx)>, Error> { pub fn recv(&mut self, rx: &mut Packet) -> Result<Option<ExchangeCtx>, Error> {
// Get the session // Get the session
let (mut proto_rx, index) = self.sess_mgr.recv()?; let index = self.sess_mgr.post_recv(rx)?;
let index = if let Some(s) = index {
s
} else {
// The sessions were full, evict one session, and re-perform post-recv
let evict_index = self.sess_mgr.get_lru();
self.evict_session(evict_index)?;
info!("Reattempting session creation");
self.sess_mgr.post_recv(&proto_rx)?.ok_or(Error::Invalid)?
};
let mut session = self.sess_mgr.get_session_handle(index); let mut session = self.sess_mgr.get_session_handle(index);
// Decrypt the message // Decrypt the message
session.recv(&mut proto_rx)?; session.recv(self.epoch, rx)?;
// Get the exchange // Get the exchange
let exch = ExchangeMgr::_get( let exch = ExchangeMgr::_get(
&mut self.exchanges, &mut self.exchanges,
index, index,
proto_rx.proto.exch_id, rx.proto.exch_id,
get_complementary_role(proto_rx.proto.is_initiator()), get_complementary_role(rx.proto.is_initiator()),
// We create a new exchange, only if the peer is the initiator // We create a new exchange, only if the peer is the initiator
proto_rx.proto.is_initiator(), rx.proto.is_initiator(),
)?; )?;
// Message Reliability Protocol // Message Reliability Protocol
exch.mrp.recv(&proto_rx)?; exch.mrp.recv(rx, self.epoch)?;
if exch.is_state_open() { if exch.is_state_open() {
Ok(Some(( Ok(Some(ExchangeCtx {
proto_rx,
ExchangeCtx {
exch, exch,
sess: session, sess: session,
}, epoch: self.epoch,
))) }))
} else { } else {
// Instead of an error, we send None here, because it is likely that // Instead of an error, we send None here, because it is likely that
// we just processed an acknowledgement that cleared the exchange // we just processed an acknowledgement that cleared the exchange
@ -346,11 +373,11 @@ impl ExchangeMgr {
} }
} }
pub fn send(&mut self, exch_id: u16, proto_tx: BoxSlab<PacketPool>) -> Result<(), Error> { pub fn send(&mut self, exch_id: u16, tx: &mut Packet) -> Result<bool, Error> {
let exchange = let exchange =
ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(ErrorCode::NoExchange)?;
let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx);
exchange.send(proto_tx, &mut session) exchange.send(tx, &mut session)
} }
pub fn purge(&mut self) { pub fn purge(&mut self) {
@ -366,28 +393,29 @@ impl ExchangeMgr {
} }
} }
pub fn pending_acks(&mut self, expired_entries: &mut LinearMap<u16, (), MAX_MRP_ENTRIES>) { pub fn pending_ack(&mut self) -> Option<u16> {
for (exch_id, exchange) in self.exchanges.iter() { self.exchanges
if exchange.mrp.is_ack_ready() { .iter()
expired_entries.insert(*exch_id, ()).unwrap(); .find(|(_, exchange)| exchange.mrp.is_ack_ready(self.epoch))
} .map(|(exch_id, _)| *exch_id)
}
} }
pub fn evict_session(&mut self, index: usize) -> Result<(), Error> { pub fn evict_session(&mut self, tx: &mut Packet) -> Result<bool, Error> {
if let Some(index) = self.sess_mgr.get_session_for_eviction() {
info!("Sessions full, vacating session with index: {}", index); info!("Sessions full, vacating session with index: {}", index);
// If we enter here, we have an LRU session that needs to be reclaimed // If we enter here, we have an LRU session that needs to be reclaimed
// As per the spec, we need to send a CLOSE here // As per the spec, we need to send a CLOSE here
let mut session = self.sess_mgr.get_session_handle(index); let mut session = self.sess_mgr.get_session_handle(index);
let mut tx = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoSpace)?;
secure_channel::common::create_sc_status_report( secure_channel::common::create_sc_status_report(
&mut tx, tx,
secure_channel::common::SCStatusCodes::CloseSession, secure_channel::common::SCStatusCodes::CloseSession,
None, None,
)?; )?;
if let Some((_, exchange)) = self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) { if let Some((_, exchange)) =
self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index)
{
// Send Close_session on this exchange, and then close the session // Send Close_session on this exchange, and then close the session
// Should this be done for all exchanges? // Should this be done for all exchanges?
error!("Sending Close Session"); error!("Sending Close Session");
@ -395,7 +423,7 @@ impl ExchangeMgr {
// TODO: This wouldn't actually send it out, because 'transport' isn't owned yet. // TODO: This wouldn't actually send it out, because 'transport' isn't owned yet.
} }
let remove_exchanges: Vec<u16> = self let remove_exchanges: heapless::Vec<u16, MAX_EXCHANGES> = self
.exchanges .exchanges
.iter() .iter()
.filter_map(|(eid, e)| { .filter_map(|(eid, e)| {
@ -414,22 +442,18 @@ impl ExchangeMgr {
// Remove from exchange list // Remove from exchange list
self.exchanges.remove(&exch_id); self.exchanges.remove(&exch_id);
} }
self.sess_mgr.remove(index); self.sess_mgr.remove(index);
Ok(())
Ok(true)
} else {
Ok(false)
}
} }
pub fn add_session(&mut self, clone_data: &CloneData) -> Result<SessionHandle, Error> { pub fn add_session(&mut self, clone_data: &CloneData) -> Result<SessionHandle, Error> {
let sess_idx = match self.sess_mgr.clone_session(clone_data) { let sess_idx = self.sess_mgr.clone_session(clone_data)?;
Ok(idx) => idx,
Err(Error::NoSpace) => {
let evict_index = self.sess_mgr.get_lru();
self.evict_session(evict_index)?;
self.sess_mgr.clone_session(clone_data)?
}
Err(e) => {
return Err(e);
}
};
Ok(self.sess_mgr.get_session_handle(sess_idx)) Ok(self.sess_mgr.get_session_handle(sess_idx))
} }
} }
@ -449,21 +473,20 @@ impl fmt::Display for ExchangeMgr {
#[cfg(test)] #[cfg(test)]
#[allow(clippy::bool_assert_comparison)] #[allow(clippy::bool_assert_comparison)]
mod tests { mod tests {
use crate::{ use crate::{
error::Error, error::ErrorCode,
transport::{ transport::{
network::{Address, NetworkInterface}, network::Address,
session::{CloneData, SessionMgr, SessionMode, MAX_SESSIONS}, session::{CloneData, SessionMode},
}, },
utils::{epoch::dummy_epoch, rand::dummy_rand},
}; };
use super::{ExchangeMgr, Role}; use super::{ExchangeMgr, Role};
#[test] #[test]
fn test_purge() { fn test_purge() {
let sess_mgr = SessionMgr::new(); let mut mgr = ExchangeMgr::new(dummy_epoch, dummy_rand);
let mut mgr = ExchangeMgr::new(sess_mgr);
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap(); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap();
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap(); let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap();
@ -509,9 +532,12 @@ mod tests {
let clone_data = get_clone_data(peer_sess_id, local_sess_id); let clone_data = get_clone_data(peer_sess_id, local_sess_id);
match mgr.add_session(&clone_data) { match mgr.add_session(&clone_data) {
Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()),
Err(Error::NoSpace) => break, Err(e) => {
_ => { if e.code() == ErrorCode::NoSpace {
panic!("Couldn't, create session"); break;
} else {
panic!("Could not create sessions");
}
} }
} }
local_sess_id += 1; local_sess_id += 1;
@ -519,33 +545,17 @@ mod tests {
} }
} }
pub struct DummyNetwork; #[cfg(feature = "std")]
impl DummyNetwork {
pub fn new() -> Self {
Self {}
}
}
impl NetworkInterface for DummyNetwork {
fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, Address), Error> {
Ok((0, Address::default()))
}
fn send(&self, _out_buf: &[u8], _addr: Address) -> Result<usize, Error> {
Ok(0)
}
}
#[test] #[test]
/// We purposefully overflow the sessions /// We purposefully overflow the sessions
/// and when the overflow happens, we confirm that /// and when the overflow happens, we confirm that
/// - The sessions are evicted in LRU /// - The sessions are evicted in LRU
/// - The exchanges associated with those sessions are evicted too /// - The exchanges associated with those sessions are evicted too
fn test_sess_evict() { fn test_sess_evict() {
let mut sess_mgr = SessionMgr::new(); use crate::transport::packet::{Packet, MAX_TX_BUF_SIZE};
let transport = Box::new(DummyNetwork::new()); use crate::transport::session::MAX_SESSIONS;
sess_mgr.add_network_interface(transport).unwrap();
let mut mgr = ExchangeMgr::new(sess_mgr); let mut mgr = ExchangeMgr::new(crate::utils::epoch::sys_epoch, dummy_rand);
fill_sessions(&mut mgr, MAX_SESSIONS + 1); fill_sessions(&mut mgr, MAX_SESSIONS + 1);
// Sessions are now full from local session id 1 to 16 // Sessions are now full from local session id 1 to 16
@ -568,6 +578,17 @@ mod tests {
for i in 1..(MAX_SESSIONS + 1) { for i in 1..(MAX_SESSIONS + 1) {
// Now purposefully overflow the sessions by adding another session // Now purposefully overflow the sessions by adding another session
let result = mgr.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id));
assert!(matches!(
result.map_err(|e| e.code()),
Err(ErrorCode::NoSpace)
));
let mut buf = [0; MAX_TX_BUF_SIZE];
let tx = &mut Packet::new_tx(&mut buf);
let evicted = mgr.evict_session(tx).unwrap();
assert!(evicted);
let session = mgr let session = mgr
.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id)) .add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id))
.unwrap(); .unwrap();


@ -1,175 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use async_channel::Receiver;
use boxslab::{BoxSlab, Slab};
use heapless::LinearMap;
use log::{debug, error, info};
use crate::error::*;
use crate::transport::mrp::ReliableMessage;
use crate::transport::packet::PacketPool;
use crate::transport::{exchange, packet::Packet, proto_demux, queue, session, udp};
use super::proto_demux::ProtoCtx;
use super::queue::Msg;
pub struct Mgr {
exch_mgr: exchange::ExchangeMgr,
proto_demux: proto_demux::ProtoDemux,
rx_q: Receiver<Msg>,
}
impl Mgr {
pub fn new() -> Result<Mgr, Error> {
let mut sess_mgr = session::SessionMgr::new();
let udp_transport = Box::new(udp::UdpListener::new()?);
sess_mgr.add_network_interface(udp_transport)?;
Ok(Mgr {
proto_demux: proto_demux::ProtoDemux::new(),
exch_mgr: exchange::ExchangeMgr::new(sess_mgr),
rx_q: queue::WorkQ::init()?,
})
}
// Allows registration of different protocols with the Transport/Protocol Demux
pub fn register_protocol(
&mut self,
proto_id_handle: Box<dyn proto_demux::HandleProto>,
) -> Result<(), Error> {
self.proto_demux.register(proto_id_handle)
}
fn send_to_exchange(
&mut self,
exch_id: u16,
proto_tx: BoxSlab<PacketPool>,
) -> Result<(), Error> {
self.exch_mgr.send(exch_id, proto_tx)
}
fn handle_rxtx(&mut self) -> Result<(), Error> {
let result = self.exch_mgr.recv().map_err(|e| {
error!("Error in recv: {:?}", e);
e
})?;
if result.is_none() {
// Nothing to process, return quietly
return Ok(());
}
// result contains something worth processing, we can safely unwrap
// as we already checked for none above
let (rx, exch_ctx) = result.unwrap();
debug!("Exchange is {:?}", exch_ctx.exch);
let tx = Self::new_tx()?;
let mut proto_ctx = ProtoCtx::new(exch_ctx, rx, tx);
// Proto Dispatch
match self.proto_demux.handle(&mut proto_ctx) {
Ok(r) => {
if let proto_demux::ResponseRequired::No = r {
// We need to send the Ack if reliability is enabled, in this case
return Ok(());
}
}
Err(e) => {
error!("Error in proto_demux {:?}", e);
return Err(e);
}
}
let ProtoCtx {
exch_ctx,
rx: _,
tx,
} = proto_ctx;
// tx_ctx now contains the response payload, send the packet
let exch_id = exch_ctx.exch.get_id();
self.send_to_exchange(exch_id, tx).map_err(|e| {
error!("Error in sending msg {:?}", e);
e
})?;
Ok(())
}
fn handle_queue_msgs(&mut self) -> Result<(), Error> {
if let Ok(msg) = self.rx_q.try_recv() {
match msg {
Msg::NewSession(clone_data) => {
// If a new session was created, add it
let _ = self
.exch_mgr
.add_session(&clone_data)
.map_err(|e| error!("Error adding new session {:?}", e));
}
_ => {
error!("Queue Message Type not yet handled {:?}", msg);
}
}
}
Ok(())
}
pub fn start(&mut self) -> Result<(), Error> {
loop {
// Handle network operations
if self.handle_rxtx().is_err() {
error!("Error in handle_rxtx");
continue;
}
if self.handle_queue_msgs().is_err() {
error!("Error in handle_queue_msg");
continue;
}
// Handle any pending acknowledgement send
let mut acks_to_send: LinearMap<u16, (), { exchange::MAX_MRP_ENTRIES }> =
LinearMap::new();
self.exch_mgr.pending_acks(&mut acks_to_send);
for exch_id in acks_to_send.keys() {
info!("Sending MRP Standalone ACK for exch {}", exch_id);
let mut proto_tx = match Self::new_tx() {
Ok(p) => p,
Err(e) => {
error!("Error creating proto_tx {:?}", e);
break;
}
};
ReliableMessage::prepare_ack(*exch_id, &mut proto_tx);
if let Err(e) = self.send_to_exchange(*exch_id, proto_tx) {
error!("Error in sending Ack {:?}", e);
}
}
// Handle exchange purging
// This need not be done in each turn of the loop, maybe once in 5 times or so?
self.exch_mgr.purge();
info!("Exchange Mgr: {}", self.exch_mgr);
}
}
fn new_tx() -> Result<BoxSlab<PacketPool>, Error> {
Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::PacketPoolExhaust)
}
}
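
For reference, the removed Mgr was the old synchronous entry point: construct it, register protocol handlers with the demux, then block in start(). A sketch of that (now deleted) usage, built only from the APIs visible in the file above; the handler argument stands in for any HandleProto implementation:

use crate::error::Error;
use crate::transport::{mgr::Mgr, proto_demux};

fn run_legacy_transport(handler: Box<dyn proto_demux::HandleProto>) -> Result<(), Error> {
    let mut mgr = Mgr::new()?;

    // Register the protocol with the demux, then block in the RX/TX, queue,
    // standalone-ACK and exchange-purge loop shown above.
    mgr.register_protocol(handler)?;
    mgr.start()
}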


@@ -15,15 +15,15 @@
  * limitations under the License.
  */

+pub mod core;
 mod dedup;
 pub mod exchange;
-pub mod mgr;
 pub mod mrp;
 pub mod network;
 pub mod packet;
+pub mod pipe;
 pub mod plain_hdr;
-pub mod proto_demux;
+pub mod proto_ctx;
 pub mod proto_hdr;
-pub mod queue;
 pub mod session;
 pub mod udp;


@@ -15,8 +15,8 @@
  * limitations under the License.
  */

-use std::time::Duration;
-use std::time::SystemTime;
+use crate::utils::epoch::Epoch;
+use core::time::Duration;

 use crate::{error::*, secure_channel, transport::packet::Packet};
 use log::error;
@@ -41,25 +41,25 @@ impl RetransEntry {
     }
 }

-#[derive(Debug, Copy, Clone)]
+#[derive(Debug, Clone)]
 pub struct AckEntry {
     // The msg counter that we should acknowledge
     msg_ctr: u32,
     // The max time after which this entry must be ACK
-    ack_timeout: SystemTime,
+    ack_timeout: Duration,
 }

 impl AckEntry {
-    pub fn new(msg_ctr: u32) -> Result<Self, Error> {
+    pub fn new(msg_ctr: u32, epoch: Epoch) -> Result<Self, Error> {
         if let Some(ack_timeout) =
-            SystemTime::now().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT))
+            epoch().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT))
         {
             Ok(Self {
                 msg_ctr,
                 ack_timeout,
             })
         } else {
-            Err(Error::Invalid)
+            Err(ErrorCode::Invalid.into())
         }
     }
@@ -67,8 +67,8 @@ impl AckEntry {
         self.msg_ctr
     }

-    pub fn has_timed_out(&self) -> bool {
-        self.ack_timeout > SystemTime::now()
+    pub fn has_timed_out(&self, epoch: Epoch) -> bool {
+        self.ack_timeout > epoch()
     }
 }
@@ -90,10 +90,10 @@ impl ReliableMessage {
     }

     // Check any pending acknowledgements / retransmissions and take action
-    pub fn is_ack_ready(&self) -> bool {
+    pub fn is_ack_ready(&self, epoch: Epoch) -> bool {
         // Acknowledgements
-        if let Some(ack_entry) = self.ack {
-            ack_entry.has_timed_out()
+        if let Some(ack_entry) = &self.ack {
+            ack_entry.has_timed_out(epoch)
         } else {
             false
         }
@@ -107,7 +107,7 @@ impl ReliableMessage {
         // Check if any acknowledgements are pending for this exchange,
         // if so, piggy back in the encoded header here
-        if let Some(ack_entry) = self.ack {
+        if let Some(ack_entry) = &self.ack {
             // Ack Entry exists, set ACK bit and remove from table
             proto_tx.proto.set_ack(ack_entry.get_msg_ctr());
             self.ack = None;
@@ -120,7 +120,7 @@ impl ReliableMessage {
         if self.retrans.is_some() {
             // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen
             error!("Previous retrans entry for this exchange already exists");
-            return Err(Error::Invalid);
+            Err(ErrorCode::Invalid)?;
         }
         self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr));
@@ -132,10 +132,10 @@ impl ReliableMessage {
      * - there can be only one pending retransmission per exchange (so this is per-exchange)
      * - duplicate detection should happen per session (obviously), so that part is per-session
      */
-    pub fn recv(&mut self, proto_rx: &Packet) -> Result<(), Error> {
+    pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> {
         if proto_rx.proto.is_ack() {
             // Handle received Acks
-            let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?;
+            let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(ErrorCode::Invalid)?;
             if let Some(entry) = &self.retrans {
                 if entry.get_msg_ctr() != ack_msg_ctr {
                     // TODO: XXX Fix this
@@ -150,10 +150,10 @@ impl ReliableMessage {
                 // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen
                 // TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK
                 error!("Previous ACK entry for this exchange already exists");
-                return Err(Error::Invalid);
+                Err(ErrorCode::Invalid)?;
             }
-            self.ack = Some(AckEntry::new(proto_rx.plain.ctr)?);
+            self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?);
         }
         Ok(())
     }
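
The MRP timers now consume an Epoch instead of calling SystemTime::now(), which keeps this module free of std. The exact definition lives in crate::utils::epoch and may differ, but a plausible shape, consistent with how epoch() is invoked above and with the sys_epoch referenced by the tests, is:

use core::time::Duration;

// Assumed shape: the diff only shows Epoch being invoked as `epoch()`, so a plain
// function pointer returning the elapsed time as a Duration is one plausible definition.
pub type Epoch = fn() -> Duration;

#[cfg(feature = "std")]
pub fn sys_epoch() -> Duration {
    // Any monotonically increasing source works; the absolute reference point does not
    // matter because the MRP code only compares durations produced by the same epoch.
    std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap()
}

// On a bare-metal target the same type can wrap a hardware tick counter instead,
// e.g. a function converting the system timer value into a Duration.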


@@ -15,18 +15,25 @@
  * limitations under the License.
  */

-use std::{
-    fmt::{Debug, Display},
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-};
+use core::fmt::{Debug, Display};

-use crate::error::Error;
+#[cfg(not(feature = "std"))]
+pub use no_std_net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
+#[cfg(feature = "std")]
+pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};

-#[derive(PartialEq, Copy, Clone)]
+#[derive(Eq, PartialEq, Copy, Clone)]
 pub enum Address {
     Udp(SocketAddr),
 }

+impl Address {
+    pub fn unwrap_udp(self) -> SocketAddr {
+        match self {
+            Self::Udp(addr) => addr,
+        }
+    }
+}
+
 impl Default for Address {
     fn default() -> Self {
         Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080))
@@ -34,7 +41,7 @@ impl Default for Address {
 }

 impl Display for Address {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         match self {
             Address::Udp(addr) => writeln!(f, "{}", addr),
         }
@@ -42,14 +49,9 @@ impl Display for Address {
 }

 impl Debug for Address {
-    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
         match self {
             Address::Udp(addr) => writeln!(f, "{}", addr),
         }
     }
 }
-
-pub trait NetworkInterface {
-    fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error>;
-    fn send(&self, out_buf: &[u8], addr: Address) -> Result<usize, Error>;
-}
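
Because the socket-address types are re-exported behind the std feature gate, downstream code can keep naming SocketAddr the same way on both std and no_std builds. A small illustrative sketch using only the items shown above:

use crate::transport::network::{Address, IpAddr, Ipv4Addr, SocketAddr};

// Identical on std and no_std builds, because SocketAddr resolves either to
// std::net::SocketAddr or to no_std_net::SocketAddr via the re-exports above.
fn loopback_matter_addr(port: u16) -> Address {
    Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), port))
}

fn port_of(addr: Address) -> u16 {
    // unwrap_udp() is total today, since Udp is the only variant
    addr.unwrap_udp().port()
}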


@@ -15,14 +15,14 @@
  * limitations under the License.
  */

-use log::{error, trace};
-use std::sync::Mutex;
-
-use boxslab::box_slab;
+use log::{error, info, trace};
+use owo_colors::OwoColorize;

 use crate::{
-    error::Error,
-    sys::MAX_PACKET_POOL_SIZE,
+    error::{Error, ErrorCode},
+    interaction_model::core::PROTO_ID_INTERACTION_MODEL,
+    secure_channel::common::PROTO_ID_SECURE_CHANNEL,
+    tlv,
     utils::{parsebuf::ParseBuf, writebuf::WriteBuf},
 };
@@ -33,56 +33,10 @@ use super::{
 };

 pub const MAX_RX_BUF_SIZE: usize = 1583;
-type Buffer = [u8; MAX_RX_BUF_SIZE];
+pub const MAX_RX_STATUS_BUF_SIZE: usize = 100;
+pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/;

-// TODO: I am not very happy with this construction, need to find another way to do this
-pub struct BufferPool {
-    buffers: [Option<Buffer>; MAX_PACKET_POOL_SIZE],
-}
-
-impl BufferPool {
-    const INIT: Option<Buffer> = None;
-    fn get() -> &'static Mutex<BufferPool> {
-        static mut BUFFER_HOLDER: Option<Mutex<BufferPool>> = None;
-        static ONCE: Once = Once::new();
-        unsafe {
-            ONCE.call_once(|| {
-                BUFFER_HOLDER = Some(Mutex::new(BufferPool {
-                    buffers: [BufferPool::INIT; MAX_PACKET_POOL_SIZE],
-                }));
-            });
-            BUFFER_HOLDER.as_ref().unwrap()
-        }
-    }
-
-    pub fn alloc() -> Option<(usize, &'static mut Buffer)> {
-        trace!("Buffer Alloc called\n");
-        let mut pool = BufferPool::get().lock().unwrap();
-        for i in 0..MAX_PACKET_POOL_SIZE {
-            if pool.buffers[i].is_none() {
-                pool.buffers[i] = Some([0; MAX_RX_BUF_SIZE]);
-                // Sigh! to by-pass the borrow-checker telling us we are stealing a mutable reference
-                // from under the lock
-                // In this case the lock only protects against the setting of Some/None,
-                // the objects then are independently accessed in a unique way
-                let buffer = unsafe { &mut *(pool.buffers[i].as_mut().unwrap() as *mut Buffer) };
-                return Some((i, buffer));
-            }
-        }
-        None
-    }
-
-    pub fn free(index: usize) {
-        trace!("Buffer Free called\n");
-        let mut pool = BufferPool::get().lock().unwrap();
-        if pool.buffers[index].is_some() {
-            pool.buffers[index] = None;
-        }
-    }
-}
-
-#[derive(PartialEq)]
+#[derive(Debug, PartialEq, Eq, Copy, Clone)]
 enum RxState {
     Uninit,
     PlainDecode,
@@ -94,51 +48,95 @@ enum Direction<'a> {
     Rx(ParseBuf<'a>, RxState),
 }

+impl<'a> Direction<'a> {
+    pub fn load(&mut self, direction: &Direction) -> Result<(), Error> {
+        if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) {
+            Err(ErrorCode::Invalid)?;
+        }
+
+        match self {
+            Self::Tx(wb) => match direction {
+                Direction::Tx(src_wb) => wb.load(src_wb)?,
+                Direction::Rx(_, _) => Err(ErrorCode::Invalid)?,
+            },
+            Self::Rx(pb, state) => match direction {
+                Direction::Tx(_) => Err(ErrorCode::Invalid)?,
+                Direction::Rx(src_pb, src_state) => {
+                    pb.load(src_pb)?;
+                    *state = *src_state;
+                }
+            },
+        }
+
+        Ok(())
+    }
+}
+
 pub struct Packet<'a> {
     pub plain: PlainHdr,
     pub proto: ProtoHdr,
     pub peer: Address,
     data: Direction<'a>,
-    buffer_index: usize,
 }

 impl<'a> Packet<'a> {
     const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len();

-    pub fn new_rx() -> Result<Self, Error> {
-        let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?;
-        let buf_len = buffer.len();
-        Ok(Self {
+    pub fn new_rx(buf: &'a mut [u8]) -> Self {
+        Self {
             plain: Default::default(),
             proto: Default::default(),
-            buffer_index,
             peer: Address::default(),
-            data: Direction::Rx(ParseBuf::new(buffer, buf_len), RxState::Uninit),
-        })
+            data: Direction::Rx(ParseBuf::new(buf), RxState::Uninit),
+        }
     }

-    pub fn new_tx() -> Result<Self, Error> {
-        let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?;
-        let buf_len = buffer.len();
-        let mut wb = WriteBuf::new(buffer, buf_len);
-        wb.reserve(Packet::HDR_RESERVE)?;
+    pub fn new_tx(buf: &'a mut [u8]) -> Self {
+        let mut wb = WriteBuf::new(buf);
+        wb.reserve(Packet::HDR_RESERVE).unwrap();

-        let mut p = Self {
+        // Reliability on by default
+        let mut proto: ProtoHdr = Default::default();
+        proto.set_reliable();
+
+        Self {
             plain: Default::default(),
-            proto: Default::default(),
-            buffer_index,
+            proto,
             peer: Address::default(),
             data: Direction::Tx(wb),
-        };
-        // Reliability on by default
-        p.proto.set_reliable();
-        Ok(p)
+        }
     }

-    pub fn as_borrow_slice(&mut self) -> &mut [u8] {
+    pub fn reset(&mut self) {
+        if let Direction::Tx(wb) = &mut self.data {
+            wb.reset();
+            wb.reserve(Packet::HDR_RESERVE).unwrap();
+
+            self.plain = Default::default();
+            self.proto = Default::default();
+            self.peer = Address::default();
+
+            self.proto.set_reliable();
+        }
+    }
+
+    pub fn load(&mut self, packet: &Packet) -> Result<(), Error> {
+        self.plain = packet.plain.clone();
+        self.proto = packet.proto.clone();
+        self.peer = packet.peer;
+        self.data.load(&packet.data)
+    }
+
+    pub fn as_slice(&self) -> &[u8] {
+        match &self.data {
+            Direction::Rx(pb, _) => pb.as_slice(),
+            Direction::Tx(wb) => wb.as_slice(),
+        }
+    }
+
+    pub fn as_mut_slice(&mut self) -> &mut [u8] {
         match &mut self.data {
-            Direction::Rx(pb, _) => pb.as_borrow_slice(),
+            Direction::Rx(pb, _) => pb.as_mut_slice(),
             Direction::Tx(wb) => wb.as_mut_slice(),
         }
     }
@@ -147,7 +145,7 @@ impl<'a> Packet<'a> {
         if let Direction::Rx(pbuf, _) = &mut self.data {
             Ok(pbuf)
         } else {
-            Err(Error::Invalid)
+            Err(ErrorCode::Invalid.into())
         }
     }
@@ -155,7 +153,7 @@ impl<'a> Packet<'a> {
         if let Direction::Tx(wbuf) = &mut self.data {
             Ok(wbuf)
         } else {
-            Err(Error::Invalid)
+            Err(ErrorCode::Invalid.into())
         }
     }
@@ -167,10 +165,22 @@ impl<'a> Packet<'a> {
         self.proto.proto_id = proto_id;
     }

-    pub fn get_proto_opcode(&self) -> u8 {
+    pub fn get_proto_opcode<T: num::FromPrimitive>(&self) -> Result<T, Error> {
+        num::FromPrimitive::from_u8(self.proto.proto_opcode).ok_or(ErrorCode::Invalid.into())
+    }
+
+    pub fn get_proto_raw_opcode(&self) -> u8 {
         self.proto.proto_opcode
     }

+    pub fn check_proto_opcode(&self, opcode: u8) -> Result<(), Error> {
+        if self.proto.proto_opcode == opcode {
+            Ok(())
+        } else {
+            Err(ErrorCode::Invalid.into())
+        }
+    }
+
     pub fn set_proto_opcode(&mut self, proto_opcode: u8) {
         self.proto.proto_opcode = proto_opcode;
     }
@@ -196,20 +206,66 @@ impl<'a> Packet<'a> {
                     .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key)
                 } else {
                     error!("Invalid state for proto_decode");
-                    Err(Error::InvalidState)
+                    Err(ErrorCode::InvalidState.into())
                 }
             }
-            _ => Err(Error::InvalidState),
+            _ => Err(ErrorCode::InvalidState.into()),
         }
     }

+    pub fn proto_encode(
+        &mut self,
+        peer: Address,
+        peer_nodeid: Option<u64>,
+        local_nodeid: u64,
+        plain_text: bool,
+        enc_key: Option<&[u8]>,
+    ) -> Result<(), Error> {
+        self.peer = peer;
+
+        // Generate encrypted header
+        let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()];
+        let mut write_buf = WriteBuf::new(&mut tmp_buf);
+        self.proto.encode(&mut write_buf)?;
+        self.get_writebuf()?.prepend(write_buf.as_slice())?;
+
+        // Generate plain-text header
+        if plain_text {
+            if let Some(d) = peer_nodeid {
+                self.plain.set_dest_u64(d);
+            }
+        }
+
+        let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()];
+        let mut write_buf = WriteBuf::new(&mut tmp_buf);
+        self.plain.encode(&mut write_buf)?;
+        let plain_hdr_bytes = write_buf.as_slice();
+
+        trace!("unencrypted packet: {:x?}", self.as_mut_slice());
+        let ctr = self.plain.ctr;
+        if let Some(e) = enc_key {
+            proto_hdr::encrypt_in_place(
+                ctr,
+                local_nodeid,
+                plain_hdr_bytes,
+                self.get_writebuf()?,
+                e,
+            )?;
+        }
+
+        self.get_writebuf()?.prepend(plain_hdr_bytes)?;
+        trace!("Full encrypted packet: {:x?}", self.as_mut_slice());
+
+        Ok(())
+    }
+
     pub fn is_plain_hdr_decoded(&self) -> Result<bool, Error> {
         match &self.data {
             Direction::Rx(_, state) => match state {
                 RxState::Uninit => Ok(false),
                 _ => Ok(true),
             },
-            _ => Err(Error::InvalidState),
+            _ => Err(ErrorCode::InvalidState.into()),
         }
     }
@@ -221,19 +277,50 @@ impl<'a> Packet<'a> {
                     self.plain.decode(pb)
                 } else {
                     error!("Invalid state for plain_decode");
-                    Err(Error::InvalidState)
+                    Err(ErrorCode::InvalidState.into())
                 }
             }
-            _ => Err(Error::InvalidState),
+            _ => Err(ErrorCode::InvalidState.into()),
         }
     }
-}
-
-impl<'a> Drop for Packet<'a> {
-    fn drop(&mut self) {
-        BufferPool::free(self.buffer_index);
-        trace!("Dropping Packet......");
-    }
-}
-
-box_slab!(PacketPool, Packet<'static>, MAX_PACKET_POOL_SIZE);
+
+    pub fn log(&self, operation: &str) {
+        match self.get_proto_id() {
+            PROTO_ID_SECURE_CHANNEL => {
+                if let Ok(opcode) = self.get_proto_opcode::<crate::secure_channel::common::OpCode>()
+                {
+                    info!("{} SC:{:?}: ", operation.cyan(), opcode);
+                } else {
+                    info!(
+                        "{} SC:{}??: ",
+                        operation.cyan(),
+                        self.get_proto_raw_opcode()
+                    );
+                }
+
+                tlv::print_tlv_list(self.as_slice());
+            }
+            PROTO_ID_INTERACTION_MODEL => {
+                if let Ok(opcode) =
+                    self.get_proto_opcode::<crate::interaction_model::core::OpCode>()
+                {
+                    info!("{} IM:{:?}: ", operation.cyan(), opcode);
+                } else {
+                    info!(
+                        "{} IM:{}??: ",
+                        operation.cyan(),
+                        self.get_proto_raw_opcode()
+                    );
+                }
+
+                tlv::print_tlv_list(self.as_slice());
+            }
+            other => info!(
+                "{} {}??:{}??: ",
+                operation.cyan(),
+                other,
+                self.get_proto_raw_opcode()
+            ),
+        }
+    }
+}
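
Packet no longer owns pooled buffers; the caller supplies the storage, so on embedded targets the RX/TX buffers can live on the stack or in static memory and nothing has to be freed. A minimal sketch using only the constructors and constants introduced above:

use crate::transport::packet::{Packet, MAX_RX_BUF_SIZE, MAX_TX_BUF_SIZE};

fn with_stack_packets() {
    // The storage is borrowed for the lifetime of each packet; nothing is allocated
    // and nothing needs to be freed (the old Drop/BufferPool machinery is gone).
    let mut rx_buf = [0u8; MAX_RX_BUF_SIZE];
    let mut tx_buf = [0u8; MAX_TX_BUF_SIZE];

    let rx = Packet::new_rx(&mut rx_buf);
    let mut tx = Packet::new_tx(&mut tx_buf);

    // reset() clears the TX packet and re-reserves the header area, so the same
    // buffer can be reused for the next outgoing message.
    tx.reset();

    let _ = (rx, tx);
}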


@@ -0,0 +1,94 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
use crate::utils::select::Notification;
use super::network::Address;
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Chunk {
pub start: usize,
pub end: usize,
pub addr: Address,
}
pub struct PipeData<'a> {
pub buf: &'a mut [u8],
pub chunk: Option<Chunk>,
}
pub struct Pipe<'a> {
pub data: Mutex<NoopRawMutex, PipeData<'a>>,
pub data_supplied_notification: Notification,
pub data_consumed_notification: Notification,
}
impl<'a> Pipe<'a> {
#[inline(always)]
pub fn new(buf: &'a mut [u8]) -> Self {
Self {
data: Mutex::new(PipeData { buf, chunk: None }),
data_supplied_notification: Notification::new(),
data_consumed_notification: Notification::new(),
}
}
pub async fn recv(&self, buf: &mut [u8]) -> (usize, Address) {
loop {
{
let mut data = self.data.lock().await;
if let Some(chunk) = data.chunk {
buf[..chunk.end - chunk.start]
.copy_from_slice(&data.buf[chunk.start..chunk.end]);
data.chunk = None;
self.data_consumed_notification.signal(());
return (chunk.end - chunk.start, chunk.addr);
}
}
self.data_supplied_notification.wait().await
}
}
pub async fn send(&self, addr: Address, buf: &[u8]) {
loop {
{
let mut data = self.data.lock().await;
if data.chunk.is_none() {
data.buf[..buf.len()].copy_from_slice(buf);
data.chunk = Some(Chunk {
start: 0,
end: buf.len(),
addr,
});
self.data_supplied_notification.signal(());
break;
}
}
self.data_consumed_notification.wait().await
}
}
}
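
Pipe is a one-slot rendezvous buffer between an async producer and consumer: send parks until the previous chunk has been consumed, recv parks until a chunk arrives. A usage sketch; joining the two futures via embassy_futures is an assumption here (any executor or join combinator that drives both futures works):

use crate::transport::network::{Address, IpAddr, Ipv4Addr, SocketAddr};
use crate::transport::pipe::Pipe;

async fn pipe_round_trip() {
    let mut storage = [0u8; 32];
    let pipe = Pipe::new(&mut storage);

    let addr = Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 5540));

    // send() completes only once the chunk has been picked up by recv()
    let producer = pipe.send(addr, b"hello");

    let consumer = async {
        let mut buf = [0u8; 32];
        let (len, from) = pipe.recv(&mut buf).await;
        assert_eq!(&buf[..len], b"hello");
        assert_eq!(from, addr);
    };

    // embassy_futures::join is assumed; any way of driving both futures works
    embassy_futures::join::join(producer, consumer).await;
}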


@@ -21,18 +21,13 @@ use crate::utils::writebuf::WriteBuf;
 use bitflags::bitflags;
 use log::info;

-#[derive(Debug, PartialEq)]
+#[derive(Debug, PartialEq, Eq, Default, Copy, Clone)]
 pub enum SessionType {
+    #[default]
     None,
     Encrypted,
 }

-impl Default for SessionType {
-    fn default() -> SessionType {
-        SessionType::None
-    }
-}
-
 bitflags! {
     #[derive(Default)]
     pub struct MsgFlags: u8 {
@@ -43,7 +38,7 @@ bitflags! {
 }

 // This is the unencrypted message
-#[derive(Debug, Default)]
+#[derive(Debug, Default, Clone)]
 pub struct PlainHdr {
     pub flags: MsgFlags,
     pub sess_type: SessionType,
@@ -70,7 +65,7 @@ impl PlainHdr {
 impl PlainHdr {
     // it will have an additional 'message length' field first
     pub fn decode(&mut self, msg: &mut ParseBuf) -> Result<(), Error> {
-        self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(Error::Invalid)?;
+        self.flags = MsgFlags::from_bits(msg.le_u8()?).ok_or(ErrorCode::Invalid)?;
         self.sess_id = msg.le_u16()?;
         let _sec_flags = msg.le_u8()?;
         self.sess_type = if self.sess_id != 0 {


@@ -0,0 +1,41 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use super::exchange::ExchangeCtx;
use super::packet::Packet;
/// This is the context in which a receive packet is being processed
pub struct ProtoCtx<'a, 'b> {
/// This is the exchange context, that includes the exchange and the session
pub exch_ctx: ExchangeCtx<'a>,
/// This is the received buffer for this transaction
pub rx: &'a Packet<'b>,
/// This is the transmit buffer for this transaction
pub tx: &'a mut Packet<'b>,
}
impl<'a, 'b> ProtoCtx<'a, 'b> {
pub fn new(exch_ctx: ExchangeCtx<'a>, rx: &'a Packet<'b>, tx: &'a mut Packet<'b>) -> Self {
Self { exch_ctx, rx, tx }
}
pub fn send(&mut self) -> Result<bool, Error> {
self.exch_ctx.exch.send(self.tx, &mut self.exch_ctx.sess)
}
}
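
ProtoCtx bundles the exchange context with the borrowed RX and TX packets for one transaction: a handler inspects rx, fills in tx, and calls send. A hypothetical handler shape; only the ProtoCtx fields/methods and the Packet accessors shown in this diff are real, and the header-echo behaviour is made up for illustration:

use log::info;

use crate::error::Error;
use crate::transport::proto_ctx::ProtoCtx;

fn handle_message(ctx: &mut ProtoCtx<'_, '_>) -> Result<bool, Error> {
    // Inspect the received packet
    let proto_id = ctx.rx.get_proto_id();
    let opcode = ctx.rx.get_proto_raw_opcode();
    info!("Handling proto {} / opcode {}", proto_id, opcode);

    // Build the (here: header-only) response in the borrowed TX packet and hand it
    // back to the exchange; the bool mirrors ProtoCtx::send's return value.
    ctx.tx.set_proto_id(proto_id);
    ctx.tx.set_proto_opcode(opcode);
    ctx.send()
}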

Some files were not shown because too many files have changed in this diff.