Merge pull request #72 from ivmarkov/sequential-embassy-net

no_std + async support
This commit is contained in:
Kedar Sovani 2023-07-22 15:11:29 +05:30 committed by GitHub
commit 6bbac0b6e9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
130 changed files with 12677 additions and 8882 deletions

View file

@ -15,11 +15,11 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
matrix: matrix:
crypto-backend: ['crypto_openssl', 'crypto_rustcrypto', 'crypto_mbedtls'] crypto-backend: ['rustcrypto', 'mbedtls', 'openssl']
steps: steps:
- uses: actions/checkout@v2 - uses: actions/checkout@v2
- name: Build - name: Build
run: cd matter; cargo build --verbose --no-default-features --features ${{matrix.crypto-backend}} run: cd matter; cargo build --no-default-features --features ${{matrix.crypto-backend}}
- name: Run tests - name: Run tests
run: cd matter; cargo test --verbose --no-default-features --features ${{matrix.crypto-backend}} -- --test-threads=1 run: cd matter; cargo test --no-default-features --features os,${{matrix.crypto-backend}} -- --test-threads=1

1
.gitignore vendored
View file

@ -1,3 +1,4 @@
target target
Cargo.lock Cargo.lock
.vscode .vscode
.embuild

View file

@ -1,4 +1,17 @@
[workspace] [workspace]
members = ["matter", "matter_macro_derive", "boxslab", "tools/tlv_tool"] resolver = "2"
members = ["matter", "matter_macro_derive"]
exclude = ["examples/*"] exclude = ["examples/*", "tools/tlv_tool"]
# For compatibility with ESP IDF
[patch.crates-io]
polling = { git = "https://github.com/esp-rs-compat/polling" }
socket2 = { git = "https://github.com/esp-rs-compat/socket2" }
[profile.release]
opt-level = 3
[profile.dev]
debug = true
opt-level = 3

View file

@ -5,55 +5,44 @@
[![Test Linux (OpenSSL)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml) [![Test Linux (OpenSSL)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-openssl.yml)
[![Test Linux (mbedTLS)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml) [![Test Linux (mbedTLS)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml/badge.svg)](https://github.com/project-chip/matter-rs/actions/workflows/test-linux-mbedtls.yml)
## Important Note
All development work is now ongoing in two other branches ([no_std](https://github.com/project-chip/matter-rs/tree/no_std) and [sequential](https://github.com/project-chip/matter-rs/tree/sequential) - explained below). The plan is one of these two branches to become the new `main`.
We highly encourage users to try out both of these branches (there is a working `onoff_light` example in both) and provide feedback.
### [no_std](https://github.com/project-chip/matter-rs/tree/no_std)
The purpose of this branch - as the name suggests - is to introduce `no_std` compatibility to the `matter-rs` library, so that it is possible to target constrained environments like MCUs which more often than not have no support for the Rust Standard library (threads, network sockets, filesystem and so on).
We have been successful in this endeavour. The library now only requires Rust `core` and runs on e.g. ESP32 baremental Rust targets.
When `matter-rs` is used on targets that do not support the Rust Standard Library, user is expected to provide the following:
- A `rand` function that can fill a `&[u8]` slice with random data
- An `epoch` function (a "current time" utility); note that since this utility is only used for measuring timeouts, it is OK to provide a function that e.g. measures elapsed millis since system boot, rather than something that tries to adhere to the UNIX epoch (1/1/1970)
- An MCU-specific UDP stack that the user would need to connect to the `matter-rs` library
Besides just having `no_std` compatibility, the `no_std` branch does not need an allocator. I.e. all structures internal to the `matter-rs` librarty are statically allocated.
Last but not least, the `no_std` branch by itself does **not** do any IO. In other words, it is "compute only" (as in, "give me a network packet and I'll produce one or more that you have to send; how you receive/send those is up to you"). Ditto for persisting fabrics and ACLs - it is up to the user to listen the matter stack for changes to those and persist.
### [sequential](https://github.com/project-chip/matter-rs/tree/sequential)
The `sequential` branch builds on top of the work implemented in the `no_std` branch by utilizing code implemented as `async` functions and methods. Committing to `async` has multiple benefits:
- (Internal for the library) We were able to turn several explicit state machines into implicit ones (after all, `async` is primarily about generating state machines automatically based on "sequential" user codee that uses the async/await language constructs - hence the name of the branch)
- (External, for the user) The ergonomics of the Exchange API in this branch (in other words, the "transport aspect of the Matter CSA spec) is much better, approaching that of dealing with regular TCP/IP sockets in the Rust Standard Library. This is only possible by utilizing async functions and methods, because - let's not forget - `matter-rs` needs to run on MCUs where native threading and task scheduling capabilities might not even exist, hence "sequentially-looking" request/response interaction can only be expressed asynchronously, or with explicit state machines.
- Certain pending concepts are much easier to implement via async functions and methods:
- Re-sending packets which were not acknowledged by the receiver yet (the MRP protocol as per the Matter spec)
- The "initiator" side of an exchange (think client clusters)
- This branch provides facilities to implement asynchronous read, write and invoke handling for server clusters, which is beneficial in certain scenarios (i.e. brdige devices)
The `async` metaphor however comes with a bit higher memory usage, due to not enough optimizations being implemented yet in the rust language when the async code is transpiled to state machines.
## Build ## Build
Building the library: ### Building the library
``` ```
$ cargo build $ cargo build
``` ```
Building the example: ### Building and running the example (Linux, MacOS X)
``` ```
$ RUST_LOG="matter" cargo run --example onoff_light $ cargo run --example onoff_light
``` ```
With the chip-tool (the current tool for testing Matter) use the Ethernet commissioning mechanism: ### Building the example (Espressif's ESP-IDF)
* Install all build prerequisites described [here](https://github.com/esp-rs/esp-idf-template#prerequisites)
* Build with the following command line:
```
export MCU=esp32; export CARGO_TARGET_XTENSA_ESP32_ESPIDF_LINKER=ldproxy; export RUSTFLAGS="-C default-linker-libraries"; export WIFI_SSID=ssid;export WIFI_PASS=pass; cargo build --example onoff_light --no-default-features --features esp-idf --target xtensa-esp32-espidf -Zbuild-std=std,panic_abort
```
* If you are building for a different Espressif MCU, change the `MCU` variable, the `xtensa-esp32-espidf` target and the name of the `CARGO_TARGET_<esp-idf-target-uppercase>_LINKER` variable to match your MCU and its Rust target. Available Espressif MCUs and targets are:
* esp32 / xtensa-esp32-espidf
* esp32s2 / xtensa-esp32s2-espidf
* esp32s3 / xtensa-esp32s3-espidf
* esp32c3 / riscv32imc-esp-espidf
* esp32c5 / riscv32imc-esp-espidf
* esp32c6 / risxcv32imac-esp-espidf
* Put in `WIFI_SSID` / `WIFI_PASS` the SSID & password for your wireless router
* Flash using the `espflash` utility described in the build prerequsites' link above
### Building the example (ESP32-XX baremetal or RP2040)
Coming soon!
## Test
With the `chip-tool` (the current tool for testing Matter) use the Ethernet commissioning mechanism:
``` ```
$ chip-tool pairing code 12344321 <Pairing-Code> $ chip-tool pairing code 12344321 <Pairing-Code>

View file

@ -1,9 +0,0 @@
[package]
name = "boxslab"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bitmaps={version="3.2.0", features=[]}

View file

@ -1,237 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::{
mem::MaybeUninit,
ops::{Deref, DerefMut},
sync::Mutex,
};
// TODO: why is max bitmap size 64 a correct max size? Could we match
// boxslabs instead or store used/not used inside the box slabs themselves?
const MAX_BITMAP_SIZE: usize = 64;
pub struct Bitmap {
inner: bitmaps::Bitmap<MAX_BITMAP_SIZE>,
max_size: usize,
}
impl Bitmap {
pub fn new(max_size: usize) -> Self {
assert!(max_size <= MAX_BITMAP_SIZE);
Bitmap {
inner: bitmaps::Bitmap::new(),
max_size,
}
}
pub fn set(&mut self, index: usize) -> bool {
assert!(index < self.max_size);
self.inner.set(index, true)
}
pub fn reset(&mut self, index: usize) -> bool {
assert!(index < self.max_size);
self.inner.set(index, false)
}
pub fn first_false_index(&self) -> Option<usize> {
match self.inner.first_false_index() {
Some(idx) if idx < self.max_size => Some(idx),
_ => None,
}
}
pub fn is_empty(&self) -> bool {
self.inner.is_empty()
}
pub fn is_full(&self) -> bool {
self.first_false_index().is_none()
}
}
#[macro_export]
macro_rules! box_slab {
($name:ident,$t:ty,$v:expr) => {
use std::mem::MaybeUninit;
use std::sync::Once;
use $crate::{BoxSlab, Slab, SlabPool};
pub struct $name;
impl SlabPool for $name {
type SlabType = $t;
fn get_slab() -> &'static Slab<Self> {
const MAYBE_INIT: MaybeUninit<$t> = MaybeUninit::uninit();
static mut SLAB_POOL: [MaybeUninit<$t>; $v] = [MAYBE_INIT; $v];
static mut SLAB_SPACE: Option<Slab<$name>> = None;
static mut INIT: Once = Once::new();
unsafe {
INIT.call_once(|| {
SLAB_SPACE = Some(Slab::<$name>::init(&mut SLAB_POOL, $v));
});
SLAB_SPACE.as_ref().unwrap()
}
}
}
};
}
pub trait SlabPool {
type SlabType: 'static;
fn get_slab() -> &'static Slab<Self>
where
Self: Sized;
}
pub struct Inner<T: 'static + SlabPool> {
pool: &'static mut [MaybeUninit<T::SlabType>],
map: Bitmap,
}
// TODO: Instead of a mutex, we should replace this with a CAS loop
pub struct Slab<T: 'static + SlabPool>(Mutex<Inner<T>>);
impl<T: SlabPool> Slab<T> {
pub fn init(pool: &'static mut [MaybeUninit<T::SlabType>], size: usize) -> Self {
Self(Mutex::new(Inner {
pool,
map: Bitmap::new(size),
}))
}
pub fn try_new(new_object: T::SlabType) -> Option<BoxSlab<T>> {
let slab = T::get_slab();
let mut inner = slab.0.lock().unwrap();
if let Some(index) = inner.map.first_false_index() {
inner.map.set(index);
inner.pool[index].write(new_object);
let cell_ptr = unsafe { &mut *inner.pool[index].as_mut_ptr() };
Some(BoxSlab {
data: cell_ptr,
index,
})
} else {
None
}
}
pub fn free(&self, index: usize) {
let mut inner = self.0.lock().unwrap();
inner.map.reset(index);
let old_value = std::mem::replace(&mut inner.pool[index], MaybeUninit::uninit());
let _old_value = unsafe { old_value.assume_init() };
// This will drop the old_value
}
}
pub struct BoxSlab<T: 'static + SlabPool> {
// Because the data is a reference within the MaybeUninit, we don't have a mechanism
// to go out to the MaybeUninit from this reference. Hence this index
index: usize,
// TODO: We should figure out a way to get rid of the index too
data: &'static mut T::SlabType,
}
impl<T: 'static + SlabPool> Drop for BoxSlab<T> {
fn drop(&mut self) {
T::get_slab().free(self.index);
}
}
impl<T: SlabPool> Deref for BoxSlab<T> {
type Target = T::SlabType;
fn deref(&self) -> &Self::Target {
self.data
}
}
impl<T: SlabPool> DerefMut for BoxSlab<T> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.data
}
}
#[cfg(test)]
mod tests {
use std::{ops::Deref, sync::Arc};
pub struct Test {
val: Arc<u32>,
}
box_slab!(TestSlab, Test, 3);
#[test]
fn simple_alloc_free() {
{
let a = Slab::<TestSlab>::try_new(Test { val: Arc::new(10) }).unwrap();
assert_eq!(*a.val.deref(), 10);
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(!inner.map.is_empty());
}
// Validates that the 'Drop' got executed
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(inner.map.is_empty());
println!("Box Size {}", std::mem::size_of::<Box<Test>>());
println!("BoxSlab Size {}", std::mem::size_of::<BoxSlab<TestSlab>>());
}
#[test]
fn alloc_full_block() {
{
let a = Slab::<TestSlab>::try_new(Test { val: Arc::new(10) }).unwrap();
let b = Slab::<TestSlab>::try_new(Test { val: Arc::new(11) }).unwrap();
let c = Slab::<TestSlab>::try_new(Test { val: Arc::new(12) }).unwrap();
// Test that at overflow, we return None
assert!(Slab::<TestSlab>::try_new(Test { val: Arc::new(13) }).is_none(),);
assert_eq!(*b.val.deref(), 11);
{
let inner = TestSlab::get_slab().0.lock().unwrap();
// Test that the bitmap is marked as full
assert!(inner.map.is_full());
}
// Purposefully drop, to test that new allocation is possible
std::mem::drop(b);
let d = Slab::<TestSlab>::try_new(Test { val: Arc::new(21) }).unwrap();
assert_eq!(*d.val.deref(), 21);
// Ensure older allocations are still valid
assert_eq!(*a.val.deref(), 10);
assert_eq!(*c.val.deref(), 12);
}
// Validates that the 'Drop' got executed - test that the bitmap is empty
let inner = TestSlab::get_slab().0.lock().unwrap();
assert!(inner.map.is_empty());
}
#[test]
fn test_drop_logic() {
let root = Arc::new(10);
{
let _a = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
let _b = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
let _c = Slab::<TestSlab>::try_new(Test { val: root.clone() }).unwrap();
assert_eq!(Arc::strong_count(&root), 4);
}
// Test that Drop was correctly called on all the members of the pool
assert_eq!(Arc::strong_count(&root), 1);
}
}

View file

@ -16,7 +16,7 @@
*/ */
use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher}; use matter::data_model::sdm::dev_att::{DataType, DevAttDataFetcher};
use matter::error::Error; use matter::error::{Error, ErrorCode};
pub struct HardCodedDevAtt {} pub struct HardCodedDevAtt {}
@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt {
data.copy_from_slice(src); data.copy_from_slice(src);
Ok(src.len()) Ok(src.len())
} else { } else {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
} }
} }

View file

@ -15,40 +15,365 @@
* limitations under the License. * limitations under the License.
*/ */
mod dev_att; use core::borrow::Borrow;
use matter::core::{self, CommissioningData}; use core::pin::pin;
use embassy_futures::select::select3;
use log::info;
use matter::core::{CommissioningData, Matter};
use matter::data_model::cluster_basic_information::BasicInfoConfig; use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::device_types::device_type_add_on_off_light; use matter::data_model::cluster_on_off;
use matter::data_model::device_types::DEV_TYPE_ON_OFF_LIGHT;
use matter::data_model::objects::*;
use matter::data_model::root_endpoint;
use matter::data_model::system_model::descriptor;
use matter::error::Error;
use matter::mdns::{MdnsRunBuffers, MdnsService};
use matter::secure_channel::spake2p::VerifierData; use matter::secure_channel::spake2p::VerifierData;
use matter::transport::core::RunBuffers;
use matter::transport::network::{Ipv4Addr, Ipv6Addr, NetworkStack};
use matter::utils::select::EitherUnwrap;
fn main() { mod dev_att;
env_logger::init();
let comm_data = CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456),
discriminator: 250,
};
// vid/pid should match those in the DAC #[cfg(feature = "std")]
let dev_info = BasicInfoConfig { fn main() -> Result<(), Error> {
let thread = std::thread::Builder::new()
.stack_size(150 * 1024)
.spawn(run)
.unwrap();
thread.join().unwrap()
}
// NOTE (no_std): For no_std, name this entry point according to your MCU platform
#[cfg(not(feature = "std"))]
#[no_mangle]
fn app_main() {
run().unwrap();
}
fn run() -> Result<(), Error> {
initialize_logger();
info!(
"Matter memory: mDNS={}, Matter={}, MdnsBuffers={}, RunBuffers={}",
core::mem::size_of::<MdnsService>(),
core::mem::size_of::<Matter>(),
core::mem::size_of::<MdnsRunBuffers>(),
core::mem::size_of::<RunBuffers>(),
);
let dev_det = BasicInfoConfig {
vid: 0xFFF1, vid: 0xFFF1,
pid: 0x8000, pid: 0x8000,
hw_ver: 2, hw_ver: 2,
sw_ver: 1, sw_ver: 1,
sw_ver_str: "1".to_string(), sw_ver_str: "1",
serial_no: "aabbccdd".to_string(), serial_no: "aabbccdd",
device_name: "OnOff Light".to_string(), device_name: "OnOff Light",
}; };
let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); let (ipv4_addr, ipv6_addr, interface) = initialize_network()?;
let dm = matter.get_data_model();
let dev_att = dev_att::HardCodedDevAtt::new();
#[cfg(feature = "std")]
let epoch = matter::utils::epoch::sys_epoch;
#[cfg(feature = "std")]
let rand = matter::utils::rand::sys_rand;
// NOTE (no_std): For no_std, provide your own function here
#[cfg(not(feature = "std"))]
let epoch = matter::utils::epoch::dummy_epoch;
// NOTE (no_std): For no_std, provide your own function here
#[cfg(not(feature = "std"))]
let rand = matter::utils::rand::dummy_rand;
let mdns = MdnsService::new(
0,
"matter-demo",
ipv4_addr.octets(),
Some((ipv6_addr.octets(), interface)),
&dev_det,
matter::MATTER_PORT,
);
info!("mDNS initialized");
let matter = Matter::new(
// vid/pid should match those in the DAC
&dev_det,
&dev_att,
&mdns,
epoch,
rand,
matter::MATTER_PORT,
);
info!("Matter initialized");
#[cfg(all(feature = "std", not(target_os = "espidf")))]
let mut psm = matter::persist::Psm::new(&matter, std::env::temp_dir().join("matter-iot"))?;
let handler = HandlerCompat(handler(&matter));
// When using a custom UDP stack, remove the network stack initialization below
// and call `Matter::run_piped()` instead, by utilizing the TX & RX `Pipe` structs
// to push/pull your UDP packets from/to the Matter stack.
// Ditto for `MdnsService`.
//
// When using the `embassy-net` feature (as opposed to the Rust Standard Library network stack),
// this initialization would be more complex.
let stack = NetworkStack::new();
let mut mdns_buffers = MdnsRunBuffers::new();
let mut mdns_runner = pin!(mdns.run(&stack, &mut mdns_buffers));
let mut buffers = RunBuffers::new();
let mut runner = matter.run(
&stack,
&mut buffers,
CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456, *matter.borrow()),
discriminator: 250,
},
&handler,
);
info!(
"Matter transport runner memory: {}",
core::mem::size_of_val(&runner)
);
let mut runner = pin!(runner);
#[cfg(all(feature = "std", not(target_os = "espidf")))]
let mut psm_runner = pin!(psm.run());
#[cfg(not(all(feature = "std", not(target_os = "espidf"))))]
let mut psm_runner = pin!(core::future::pending());
let mut runner = select3(&mut runner, &mut mdns_runner, &mut psm_runner);
#[cfg(feature = "std")]
async_io::block_on(&mut runner).unwrap()?;
// NOTE (no_std): For no_std, replace with your own more efficient no_std executor,
// because the executor used below is a simple busy-loop poller
#[cfg(not(feature = "std"))]
embassy_futures::block_on(&mut runner).unwrap()?;
Ok(())
}
const NODE: Node<'static> = Node {
id: 0,
endpoints: &[
root_endpoint::endpoint(0),
Endpoint {
id: 1,
device_type: DEV_TYPE_ON_OFF_LIGHT,
clusters: &[descriptor::CLUSTER, cluster_on_off::CLUSTER],
},
],
};
fn handler<'a>(matter: &'a Matter<'a>) -> impl Metadata + NonBlockingHandler + 'a {
(
NODE,
root_endpoint::handler(0, matter)
.chain(
1,
descriptor::ID,
descriptor::DescriptorCluster::new(*matter.borrow()),
)
.chain(
1,
cluster_on_off::ID,
cluster_on_off::OnOffCluster::new(*matter.borrow()),
),
)
}
// NOTE (no_std): For no_std, implement here your own way of initializing the logger
#[cfg(all(not(feature = "std"), not(target_os = "espidf")))]
#[inline(never)]
fn initialize_logger() {}
// NOTE (no_std): For no_std, implement here your own way of initializing the network
#[cfg(all(not(feature = "std"), not(target_os = "espidf")))]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
Ok((Ipv4Addr::UNSPECIFIED, Ipv6Addr::UNSPECIFIED, 0))
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
#[inline(never)]
fn initialize_logger() {
env_logger::init_from_env(
env_logger::Env::default().filter_or(env_logger::DEFAULT_FILTER_ENV, "info"),
);
}
#[cfg(all(feature = "std", not(target_os = "espidf")))]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
use log::error;
use matter::error::ErrorCode;
use nix::{net::if_::InterfaceFlags, sys::socket::SockaddrIn6};
let interfaces = || {
nix::ifaddrs::getifaddrs().unwrap().filter(|ia| {
ia.flags
.contains(InterfaceFlags::IFF_UP | InterfaceFlags::IFF_BROADCAST)
&& !ia
.flags
.intersects(InterfaceFlags::IFF_LOOPBACK | InterfaceFlags::IFF_POINTOPOINT)
})
};
// A quick and dirty way to get a network interface that has a link-local IPv6 address assigned as well as a non-loopback IPv4
// Most likely, this is the interface we need
// (as opposed to all the docker and libvirt interfaces that might be assigned on the machine and which seem by default to be IPv4 only)
let (iname, ip, ipv6) = interfaces()
.filter_map(|ia| {
ia.address
.and_then(|addr| addr.as_sockaddr_in6().map(SockaddrIn6::ip))
.filter(|ip| ip.octets()[..2] == [0xfe, 0x80])
.map(|ipv6| (ia.interface_name, ipv6))
})
.filter_map(|(iname, ipv6)| {
interfaces()
.filter(|ia2| ia2.interface_name == iname)
.find_map(|ia2| {
ia2.address
.and_then(|addr| addr.as_sockaddr_in().map(|addr| addr.ip().into()))
.map(|ip| (iname.clone(), ip, ipv6))
})
})
.next()
.ok_or_else(|| {
error!("Cannot find network interface suitable for mDNS broadcasting");
ErrorCode::Network
})?;
info!(
"Will use network interface {} with {}/{} for mDNS",
iname, ip, ipv6
);
Ok((ip, ipv6, 0 as _))
}
#[cfg(target_os = "espidf")]
#[inline(never)]
fn initialize_logger() {
esp_idf_svc::log::EspLogger::initialize_default();
}
#[cfg(target_os = "espidf")]
#[inline(never)]
fn initialize_network() -> Result<(Ipv4Addr, Ipv6Addr, u32), Error> {
use core::time::Duration;
use embedded_svc::wifi::{AuthMethod, ClientConfiguration, Configuration};
use esp_idf_hal::prelude::Peripherals;
use esp_idf_svc::handle::RawHandle;
use esp_idf_svc::wifi::{BlockingWifi, EspWifi};
use esp_idf_svc::{eventloop::EspSystemEventLoop, nvs::EspDefaultNvsPartition};
use esp_idf_sys::{
self as _, esp, esp_ip6_addr_t, esp_netif_create_ip6_linklocal, esp_netif_get_ip6_linklocal,
}; // If using the `binstart` feature of `esp-idf-sys`, always keep this module imported
const SSID: &'static str = env!("WIFI_SSID");
const PASSWORD: &'static str = env!("WIFI_PASS");
#[allow(clippy::needless_update)]
{ {
let mut node = dm.node.write().unwrap(); // VFS is necessary for poll-based async IO
let endpoint = device_type_add_on_off_light(&mut node).unwrap(); esp_idf_sys::esp!(unsafe {
println!("Added OnOff Light Device type at endpoint id: {}", endpoint); esp_idf_sys::esp_vfs_eventfd_register(&esp_idf_sys::esp_vfs_eventfd_config_t {
println!("Data Model now is: {}", node); max_fds: 5,
..Default::default()
})
})?;
} }
matter.start_daemon().unwrap(); let peripherals = Peripherals::take().unwrap();
let sys_loop = EspSystemEventLoop::take()?;
let nvs = EspDefaultNvsPartition::take()?;
let mut wifi = EspWifi::new(peripherals.modem, sys_loop.clone(), Some(nvs))?;
let mut bwifi = BlockingWifi::wrap(&mut wifi, sys_loop)?;
let wifi_configuration: Configuration = Configuration::Client(ClientConfiguration {
ssid: SSID.into(),
bssid: None,
auth_method: AuthMethod::WPA2Personal,
password: PASSWORD.into(),
channel: None,
});
bwifi.set_configuration(&wifi_configuration)?;
bwifi.start()?;
info!("Wifi started");
bwifi.connect()?;
info!("Wifi connected");
esp!(unsafe {
esp_netif_create_ip6_linklocal(bwifi.wifi_mut().sta_netif_mut().handle() as _)
})?;
bwifi.wait_netif_up()?;
info!("Wifi netif up");
let ip_info = wifi.sta_netif().get_ip_info()?;
let mut ipv6: esp_ip6_addr_t = Default::default();
info!("Waiting for IPv6 address");
while esp!(unsafe { esp_netif_get_ip6_linklocal(wifi.sta_netif().handle() as _, &mut ipv6) })
.is_err()
{
info!("Waiting...");
std::thread::sleep(Duration::from_secs(2));
}
info!("Wifi DHCP info: {:?}, IPv6: {:?}", ip_info, ipv6.addr);
let ipv4_octets = ip_info.ip.octets();
let ipv6_octets = [
ipv6.addr[0].to_le_bytes()[0],
ipv6.addr[0].to_le_bytes()[1],
ipv6.addr[0].to_le_bytes()[2],
ipv6.addr[0].to_le_bytes()[3],
ipv6.addr[1].to_le_bytes()[0],
ipv6.addr[1].to_le_bytes()[1],
ipv6.addr[1].to_le_bytes()[2],
ipv6.addr[1].to_le_bytes()[3],
ipv6.addr[2].to_le_bytes()[0],
ipv6.addr[2].to_le_bytes()[1],
ipv6.addr[2].to_le_bytes()[2],
ipv6.addr[2].to_le_bytes()[3],
ipv6.addr[3].to_le_bytes()[0],
ipv6.addr[3].to_le_bytes()[1],
ipv6.addr[3].to_le_bytes()[2],
ipv6.addr[3].to_le_bytes()[3],
];
let interface = wifi.sta_netif().get_index();
// Not OK of course, but for a demo this is good enough
// Wifi will continue to be available and working in the background
core::mem::forget(wifi);
Ok((ipv4_octets.into(), ipv6_octets.into(), interface))
} }

View file

@ -159,7 +159,7 @@ impl DevAttDataFetcher for HardCodedDevAtt {
data.copy_from_slice(src); data.copy_from_slice(src);
Ok(src.len()) Ok(src.len())
} else { } else {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
} }
} }

View file

@ -15,4 +15,4 @@
* limitations under the License. * limitations under the License.
*/ */
pub mod dev_att; // TODO pub mod dev_att;

View file

@ -15,55 +15,56 @@
* limitations under the License. * limitations under the License.
*/ */
mod dev_att; // TODO
use matter::core::{self, CommissioningData}; // mod dev_att;
use matter::data_model::cluster_basic_information::BasicInfoConfig; // use matter::core::{self, CommissioningData};
use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster}; // use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER; // use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster};
use matter::secure_channel::spake2p::VerifierData; // use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER;
// use matter::secure_channel::spake2p::VerifierData;
fn main() { fn main() {
env_logger::init(); // env_logger::init();
let comm_data = CommissioningData { // let comm_data = CommissioningData {
// TODO: Hard-coded for now // // TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456), // verifier: VerifierData::new_with_pw(123456),
discriminator: 250, // discriminator: 250,
}; // };
// vid/pid should match those in the DAC // // vid/pid should match those in the DAC
let dev_info = BasicInfoConfig { // let dev_info = BasicInfoConfig {
vid: 0xFFF1, // vid: 0xFFF1,
pid: 0x8002, // pid: 0x8002,
hw_ver: 2, // hw_ver: 2,
sw_ver: 1, // sw_ver: 1,
sw_ver_str: "1".to_string(), // sw_ver_str: "1".to_string(),
serial_no: "aabbccdd".to_string(), // serial_no: "aabbccdd".to_string(),
device_name: "Smart Speaker".to_string(), // device_name: "Smart Speaker".to_string(),
}; // };
let dev_att = Box::new(dev_att::HardCodedDevAtt::new()); // let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap(); // let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap();
let dm = matter.get_data_model(); // let dm = matter.get_data_model();
{ // {
let mut node = dm.node.write().unwrap(); // let mut node = dm.node.write().unwrap();
let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap(); // let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap();
let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap(); // let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap();
// Add some callbacks // // Add some callbacks
let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback.")); // let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback."));
let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback.")); // let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback."));
let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback.")); // let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback."));
let start_over_callback = // let start_over_callback =
Box::new(|| log::info!("Comamnd [StartOver] handled with callback.")); // Box::new(|| log::info!("Comamnd [StartOver] handled with callback."));
media_playback_cluster.add_callback(Commands::Play, play_callback); // media_playback_cluster.add_callback(Commands::Play, play_callback);
media_playback_cluster.add_callback(Commands::Pause, pause_callback); // media_playback_cluster.add_callback(Commands::Pause, pause_callback);
media_playback_cluster.add_callback(Commands::Stop, stop_callback); // media_playback_cluster.add_callback(Commands::Stop, stop_callback);
media_playback_cluster.add_callback(Commands::StartOver, start_over_callback); // media_playback_cluster.add_callback(Commands::StartOver, start_over_callback);
node.add_cluster(endpoint_audio, media_playback_cluster) // node.add_cluster(endpoint_audio, media_playback_cluster)
.unwrap(); // .unwrap();
println!("Added Speaker type at endpoint id: {}", endpoint_audio) // println!("Added Speaker type at endpoint id: {}", endpoint_audio)
} // }
matter.start_daemon().unwrap(); // matter.start_daemon().unwrap();
} }

View file

@ -0,0 +1,69 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
mod dev_att;
use matter::core::{self, CommissioningData};
use matter::data_model::cluster_basic_information::BasicInfoConfig;
use matter::data_model::cluster_media_playback::{Commands, MediaPlaybackCluster};
use matter::data_model::device_types::DEV_TYPE_ON_SMART_SPEAKER;
use matter::secure_channel::spake2p::VerifierData;
fn main() {
env_logger::init();
let comm_data = CommissioningData {
// TODO: Hard-coded for now
verifier: VerifierData::new_with_pw(123456),
discriminator: 250,
};
// vid/pid should match those in the DAC
let dev_info = BasicInfoConfig {
vid: 0xFFF1,
pid: 0x8002,
hw_ver: 2,
sw_ver: 1,
sw_ver_str: "1".to_string(),
serial_no: "aabbccdd".to_string(),
device_name: "Smart Speaker".to_string(),
};
let dev_att = Box::new(dev_att::HardCodedDevAtt::new());
let mut matter = core::Matter::new(dev_info, dev_att, comm_data).unwrap();
let dm = matter.get_data_model();
{
let mut node = dm.node.write().unwrap();
let endpoint_audio = node.add_endpoint(DEV_TYPE_ON_SMART_SPEAKER).unwrap();
let mut media_playback_cluster = MediaPlaybackCluster::new().unwrap();
// Add some callbacks
let play_callback = Box::new(|| log::info!("Comamnd [Play] handled with callback."));
let pause_callback = Box::new(|| log::info!("Comamnd [Pause] handled with callback."));
let stop_callback = Box::new(|| log::info!("Comamnd [Stop] handled with callback."));
let start_over_callback =
Box::new(|| log::info!("Comamnd [StartOver] handled with callback."));
media_playback_cluster.add_callback(Commands::Play, play_callback);
media_playback_cluster.add_callback(Commands::Pause, pause_callback);
media_playback_cluster.add_callback(Commands::Stop, stop_callback);
media_playback_cluster.add_callback(Commands::StartOver, start_over_callback);
node.add_cluster(endpoint_audio, media_playback_cluster)
.unwrap();
println!("Added Speaker type at endpoint id: {}", endpoint_audio)
}
matter.start_daemon().unwrap();
}

View file

@ -1,52 +1,70 @@
[package] [package]
name = "matter-iot" name = "matter-iot"
version = "0.1.0" version = "0.1.0"
edition = "2018" edition = "2021"
authors = ["Kedar Sovani <kedars@gmail.com>"] authors = ["Kedar Sovani <kedars@gmail.com>"]
description = "Native RUST implementation of the Matter (Smart-Home) ecosystem" description = "Native RUST implementation of the Matter (Smart-Home) ecosystem"
repository = "https://github.com/kedars/matter-rs" repository = "https://github.com/kedars/matter-rs"
readme = "README.md" readme = "README.md"
keywords = ["matter", "smart", "smart-home", "IoT", "ESP32"] keywords = ["matter", "smart", "smart-home", "IoT", "ESP32"]
categories = ["embedded", "network-programming"] categories = ["embedded", "network-programming"]
license = "MIT" license = "Apache-2.0"
[lib] [lib]
name = "matter" name = "matter"
path = "src/lib.rs" path = "src/lib.rs"
[features] [features]
default = ["crypto_mbedtls"] default = ["os", "mbedtls"]
crypto_openssl = ["openssl", "foreign-types", "hmac", "sha2"] os = ["std", "backtrace", "env_logger", "nix", "critical-section/std", "embassy-sync/std", "embassy-time/std"]
crypto_mbedtls = ["mbedtls"] esp-idf = ["std", "rustcrypto", "esp-idf-sys"]
crypto_esp_mbedtls = ["esp-idf-sys"] std = ["alloc", "rand", "qrcode", "async-io", "esp-idf-sys?/std", "embassy-time/generic-queue-16"]
crypto_rustcrypto = ["sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert"] backtrace = []
alloc = []
nightly = []
openssl = ["alloc", "dep:openssl", "foreign-types", "hmac", "sha2"]
mbedtls = ["alloc", "dep:mbedtls"]
rustcrypto = ["alloc", "sha2", "hmac", "pbkdf2", "hkdf", "aes", "ccm", "p256", "elliptic-curve", "crypto-bigint", "x509-cert", "rand_core"]
embassy-net = ["dep:embassy-net", "dep:embassy-net-driver", "smoltcp"]
[dependencies] [dependencies]
boxslab = { path = "../boxslab" }
matter_macro_derive = { path = "../matter_macro_derive" } matter_macro_derive = { path = "../matter_macro_derive" }
bitflags = "1.3" bitflags = { version = "1.3", default-features = false }
byteorder = "1.4.3" byteorder = { version = "1.4.3", default-features = false }
heapless = { version = "0.7.16", features = ["x86-sync-pool"] } heapless = "0.7.16"
generic-array = "0.14.6" num = { version = "0.4", default-features = false }
num = "0.4"
num-derive = "0.3.3" num-derive = "0.3.3"
num-traits = "0.2.15" num-traits = { version = "0.2.15", default-features = false }
strum = { version = "0.24", features = ["derive"], default-features = false }
log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] } log = { version = "0.4.17", features = ["max_level_debug", "release_max_level_debug"] }
env_logger = { version = "0.10.0", default-features = false, features = [] } no-std-net = "0.6"
rand = "0.8.5" subtle = { version = "2.4.1", default-features = false }
esp-idf-sys = { version = "0.32", features = ["binstart"], optional = true } safemem = { version = "0.3.3", default-features = false }
subtle = "2.4.1" owo-colors = "3"
colored = "2.0.0" time = { version = "0.3", default-features = false }
smol = "1.3.0" verhoeff = { version = "1", default-features = false }
owning_ref = "0.4.1" embassy-futures = "0.1"
safemem = "0.3.3" embassy-time = "0.1.1"
chrono = { version = "0.4.23", default-features = false, features = ["clock", "std"] } embassy-sync = "0.2"
async-channel = "1.8" critical-section = "1.1.1"
domain = { version = "0.7.2", default_features = false, features = ["heapless"] }
portable-atomic = "1"
# embassy-net dependencies
embassy-net = { version = "0.1", features = ["igmp", "proto-ipv6", "udp"], optional = true }
embassy-net-driver = { version = "0.1", optional = true }
smoltcp = { version = "0.10", default-features = false, optional = true }
# STD-only dependencies
rand = { version = "0.8.5", optional = true }
qrcode = { version = "0.12", default-features = false, optional = true } # Print QR code
async-io = { version = "=1.12", optional = true } # =1.12 for compatibility with ESP IDF
# crypto # crypto
openssl = { git = "https://github.com/sfackler/rust-openssl", optional = true } openssl = { version = "0.10.55", optional = true }
foreign-types = { version = "0.3.2", optional = true } foreign-types = { version = "0.3.2", optional = true }
mbedtls = { version = "0.9", optional = true }
# rust-crypto
sha2 = { version = "0.10", default-features = false, optional = true } sha2 = { version = "0.10", default-features = false, optional = true }
hmac = { version = "0.12", optional = true } hmac = { version = "0.12", optional = true }
pbkdf2 = { version = "0.12", optional = true } pbkdf2 = { version = "0.12", optional = true }
@ -56,28 +74,33 @@ ccm = { version = "0.5", default-features = false, features = ["alloc"], optiona
p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true } p256 = { version = "0.13.0", default-features = false, features = ["arithmetic", "ecdh", "ecdsa"], optional = true }
elliptic-curve = { version = "0.13.2", optional = true } elliptic-curve = { version = "0.13.2", optional = true }
crypto-bigint = { version = "0.4", default-features = false, optional = true } crypto-bigint = { version = "0.4", default-features = false, optional = true }
# Note: requires std rand_core = { version = "0.6", default-features = false, optional = true }
x509-cert = { version = "0.2.0", default-features = false, features = ["pem", "std"], optional = true } x509-cert = { version = "0.2.0", default-features = false, features = ["pem"], optional = true } # TODO: requires `alloc`
# to compute the check digit
verhoeff = "1"
# print QR code
qrcode = { version = "0.12", default-features = false }
[target.'cfg(target_os = "macos")'.dependencies] [target.'cfg(target_os = "macos")'.dependencies]
astro-dnssd = "0.3" astro-dnssd = { version = "0.3" }
# MDNS support [target.'cfg(not(target_os = "espidf"))'.dependencies]
[target.'cfg(target_os = "linux")'.dependencies] mbedtls = { version = "0.9", optional = true }
lazy_static = "1.4.0" env_logger = { version = "0.10.0", optional = true }
libmdns = { version = "0.7.4" } nix = { version = "0.26", features = ["net"], optional = true }
[target.'cfg(target_os = "espidf")'.dependencies]
esp-idf-sys = { version = "0.33", optional = true, default-features = false, features = ["native"] }
[build-dependencies]
embuild = "0.31.2"
[target.'cfg(target_os = "espidf")'.dev-dependencies]
esp-idf-sys = { version = "0.33", default-features = false, features = ["binstart"] }
esp-idf-hal = { version = "0.41", features = ["embassy-sync", "critical-section"] }
esp-idf-svc = { version = "0.46", features = ["embassy-time-driver"] }
embedded-svc = { version = "0.25" }
[[example]] [[example]]
name = "onoff_light" name = "onoff_light"
path = "../examples/onoff_light/src/main.rs" path = "../examples/onoff_light/src/main.rs"
# [[example]]
[[example]] # name = "speaker"
name = "speaker" # path = "../examples/speaker/src/main.rs"
path = "../examples/speaker/src/main.rs"

11
matter/build.rs Normal file
View file

@ -0,0 +1,11 @@
use std::env::var;
// Necessary because of this issue: https://github.com/rust-lang/cargo/issues/9641
fn main() -> Result<(), Box<dyn std::error::Error>> {
if var("TARGET").unwrap().ends_with("-espidf") {
embuild::build::CfgArgs::output_propagated("ESP_IDF")?;
embuild::build::LinkArgs::output_propagated("ESP_IDF")?;
}
Ok(())
}

View file

@ -15,19 +15,15 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{cell::RefCell, fmt::Display};
fmt::Display,
sync::{Arc, Mutex, MutexGuard, RwLock},
};
use crate::{ use crate::{
data_model::objects::{Access, ClusterId, EndptId, Privilege}, data_model::objects::{Access, ClusterId, EndptId, Privilege},
error::Error, error::{Error, ErrorCode},
fabric, fabric,
interaction_model::messages::GenericPath, interaction_model::messages::GenericPath,
sys::Psm, tlv::{self, FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV},
tlv::{FromTLV, TLVElement, TLVList, TLVWriter, TagType, ToTLV}, transport::session::{Session, SessionMode, MAX_CAT_IDS_PER_NOC},
transport::session::MAX_CAT_IDS_PER_NOC,
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
use log::error; use log::error;
@ -54,7 +50,7 @@ impl FromTLV<'_> for AuthMode {
{ {
num::FromPrimitive::from_u32(t.u32()?) num::FromPrimitive::from_u32(t.u32()?)
.filter(|a| *a != AuthMode::Invalid) .filter(|a| *a != AuthMode::Invalid)
.ok_or(Error::Invalid) .ok_or_else(|| ErrorCode::Invalid.into())
} }
} }
@ -116,7 +112,7 @@ impl AccessorSubjects {
return Ok(()); return Ok(());
} }
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
/// Match the match_subject with any of the current subjects /// Match the match_subject with any of the current subjects
@ -146,7 +142,7 @@ impl AccessorSubjects {
} }
impl Display for AccessorSubjects { impl Display for AccessorSubjects {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::result::Result<(), std::fmt::Error> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::result::Result<(), core::fmt::Error> {
write!(f, "[")?; write!(f, "[")?;
for i in self.0 { for i in self.0 {
if is_noc_cat(i) { if is_noc_cat(i) {
@ -160,7 +156,7 @@ impl Display for AccessorSubjects {
} }
/// The Accessor Object /// The Accessor Object
pub struct Accessor { pub struct Accessor<'a> {
/// The fabric index of the accessor /// The fabric index of the accessor
pub fab_idx: u8, pub fab_idx: u8,
/// Accessor's subject: could be node-id, NoC CAT, group id /// Accessor's subject: could be node-id, NoC CAT, group id
@ -168,15 +164,37 @@ pub struct Accessor {
/// The Authmode of this session /// The Authmode of this session
auth_mode: AuthMode, auth_mode: AuthMode,
// TODO: Is this the right place for this though, or should we just use a global-acl-handle-get // TODO: Is this the right place for this though, or should we just use a global-acl-handle-get
acl_mgr: Arc<AclMgr>, acl_mgr: &'a RefCell<AclMgr>,
} }
impl Accessor { impl<'a> Accessor<'a> {
pub fn new( pub fn for_session(session: &Session, acl_mgr: &'a RefCell<AclMgr>) -> Self {
match session.get_session_mode() {
SessionMode::Case(c) => {
let mut subject =
AccessorSubjects::new(session.get_peer_node_id().unwrap_or_default());
for i in c.cat_ids {
if i != 0 {
let _ = subject.add_catid(i);
}
}
Accessor::new(c.fab_idx, subject, AuthMode::Case, acl_mgr)
}
SessionMode::Pase => {
Accessor::new(0, AccessorSubjects::new(1), AuthMode::Pase, acl_mgr)
}
SessionMode::PlainText => {
Accessor::new(0, AccessorSubjects::new(1), AuthMode::Invalid, acl_mgr)
}
}
}
pub const fn new(
fab_idx: u8, fab_idx: u8,
subjects: AccessorSubjects, subjects: AccessorSubjects,
auth_mode: AuthMode, auth_mode: AuthMode,
acl_mgr: Arc<AclMgr>, acl_mgr: &'a RefCell<AclMgr>,
) -> Self { ) -> Self {
Self { Self {
fab_idx, fab_idx,
@ -188,9 +206,9 @@ impl Accessor {
} }
#[derive(Debug)] #[derive(Debug)]
pub struct AccessDesc<'a> { pub struct AccessDesc {
/// The object to be acted upon /// The object to be acted upon
path: &'a GenericPath, path: GenericPath,
/// The target permissions /// The target permissions
target_perms: Option<Access>, target_perms: Option<Access>,
// The operation being done // The operation being done
@ -200,8 +218,8 @@ pub struct AccessDesc<'a> {
/// Access Request Object /// Access Request Object
pub struct AccessReq<'a> { pub struct AccessReq<'a> {
accessor: &'a Accessor, accessor: &'a Accessor<'a>,
object: AccessDesc<'a>, object: AccessDesc,
} }
impl<'a> AccessReq<'a> { impl<'a> AccessReq<'a> {
@ -209,7 +227,7 @@ impl<'a> AccessReq<'a> {
/// ///
/// An access request specifies the _accessor_ attempting to access _path_ /// An access request specifies the _accessor_ attempting to access _path_
/// with _operation_ /// with _operation_
pub fn new(accessor: &'a Accessor, path: &'a GenericPath, operation: Access) -> Self { pub fn new(accessor: &'a Accessor, path: GenericPath, operation: Access) -> Self {
AccessReq { AccessReq {
accessor, accessor,
object: AccessDesc { object: AccessDesc {
@ -220,6 +238,10 @@ impl<'a> AccessReq<'a> {
} }
} }
pub fn operation(&self) -> Access {
self.object.operation
}
/// Add target's permissions to the request /// Add target's permissions to the request
/// ///
/// The permissions that are associated with the target (identified by the /// The permissions that are associated with the target (identified by the
@ -234,11 +256,11 @@ impl<'a> AccessReq<'a> {
/// _accessor_ the necessary privileges to access the target as per its /// _accessor_ the necessary privileges to access the target as per its
/// permissions /// permissions
pub fn allow(&self) -> bool { pub fn allow(&self) -> bool {
self.accessor.acl_mgr.allow(self) self.accessor.acl_mgr.borrow().allow(self)
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone, Debug, PartialEq)] #[derive(FromTLV, ToTLV, Clone, Debug, PartialEq)]
pub struct Target { pub struct Target {
cluster: Option<ClusterId>, cluster: Option<ClusterId>,
endpoint: Option<EndptId>, endpoint: Option<EndptId>,
@ -261,7 +283,7 @@ impl Target {
type Subjects = [Option<u64>; SUBJECTS_PER_ENTRY]; type Subjects = [Option<u64>; SUBJECTS_PER_ENTRY];
type Targets = [Option<Target>; TARGETS_PER_ENTRY]; type Targets = [Option<Target>; TARGETS_PER_ENTRY];
#[derive(ToTLV, FromTLV, Copy, Clone, Debug, PartialEq)] #[derive(ToTLV, FromTLV, Clone, Debug, PartialEq)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
pub struct AclEntry { pub struct AclEntry {
privilege: Privilege, privilege: Privilege,
@ -292,7 +314,7 @@ impl AclEntry {
.subjects .subjects
.iter() .iter()
.position(|s| s.is_none()) .position(|s| s.is_none())
.ok_or(Error::NoSpace)?; .ok_or(ErrorCode::NoSpace)?;
self.subjects[index] = Some(subject); self.subjects[index] = Some(subject);
Ok(()) Ok(())
} }
@ -306,7 +328,7 @@ impl AclEntry {
.targets .targets
.iter() .iter()
.position(|s| s.is_none()) .position(|s| s.is_none())
.ok_or(Error::NoSpace)?; .ok_or(ErrorCode::NoSpace)?;
self.targets[index] = Some(target); self.targets[index] = Some(target);
Ok(()) Ok(())
} }
@ -367,35 +389,151 @@ impl AclEntry {
} }
const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS; const MAX_ACL_ENTRIES: usize = ENTRIES_PER_FABRIC * fabric::MAX_SUPPORTED_FABRICS;
type AclEntries = [Option<AclEntry>; MAX_ACL_ENTRIES];
#[derive(ToTLV, FromTLV, Debug)] type AclEntries = heapless::Vec<Option<AclEntry>, MAX_ACL_ENTRIES>;
struct AclMgrInner {
pub struct AclMgr {
entries: AclEntries, entries: AclEntries,
changed: bool,
} }
const ACL_KV_ENTRY: &str = "acl"; impl AclMgr {
const ACL_KV_MAX_SIZE: usize = 300; #[inline(always)]
impl AclMgrInner { pub const fn new() -> Self {
pub fn store(&self, psm: &MutexGuard<Psm>) -> Result<(), Error> { Self {
let mut acl_tlvs = [0u8; ACL_KV_MAX_SIZE]; entries: AclEntries::new(),
let mut wb = WriteBuf::new(&mut acl_tlvs, ACL_KV_MAX_SIZE); changed: false,
let mut tw = TLVWriter::new(&mut wb); }
self.entries.to_tlv(&mut tw, TagType::Anonymous)?;
psm.set_kv_slice(ACL_KV_ENTRY, wb.as_slice())
} }
pub fn load(psm: &MutexGuard<Psm>) -> Result<Self, Error> { pub fn erase_all(&mut self) -> Result<(), Error> {
let mut acl_tlvs = Vec::new(); self.entries.clear();
psm.get_kv_slice(ACL_KV_ENTRY, &mut acl_tlvs)?; self.changed = true;
let root = TLVList::new(&acl_tlvs)
Ok(())
}
pub fn add(&mut self, entry: AclEntry) -> Result<(), Error> {
let cnt = self
.entries
.iter() .iter()
.next() .flatten()
.ok_or(Error::Invalid)?; .filter(|a| a.fab_idx == entry.fab_idx)
.count();
if cnt >= ENTRIES_PER_FABRIC {
Err(ErrorCode::NoSpace)?;
}
Ok(Self { let slot = self.entries.iter().position(|a| a.is_none());
entries: AclEntries::from_tlv(&root)?,
}) if slot.is_some() || self.entries.len() < MAX_ACL_ENTRIES {
if let Some(index) = slot {
self.entries[index] = Some(entry);
} else {
self.entries
.push(Some(entry))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
}
self.changed = true;
}
Ok(())
}
// Since the entries are fabric-scoped, the index is only for entries with the matching fabric index
pub fn edit(&mut self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> {
let old = self.for_index_in_fabric(index, fab_idx)?;
*old = Some(new);
self.changed = true;
Ok(())
}
pub fn delete(&mut self, index: u8, fab_idx: u8) -> Result<(), Error> {
let old = self.for_index_in_fabric(index, fab_idx)?;
*old = None;
self.changed = true;
Ok(())
}
pub fn delete_for_fabric(&mut self, fab_idx: u8) -> Result<(), Error> {
for entry in &mut self.entries {
if entry
.as_ref()
.map(|e| e.fab_idx == Some(fab_idx))
.unwrap_or(false)
{
*entry = None;
self.changed = true;
}
}
Ok(())
}
pub fn for_each_acl<T>(&self, mut f: T) -> Result<(), Error>
where
T: FnMut(&AclEntry) -> Result<(), Error>,
{
for entry in self.entries.iter().flatten() {
f(entry)?;
}
Ok(())
}
pub fn allow(&self, req: &AccessReq) -> bool {
// PASE Sessions have implicit access grant
if req.accessor.auth_mode == AuthMode::Pase {
return true;
}
for e in self.entries.iter().flatten() {
if e.allow(req) {
return true;
}
}
error!(
"ACL Disallow for subjects {} fab idx {}",
req.accessor.subjects, req.accessor.fab_idx
);
error!("{}", self);
false
}
pub fn load(&mut self, data: &[u8]) -> Result<(), Error> {
let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
tlv::from_tlv(&mut self.entries, &root)?;
self.changed = false;
Ok(())
}
pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
if self.changed {
let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb);
self.entries
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
self.changed = false;
let len = tw.get_tail();
Ok(Some(&buf[..len]))
} else {
Ok(None)
}
}
pub fn is_changed(&self) -> bool {
self.changed
} }
/// Traverse fabric specific entries to find the index /// Traverse fabric specific entries to find the index
@ -411,180 +549,25 @@ impl AclMgrInner {
for (curr_index, entry) in self for (curr_index, entry) in self
.entries .entries
.iter_mut() .iter_mut()
.filter(|e| e.filter(|e1| e1.fab_idx == Some(fab_idx)).is_some()) .filter(|e| {
e.as_ref()
.filter(|e1| e1.fab_idx == Some(fab_idx))
.is_some()
})
.enumerate() .enumerate()
{ {
if curr_index == index as usize { if curr_index == index as usize {
return Ok(entry); return Ok(entry);
} }
} }
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
} }
pub struct AclMgr { impl core::fmt::Display for AclMgr {
inner: RwLock<AclMgrInner>, fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
// The Option<> is solely because test execution is faster
// Doing this here adds the least overhead during ACL verification
psm: Option<Arc<Mutex<Psm>>>,
}
impl AclMgr {
pub fn new() -> Result<Self, Error> {
AclMgr::new_with(true)
}
pub fn new_with(psm_support: bool) -> Result<Self, Error> {
const INIT: Option<AclEntry> = None;
let mut psm = None;
let inner = if !psm_support {
AclMgrInner {
entries: [INIT; MAX_ACL_ENTRIES],
}
} else {
let psm_handle = Psm::get()?;
let inner = {
let psm_lock = psm_handle.lock().unwrap();
AclMgrInner::load(&psm_lock)
};
psm = Some(psm_handle);
inner.unwrap_or({
// Error loading from PSM
AclMgrInner {
entries: [INIT; MAX_ACL_ENTRIES],
}
})
};
Ok(Self {
inner: RwLock::new(inner),
psm,
})
}
pub fn erase_all(&self) {
let mut inner = self.inner.write().unwrap();
for i in 0..MAX_ACL_ENTRIES {
inner.entries[i] = None;
}
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
let _ = inner.store(&psm).map_err(|e| {
error!("Error in storing ACLs {}", e);
});
}
}
pub fn add(&self, entry: AclEntry) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let cnt = inner
.entries
.iter()
.flatten()
.filter(|a| a.fab_idx == entry.fab_idx)
.count();
if cnt >= ENTRIES_PER_FABRIC {
return Err(Error::NoSpace);
}
let index = inner
.entries
.iter()
.position(|a| a.is_none())
.ok_or(Error::NoSpace)?;
inner.entries[index] = Some(entry);
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
// Since the entries are fabric-scoped, the index is only for entries with the matching fabric index
pub fn edit(&self, index: u8, fab_idx: u8, new: AclEntry) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let old = inner.for_index_in_fabric(index, fab_idx)?;
*old = Some(new);
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn delete(&self, index: u8, fab_idx: u8) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
let old = inner.for_index_in_fabric(index, fab_idx)?;
*old = None;
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn delete_for_fabric(&self, fab_idx: u8) -> Result<(), Error> {
let mut inner = self.inner.write().unwrap();
for i in 0..MAX_ACL_ENTRIES {
if inner.entries[i]
.filter(|e| e.fab_idx == Some(fab_idx))
.is_some()
{
inner.entries[i] = None;
}
}
if let Some(psm) = self.psm.as_ref() {
let psm = psm.lock().unwrap();
inner.store(&psm)
} else {
Ok(())
}
}
pub fn for_each_acl<T>(&self, mut f: T) -> Result<(), Error>
where
T: FnMut(&AclEntry),
{
let inner = self.inner.read().unwrap();
for entry in inner.entries.iter().flatten() {
f(entry)
}
Ok(())
}
pub fn allow(&self, req: &AccessReq) -> bool {
// PASE Sessions have implicit access grant
if req.accessor.auth_mode == AuthMode::Pase {
return true;
}
let inner = self.inner.read().unwrap();
for e in inner.entries.iter().flatten() {
if e.allow(req) {
return true;
}
}
error!(
"ACL Disallow for subjects {} fab idx {}",
req.accessor.subjects, req.accessor.fab_idx
);
error!("{}", self);
false
}
}
impl std::fmt::Display for AclMgr {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let inner = self.inner.read().unwrap();
write!(f, "ACLS: [")?; write!(f, "ACLS: [")?;
for i in inner.entries.iter().flatten() { for i in self.entries.iter().flatten() {
write!(f, " {{ {:?} }}, ", i)?; write!(f, " {{ {:?} }}, ", i)?;
} }
write!(f, "]") write!(f, "]")
@ -594,22 +577,23 @@ impl std::fmt::Display for AclMgr {
#[cfg(test)] #[cfg(test)]
#[allow(clippy::bool_assert_comparison)] #[allow(clippy::bool_assert_comparison)]
mod tests { mod tests {
use core::cell::RefCell;
use crate::{ use crate::{
acl::{gen_noc_cat, AccessorSubjects}, acl::{gen_noc_cat, AccessorSubjects},
data_model::objects::{Access, Privilege}, data_model::objects::{Access, Privilege},
interaction_model::messages::GenericPath, interaction_model::messages::GenericPath,
}; };
use std::sync::Arc;
use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target}; use super::{AccessReq, Accessor, AclEntry, AclMgr, AuthMode, Target};
#[test] #[test]
fn test_basic_empty_subject_target() { fn test_basic_empty_subject_target() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Default deny // Default deny
@ -617,46 +601,46 @@ mod tests {
// Deny for session mode mismatch // Deny for session mode mismatch
let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase); let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Pase);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Deny for fab idx mismatch // Deny for fab idx mismatch
let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(1, Privilege::VIEW, AuthMode::Case);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow // Allow
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_subject() { fn test_subject() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for subject mismatch // Deny for subject mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112232).unwrap(); new.add_subject(112232).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for subject match - target is wildcard // Allow for subject match - target is wildcard
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_cat() { fn test_cat() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let allow_cat = 0xABCD; let allow_cat = 0xABCD;
let disallow_cat = 0xCAFE; let disallow_cat = 0xCAFE;
@ -666,35 +650,35 @@ mod tests {
let mut subjects = AccessorSubjects::new(112233); let mut subjects = AccessorSubjects::new(112233);
subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap(); subjects.add_catid(gen_noc_cat(allow_cat, v2)).unwrap();
let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); let accessor = Accessor::new(2, subjects, AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for CAT id mismatch // Deny for CAT id mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) new.add_subject_catid(gen_noc_cat(disallow_cat, v2))
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Deny of CAT version mismatch // Deny of CAT version mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v3)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for CAT match // Allow for CAT match
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_cat_version() { fn test_cat_version() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let allow_cat = 0xABCD; let allow_cat = 0xABCD;
let disallow_cat = 0xCAFE; let disallow_cat = 0xCAFE;
@ -704,32 +688,32 @@ mod tests {
let mut subjects = AccessorSubjects::new(112233); let mut subjects = AccessorSubjects::new(112233);
subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap(); subjects.add_catid(gen_noc_cat(allow_cat, v3)).unwrap();
let accessor = Accessor::new(2, subjects, AuthMode::Case, am.clone()); let accessor = Accessor::new(2, subjects, AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for CAT id mismatch // Deny for CAT id mismatch
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(disallow_cat, v2)) new.add_subject_catid(gen_noc_cat(disallow_cat, v2))
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for CAT match and version more than ACL version // Allow for CAT match and version more than ACL version
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap(); new.add_subject_catid(gen_noc_cat(allow_cat, v2)).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_target() { fn test_target() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let mut req = AccessReq::new(&accessor, &path, Access::READ); let mut req = AccessReq::new(&accessor, path, Access::READ);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
// Deny for target mismatch // Deny for target mismatch
@ -740,7 +724,7 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
// Allow for cluster match - subject wildcard // Allow for cluster match - subject wildcard
@ -751,11 +735,11 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
// Clean Slate // Clean Slate
am.erase_all(); am.borrow_mut().erase_all().unwrap();
// Allow for endpoint match - subject wildcard // Allow for endpoint match - subject wildcard
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
@ -765,11 +749,11 @@ mod tests {
device_type: None, device_type: None,
}) })
.unwrap(); .unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
// Clean Slate // Clean Slate
am.erase_all(); am.borrow_mut().erase_all().unwrap();
// Allow for exact match // Allow for exact match
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
@ -780,16 +764,15 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_privilege() { fn test_privilege() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let accessor = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone());
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
// Create an Exact Match ACL with View privilege // Create an Exact Match ACL with View privilege
@ -801,10 +784,10 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Write on an RWVA without admin access - deny // Write on an RWVA without admin access - deny
let mut req = AccessReq::new(&accessor, &path, Access::WRITE); let mut req = AccessReq::new(&accessor, path.clone(), Access::WRITE);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
assert_eq!(req.allow(), false); assert_eq!(req.allow(), false);
@ -817,40 +800,40 @@ mod tests {
}) })
.unwrap(); .unwrap();
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Write on an RWVA with admin access - allow // Write on an RWVA with admin access - allow
let mut req = AccessReq::new(&accessor, &path, Access::WRITE); let mut req = AccessReq::new(&accessor, path, Access::WRITE);
req.set_target_perms(Access::RWVA); req.set_target_perms(Access::RWVA);
assert_eq!(req.allow(), true); assert_eq!(req.allow(), true);
} }
#[test] #[test]
fn test_delete_for_fabric() { fn test_delete_for_fabric() {
let am = Arc::new(AclMgr::new_with(false).unwrap()); let am = RefCell::new(AclMgr::new());
am.erase_all(); am.borrow_mut().erase_all().unwrap();
let path = GenericPath::new(Some(1), Some(1234), None); let path = GenericPath::new(Some(1), Some(1234), None);
let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor2 = Accessor::new(2, AccessorSubjects::new(112233), AuthMode::Case, &am);
let mut req2 = AccessReq::new(&accessor2, &path, Access::READ); let mut req2 = AccessReq::new(&accessor2, path.clone(), Access::READ);
req2.set_target_perms(Access::RWVA); req2.set_target_perms(Access::RWVA);
let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, am.clone()); let accessor3 = Accessor::new(3, AccessorSubjects::new(112233), AuthMode::Case, &am);
let mut req3 = AccessReq::new(&accessor3, &path, Access::READ); let mut req3 = AccessReq::new(&accessor3, path, Access::READ);
req3.set_target_perms(Access::RWVA); req3.set_target_perms(Access::RWVA);
// Allow for subject match - target is wildcard - Fabric idx 2 // Allow for subject match - target is wildcard - Fabric idx 2
let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Allow for subject match - target is wildcard - Fabric idx 3 // Allow for subject match - target is wildcard - Fabric idx 3
let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case); let mut new = AclEntry::new(3, Privilege::VIEW, AuthMode::Case);
new.add_subject(112233).unwrap(); new.add_subject(112233).unwrap();
am.add(new).unwrap(); am.borrow_mut().add(new).unwrap();
// Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed // Req for Fabric idx 2 gets denied, and that for Fabric idx 3 is allowed
assert_eq!(req2.allow(), true); assert_eq!(req2.allow(), true);
assert_eq!(req3.allow(), true); assert_eq!(req3.allow(), true);
am.delete_for_fabric(2).unwrap(); am.borrow_mut().delete_for_fabric(2).unwrap();
assert_eq!(req2.allow(), false); assert_eq!(req2.allow(), false);
assert_eq!(req3.allow(), true); assert_eq!(req3.allow(), true);
} }

View file

@ -15,10 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use time::OffsetDateTime;
use super::{CertConsumer, MAX_DEPTH}; use super::{CertConsumer, MAX_DEPTH};
use crate::error::Error; use crate::{
use chrono::{Datelike, TimeZone, Utc}; error::{Error, ErrorCode},
use log::warn; utils::epoch::MATTER_EPOCH_SECS,
};
use core::fmt::Write;
#[derive(Debug)] #[derive(Debug)]
pub struct ASN1Writer<'a> { pub struct ASN1Writer<'a> {
@ -52,7 +56,7 @@ impl<'a> ASN1Writer<'a> {
self.offset += size; self.offset += size;
return Ok(()); return Ok(());
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
pub fn append_tlv<F>(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error> pub fn append_tlv<F>(&mut self, tag: u8, len: usize, f: F) -> Result<(), Error>
@ -68,7 +72,7 @@ impl<'a> ASN1Writer<'a> {
self.offset += len; self.offset += len;
return Ok(()); return Ok(());
} }
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
fn add_compound(&mut self, val: u8) -> Result<(), Error> { fn add_compound(&mut self, val: u8) -> Result<(), Error> {
@ -78,7 +82,7 @@ impl<'a> ASN1Writer<'a> {
self.depth[self.current_depth] = self.offset; self.depth[self.current_depth] = self.offset;
self.current_depth += 1; self.current_depth += 1;
if self.current_depth >= MAX_DEPTH { if self.current_depth >= MAX_DEPTH {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
Ok(()) Ok(())
} }
@ -111,7 +115,7 @@ impl<'a> ASN1Writer<'a> {
fn end_compound(&mut self) -> Result<(), Error> { fn end_compound(&mut self) -> Result<(), Error> {
if self.current_depth == 0 { if self.current_depth == 0 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
let seq_len = self.get_compound_len(); let seq_len = self.get_compound_len();
let write_offset = self.get_length_encoding_offset(); let write_offset = self.get_length_encoding_offset();
@ -146,7 +150,7 @@ impl<'a> ASN1Writer<'a> {
// This is done with an 0xA2 followed by 2 bytes of actual len // This is done with an 0xA2 followed by 2 bytes of actual len
3 3
} else { } else {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?
}; };
Ok(len) Ok(len)
} }
@ -261,28 +265,38 @@ impl<'a> CertConsumer for ASN1Writer<'a> {
} }
fn utctime(&mut self, _tag: &str, epoch: u32) -> Result<(), Error> { fn utctime(&mut self, _tag: &str, epoch: u32) -> Result<(), Error> {
let mut matter_epoch = Utc let matter_epoch = MATTER_EPOCH_SECS + epoch as u64;
.with_ymd_and_hms(2000, 1, 1, 0, 0, 0)
.unwrap()
.timestamp();
matter_epoch += epoch as i64; let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap();
let dt = match Utc.timestamp_opt(matter_epoch, 0) { let mut time_str: heapless::String<32> = heapless::String::<32>::new();
chrono::LocalResult::None => return Err(Error::InvalidTime),
chrono::LocalResult::Single(s) => s,
chrono::LocalResult::Ambiguous(_, a) => {
warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}");
a
}
};
if dt.year() >= 2050 { if dt.year() >= 2050 {
// If year is >= 2050, ASN.1 requires it to be Generalised Time // If year is >= 2050, ASN.1 requires it to be Generalised Time
let time_str = format!("{}Z", dt.format("%Y%m%d%H%M%S")); write!(
&mut time_str,
"{:04}{:02}{:02}{:02}{:02}{:02}Z",
dt.year(),
dt.month() as u8,
dt.day(),
dt.hour(),
dt.minute(),
dt.second()
)
.unwrap();
self.write_str(0x18, time_str.as_bytes()) self.write_str(0x18, time_str.as_bytes())
} else { } else {
let time_str = format!("{}Z", dt.format("%y%m%d%H%M%S")); write!(
&mut time_str,
"{:02}{:02}{:02}{:02}{:02}{:02}Z",
dt.year() % 100,
dt.month() as u8,
dt.day(),
dt.hour(),
dt.minute(),
dt.second()
)
.unwrap();
self.write_str(0x17, time_str.as_bytes()) self.write_str(0x17, time_str.as_bytes())
} }
} }

View file

@ -15,12 +15,12 @@
* limitations under the License. * limitations under the License.
*/ */
use std::fmt; use core::fmt::{self, Write};
use crate::{ use crate::{
crypto::{CryptoKeyPair, KeyPair}, crypto::KeyPair,
error::Error, error::{Error, ErrorCode},
tlv::{self, FromTLV, TLVArrayOwned, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{self, FromTLV, OctetStr, TLVArray, TLVElement, TLVWriter, TagType, ToTLV},
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
use log::error; use log::error;
@ -29,6 +29,8 @@ use num_derive::FromPrimitive;
pub use self::asn1_writer::ASN1Writer; pub use self::asn1_writer::ASN1Writer;
use self::printer::CertPrinter; use self::printer::CertPrinter;
pub const MAX_CERT_TLV_LEN: usize = 1024; // TODO
// As per https://datatracker.ietf.org/doc/html/rfc5280 // As per https://datatracker.ietf.org/doc/html/rfc5280
const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01]; const OID_PUB_KEY_ECPUBKEY: [u8; 7] = [0x2A, 0x86, 0x48, 0xCE, 0x3D, 0x02, 0x01];
@ -113,8 +115,10 @@ macro_rules! add_if {
}; };
} }
fn get_print_str(key_usage: u16) -> String { fn get_print_str(key_usage: u16) -> heapless::String<256> {
format!( let mut string = heapless::String::new();
write!(
&mut string,
"{}{}{}{}{}{}{}{}{}", "{}{}{}{}{}{}{}{}{}",
add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "), add_if!(key_usage, KEY_USAGE_DIGITAL_SIGN, "digitalSignature "),
add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "), add_if!(key_usage, KEY_USAGE_NON_REPUDIATION, "nonRepudiation "),
@ -126,6 +130,9 @@ fn get_print_str(key_usage: u16) -> String {
add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "), add_if!(key_usage, KEY_USAGE_ENCIPHER_ONLY, "encipherOnly "),
add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "), add_if!(key_usage, KEY_USAGE_DECIPHER_ONLY, "decipherOnly "),
) )
.unwrap();
string
} }
#[allow(unused_assignments)] #[allow(unused_assignments)]
@ -137,7 +144,7 @@ fn encode_key_usage(key_usage: u16, w: &mut dyn CertConsumer) -> Result<(), Erro
} }
fn encode_extended_key_usage( fn encode_extended_key_usage(
list: &TLVArrayOwned<u8>, list: impl Iterator<Item = u8>,
w: &mut dyn CertConsumer, w: &mut dyn CertConsumer,
) -> Result<(), Error> { ) -> Result<(), Error> {
const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01]; const OID_SERVER_AUTH: [u8; 8] = [0x2B, 0x06, 0x01, 0x05, 0x05, 0x07, 0x03, 0x01];
@ -157,19 +164,18 @@ fn encode_extended_key_usage(
]; ];
w.start_seq("")?; w.start_seq("")?;
for t in list.iter() { for t in list {
let t = *t as usize; let t = t as usize;
if t > 0 && t <= encoding.len() { if t > 0 && t <= encoding.len() {
w.oid(encoding[t].0, encoding[t].1)?; w.oid(encoding[t].0, encoding[t].1)?;
} else { } else {
error!("Skipping encoding key usage out of bounds"); error!("Skipping encoding key usage out of bounds");
} }
} }
w.end_seq()?; w.end_seq()
Ok(())
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
struct BasicConstraints { struct BasicConstraints {
is_ca: bool, is_ca: bool,
@ -209,18 +215,18 @@ fn encode_extension_end(w: &mut dyn CertConsumer) -> Result<(), Error> {
w.end_seq() w.end_seq()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1, datatype = "list")] #[tlvargs(lifetime = "'a", start = 1, datatype = "list")]
struct Extensions { struct Extensions<'a> {
basic_const: Option<BasicConstraints>, basic_const: Option<BasicConstraints>,
key_usage: Option<u16>, key_usage: Option<u16>,
ext_key_usage: Option<TLVArrayOwned<u8>>, ext_key_usage: Option<TLVArray<'a, u8>>,
subj_key_id: Option<Vec<u8>>, subj_key_id: Option<OctetStr<'a>>,
auth_key_id: Option<Vec<u8>>, auth_key_id: Option<OctetStr<'a>>,
future_extensions: Option<Vec<u8>>, future_extensions: Option<OctetStr<'a>>,
} }
impl Extensions { impl<'a> Extensions<'a> {
fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13]; const OID_BASIC_CONSTRAINTS: [u8; 3] = [0x55, 0x1D, 0x13];
const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F]; const OID_KEY_USAGE: [u8; 3] = [0x55, 0x1D, 0x0F];
@ -242,30 +248,29 @@ impl Extensions {
} }
if let Some(t) = &self.ext_key_usage { if let Some(t) = &self.ext_key_usage {
encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?; encode_extension_start("X509v3 Extended Key Usage", true, &OID_EXT_KEY_USAGE, w)?;
encode_extended_key_usage(t, w)?; encode_extended_key_usage(t.iter(), w)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.subj_key_id { if let Some(t) = &self.subj_key_id {
encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?; encode_extension_start("Subject Key ID", false, &OID_SUBJ_KEY_IDENTIFIER, w)?;
w.ostr("", t.as_slice())?; w.ostr("", t.0)?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.auth_key_id { if let Some(t) = &self.auth_key_id {
encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?; encode_extension_start("Auth Key ID", false, &OID_AUTH_KEY_ID, w)?;
w.start_seq("")?; w.start_seq("")?;
w.ctx("", 0, t.as_slice())?; w.ctx("", 0, t.0)?;
w.end_seq()?; w.end_seq()?;
encode_extension_end(w)?; encode_extension_end(w)?;
} }
if let Some(t) = &self.future_extensions { if let Some(t) = &self.future_extensions {
error!("Future Extensions Not Yet Supported: {:x?}", t.as_slice()) error!("Future Extensions Not Yet Supported: {:x?}", t.0);
} }
w.end_seq()?; w.end_seq()?;
w.end_ctx()?; w.end_ctx()?;
Ok(()) Ok(())
} }
} }
const MAX_DN_ENTRIES: usize = 5;
#[derive(FromPrimitive, Copy, Clone)] #[derive(FromPrimitive, Copy, Clone)]
enum DnTags { enum DnTags {
@ -293,20 +298,23 @@ enum DnTags {
NocCat = 22, NocCat = 22,
} }
enum DistNameValue { #[derive(Debug)]
enum DistNameValue<'a> {
Uint(u64), Uint(u64),
Utf8Str(Vec<u8>), Utf8Str(&'a [u8]),
PrintableStr(Vec<u8>), PrintableStr(&'a [u8]),
} }
#[derive(Default)] const MAX_DN_ENTRIES: usize = 5;
struct DistNames {
#[derive(Default, Debug)]
struct DistNames<'a> {
// The order in which the DNs arrive is important, as the signing // The order in which the DNs arrive is important, as the signing
// requires that the ASN1 notation retains the same order // requires that the ASN1 notation retains the same order
dn: Vec<(u8, DistNameValue)>, dn: heapless::Vec<(u8, DistNameValue<'a>), MAX_DN_ENTRIES>,
} }
impl DistNames { impl<'a> DistNames<'a> {
fn u64(&self, match_id: DnTags) -> Option<u64> { fn u64(&self, match_id: DnTags) -> Option<u64> {
self.dn self.dn
.iter() .iter()
@ -336,24 +344,27 @@ impl DistNames {
const PRINTABLE_STR_THRESHOLD: u8 = 0x80; const PRINTABLE_STR_THRESHOLD: u8 = 0x80;
impl<'a> FromTLV<'a> for DistNames { impl<'a> FromTLV<'a> for DistNames<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
let mut d = Self { let mut d = Self {
dn: Vec::with_capacity(MAX_DN_ENTRIES), dn: heapless::Vec::new(),
}; };
let iter = t.confirm_list()?.enter().ok_or(Error::Invalid)?; let iter = t.confirm_list()?.enter().ok_or(ErrorCode::Invalid)?;
for t in iter { for t in iter {
if let TagType::Context(tag) = t.get_tag() { if let TagType::Context(tag) = t.get_tag() {
if let Ok(value) = t.u64() { if let Ok(value) = t.u64() {
d.dn.push((tag, DistNameValue::Uint(value))); d.dn.push((tag, DistNameValue::Uint(value)))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} else if let Ok(value) = t.slice() { } else if let Ok(value) = t.slice() {
if tag > PRINTABLE_STR_THRESHOLD { if tag > PRINTABLE_STR_THRESHOLD {
d.dn.push(( d.dn.push((
tag - PRINTABLE_STR_THRESHOLD, tag - PRINTABLE_STR_THRESHOLD,
DistNameValue::PrintableStr(value.to_vec()), DistNameValue::PrintableStr(value),
)); ))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} else { } else {
d.dn.push((tag, DistNameValue::Utf8Str(value.to_vec()))); d.dn.push((tag, DistNameValue::Utf8Str(value)))
.map_err(|_| ErrorCode::BufferTooSmall)?;
} }
} }
} }
@ -362,24 +373,23 @@ impl<'a> FromTLV<'a> for DistNames {
} }
} }
impl ToTLV for DistNames { impl<'a> ToTLV for DistNames<'a> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_list(tag)?; tw.start_list(tag)?;
for (name, value) in &self.dn { for (name, value) in &self.dn {
match value { match value {
DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?, DistNameValue::Uint(v) => tw.u64(TagType::Context(*name), *v)?,
DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v.as_slice())?, DistNameValue::Utf8Str(v) => tw.utf8(TagType::Context(*name), v)?,
DistNameValue::PrintableStr(v) => tw.utf8( DistNameValue::PrintableStr(v) => {
TagType::Context(*name + PRINTABLE_STR_THRESHOLD), tw.utf8(TagType::Context(*name + PRINTABLE_STR_THRESHOLD), v)?
v.as_slice(), }
)?,
} }
} }
tw.end_container() tw.end_container()
} }
} }
impl DistNames { impl<'a> DistNames<'a> {
fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> { fn encode(&self, tag: &str, w: &mut dyn CertConsumer) -> Result<(), Error> {
const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03]; const OID_COMMON_NAME: [u8; 3] = [0x55_u8, 0x04, 0x03];
const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04]; const OID_SURNAME: [u8; 3] = [0x55_u8, 0x04, 0x04];
@ -509,52 +519,60 @@ fn encode_dn_value(
w.oid(name, oid)?; w.oid(name, oid)?;
match value { match value {
DistNameValue::Uint(v) => match expected_len { DistNameValue::Uint(v) => match expected_len {
Some(IntToStringLen::Len16) => w.utf8str("", format!("{:016X}", v).as_str())?, Some(IntToStringLen::Len16) => {
Some(IntToStringLen::Len8) => w.utf8str("", format!("{:08X}", v).as_str())?, let mut string = heapless::String::<32>::new();
write!(&mut string, "{:016X}", v).unwrap();
w.utf8str("", &string)?
}
Some(IntToStringLen::Len8) => {
let mut string = heapless::String::<32>::new();
write!(&mut string, "{:08X}", v).unwrap();
w.utf8str("", &string)?
}
_ => { _ => {
error!("Invalid encoding"); error!("Invalid encoding");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?
} }
}, },
DistNameValue::Utf8Str(v) => { DistNameValue::Utf8Str(v) => {
let str = String::from_utf8(v.to_vec())?; w.utf8str("", core::str::from_utf8(v)?)?;
w.utf8str("", &str)?;
} }
DistNameValue::PrintableStr(v) => { DistNameValue::PrintableStr(v) => {
let str = String::from_utf8(v.to_vec())?; w.printstr("", core::str::from_utf8(v)?)?;
w.printstr("", &str)?;
} }
} }
w.end_seq()?; w.end_seq()?;
w.end_set() w.end_set()
} }
#[derive(FromTLV, ToTLV, Default)] #[derive(FromTLV, ToTLV, Default, Debug)]
#[tlvargs(start = 1)] #[tlvargs(lifetime = "'a", start = 1)]
pub struct Cert { pub struct Cert<'a> {
serial_no: Vec<u8>, serial_no: OctetStr<'a>,
sign_algo: u8, sign_algo: u8,
issuer: DistNames, issuer: DistNames<'a>,
not_before: u32, not_before: u32,
not_after: u32, not_after: u32,
subject: DistNames, subject: DistNames<'a>,
pubkey_algo: u8, pubkey_algo: u8,
ec_curve_id: u8, ec_curve_id: u8,
pubkey: Vec<u8>, pubkey: OctetStr<'a>,
extensions: Extensions, extensions: Extensions<'a>,
signature: Vec<u8>, signature: OctetStr<'a>,
} }
// TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding // TODO: Instead of parsing the TLVs everytime, we should just cache this, but the encoding
// rules in terms of sequence may get complicated. Need to look into this // rules in terms of sequence may get complicated. Need to look into this
impl Cert { impl<'a> Cert<'a> {
pub fn new(cert_bin: &[u8]) -> Result<Self, Error> { pub fn new(cert_bin: &'a [u8]) -> Result<Self, Error> {
let root = tlv::get_root_node(cert_bin)?; let root = tlv::get_root_node(cert_bin)?;
Cert::from_tlv(&root) Cert::from_tlv(&root)
} }
pub fn get_node_id(&self) -> Result<u64, Error> { pub fn get_node_id(&self) -> Result<u64, Error> {
self.subject.u64(DnTags::NodeId).ok_or(Error::NoNodeId) self.subject
.u64(DnTags::NodeId)
.ok_or_else(|| Error::from(ErrorCode::NoNodeId))
} }
pub fn get_cat_ids(&self, output: &mut [u32]) { pub fn get_cat_ids(&self, output: &mut [u32]) {
@ -562,21 +580,27 @@ impl Cert {
} }
pub fn get_fabric_id(&self) -> Result<u64, Error> { pub fn get_fabric_id(&self) -> Result<u64, Error> {
self.subject.u64(DnTags::FabricId).ok_or(Error::NoFabricId) self.subject
.u64(DnTags::FabricId)
.ok_or_else(|| Error::from(ErrorCode::NoFabricId))
} }
pub fn get_pubkey(&self) -> &[u8] { pub fn get_pubkey(&self) -> &[u8] {
self.pubkey.as_slice() self.pubkey.0
} }
pub fn get_subject_key_id(&self) -> Result<&[u8], Error> { pub fn get_subject_key_id(&self) -> Result<&[u8], Error> {
self.extensions.subj_key_id.as_deref().ok_or(Error::Invalid) if let Some(id) = self.extensions.subj_key_id.as_ref() {
Ok(id.0)
} else {
Err(ErrorCode::Invalid.into())
}
} }
pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> { pub fn is_authority(&self, their: &Cert) -> Result<bool, Error> {
if let Some(our_auth_key) = &self.extensions.auth_key_id { if let Some(our_auth_key) = &self.extensions.auth_key_id {
let their_subject = their.get_subject_key_id()?; let their_subject = their.get_subject_key_id()?;
if our_auth_key == their_subject { if our_auth_key.0 == their_subject {
Ok(true) Ok(true)
} else { } else {
Ok(false) Ok(false)
@ -587,11 +611,11 @@ impl Cert {
} }
pub fn get_signature(&self) -> &[u8] { pub fn get_signature(&self) -> &[u8] {
self.signature.as_slice() self.signature.0
} }
pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> { pub fn as_tlv(&self, buf: &mut [u8]) -> Result<usize, Error> {
let mut wb = WriteBuf::new(buf, buf.len()); let mut wb = WriteBuf::new(buf);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
self.to_tlv(&mut tw, TagType::Anonymous)?; self.to_tlv(&mut tw, TagType::Anonymous)?;
Ok(wb.as_slice().len()) Ok(wb.as_slice().len())
@ -614,10 +638,10 @@ impl Cert {
w.integer("", &[2])?; w.integer("", &[2])?;
w.end_ctx()?; w.end_ctx()?;
w.integer("Serial Num:", self.serial_no.as_slice())?; w.integer("Serial Num:", self.serial_no.0)?;
w.start_seq("Signature Algorithm:")?; w.start_seq("Signature Algorithm:")?;
let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(Error::Invalid)? { let (str, oid) = match get_sign_algo(self.sign_algo).ok_or(ErrorCode::Invalid)? {
SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256), SignAlgoValue::ECDSAWithSHA256 => ("ECDSA with SHA256", OID_ECDSA_WITH_SHA256),
}; };
w.oid(str, &oid)?; w.oid(str, &oid)?;
@ -634,17 +658,17 @@ impl Cert {
w.start_seq("")?; w.start_seq("")?;
w.start_seq("Public Key Algorithm")?; w.start_seq("Public Key Algorithm")?;
let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(Error::Invalid)? { let (str, pub_key) = match get_pubkey_algo(self.pubkey_algo).ok_or(ErrorCode::Invalid)? {
PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY), PubKeyAlgoValue::EcPubKey => ("ECPubKey", OID_PUB_KEY_ECPUBKEY),
}; };
w.oid(str, &pub_key)?; w.oid(str, &pub_key)?;
let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(Error::Invalid)? { let (str, curve_id) = match get_ec_curve_id(self.ec_curve_id).ok_or(ErrorCode::Invalid)? {
EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1), EcCurveIdValue::Prime256V1 => ("Prime256v1", OID_EC_TYPE_PRIME256V1),
}; };
w.oid(str, &curve_id)?; w.oid(str, &curve_id)?;
w.end_seq()?; w.end_seq()?;
w.bitstr("Public-Key:", false, self.pubkey.as_slice())?; w.bitstr("Public-Key:", false, self.pubkey.0)?;
w.end_seq()?; w.end_seq()?;
self.extensions.encode(w)?; self.extensions.encode(w)?;
@ -655,7 +679,7 @@ impl Cert {
} }
} }
impl fmt::Display for Cert { impl<'a> fmt::Display for Cert<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut printer = CertPrinter::new(f); let mut printer = CertPrinter::new(f);
let _ = self let _ = self
@ -667,7 +691,7 @@ impl fmt::Display for Cert {
} }
pub struct CertVerifier<'a> { pub struct CertVerifier<'a> {
cert: &'a Cert, cert: &'a Cert<'a>,
} }
impl<'a> CertVerifier<'a> { impl<'a> CertVerifier<'a> {
@ -677,7 +701,7 @@ impl<'a> CertVerifier<'a> {
pub fn add_cert(self, parent: &'a Cert) -> Result<CertVerifier<'a>, Error> { pub fn add_cert(self, parent: &'a Cert) -> Result<CertVerifier<'a>, Error> {
if !self.cert.is_authority(parent)? { if !self.cert.is_authority(parent)? {
return Err(Error::InvalidAuthKey); Err(ErrorCode::InvalidAuthKey)?;
} }
let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE]; let mut asn1 = [0u8; MAX_ASN1_CERT_SIZE];
let len = self.cert.as_asn1(&mut asn1)?; let len = self.cert.as_asn1(&mut asn1)?;
@ -731,8 +755,9 @@ mod printer;
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use crate::cert::Cert; use crate::cert::Cert;
use crate::error::Error;
use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV}; use crate::tlv::{self, FromTLV, TLVWriter, TagType, ToTLV};
use crate::utils::writebuf::WriteBuf; use crate::utils::writebuf::WriteBuf;
@ -777,29 +802,41 @@ mod tests {
#[test] #[test]
fn test_verify_chain_incomplete() { fn test_verify_chain_incomplete() {
// The chain doesn't lead up to a self-signed certificate // The chain doesn't lead up to a self-signed certificate
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap(); let noc = Cert::new(&test_vectors::NOC1_SUCCESS).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!( assert_eq!(
Err(Error::InvalidAuthKey), Err(ErrorCode::InvalidAuthKey),
a.add_cert(&icac).unwrap().finalise() a.add_cert(&icac).unwrap().finalise().map_err(|e| e.code())
); );
} }
#[test] #[test]
fn test_auth_key_chain_incorrect() { fn test_auth_key_chain_incorrect() {
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap(); let noc = Cert::new(&test_vectors::NOC1_AUTH_KEY_FAIL).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!(Err(Error::InvalidAuthKey), a.add_cert(&icac).map(|_| ())); assert_eq!(
Err(ErrorCode::InvalidAuthKey),
a.add_cert(&icac).map(|_| ()).map_err(|e| e.code())
);
} }
#[test] #[test]
fn test_cert_corrupted() { fn test_cert_corrupted() {
use crate::error::ErrorCode;
let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap(); let noc = Cert::new(&test_vectors::NOC1_CORRUPT_CERT).unwrap();
let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap(); let icac = Cert::new(&test_vectors::ICAC1_SUCCESS).unwrap();
let a = noc.verify_chain_start(); let a = noc.verify_chain_start();
assert_eq!(Err(Error::InvalidSignature), a.add_cert(&icac).map(|_| ())); assert_eq!(
Err(ErrorCode::InvalidSignature),
a.add_cert(&icac).map(|_| ()).map_err(|e| e.code())
);
} }
#[test] #[test]
@ -811,12 +848,11 @@ mod tests {
]; ];
for input in test_input.iter() { for input in test_input.iter() {
println!("Testing next input..."); info!("Testing next input...");
let root = tlv::get_root_node(input).unwrap(); let root = tlv::get_root_node(input).unwrap();
let cert = Cert::from_tlv(&root).unwrap(); let cert = Cert::from_tlv(&root).unwrap();
let mut buf = [0u8; 1024]; let mut buf = [0u8; 1024];
let buf_len = buf.len(); let mut wb = WriteBuf::new(&mut buf);
let mut wb = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
cert.to_tlv(&mut tw, TagType::Anonymous).unwrap(); cert.to_tlv(&mut tw, TagType::Anonymous).unwrap();
assert_eq!(*input, wb.as_slice()); assert_eq!(*input, wb.as_slice());

View file

@ -15,11 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use time::OffsetDateTime;
use super::{CertConsumer, MAX_DEPTH}; use super::{CertConsumer, MAX_DEPTH};
use crate::error::Error; use crate::{error::Error, utils::epoch::MATTER_EPOCH_SECS};
use chrono::{TimeZone, Utc}; use core::fmt;
use log::warn;
use std::fmt;
pub struct CertPrinter<'a, 'b> { pub struct CertPrinter<'a, 'b> {
level: usize, level: usize,
@ -123,23 +123,11 @@ impl<'a, 'b> CertConsumer for CertPrinter<'a, 'b> {
Ok(()) Ok(())
} }
fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error> { fn utctime(&mut self, tag: &str, epoch: u32) -> Result<(), Error> {
let mut matter_epoch = Utc let matter_epoch = MATTER_EPOCH_SECS + epoch as u64;
.with_ymd_and_hms(2000, 1, 1, 0, 0, 0)
.unwrap()
.timestamp();
matter_epoch += epoch as i64; let dt = OffsetDateTime::from_unix_timestamp(matter_epoch as _).unwrap();
let dt = match Utc.timestamp_opt(matter_epoch, 0) { let _ = writeln!(self.f, "{} {} {:?}", SPACE[self.level], tag, dt);
chrono::LocalResult::None => return Err(Error::InvalidTime),
chrono::LocalResult::Single(s) => s,
chrono::LocalResult::Ambiguous(_, a) => {
warn!("Ambiguous time for epoch {epoch}; returning latest timestamp: {a}");
a
}
};
let _ = writeln!(self.f, "{} {} {}", SPACE[self.level], tag, dt);
Ok(()) Ok(())
} }
} }

View file

@ -17,7 +17,7 @@
//! Base38 encoding and decoding functions. //! Base38 encoding and decoding functions.
use crate::error::Error; use crate::error::{Error, ErrorCode};
const BASE38_CHARS: [char; 38] = [ const BASE38_CHARS: [char; 38] = [
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I',
@ -77,60 +77,68 @@ const DECODE_BASE38: [u8; 46] = [
35, // 'Z', =90 35, // 'Z', =90
]; ];
const BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK: [u8; 3] = [2, 4, 5];
const RADIX: u32 = BASE38_CHARS.len() as u32; const RADIX: u32 = BASE38_CHARS.len() as u32;
/// Encode a byte array into a base38 string. /// Encode a byte array into a base38 string.
/// ///
/// # Arguments /// # Arguments
/// * `bytes` - byte array to encode /// * `bytes` - byte array to encode
/// * `length` - optional length of the byte array to encode. If not specified, the entire byte array is encoded. pub fn encode_string<const N: usize>(bytes: &[u8]) -> Result<heapless::String<N>, Error> {
pub fn encode(bytes: &[u8], length: Option<usize>) -> String { let mut string = heapless::String::new();
let mut offset = 0; for c in encode(bytes) {
let mut result = String::new(); string.push(c).map_err(|_| ErrorCode::NoSpace)?;
}
// if length is specified, use it, otherwise use the length of the byte array Ok(string)
// if length is specified but is greater than the length of the byte array, use the length of the byte array }
let b_len = bytes.len();
let length = length.map(|l| l.min(b_len)).unwrap_or(b_len);
while offset < length { pub fn encode(bytes: &[u8]) -> impl Iterator<Item = char> + '_ {
let remaining = length - offset; (0..bytes.len() / 3)
match remaining.cmp(&2) { .flat_map(move |index| {
std::cmp::Ordering::Greater => { let offset = index * 3;
result.push_str(&encode_base38(
encode_base38(
((bytes[offset + 2] as u32) << 16) ((bytes[offset + 2] as u32) << 16)
| ((bytes[offset + 1] as u32) << 8) | ((bytes[offset + 1] as u32) << 8)
| (bytes[offset] as u32), | (bytes[offset] as u32),
5, 5,
)); )
offset += 3; })
} .chain(
std::cmp::Ordering::Equal => { core::iter::once(bytes.len() % 3).flat_map(move |remainder| {
result.push_str(&encode_base38( let offset = bytes.len() / 3 * 3;
match remainder {
2 => encode_base38(
((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32), ((bytes[offset + 1] as u32) << 8) | (bytes[offset] as u32),
4, 4,
)); ),
break; 1 => encode_base38(bytes[offset] as u32, 2),
} _ => encode_base38(0, 0),
std::cmp::Ordering::Less => {
result.push_str(&encode_base38(bytes[offset] as u32, 2));
break;
}
} }
}),
)
} }
result fn encode_base38(mut value: u32, repeat: usize) -> impl Iterator<Item = char> {
(0..repeat).map(move |_| {
let remainder = value % RADIX;
let c = BASE38_CHARS[remainder as usize];
value = (value - remainder) / RADIX;
c
})
} }
fn encode_base38(mut value: u32, char_count: u8) -> String { pub fn decode_vec<const N: usize>(base38_str: &str) -> Result<heapless::Vec<u8, N>, Error> {
let mut result = String::new(); let mut vec = heapless::Vec::new();
for _ in 0..char_count {
let remainder = value % 38; for byte in decode(base38_str) {
result.push(BASE38_CHARS[remainder as usize]); vec.push(byte?).map_err(|_| ErrorCode::NoSpace)?;
value = (value - remainder) / 38;
} }
result
Ok(vec)
} }
/// Decode a base38-encoded string into a byte slice /// Decode a base38-encoded string into a byte slice
@ -138,64 +146,71 @@ fn encode_base38(mut value: u32, char_count: u8) -> String {
/// # Arguments /// # Arguments
/// * `base38_str` - base38-encoded string to decode /// * `base38_str` - base38-encoded string to decode
/// ///
/// Fails if the string contains invalid characters /// Fails if the string contains invalid characters or if the supplied buffer is too small to fit the decoded data
pub fn decode(base38_str: &str) -> Result<Vec<u8>, Error> { pub fn decode(base38_str: &str) -> impl Iterator<Item = Result<u8, Error>> + '_ {
let mut result = Vec::new(); let stru = base38_str.as_bytes();
let mut base38_characters_number: usize = base38_str.len();
let mut decoded_base38_characters: usize = 0;
while base38_characters_number > 0 { (0..stru.len() / 5)
let base38_characters_in_chunk: usize; .flat_map(move |index| {
let bytes_in_decoded_chunk: usize; let offset = index * 5;
decode_base38(&stru[offset..offset + 5])
if base38_characters_number >= BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize { })
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[2] as usize; .chain({
bytes_in_decoded_chunk = 3; let offset = stru.len() / 5 * 5;
} else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize { decode_base38(&stru[offset..])
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[1] as usize; })
bytes_in_decoded_chunk = 2; .take_while(Result::is_ok)
} else if base38_characters_number == BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize {
base38_characters_in_chunk = BASE38_CHARACTERS_NEEDED_IN_NBYTES_CHUNK[0] as usize;
bytes_in_decoded_chunk = 1;
} else {
return Err(Error::InvalidData);
} }
fn decode_base38(chars: &[u8]) -> impl Iterator<Item = Result<u8, Error>> {
let mut value = 0u32; let mut value = 0u32;
let mut cerr = None;
for i in (1..=base38_characters_in_chunk).rev() { let repeat = match chars.len() {
let mut base38_chars = base38_str.chars(); 5 => 3,
let v = decode_char(base38_chars.nth(decoded_base38_characters + i - 1).unwrap())?; 4 => 2,
2 => 1,
0 => 0,
_ => -1,
};
value = value * RADIX + v as u32; if repeat >= 0 {
for c in chars.iter().rev() {
match decode_char(*c) {
Ok(v) => value = value * RADIX + v as u32,
Err(err) => {
cerr = Some(err.code());
break;
}
}
}
} else {
cerr = Some(ErrorCode::InvalidData)
} }
decoded_base38_characters += base38_characters_in_chunk; (0..repeat)
base38_characters_number -= base38_characters_in_chunk; .map(move |_| {
if let Some(err) = cerr {
Err(err.into())
} else {
let byte = (value & 0xff) as u8;
for _i in 0..bytes_in_decoded_chunk {
result.push(value as u8);
value >>= 8; value >>= 8;
Ok(byte)
}
})
.take_while(Result::is_ok)
} }
if value > 0 { fn decode_char(c: u8) -> Result<u8, Error> {
// encoded value is too big to represent a correct chunk of size 1, 2 or 3 bytes
return Err(Error::InvalidArgument);
}
}
Ok(result)
}
fn decode_char(c: char) -> Result<u8, Error> {
let c = c as u8;
if !(45..=90).contains(&c) { if !(45..=90).contains(&c) {
return Err(Error::InvalidData); Err(ErrorCode::InvalidData)?;
} }
let c = DECODE_BASE38[c as usize - 45]; let c = DECODE_BASE38[c as usize - 45];
if c == UNUSED { if c == UNUSED {
return Err(Error::InvalidData); Err(ErrorCode::InvalidData)?;
} }
Ok(c) Ok(c)
@ -211,15 +226,17 @@ mod tests {
#[test] #[test]
fn can_base38_encode() { fn can_base38_encode() {
assert_eq!(encode(&DECODED, None), ENCODED); assert_eq!(
assert_eq!(encode(&DECODED, Some(11)), ENCODED); encode_string::<{ ENCODED.len() }>(&DECODED).unwrap(),
ENCODED
// length is greater than the length of the byte array );
assert_eq!(encode(&DECODED, Some(12)), ENCODED);
} }
#[test] #[test]
fn can_base38_decode() { fn can_base38_decode() {
assert_eq!(decode(ENCODED).expect("can not decode base38"), DECODED); assert_eq!(
decode_vec::<{ DECODED.len() }>(ENCODED).expect("Cannot decode base38"),
DECODED
);
} }
} }

View file

@ -15,21 +15,28 @@
* limitations under the License. * limitations under the License.
*/ */
use core::{borrow::Borrow, cell::RefCell};
use crate::{ use crate::{
acl::AclMgr, acl::AclMgr,
data_model::{ data_model::{
cluster_basic_information::BasicInfoConfig, core::DataModel, cluster_basic_information::BasicInfoConfig,
sdm::dev_att::DevAttDataFetcher, sdm::{dev_att::DevAttDataFetcher, failsafe::FailSafe},
}, },
error::*, error::*,
fabric::FabricMgr, fabric::FabricMgr,
interaction_model::InteractionModel,
mdns::Mdns, mdns::Mdns,
pairing::{print_pairing_code_and_qr, DiscoveryCapabilities}, pairing::{print_pairing_code_and_qr, DiscoveryCapabilities},
secure_channel::{core::SecureChannel, pake::PaseMgr, spake2p::VerifierData}, secure_channel::{pake::PaseMgr, spake2p::VerifierData},
transport, transport::{
exchange::{ExchangeCtx, MAX_EXCHANGES},
session::SessionMgr,
},
utils::{epoch::Epoch, rand::Rand, select::Notification},
}; };
use std::sync::Arc;
/* The Matter Port */
pub const MATTER_PORT: u16 = 5540;
/// Device Commissioning Data /// Device Commissioning Data
pub struct CommissioningData { pub struct CommissioningData {
@ -40,71 +47,190 @@ pub struct CommissioningData {
} }
/// The primary Matter Object /// The primary Matter Object
pub struct Matter { pub struct Matter<'a> {
transport_mgr: transport::mgr::Mgr, fabric_mgr: RefCell<FabricMgr>,
data_model: DataModel, pub acl_mgr: RefCell<AclMgr>, // Public for tests
fabric_mgr: Arc<FabricMgr>, pase_mgr: RefCell<PaseMgr>,
failsafe: RefCell<FailSafe>,
persist_notification: Notification,
pub(crate) send_notification: Notification,
mdns: &'a dyn Mdns,
pub(crate) epoch: Epoch,
pub(crate) rand: Rand,
dev_det: &'a BasicInfoConfig<'a>,
dev_att: &'a dyn DevAttDataFetcher,
pub(crate) port: u16,
pub(crate) exchanges: RefCell<heapless::Vec<ExchangeCtx, MAX_EXCHANGES>>,
pub session_mgr: RefCell<SessionMgr>, // Public for tests
}
impl<'a> Matter<'a> {
#[cfg(feature = "std")]
#[inline(always)]
pub fn new_default(
dev_det: &'a BasicInfoConfig<'a>,
dev_att: &'a dyn DevAttDataFetcher,
mdns: &'a dyn Mdns,
port: u16,
) -> Self {
use crate::utils::epoch::sys_epoch;
use crate::utils::rand::sys_rand;
Self::new(dev_det, dev_att, mdns, sys_epoch, sys_rand, port)
} }
impl Matter {
/// Creates a new Matter object /// Creates a new Matter object
/// ///
/// # Parameters /// # Parameters
/// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device /// * dev_att: An object that implements the trait [DevAttDataFetcher]. Any Matter device
/// requires a set of device attestation certificates and keys. It is the responsibility of /// requires a set of device attestation certificates and keys. It is the responsibility of
/// this object to return the device attestation details when queried upon. /// this object to return the device attestation details when queried upon.
#[inline(always)]
pub fn new( pub fn new(
dev_det: BasicInfoConfig, dev_det: &'a BasicInfoConfig<'a>,
dev_att: Box<dyn DevAttDataFetcher>, dev_att: &'a dyn DevAttDataFetcher,
mdns: &'a dyn Mdns,
epoch: Epoch,
rand: Rand,
port: u16,
) -> Self {
Self {
fabric_mgr: RefCell::new(FabricMgr::new()),
acl_mgr: RefCell::new(AclMgr::new()),
pase_mgr: RefCell::new(PaseMgr::new(epoch, rand)),
failsafe: RefCell::new(FailSafe::new()),
persist_notification: Notification::new(),
send_notification: Notification::new(),
mdns,
epoch,
rand,
dev_det,
dev_att,
port,
exchanges: RefCell::new(heapless::Vec::new()),
session_mgr: RefCell::new(SessionMgr::new(epoch, rand)),
}
}
pub fn dev_det(&self) -> &BasicInfoConfig<'_> {
self.dev_det
}
pub fn dev_att(&self) -> &dyn DevAttDataFetcher {
self.dev_att
}
pub fn port(&self) -> u16 {
self.port
}
pub fn load_fabrics(&self, data: &[u8]) -> Result<(), Error> {
self.fabric_mgr.borrow_mut().load(data, self.mdns)
}
pub fn load_acls(&self, data: &[u8]) -> Result<(), Error> {
self.acl_mgr.borrow_mut().load(data)
}
pub fn store_fabrics<'b>(&self, buf: &'b mut [u8]) -> Result<Option<&'b [u8]>, Error> {
self.fabric_mgr.borrow_mut().store(buf)
}
pub fn store_acls<'b>(&self, buf: &'b mut [u8]) -> Result<Option<&'b [u8]>, Error> {
self.acl_mgr.borrow_mut().store(buf)
}
pub fn is_changed(&self) -> bool {
self.acl_mgr.borrow().is_changed() || self.fabric_mgr.borrow().is_changed()
}
pub fn start_comissioning(
&self,
dev_comm: CommissioningData, dev_comm: CommissioningData,
) -> Result<Box<Matter>, Error> { buf: &mut [u8],
let mdns = Mdns::get()?; ) -> Result<bool, Error> {
mdns.set_values(dev_det.vid, dev_det.pid, &dev_det.device_name); if !self.pase_mgr.borrow().is_pase_session_enabled() && self.fabric_mgr.borrow().is_empty()
{
print_pairing_code_and_qr(
self.dev_det,
&dev_comm,
DiscoveryCapabilities::default(),
buf,
)?;
let fabric_mgr = Arc::new(FabricMgr::new()?); self.pase_mgr.borrow_mut().enable_pase_session(
let open_comm_window = fabric_mgr.is_empty(); dev_comm.verifier,
if open_comm_window { dev_comm.discriminator,
print_pairing_code_and_qr(&dev_det, &dev_comm, DiscoveryCapabilities::default()); self.mdns,
} )?;
let acl_mgr = Arc::new(AclMgr::new()?); Ok(true)
let mut pase = PaseMgr::new(); } else {
let data_model = Ok(false)
DataModel::new(dev_det, dev_att, fabric_mgr.clone(), acl_mgr, pase.clone())?; }
let mut matter = Box::new(Matter { }
transport_mgr: transport::mgr::Mgr::new()?,
data_model, pub fn notify_changed(&self) {
fabric_mgr, if self.is_changed() {
}); self.persist_notification.signal(());
let interaction_model = }
Box::new(InteractionModel::new(Box::new(matter.data_model.clone()))); }
matter.transport_mgr.register_protocol(interaction_model)?;
pub async fn wait_changed(&self) {
if open_comm_window { self.persist_notification.wait().await
pase.enable_pase_session(dev_comm.verifier, dev_comm.discriminator)?; }
} }
let secure_channel = Box::new(SecureChannel::new(pase, matter.fabric_mgr.clone())); impl<'a> Borrow<RefCell<FabricMgr>> for Matter<'a> {
matter.transport_mgr.register_protocol(secure_channel)?; fn borrow(&self) -> &RefCell<FabricMgr> {
Ok(matter) &self.fabric_mgr
} }
}
/// Returns an Arc to [DataModel]
/// impl<'a> Borrow<RefCell<AclMgr>> for Matter<'a> {
/// The Data Model is where you express what is the type of your device. Typically fn borrow(&self) -> &RefCell<AclMgr> {
/// once you gets this reference, you acquire the write lock and add your device &self.acl_mgr
/// types, clusters, attributes, commands to the data model. }
pub fn get_data_model(&self) -> DataModel { }
self.data_model.clone()
} impl<'a> Borrow<RefCell<PaseMgr>> for Matter<'a> {
fn borrow(&self) -> &RefCell<PaseMgr> {
/// Starts the Matter daemon &self.pase_mgr
/// }
/// This call does NOT return }
///
/// This call starts the Matter daemon that starts communication with other Matter impl<'a> Borrow<RefCell<FailSafe>> for Matter<'a> {
/// devices on the network. fn borrow(&self) -> &RefCell<FailSafe> {
pub fn start_daemon(&mut self) -> Result<(), Error> { &self.failsafe
self.transport_mgr.start() }
}
impl<'a> Borrow<BasicInfoConfig<'a>> for Matter<'a> {
fn borrow(&self) -> &BasicInfoConfig<'a> {
self.dev_det
}
}
impl<'a> Borrow<dyn DevAttDataFetcher + 'a> for Matter<'a> {
fn borrow(&self) -> &(dyn DevAttDataFetcher + 'a) {
self.dev_att
}
}
impl<'a> Borrow<dyn Mdns + 'a> for Matter<'a> {
fn borrow(&self) -> &(dyn Mdns + 'a) {
self.mdns
}
}
impl<'a> Borrow<Epoch> for Matter<'a> {
fn borrow(&self) -> &Epoch {
&self.epoch
}
}
impl<'a> Borrow<Rand> for Matter<'a> {
fn borrow(&self) -> &Rand {
&self.rand
} }
} }

View file

@ -17,41 +17,120 @@
use log::error; use log::error;
use crate::error::Error; use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
};
use super::CryptoKeyPair; pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called");
Ok(())
}
pub struct KeyPairDummy {} #[derive(Clone, Debug)]
pub struct Sha256 {}
impl KeyPairDummy { impl Sha256 {
pub fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
Ok(Self {}) Ok(Self {})
} }
pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> {
Ok(())
} }
impl CryptoKeyPair for KeyPairDummy { pub fn finish(self, _digest: &mut [u8]) -> Result<(), Error> {
fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { Ok(())
}
}
pub struct HmacSha256 {}
impl HmacSha256 {
pub fn new(_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(Self {})
} }
fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
pub fn update(&mut self, _data: &[u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(())
} }
fn get_private_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
pub fn finish(self, _out: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Ok(())
} }
fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> { }
#[derive(Debug)]
pub struct KeyPair;
impl KeyPair {
pub fn new(_rand: Rand) -> Result<Self, Error> {
Ok(Self)
}
pub fn new_from_components(_pub_key: &[u8], _priv_key: &[u8]) -> Result<Self, Error> {
Ok(Self {})
}
pub fn new_from_public(_pub_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid)
Ok(Self {})
} }
fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
Ok(0)
}
pub fn get_private_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
Ok(0)
}
pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
}
pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
}
pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
} }
} }
pub fn pbkdf2_hmac(_pass: &[u8], _iter: usize, _salt: &[u8], _key: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called");
Ok(())
}
pub fn encrypt_in_place(
_key: &[u8],
_nonce: &[u8],
_ad: &[u8],
_data: &mut [u8],
_data_len: usize,
) -> Result<usize, Error> {
Ok(0)
}
pub fn decrypt_in_place(
_key: &[u8],
_nonce: &[u8],
_ad: &[u8],
_data: &mut [u8],
) -> Result<usize, Error> {
Ok(0)
}

View file

@ -17,9 +17,8 @@
use log::error; use log::error;
use crate::error::Error; use crate::error::{Error, ErrorCode};
use crate::utils::rand::Rand;
use super::CryptoKeyPair;
pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> { pub fn hkdf_sha256(_salt: &[u8], _ikm: &[u8], _info: &[u8], _key: &mut [u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
@ -62,18 +61,17 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub struct KeyPair {} pub struct KeyPair {}
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
error!("This API should never get called"); error!("This API should never get called");
Ok(Self {}) Ok(Self {})
} }
pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result<Self, Error> { pub fn new_from_components(_pub_key: &[u8], priv_key: &[u8]) -> Result<Self, Error> {
error!("This API should never get called");
Ok(Self {}) Ok(Self {})
} }
@ -82,28 +80,33 @@ impl KeyPair {
Ok(Self {}) Ok(Self {})
} }
pub fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
error!("This API should never get called");
Err(ErrorCode::Invalid.into())
} }
impl CryptoKeyPair for KeyPair { pub fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
fn get_csr<'a>(&self, _out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { Ok(0)
error!("This API should never get called");
Err(Error::Invalid)
} }
fn get_public_key(&self, _pub_key: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
Err(Error::Invalid) Ok(0)
} }
fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
pub fn derive_secret(self, _peer_pub_key: &[u8], _secret: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
pub fn sign_msg(&self, _msg: &[u8], _signature: &mut [u8]) -> Result<usize, Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
pub fn verify_msg(&self, _msg: &[u8], _signature: &[u8]) -> Result<(), Error> {
error!("This API should never get called"); error!("This API should never get called");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }

View file

@ -15,9 +15,13 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; extern crate alloc;
use log::error; use core::fmt::{self, Debug};
use alloc::sync::Arc;
use log::{error, info};
use mbedtls::{ use mbedtls::{
bignum::Mpi, bignum::Mpi,
cipher::{Authenticated, Cipher}, cipher::{Authenticated, Cipher},
@ -28,12 +32,12 @@ use mbedtls::{
x509, x509,
}; };
use super::CryptoKeyPair;
use crate::{ use crate::{
// TODO: We should move ASN1Writer out of Cert, // TODO: We should move ASN1Writer out of Cert,
// so Crypto doesn't have to depend on Cert // so Crypto doesn't have to depend on Cert
cert::{ASN1Writer, CertConsumer}, cert::{ASN1Writer, CertConsumer},
error::Error, error::{Error, ErrorCode},
utils::rand::Rand,
}; };
pub struct HmacSha256 { pub struct HmacSha256 {
@ -48,11 +52,13 @@ impl HmacSha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.inner.update(data).map_err(|_| Error::TLSStack) self.inner
.update(data)
.map_err(|_| ErrorCode::TLSStack.into())
} }
pub fn finish(self, out: &mut [u8]) -> Result<(), Error> { pub fn finish(self, out: &mut [u8]) -> Result<(), Error> {
self.inner.finish(out).map_err(|_| Error::TLSStack)?; self.inner.finish(out).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
} }
@ -62,7 +68,7 @@ pub struct KeyPair {
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?;
Ok(Self { Ok(Self {
key: Pk::generate_ec(&mut ctr_drbg, EcGroupId::SecP256R1)?, key: Pk::generate_ec(&mut ctr_drbg, EcGroupId::SecP256R1)?,
@ -85,10 +91,8 @@ impl KeyPair {
key: Pk::public_from_ec_components(group, pub_key)?, key: Pk::public_from_ec_components(group, pub_key)?,
}) })
} }
}
impl CryptoKeyPair for KeyPair { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
let tmp_priv = self.key.ec_private()?; let tmp_priv = self.key.ec_private()?;
let mut tmp_key = let mut tmp_key =
Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?; Pk::private_from_ec_components(EcGroup::new(EcGroupId::SecP256R1)?, tmp_priv)?;
@ -103,16 +107,16 @@ impl CryptoKeyPair for KeyPair {
Ok(Some(a)) => Ok(a), Ok(Some(a)) => Ok(a),
Ok(None) => { Ok(None) => {
error!("Error in writing CSR: None received"); error!("Error in writing CSR: None received");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
Err(e) => { Err(e) => {
error!("Error in writing CSR {}", e); error!("Error in writing CSR {}", e);
Err(Error::TLSStack) Err(ErrorCode::TLSStack.into())
} }
} }
} }
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let public_key = self.key.ec_public()?; let public_key = self.key.ec_public()?;
let group = EcGroup::new(EcGroupId::SecP256R1)?; let group = EcGroup::new(EcGroupId::SecP256R1)?;
let vec = public_key.to_binary(&group, false)?; let vec = public_key.to_binary(&group, false)?;
@ -122,7 +126,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
let priv_key_mpi = self.key.ec_private()?; let priv_key_mpi = self.key.ec_private()?;
let vec = priv_key_mpi.to_binary()?; let vec = priv_key_mpi.to_binary()?;
@ -131,7 +135,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
@ -149,7 +153,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
let tmp_key = self.key.ec_private()?; let tmp_key = self.key.ec_private()?;
@ -162,7 +166,7 @@ impl CryptoKeyPair for KeyPair {
let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?; let mut ctr_drbg = CtrDrbg::new(Arc::new(OsEntropy::new()), None)?;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
safemem::write_bytes(signature, 0); safemem::write_bytes(signature, 0);
@ -175,7 +179,7 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
// mbedtls requires a 'mut' key. Instead of making a change in our Trait, // mbedtls requires a 'mut' key. Instead of making a change in our Trait,
// we just clone the key this way // we just clone the key this way
let tmp_key = self.key.ec_public()?; let tmp_key = self.key.ec_public()?;
@ -192,14 +196,20 @@ impl CryptoKeyPair for KeyPair {
let mbedtls_sign = &mbedtls_sign[..len]; let mbedtls_sign = &mbedtls_sign[..len];
if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) { if let Err(e) = tmp_key.verify(hash::Type::Sha256, &msg_hash, mbedtls_sign) {
println!("The error is {}", e); info!("The error is {}", e);
Err(Error::InvalidSignature) Err(ErrorCode::InvalidSignature.into())
} else { } else {
Ok(()) Ok(())
} }
} }
} }
impl core::fmt::Debug for KeyPair {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.debug_struct("KeyPair").finish()
}
}
fn convert_r_s_to_asn1_sign(signature: &[u8], mbedtls_sign: &mut [u8]) -> Result<usize, Error> { fn convert_r_s_to_asn1_sign(signature: &[u8], mbedtls_sign: &mut [u8]) -> Result<usize, Error> {
let r = &signature[0..32]; let r = &signature[0..32];
let s = &signature[32..64]; let s = &signature[32..64];
@ -224,7 +234,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
// Type 0x2 is Integer (first integer is r) // Type 0x2 is Integer (first integer is r)
if signature[offset] != 2 { if signature[offset] != 2 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
offset += 1; offset += 1;
@ -249,7 +259,7 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
// Type 0x2 is Integer (this integer is s) // Type 0x2 is Integer (this integer is s)
if signature[offset] != 2 { if signature[offset] != 2 {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
offset += 1; offset += 1;
@ -268,17 +278,17 @@ fn convert_asn1_sign_to_r_s(signature: &mut [u8]) -> Result<usize, Error> {
Ok(64) Ok(64)
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> {
mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key) mbedtls::hash::pbkdf2_hmac(Type::Sha256, pass, salt, iter as u32, key)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> {
Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| Error::TLSStack) Hkdf::hkdf(Type::Sha256, salt, ikm, info, key).map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn encrypt_in_place( pub fn encrypt_in_place(
@ -299,7 +309,7 @@ pub fn encrypt_in_place(
cipher cipher
.encrypt_auth_inplace(ad, data, tag) .encrypt_auth_inplace(ad, data, tag)
.map(|(len, _)| len) .map(|(len, _)| len)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn decrypt_in_place( pub fn decrypt_in_place(
@ -321,7 +331,7 @@ pub fn decrypt_in_place(
.map(|(len, _)| len) .map(|(len, _)| len)
.map_err(|e| { .map_err(|e| {
error!("Error during decryption: {:?}", e); error!("Error during decryption: {:?}", e);
Error::TLSStack ErrorCode::TLSStack.into()
}) })
} }
@ -338,12 +348,18 @@ impl Sha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.ctx.update(data).map_err(|_| Error::TLSStack)?; self.ctx.update(data).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> { pub fn finish(self, digest: &mut [u8]) -> Result<(), Error> {
self.ctx.finish(digest).map_err(|_| Error::TLSStack)?; self.ctx.finish(digest).map_err(|_| ErrorCode::TLSStack)?;
Ok(()) Ok(())
} }
} }
impl Debug for Sha256 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Sha256")
}
}

View file

@ -15,9 +15,12 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use core::fmt::{self, Debug};
use super::CryptoKeyPair; use crate::error::{Error, ErrorCode};
use crate::utils::rand::Rand;
use alloc::vec;
use foreign_types::ForeignTypeRef; use foreign_types::ForeignTypeRef;
use log::error; use log::error;
use openssl::asn1::Asn1Type; use openssl::asn1::Asn1Type;
@ -40,6 +43,9 @@ use openssl::x509::{X509NameBuilder, X509ReqBuilder, X509};
// problem while using OpenSSL's Signer // problem while using OpenSSL's Signer
// TODO: Use proper OpenSSL method for this // TODO: Use proper OpenSSL method for this
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
extern crate alloc;
pub struct HmacSha256 { pub struct HmacSha256 {
ctx: Hmac<sha2::Sha256>, ctx: Hmac<sha2::Sha256>,
} }
@ -47,7 +53,8 @@ pub struct HmacSha256 {
impl HmacSha256 { impl HmacSha256 {
pub fn new(key: &[u8]) -> Result<Self, Error> { pub fn new(key: &[u8]) -> Result<Self, Error> {
Ok(Self { Ok(Self {
ctx: Hmac::<sha2::Sha256>::new_from_slice(key).map_err(|_x| Error::InvalidKeyLength)?, ctx: Hmac::<sha2::Sha256>::new_from_slice(key)
.map_err(|_x| ErrorCode::InvalidKeyLength)?,
}) })
} }
@ -62,16 +69,18 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub enum KeyType { pub enum KeyType {
Public(EcKey<pkey::Public>), Public(EcKey<pkey::Public>),
Private(EcKey<pkey::Private>), Private(EcKey<pkey::Private>),
} }
#[derive(Debug)]
pub struct KeyPair { pub struct KeyPair {
key: KeyType, key: KeyType,
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(_rand: Rand) -> Result<Self, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let key = EcKey::generate(&group)?; let key = EcKey::generate(&group)?;
Ok(Self { Ok(Self {
@ -108,14 +117,12 @@ impl KeyPair {
fn private_key(&self) -> Result<&EcKey<Private>, Error> { fn private_key(&self) -> Result<&EcKey<Private>, Error> {
match &self.key { match &self.key {
KeyType::Public(_) => Err(Error::Invalid), KeyType::Public(_) => Err(ErrorCode::Invalid.into()),
KeyType::Private(k) => Ok(&k), KeyType::Private(k) => Ok(&k),
} }
} }
}
impl CryptoKeyPair for KeyPair { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let mut bn_ctx = BigNumContext::new()?; let mut bn_ctx = BigNumContext::new()?;
let s = self.public_key_point().to_bytes( let s = self.public_key_point().to_bytes(
@ -128,14 +135,14 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
let s = self.private_key()?.private_key().to_vec(); let s = self.private_key()?.private_key().to_vec();
let len = s.len(); let len = s.len();
priv_key[..len].copy_from_slice(s.as_slice()); priv_key[..len].copy_from_slice(s.as_slice());
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?; let self_pkey = PKey::from_ec_key(self.private_key()?.clone())?;
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
@ -149,7 +156,7 @@ impl CryptoKeyPair for KeyPair {
Ok(deriver.derive(secret)?) Ok(deriver.derive(secret)?)
} }
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
let mut builder = X509ReqBuilder::new()?; let mut builder = X509ReqBuilder::new()?;
builder.set_version(0)?; builder.set_version(0)?;
@ -170,18 +177,18 @@ impl CryptoKeyPair for KeyPair {
a.copy_from_slice(csr); a.copy_from_slice(csr);
Ok(a) Ok(a)
} else { } else {
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} }
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
// First get the SHA256 of the message // First get the SHA256 of the message
let mut h = Hasher::new(MessageDigest::sha256())?; let mut h = Hasher::new(MessageDigest::sha256())?;
h.update(msg)?; h.update(msg)?;
let msg = h.finish()?; let msg = h.finish()?;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
safemem::write_bytes(signature, 0); safemem::write_bytes(signature, 0);
@ -193,7 +200,7 @@ impl CryptoKeyPair for KeyPair {
Ok(64) Ok(64)
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
// First get the SHA256 of the message // First get the SHA256 of the message
let mut h = Hasher::new(MessageDigest::sha256())?; let mut h = Hasher::new(MessageDigest::sha256())?;
h.update(msg)?; h.update(msg)?;
@ -208,11 +215,11 @@ impl CryptoKeyPair for KeyPair {
KeyType::Public(key) => key, KeyType::Public(key) => key,
_ => { _ => {
error!("Not yet supported"); error!("Not yet supported");
return Err(Error::Invalid); return Err(ErrorCode::Invalid.into());
} }
}; };
if !sig.verify(&msg, k)? { if !sig.verify(&msg, k)? {
Err(Error::InvalidSignature) Err(ErrorCode::InvalidSignature.into())
} else { } else {
Ok(()) Ok(())
} }
@ -223,7 +230,7 @@ const P256_KEY_LEN: usize = 256 / 8;
pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> { pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Error> {
if out_key.len() != P256_KEY_LEN { if out_key.len() != P256_KEY_LEN {
error!("Insufficient length"); error!("Insufficient length");
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
let key = X509::from_der(der)?.public_key()?.public_key_to_der()?; let key = X509::from_der(der)?.public_key()?.public_key_to_der()?;
let len = key.len(); let len = key.len();
@ -235,7 +242,7 @@ pub fn pubkey_from_der<'a>(der: &'a [u8], out_key: &mut [u8]) -> Result<(), Erro
pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn pbkdf2_hmac(pass: &[u8], iter: usize, salt: &[u8], key: &mut [u8]) -> Result<(), Error> {
openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key) openssl::pkcs5::pbkdf2_hmac(pass, salt, iter, MessageDigest::sha256(), key)
.map_err(|_e| Error::TLSStack) .map_err(|_e| ErrorCode::TLSStack.into())
} }
pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> { pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Result<(), Error> {
@ -295,7 +302,7 @@ pub fn lowlevel_encrypt_aead(
aad: &[u8], aad: &[u8],
data: &[u8], data: &[u8],
tag: &mut [u8], tag: &mut [u8],
) -> Result<Vec<u8>, ErrorStack> { ) -> Result<alloc::vec::Vec<u8>, ErrorStack> {
let t = symm::Cipher::aes_128_ccm(); let t = symm::Cipher::aes_128_ccm();
let mut ctx = CipherCtx::new()?; let mut ctx = CipherCtx::new()?;
CipherCtxRef::encrypt_init( CipherCtxRef::encrypt_init(
@ -331,7 +338,7 @@ pub fn lowlevel_decrypt_aead(
aad: &[u8], aad: &[u8],
data: &[u8], data: &[u8],
tag: &[u8], tag: &[u8],
) -> Result<Vec<u8>, ErrorStack> { ) -> Result<alloc::vec::Vec<u8>, ErrorStack> {
let t = symm::Cipher::aes_128_ccm(); let t = symm::Cipher::aes_128_ccm();
let mut ctx = CipherCtx::new()?; let mut ctx = CipherCtx::new()?;
CipherCtxRef::decrypt_init( CipherCtxRef::decrypt_init(
@ -375,7 +382,9 @@ impl Sha256 {
} }
pub fn update(&mut self, data: &[u8]) -> Result<(), Error> { pub fn update(&mut self, data: &[u8]) -> Result<(), Error> {
self.hasher.update(data).map_err(|_| Error::TLSStack) self.hasher
.update(data)
.map_err(|_| ErrorCode::TLSStack.into())
} }
pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> { pub fn finish(mut self, data: &mut [u8]) -> Result<(), Error> {
@ -384,3 +393,9 @@ impl Sha256 {
Ok(()) Ok(())
} }
} }
impl Debug for Sha256 {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(f, "Sha256")
}
}

View file

@ -15,9 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
use std::convert::{TryFrom, TryInto}; use core::convert::{TryFrom, TryInto};
use aes::Aes128; use aes::Aes128;
use alloc::vec;
use ccm::{ use ccm::{
aead::generic_array::GenericArray, aead::generic_array::GenericArray,
consts::{U13, U16}, consts::{U13, U16},
@ -33,20 +34,24 @@ use p256::{
use sha2::Digest; use sha2::Digest;
use x509_cert::{ use x509_cert::{
attr::AttributeType, attr::AttributeType,
der::{asn1::BitString, Any, Encode}, der::{asn1::BitString, Any, Encode, Writer},
name::RdnSequence, name::RdnSequence,
request::CertReq, request::CertReq,
spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned}, spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned},
}; };
use crate::error::Error; use crate::{
error::{Error, ErrorCode},
use super::CryptoKeyPair; secure_channel::crypto_rustcrypto::RandRngCore,
utils::rand::Rand,
};
type HmacSha256I = hmac::Hmac<sha2::Sha256>; type HmacSha256I = hmac::Hmac<sha2::Sha256>;
type AesCcm = Ccm<Aes128, U16, U13>; type AesCcm = Ccm<Aes128, U16, U13>;
#[derive(Clone)] extern crate alloc;
#[derive(Debug, Clone)]
pub struct Sha256 { pub struct Sha256 {
hasher: sha2::Sha256, hasher: sha2::Sha256,
} }
@ -79,7 +84,7 @@ impl HmacSha256 {
Ok(Self { Ok(Self {
inner: HmacSha256I::new_from_slice(key).map_err(|e| { inner: HmacSha256I::new_from_slice(key).map_err(|e| {
error!("Error creating HmacSha256 {:?}", e); error!("Error creating HmacSha256 {:?}", e);
Error::TLSStack ErrorCode::TLSStack
})?, })?,
}) })
} }
@ -96,18 +101,20 @@ impl HmacSha256 {
} }
} }
#[derive(Debug)]
pub enum KeyType { pub enum KeyType {
Private(SecretKey), Private(SecretKey),
Public(PublicKey), Public(PublicKey),
} }
#[derive(Debug)]
pub struct KeyPair { pub struct KeyPair {
key: KeyType, key: KeyType,
} }
impl KeyPair { impl KeyPair {
pub fn new() -> Result<Self, Error> { pub fn new(rand: Rand) -> Result<Self, Error> {
let mut rng = rand::thread_rng(); let mut rng = RandRngCore(rand);
let secret_key = SecretKey::random(&mut rng); let secret_key = SecretKey::random(&mut rng);
Ok(Self { Ok(Self {
@ -143,13 +150,11 @@ impl KeyPair {
fn private_key(&self) -> Result<&SecretKey, Error> { fn private_key(&self) -> Result<&SecretKey, Error> {
match &self.key { match &self.key {
KeyType::Private(key) => Ok(key), KeyType::Private(key) => Ok(key),
KeyType::Public(_) => Err(Error::Crypto), KeyType::Public(_) => Err(ErrorCode::Crypto.into()),
}
} }
} }
impl CryptoKeyPair for KeyPair { pub fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error> {
match &self.key { match &self.key {
KeyType::Private(key) => { KeyType::Private(key) => {
let bytes = key.to_bytes(); let bytes = key.to_bytes();
@ -158,10 +163,10 @@ impl CryptoKeyPair for KeyPair {
priv_key[..slice.len()].copy_from_slice(slice); priv_key[..slice.len()].copy_from_slice(slice);
Ok(len) Ok(len)
} }
KeyType::Public(_) => Err(Error::Crypto), KeyType::Public(_) => Err(ErrorCode::Crypto.into()),
} }
} }
fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> { pub fn get_csr<'a>(&self, out_csr: &'a mut [u8]) -> Result<&'a [u8], Error> {
use p256::ecdsa::signature::Signer; use p256::ecdsa::signature::Signer;
let subject = RdnSequence(vec![x509_cert::name::RelativeDistinguishedName( let subject = RdnSequence(vec![x509_cert::name::RelativeDistinguishedName(
@ -200,7 +205,7 @@ impl CryptoKeyPair for KeyPair {
attributes: Default::default(), attributes: Default::default(),
}; };
let mut message = vec![]; let mut message = vec![];
info.encode(&mut message).unwrap(); info.encode(&mut VecWriter(&mut message)).unwrap();
// Can't use self.sign_msg as the signature has to be in DER format // Can't use self.sign_msg as the signature has to be in DER format
let private_key = self.private_key()?; let private_key = self.private_key()?;
@ -224,14 +229,14 @@ impl CryptoKeyPair for KeyPair {
Ok(a) Ok(a)
} }
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> { pub fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error> {
let point = self.public_key_point().to_encoded_point(false); let point = self.public_key_point().to_encoded_point(false);
let bytes = point.as_bytes(); let bytes = point.as_bytes();
let len = bytes.len(); let len = bytes.len();
pub_key[..len].copy_from_slice(bytes); pub_key[..len].copy_from_slice(bytes);
Ok(len) Ok(len)
} }
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> { pub fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error> {
let encoded_point = EncodedPoint::from_bytes(peer_pub_key).unwrap(); let encoded_point = EncodedPoint::from_bytes(peer_pub_key).unwrap();
let peer_pubkey = PublicKey::from_encoded_point(&encoded_point).unwrap(); let peer_pubkey = PublicKey::from_encoded_point(&encoded_point).unwrap();
let private_key = self.private_key()?; let private_key = self.private_key()?;
@ -247,11 +252,11 @@ impl CryptoKeyPair for KeyPair {
Ok(len) Ok(len)
} }
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> { pub fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error> {
use p256::ecdsa::signature::Signer; use p256::ecdsa::signature::Signer;
if signature.len() < super::EC_SIGNATURE_LEN_BYTES { if signature.len() < super::EC_SIGNATURE_LEN_BYTES {
return Err(Error::NoSpace); return Err(ErrorCode::NoSpace.into());
} }
match &self.key { match &self.key {
@ -266,7 +271,7 @@ impl CryptoKeyPair for KeyPair {
KeyType::Public(_) => todo!(), KeyType::Public(_) => todo!(),
} }
} }
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> { pub fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error> {
use p256::ecdsa::signature::Verifier; use p256::ecdsa::signature::Verifier;
let verifying_key = VerifyingKey::from_affine(self.public_key_point()).unwrap(); let verifying_key = VerifyingKey::from_affine(self.public_key_point()).unwrap();
@ -274,7 +279,7 @@ impl CryptoKeyPair for KeyPair {
verifying_key verifying_key
.verify(msg, &signature) .verify(msg, &signature)
.map_err(|_| Error::InvalidSignature)?; .map_err(|_| ErrorCode::InvalidSignature)?;
Ok(()) Ok(())
} }
@ -291,7 +296,7 @@ pub fn hkdf_sha256(salt: &[u8], ikm: &[u8], info: &[u8], key: &mut [u8]) -> Resu
.expand(info, key) .expand(info, key)
.map_err(|e| { .map_err(|e| {
error!("Error with hkdf_sha256 {:?}", e); error!("Error with hkdf_sha256 {:?}", e);
Error::TLSStack ErrorCode::TLSStack.into()
}) })
} }
@ -370,3 +375,13 @@ impl<'a> ccm::aead::Buffer for SliceBuffer<'a> {
self.len = len; self.len = len;
} }
} }
struct VecWriter<'a>(&'a mut alloc::vec::Vec<u8>);
impl<'a> Writer for VecWriter<'a> {
fn write(&mut self, slice: &[u8]) -> x509_cert::der::Result<()> {
self.0.extend_from_slice(slice);
Ok(())
}
}

View file

@ -14,8 +14,10 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
use crate::{
use crate::error::Error; error::{Error, ErrorCode},
tlv::{FromTLV, TLVWriter, TagType, ToTLV},
};
pub const SYMM_KEY_LEN_BITS: usize = 128; pub const SYMM_KEY_LEN_BITS: usize = 128;
pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8; pub const SYMM_KEY_LEN_BYTES: usize = SYMM_KEY_LEN_BITS / 8;
@ -35,43 +37,70 @@ pub const ECDH_SHARED_SECRET_LEN_BYTES: usize = 32;
pub const EC_SIGNATURE_LEN_BYTES: usize = 64; pub const EC_SIGNATURE_LEN_BYTES: usize = 64;
// APIs particular to a KeyPair so a KeyPair object can be defined #[cfg(all(feature = "mbedtls", target_os = "espidf"))]
pub trait CryptoKeyPair {
fn get_csr<'a>(&self, csr: &'a mut [u8]) -> Result<&'a [u8], Error>;
fn get_public_key(&self, pub_key: &mut [u8]) -> Result<usize, Error>;
fn get_private_key(&self, priv_key: &mut [u8]) -> Result<usize, Error>;
fn derive_secret(self, peer_pub_key: &[u8], secret: &mut [u8]) -> Result<usize, Error>;
fn sign_msg(&self, msg: &[u8], signature: &mut [u8]) -> Result<usize, Error>;
fn verify_msg(&self, msg: &[u8], signature: &[u8]) -> Result<(), Error>;
}
#[cfg(feature = "crypto_esp_mbedtls")]
mod crypto_esp_mbedtls; mod crypto_esp_mbedtls;
#[cfg(feature = "crypto_esp_mbedtls")] #[cfg(all(feature = "mbedtls", target_os = "espidf"))]
pub use self::crypto_esp_mbedtls::*; pub use self::crypto_esp_mbedtls::*;
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "mbedtls", not(target_os = "espidf")))]
mod crypto_mbedtls; mod crypto_mbedtls;
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "mbedtls", not(target_os = "espidf")))]
pub use self::crypto_mbedtls::*; pub use self::crypto_mbedtls::*;
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "openssl")]
mod crypto_openssl; mod crypto_openssl;
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "openssl")]
pub use self::crypto_openssl::*; pub use self::crypto_openssl::*;
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "rustcrypto")]
mod crypto_rustcrypto; mod crypto_rustcrypto;
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "rustcrypto")]
pub use self::crypto_rustcrypto::*; pub use self::crypto_rustcrypto::*;
#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))]
pub mod crypto_dummy; pub mod crypto_dummy;
#[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))]
pub use self::crypto_dummy::*;
impl<'a> FromTLV<'a> for KeyPair {
fn from_tlv(t: &crate::tlv::TLVElement<'a>) -> Result<Self, Error>
where
Self: Sized,
{
t.confirm_array()?.enter();
if let Some(mut array) = t.enter() {
let pub_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?;
let priv_key = array.next().ok_or(ErrorCode::Invalid)?.slice()?;
KeyPair::new_from_components(pub_key, priv_key)
} else {
Err(ErrorCode::Invalid.into())
}
}
}
impl ToTLV for KeyPair {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
let mut buf = [0; 1024]; // TODO
tw.start_array(tag)?;
let size = self.get_public_key(&mut buf)?;
tw.str16(TagType::Anonymous, &buf[..size])?;
let size = self.get_private_key(&mut buf)?;
tw.str16(TagType::Anonymous, &buf[..size])?;
tw.end_container()
}
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use crate::error::Error; use crate::error::ErrorCode;
use super::{CryptoKeyPair, KeyPair}; use super::KeyPair;
#[test] #[test]
fn test_verify_msg_success() { fn test_verify_msg_success() {
@ -83,8 +112,9 @@ mod tests {
fn test_verify_msg_fail() { fn test_verify_msg_fail() {
let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap(); let key = KeyPair::new_from_public(&test_vectors::PUB_KEY1).unwrap();
assert_eq!( assert_eq!(
key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1), key.verify_msg(&test_vectors::MSG1_FAIL, &test_vectors::SIGNATURE1)
Err(Error::InvalidSignature) .map_err(|e| e.code()),
Err(ErrorCode::InvalidSignature)
); );
} }

View file

@ -15,14 +15,29 @@
* limitations under the License. * limitations under the License.
*/ */
use core::convert::TryInto;
use super::objects::*; use super::objects::*;
use crate::error::*; use crate::{attribute_enum, error::Error, utils::rand::Rand};
use num_derive::FromPrimitive; use strum::FromRepr;
pub const ID: u32 = 0x0028; pub const ID: u32 = 0x0028;
#[derive(FromPrimitive)] #[derive(Clone, Copy, Debug, FromRepr)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
DMRevision(AttrType<u8>) = 0,
VendorId(AttrType<u16>) = 2,
ProductId(AttrType<u16>) = 4,
HwVer(AttrType<u16>) = 7,
SwVer(AttrType<u32>) = 9,
SwVerString(AttrUtfType) = 0xa,
SerialNo(AttrUtfType) = 0x0f,
}
attribute_enum!(Attributes);
pub enum AttributesDiscriminants {
DMRevision = 0, DMRevision = 0,
VendorId = 2, VendorId = 2,
ProductId = 4, ProductId = 4,
@ -33,82 +48,106 @@ pub enum Attributes {
} }
#[derive(Default)] #[derive(Default)]
pub struct BasicInfoConfig { pub struct BasicInfoConfig<'a> {
pub vid: u16, pub vid: u16,
pub pid: u16, pub pid: u16,
pub hw_ver: u16, pub hw_ver: u16,
pub sw_ver: u32, pub sw_ver: u32,
pub sw_ver_str: String, pub sw_ver_str: &'a str,
pub serial_no: String, pub serial_no: &'a str,
/// Device name; up to 32 characters /// Device name; up to 32 characters
pub device_name: String, pub device_name: &'a str,
} }
pub struct BasicInfoCluster { pub const CLUSTER: Cluster<'static> = Cluster {
base: Cluster, id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new(
AttributesDiscriminants::DMRevision as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::VendorId as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::ProductId as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::HwVer as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SwVer as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SwVerString as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::SerialNo as u16,
Access::RV,
Quality::FIXED,
),
],
commands: &[],
};
pub struct BasicInfoCluster<'a> {
data_ver: Dataver,
cfg: &'a BasicInfoConfig<'a>,
} }
impl BasicInfoCluster { impl<'a> BasicInfoCluster<'a> {
pub fn new(cfg: BasicInfoConfig) -> Result<Box<Self>, Error> { pub fn new(cfg: &'a BasicInfoConfig<'a>, rand: Rand) -> Self {
let mut cluster = Box::new(BasicInfoCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
}); cfg,
let attrs = [
Attribute::new(
Attributes::DMRevision as u16,
AttrValue::Uint8(1),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::VendorId as u16,
AttrValue::Uint16(cfg.vid),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::ProductId as u16,
AttrValue::Uint16(cfg.pid),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::HwVer as u16,
AttrValue::Uint16(cfg.hw_ver),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SwVer as u16,
AttrValue::Uint32(cfg.sw_ver),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SwVerString as u16,
AttrValue::Utf8(cfg.sw_ver_str),
Access::RV,
Quality::FIXED,
),
Attribute::new(
Attributes::SerialNo as u16,
AttrValue::Utf8(cfg.serial_no),
Access::RV,
Quality::FIXED,
),
];
cluster.base.add_attributes(&attrs[..])?;
Ok(cluster)
} }
} }
impl ClusterType for BasicInfoCluster { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
fn base(&self) -> &Cluster { if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
&self.base if attr.is_system() {
} CLUSTER.read(attr.attr_id, writer)
fn base_mut(&mut self) -> &mut Cluster { } else {
&mut self.base match attr.attr_id.try_into()? {
Attributes::DMRevision(codec) => codec.encode(writer, 1),
Attributes::VendorId(codec) => codec.encode(writer, self.cfg.vid),
Attributes::ProductId(codec) => codec.encode(writer, self.cfg.pid),
Attributes::HwVer(codec) => codec.encode(writer, self.cfg.hw_ver),
Attributes::SwVer(codec) => codec.encode(writer, self.cfg.sw_ver),
Attributes::SwVerString(codec) => codec.encode(writer, self.cfg.sw_ver_str),
Attributes::SerialNo(codec) => codec.encode(writer, self.cfg.serial_no),
}
}
} else {
Ok(())
}
}
}
impl<'a> Handler for BasicInfoCluster<'a> {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
BasicInfoCluster::read(self, attr, encoder)
}
}
impl<'a> NonBlockingHandler for BasicInfoCluster<'a> {}
impl<'a> ChangeNotifier<()> for BasicInfoCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }

View file

@ -15,114 +15,154 @@
* limitations under the License. * limitations under the License.
*/ */
use core::{cell::Cell, convert::TryInto};
use super::objects::*; use super::objects::*;
use crate::{ use crate::{
cmd_enter, attribute_enum, cmd_enter, command_enum, error::Error, tlv::TLVElement,
error::*, transport::exchange::Exchange, utils::rand::Rand,
interaction_model::{command::CommandReq, core::IMStatusCode},
}; };
use log::info; use log::info;
use num_derive::FromPrimitive; use strum::{EnumDiscriminants, FromRepr};
pub const ID: u32 = 0x0006; pub const ID: u32 = 0x0006;
#[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
OnOff = 0x0, OnOff(AttrType<bool>) = 0x0,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr, EnumDiscriminants)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
Off = 0x0, Off = 0x0,
On = 0x01, On = 0x01,
Toggle = 0x02, Toggle = 0x02,
} }
fn attr_on_off_new() -> Attribute { command_enum!(Commands);
// OnOff, Value: false
pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::OnOff as u16, AttributesDiscriminants::OnOff as u16,
AttrValue::Bool(false),
Access::RV, Access::RV,
Quality::PERSISTENT, Quality::PERSISTENT,
) ),
} ],
commands: &[
CommandsDiscriminants::Off as _,
CommandsDiscriminants::On as _,
CommandsDiscriminants::Toggle as _,
],
};
pub struct OnOffCluster { pub struct OnOffCluster {
base: Cluster, data_ver: Dataver,
on: Cell<bool>,
} }
impl OnOffCluster { impl OnOffCluster {
pub fn new() -> Result<Box<Self>, Error> { pub fn new(rand: Rand) -> Self {
let mut cluster = Box::new(OnOffCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
}); on: Cell::new(false),
cluster.base.add_attribute(attr_on_off_new())?;
Ok(cluster)
} }
} }
impl ClusterType for OnOffCluster { pub fn set(&self, on: bool) {
fn base(&self) -> &Cluster { if self.on.get() != on {
&self.base self.on.set(on);
self.data_ver.changed();
} }
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
let cmd = cmd_req if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
.cmd if attr.is_system() {
.path CLUSTER.read(attr.attr_id, writer)
.leaf } else {
.map(num::FromPrimitive::from_u32) match attr.attr_id.try_into()? {
.ok_or(IMStatusCode::UnsupportedCommand)? Attributes::OnOff(codec) => codec.encode(writer, self.on.get()),
.ok_or(IMStatusCode::UnsupportedCommand)?; }
match cmd { }
} else {
Ok(())
}
}
pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
let data = data.with_dataver(self.data_ver.get())?;
match attr.attr_id.try_into()? {
Attributes::OnOff(codec) => self.set(codec.decode(data)?),
}
self.data_ver.changed();
Ok(())
}
pub fn invoke(
&self,
_exchange: &Exchange,
cmd: &CmdDetails,
_data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::Off => { Commands::Off => {
cmd_enter!("Off"); cmd_enter!("Off");
let value = self self.set(false);
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap();
if AttrValue::Bool(true) == *value {
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(false))
.map_err(|_| IMStatusCode::Failure)?;
}
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
Commands::On => { Commands::On => {
cmd_enter!("On"); cmd_enter!("On");
let value = self self.set(true);
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap();
if AttrValue::Bool(false) == *value {
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(true))
.map_err(|_| IMStatusCode::Failure)?;
}
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
Commands::Toggle => { Commands::Toggle => {
cmd_enter!("Toggle"); cmd_enter!("Toggle");
let value = match self self.set(!self.on.get());
.base
.read_attribute_raw(Attributes::OnOff as u16)
.unwrap()
{
&AttrValue::Bool(v) => v,
_ => false,
};
self.base
.write_attribute_raw(Attributes::OnOff as u16, AttrValue::Bool(!value))
.map_err(|_| IMStatusCode::Failure)?;
cmd_req.trans.complete();
Err(IMStatusCode::Success)
} }
} }
self.data_ver.changed();
Ok(())
}
}
impl Handler for OnOffCluster {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
OnOffCluster::read(self, attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
OnOffCluster::write(self, attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
OnOffCluster::invoke(self, exchange, cmd, data, encoder)
}
}
// TODO: Might be removed once the `on` member is externalized
impl NonBlockingHandler for OnOffCluster {}
impl ChangeNotifier<()> for OnOffCluster {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }

View file

@ -16,29 +16,59 @@
*/ */
use crate::{ use crate::{
data_model::objects::{Cluster, ClusterType}, data_model::objects::{Cluster, Handler},
error::Error, error::{Error, ErrorCode},
utils::rand::Rand,
};
use super::objects::{
AttrDataEncoder, AttrDetails, ChangeNotifier, Dataver, NonBlockingHandler, ATTRIBUTE_LIST,
FEATURE_MAP,
}; };
const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031; const CLUSTER_NETWORK_COMMISSIONING_ID: u32 = 0x0031;
pub struct TemplateCluster { pub const CLUSTER: Cluster<'static> = Cluster {
base: Cluster, id: CLUSTER_NETWORK_COMMISSIONING_ID as _,
} feature_map: 0,
attributes: &[FEATURE_MAP, ATTRIBUTE_LIST],
commands: &[],
};
impl ClusterType for TemplateCluster { pub struct TemplateCluster {
fn base(&self) -> &Cluster { data_ver: Dataver,
&self.base
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
}
} }
impl TemplateCluster { impl TemplateCluster {
pub fn new() -> Result<Box<Self>, Error> { pub fn new(rand: Rand) -> Self {
Ok(Box::new(Self { Self {
base: Cluster::new(CLUSTER_NETWORK_COMMISSIONING_ID)?, data_ver: Dataver::new(rand),
})) }
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
Err(ErrorCode::AttributeNotFound.into())
}
} else {
Ok(())
}
}
}
impl Handler for TemplateCluster {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
TemplateCluster::read(self, attr, encoder)
}
}
impl NonBlockingHandler for TemplateCluster {}
impl ChangeNotifier<()> for TemplateCluster {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }

View file

@ -0,0 +1,141 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use portable_atomic::{AtomicU32, Ordering};
use super::objects::*;
use crate::{
alloc,
error::*,
interaction_model::core::Interaction,
transport::{exchange::Exchange, packet::Packet},
};
// TODO: For now...
static SUBS_ID: AtomicU32 = AtomicU32::new(1);
pub struct DataModel<T>(T);
impl<T> DataModel<T> {
pub fn new(handler: T) -> Self {
Self(handler)
}
pub async fn handle<'r, 'p>(
&self,
exchange: &'r mut Exchange<'_>,
rx: &'r mut Packet<'p>,
tx: &'r mut Packet<'p>,
rx_status: &'r mut Packet<'p>,
) -> Result<(), Error>
where
T: DataModelHandler,
{
let timeout = Interaction::timeout(exchange, rx, tx).await?;
let mut interaction = alloc!(Interaction::new(
exchange,
rx,
tx,
rx_status,
|| SUBS_ID.fetch_add(1, Ordering::SeqCst),
timeout,
)?);
#[cfg(feature = "alloc")]
let interaction = &mut *interaction;
#[cfg(not(feature = "alloc"))]
let interaction = &mut interaction;
#[cfg(feature = "nightly")]
let metadata = self.0.lock().await;
#[cfg(not(feature = "nightly"))]
let metadata = self.0.lock();
if interaction.start().await? {
match interaction {
Interaction::Read {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
'outer: for item in metadata.node().read(req, None, &accessor) {
while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?)
.await?
{
if !driver.send_chunk(req).await? {
break 'outer;
}
}
}
driver.complete(req).await?;
}
Interaction::Write {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
for item in metadata.node().write(req, &accessor) {
AttrDataEncoder::handle_write(&item, &self.0, &mut driver.writer()?)
.await?;
}
driver.complete(req).await?;
}
Interaction::Invoke {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
for item in metadata.node().invoke(req, &accessor) {
let (mut tw, exchange) = driver.writer_exchange()?;
CmdDataEncoder::handle(&item, &self.0, &mut tw, exchange).await?;
}
driver.complete(req).await?;
}
Interaction::Subscribe {
req,
ref mut driver,
} => {
let accessor = driver.accessor()?;
'outer: for item in metadata.node().subscribing_read(req, None, &accessor) {
while !AttrDataEncoder::handle_read(&item, &self.0, &mut driver.writer()?)
.await?
{
if !driver.send_chunk(req).await? {
break 'outer;
}
}
}
driver.complete(req).await?;
}
}
}
Ok(())
}
}

View file

@ -1,394 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use self::subscribe::SubsCtx;
use super::{
cluster_basic_information::BasicInfoConfig,
device_types::device_type_add_root_node,
objects::{self, *},
sdm::dev_att::DevAttDataFetcher,
system_model::descriptor::DescriptorCluster,
};
use crate::{
acl::{AccessReq, Accessor, AccessorSubjects, AclMgr, AuthMode},
error::*,
fabric::FabricMgr,
interaction_model::{
command::CommandReq,
core::{IMStatusCode, OpCode},
messages::{
ib::{self, AttrData, DataVersionFilter},
msg::{self, InvReq, ReadReq, WriteReq},
GenericPath,
},
InteractionConsumer, Transaction,
},
secure_channel::pake::PaseMgr,
tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
transport::{
proto_demux::ResponseRequired,
session::{Session, SessionMode},
},
};
use log::{error, info};
use std::sync::{Arc, RwLock};
#[derive(Clone)]
pub struct DataModel {
pub node: Arc<RwLock<Box<Node>>>,
acl_mgr: Arc<AclMgr>,
}
impl DataModel {
pub fn new(
dev_details: BasicInfoConfig,
dev_att: Box<dyn DevAttDataFetcher>,
fabric_mgr: Arc<FabricMgr>,
acl_mgr: Arc<AclMgr>,
pase_mgr: PaseMgr,
) -> Result<Self, Error> {
let dm = DataModel {
node: Arc::new(RwLock::new(Node::new()?)),
acl_mgr: acl_mgr.clone(),
};
{
let mut node = dm.node.write()?;
node.set_changes_cb(Box::new(dm.clone()));
device_type_add_root_node(
&mut node,
dev_details,
dev_att,
fabric_mgr,
acl_mgr,
pase_mgr,
)?;
}
Ok(dm)
}
// Encode a write attribute from a path that may or may not be wildcard
fn handle_write_attr_path(
node: &mut Node,
accessor: &Accessor,
attr_data: &AttrData,
tw: &mut TLVWriter,
) {
let gen_path = attr_data.path.to_gp();
let mut encoder = AttrWriteEncoder::new(tw, TagType::Anonymous);
encoder.set_path(gen_path);
// The unsupported pieces of the wildcard path
if attr_data.path.cluster.is_none() {
encoder.encode_status(IMStatusCode::UnsupportedCluster, 0);
return;
}
if attr_data.path.attr.is_none() {
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0);
return;
}
// Get the data
let write_data = match &attr_data.data {
EncodeValue::Closure(_) | EncodeValue::Value(_) => {
error!("Not supported");
return;
}
EncodeValue::Tlv(t) => t,
};
if gen_path.is_wildcard() {
// This is a wildcard path, skip error
// This is required because there could be access control errors too that need
// to be taken care of.
encoder.skip_error();
}
let mut attr = AttrDetails {
// will be udpated in the loop below
attr_id: 0,
list_index: attr_data.path.list_index,
fab_filter: false,
fab_idx: accessor.fab_idx,
};
let result = node.for_each_cluster_mut(&gen_path, |path, c| {
if attr_data.data_ver.is_some() && Some(c.base().get_dataver()) != attr_data.data_ver {
encoder.encode_status(IMStatusCode::DataVersionMismatch, 0);
return Ok(());
}
attr.attr_id = path.leaf.unwrap_or_default() as u16;
encoder.set_path(*path);
let mut access_req = AccessReq::new(accessor, path, Access::WRITE);
let r = match Cluster::write_attribute(c, &mut access_req, write_data, &attr) {
Ok(_) => IMStatusCode::Success,
Err(e) => e,
};
encoder.encode_status(r, 0);
Ok(())
});
if let Err(e) = result {
// We hit this only if this is a non-wildcard path and some parts of the path are missing
encoder.encode_status(e, 0);
}
}
// Handle command from a path that may or may not be wildcard
fn handle_command_path(node: &mut Node, cmd_req: &mut CommandReq) {
let wildcard = cmd_req.cmd.path.is_wildcard();
let path = cmd_req.cmd.path;
let result = node.for_each_cluster_mut(&path, |path, c| {
cmd_req.cmd.path = *path;
let result = c.handle_command(cmd_req);
if let Err(e) = result {
// It is likely that we might have to do an 'Access' aware traversal
// if there are other conditions in the wildcard scenario that shouldn't be
// encoded as CmdStatus
if !(wildcard && e == IMStatusCode::UnsupportedCommand) {
let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0);
let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous);
}
}
Ok(())
});
if !wildcard {
if let Err(e) = result {
// We hit this only if this is a non-wildcard path
let invoke_resp = ib::InvResp::status_new(cmd_req.cmd, e, 0);
let _ = invoke_resp.to_tlv(cmd_req.resp, TagType::Anonymous);
}
}
}
fn sess_to_accessor(&self, sess: &Session) -> Accessor {
match sess.get_session_mode() {
SessionMode::Case(c) => {
let mut subject =
AccessorSubjects::new(sess.get_peer_node_id().unwrap_or_default());
for i in c.cat_ids {
if i != 0 {
let _ = subject.add_catid(i);
}
}
Accessor::new(c.fab_idx, subject, AuthMode::Case, self.acl_mgr.clone())
}
SessionMode::Pase => Accessor::new(
0,
AccessorSubjects::new(1),
AuthMode::Pase,
self.acl_mgr.clone(),
),
SessionMode::PlainText => Accessor::new(
0,
AccessorSubjects::new(1),
AuthMode::Invalid,
self.acl_mgr.clone(),
),
}
}
/// Returns true if the path matches the cluster path and the data version is a match
fn data_filter_matches(
filters: &Option<&TLVArray<DataVersionFilter>>,
path: &GenericPath,
data_ver: u32,
) -> bool {
if let Some(filters) = *filters {
for filter in filters.iter() {
// TODO: No handling of 'node' comparision yet
if Some(filter.path.endpoint) == path.endpoint
&& Some(filter.path.cluster) == path.cluster
&& filter.data_ver == data_ver
{
return true;
}
}
}
false
}
}
pub mod read;
pub mod subscribe;
/// Type of Resume Request
enum ResumeReq {
Subscribe(subscribe::SubsCtx),
Read(read::ResumeReadReq),
}
impl objects::ChangeConsumer for DataModel {
fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error> {
endpoint.add_cluster(DescriptorCluster::new(id, self.clone())?)?;
Ok(())
}
}
impl InteractionConsumer for DataModel {
fn consume_write_attr(
&self,
write_req: &WriteReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let accessor = self.sess_to_accessor(trans.session);
tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?;
let mut node = self.node.write().unwrap();
for attr_data in write_req.write_requests.iter() {
DataModel::handle_write_attr_path(&mut node, &accessor, &attr_data, tw);
}
tw.end_container()?;
Ok(())
}
fn consume_read_attr(
&self,
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let mut resume_from = None;
let root = tlv::get_root_node(rx_buf)?;
let req = ReadReq::from_tlv(&root)?;
self.handle_read_req(&req, trans, tw, &mut resume_from)?;
if resume_from.is_some() {
// This is a multi-hop read transaction, remember this read request
let resume = read::ResumeReadReq::new(rx_buf, &resume_from)?;
if !trans.exch.is_data_none() {
error!("Exchange data already set, and multi-hop read");
return Err(Error::InvalidState);
}
trans.exch.set_data_boxed(Box::new(ResumeReq::Read(resume)));
}
Ok(())
}
fn consume_invoke_cmd(
&self,
inv_req_msg: &InvReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error> {
let mut node = self.node.write().unwrap();
if let Some(inv_requests) = &inv_req_msg.inv_requests {
// Array of InvokeResponse IBs
tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?;
for i in inv_requests.iter() {
let data = if let Some(data) = i.data.unwrap_tlv() {
data
} else {
continue;
};
info!("Invoke Commmand Handler executing: {:?}", i.path);
let mut cmd_req = CommandReq {
cmd: i.path,
data,
trans,
resp: tw,
};
DataModel::handle_command_path(&mut node, &mut cmd_req);
}
tw.end_container()?;
}
Ok(())
}
fn consume_status_report(
&self,
req: &msg::StatusResp,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if let Some(mut resume) = trans.exch.take_data_boxed::<ResumeReq>() {
let result = match *resume {
ResumeReq::Read(ref mut read) => self.handle_resume_read(read, trans, tw)?,
ResumeReq::Subscribe(ref mut ctx) => ctx.handle_status_report(trans, tw, self)?,
};
trans.exch.set_data_boxed(resume);
Ok(result)
} else {
// Nothing to do for now
trans.complete();
info!("Received status report with status {:?}", req.status);
Ok((OpCode::Reserved, ResponseRequired::No))
}
}
fn consume_subscribe(
&self,
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if !trans.exch.is_data_none() {
error!("Exchange data already set!");
return Err(Error::InvalidState);
}
let ctx = SubsCtx::new(rx_buf, trans, tw, self)?;
trans
.exch
.set_data_boxed(Box::new(ResumeReq::Subscribe(ctx)));
Ok((OpCode::ReportData, ResponseRequired::Yes))
}
}
/// Encoder for generating a response to a write request
pub struct AttrWriteEncoder<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
tag: TagType,
path: GenericPath,
skip_error: bool,
}
impl<'a, 'b, 'c> AttrWriteEncoder<'a, 'b, 'c> {
pub fn new(tw: &'a mut TLVWriter<'b, 'c>, tag: TagType) -> Self {
Self {
tw,
tag,
path: Default::default(),
skip_error: false,
}
}
pub fn skip_error(&mut self) {
self.skip_error = true;
}
pub fn set_path(&mut self, path: GenericPath) {
self.path = path;
}
}
impl<'a, 'b, 'c> Encoder for AttrWriteEncoder<'a, 'b, 'c> {
fn encode(&mut self, _value: EncodeValue) {
// Only status encodes for AttrWriteResponse
}
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) {
if self.skip_error && status != IMStatusCode::Success {
// Don't encode errors
return;
}
let resp = ib::AttrStatus::new(&self.path, status, cluster_status);
let _ = resp.to_tlv(self.tw, self.tag);
}
}

View file

@ -1,319 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
acl::{AccessReq, Accessor},
data_model::{core::DataModel, objects::*},
error::*,
interaction_model::{
core::{IMStatusCode, OpCode},
messages::{
ib::{self, DataVersionFilter},
msg::{self, ReadReq, ReportDataTag::MoreChunkedMsgs, ReportDataTag::SupressResponse},
GenericPath,
},
Transaction,
},
tlv::{self, FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
transport::{packet::Packet, proto_demux::ResponseRequired},
utils::writebuf::WriteBuf,
wb_shrink, wb_unshrink,
};
use log::error;
/// Encoder for generating a response to a read request
pub struct AttrReadEncoder<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
data_ver: u32,
path: GenericPath,
skip_error: bool,
data_ver_filters: Option<&'a TLVArray<'a, DataVersionFilter>>,
is_buffer_full: bool,
}
impl<'a, 'b, 'c> AttrReadEncoder<'a, 'b, 'c> {
pub fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self {
Self {
tw,
data_ver: 0,
skip_error: false,
path: Default::default(),
data_ver_filters: None,
is_buffer_full: false,
}
}
pub fn skip_error(&mut self, skip: bool) {
self.skip_error = skip;
}
pub fn set_data_ver(&mut self, data_ver: u32) {
self.data_ver = data_ver;
}
pub fn set_data_ver_filters(&mut self, filters: &'a TLVArray<'a, DataVersionFilter>) {
self.data_ver_filters = Some(filters);
}
pub fn set_path(&mut self, path: GenericPath) {
self.path = path;
}
pub fn is_buffer_full(&self) -> bool {
self.is_buffer_full
}
}
impl<'a, 'b, 'c> Encoder for AttrReadEncoder<'a, 'b, 'c> {
fn encode(&mut self, value: EncodeValue) {
let resp = ib::AttrResp::Data(ib::AttrData::new(
Some(self.data_ver),
ib::AttrPath::new(&self.path),
value,
));
let anchor = self.tw.get_tail();
if resp.to_tlv(self.tw, TagType::Anonymous).is_err() {
self.is_buffer_full = true;
self.tw.rewind_to(anchor);
}
}
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16) {
if !self.skip_error {
let resp =
ib::AttrResp::Status(ib::AttrStatus::new(&self.path, status, cluster_status));
let _ = resp.to_tlv(self.tw, TagType::Anonymous);
}
}
}
/// State to maintain when a Read Request needs to be resumed
/// resumed - the next chunk of the read needs to be returned
#[derive(Default)]
pub struct ResumeReadReq {
/// The Read Request Attribute Path that caused chunking, and this is the path
/// that needs to be resumed.
pub pending_req: Option<Packet<'static>>,
/// The Attribute that couldn't be encoded because our buffer got full. The next chunk
/// will start encoding from this attribute onwards.
/// Note that given wildcard reads, one PendingPath in the member above can generated
/// multiple encode paths. Hence this has to be maintained separately.
pub resume_from: Option<GenericPath>,
}
impl ResumeReadReq {
pub fn new(rx_buf: &[u8], resume_from: &Option<GenericPath>) -> Result<Self, Error> {
let mut packet = Packet::new_rx()?;
let dst = packet.as_borrow_slice();
let src_len = rx_buf.len();
dst[..src_len].copy_from_slice(rx_buf);
packet.get_parsebuf()?.set_len(src_len);
Ok(ResumeReadReq {
pending_req: Some(packet),
resume_from: *resume_from,
})
}
}
impl DataModel {
pub fn read_attribute_raw(
&self,
endpoint: EndptId,
cluster: ClusterId,
attr: AttrId,
) -> Result<AttrValue, IMStatusCode> {
let node = self.node.read().unwrap();
let cluster = node.get_cluster(endpoint, cluster)?;
cluster.base().read_attribute_raw(attr).map(|a| a.clone())
}
/// Encode a read attribute from a path that may or may not be wildcard
///
/// If the buffer gets full while generating the read response, we will return
/// an Err(path), where the path is the path that we should resume from, for the next chunk.
/// This facilitates chunk management
fn handle_read_attr_path(
node: &Node,
accessor: &Accessor,
attr_encoder: &mut AttrReadEncoder,
attr_details: &mut AttrDetails,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let mut status = Ok(());
let path = attr_encoder.path;
// Skip error reporting for wildcard paths, don't for concrete paths
attr_encoder.skip_error(path.is_wildcard());
let result = node.for_each_attribute(&path, |path, c| {
// Ignore processing if data filter matches.
// For a wildcard attribute, this may end happening unnecessarily for all attributes, although
// a single skip for the cluster is sufficient. That requires us to replace this for_each with a
// for_each_cluster
let cluster_data_ver = c.base().get_dataver();
if Self::data_filter_matches(&attr_encoder.data_ver_filters, path, cluster_data_ver) {
return Ok(());
}
// The resume_from indicates that this is the next chunk of a previous Read Request. In such cases, we
// need to skip until we hit this path.
if let Some(r) = resume_from {
// If resume_from is valid, and we haven't hit the resume_from yet, skip encoding
if r != path {
return Ok(());
} else {
// Else, wipe out the resume_from so subsequent paths can be encoded
*resume_from = None;
}
}
attr_details.attr_id = path.leaf.unwrap_or_default() as u16;
// Overwrite the previous path with the concrete path
attr_encoder.set_path(*path);
// Set the cluster's data version
attr_encoder.set_data_ver(cluster_data_ver);
let mut access_req = AccessReq::new(accessor, path, Access::READ);
Cluster::read_attribute(c, &mut access_req, attr_encoder, attr_details);
if attr_encoder.is_buffer_full() {
// Buffer is full, next time resume from this attribute
*resume_from = Some(*path);
status = Err(Error::NoSpace);
}
Ok(())
});
if let Err(e) = result {
// We hit this only if this is a non-wildcard path
attr_encoder.encode_status(e, 0);
}
status
}
/// Process an array of Attribute Read Requests
///
/// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise
/// the read is complete
pub(super) fn handle_read_attr_array(
&self,
read_req: &ReadReq,
trans: &mut Transaction,
old_tw: &mut TLVWriter,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let old_wb = old_tw.get_buf();
// Note, this function may be called from multiple places: a) an actual read
// request, a b) resumed read request, c) subscribe request or d) resumed subscribe
// request. Hopefully 18 is sufficient to address all those scenarios.
//
// This is the amount of space we reserve for other things to be attached towards
// the end
const RESERVE_SIZE: usize = 24;
let mut new_wb = wb_shrink!(old_wb, RESERVE_SIZE);
let mut tw = TLVWriter::new(&mut new_wb);
let mut attr_encoder = AttrReadEncoder::new(&mut tw);
if let Some(filters) = &read_req.dataver_filters {
attr_encoder.set_data_ver_filters(filters);
}
if let Some(attr_requests) = &read_req.attr_requests {
let accessor = self.sess_to_accessor(trans.session);
let mut attr_details = AttrDetails::new(accessor.fab_idx, read_req.fabric_filtered);
let node = self.node.read().unwrap();
attr_encoder
.tw
.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?;
let mut result = Ok(());
for attr_path in attr_requests.iter() {
attr_encoder.set_path(attr_path.to_gp());
// Extract the attr_path fields into various structures
attr_details.list_index = attr_path.list_index;
result = DataModel::handle_read_attr_path(
&node,
&accessor,
&mut attr_encoder,
&mut attr_details,
resume_from,
);
if result.is_err() {
break;
}
}
// Now that all the read reports are captured, let's use the old_tw that is
// the full writebuf, and hopefully as all the necessary space to store this
wb_unshrink!(old_wb, new_wb);
old_tw.end_container()?; // Finish the AttrReports
if result.is_err() {
// If there was an error, indicate chunking. The resume_read_req would have been
// already populated in the loop above.
old_tw.bool(TagType::Context(MoreChunkedMsgs as u8), true)?;
} else {
// A None resume_from indicates no chunking
*resume_from = None;
}
}
Ok(())
}
/// Handle a read request
///
/// This could be called from an actual read request or a resumed read request. Subscription
/// requests do not come to this function.
/// When the API returns the chunked read is on, if *resume_from is Some(x) otherwise
/// the read is complete
pub fn handle_read_req(
&self,
read_req: &ReadReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
resume_from: &mut Option<GenericPath>,
) -> Result<(OpCode, ResponseRequired), Error> {
tw.start_struct(TagType::Anonymous)?;
self.handle_read_attr_array(read_req, trans, tw, resume_from)?;
if resume_from.is_none() {
tw.bool(TagType::Context(SupressResponse as u8), true)?;
// Mark transaction complete, if not chunked
trans.complete();
}
tw.end_container()?;
Ok((OpCode::ReportData, ResponseRequired::Yes))
}
/// Handle a resumed read request
pub fn handle_resume_read(
&self,
resume_read_req: &mut ResumeReadReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
if let Some(packet) = resume_read_req.pending_req.as_mut() {
let rx_buf = packet.get_parsebuf()?.as_borrow_slice();
let root = tlv::get_root_node(rx_buf)?;
let req = ReadReq::from_tlv(&root)?;
self.handle_read_req(&req, trans, tw, &mut resume_read_req.resume_from)
} else {
// No pending req, is that even possible?
error!("This shouldn't have happened");
Ok((OpCode::Reserved, ResponseRequired::No))
}
}
}

View file

@ -1,142 +0,0 @@
/*
*
* Copyright (c) 2023 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::sync::atomic::{AtomicU32, Ordering};
use crate::{
error::Error,
interaction_model::{
core::OpCode,
messages::{
msg::{self, SubscribeReq, SubscribeResp},
GenericPath,
},
},
tlv::{self, get_root_node_struct, FromTLV, TLVWriter, TagType, ToTLV},
transport::proto_demux::ResponseRequired,
};
use super::{read::ResumeReadReq, DataModel, Transaction};
static SUBS_ID: AtomicU32 = AtomicU32::new(1);
#[derive(PartialEq)]
enum SubsState {
Confirming,
Confirmed,
}
pub struct SubsCtx {
state: SubsState,
id: u32,
resume_read_req: Option<ResumeReadReq>,
}
impl SubsCtx {
pub fn new(
rx_buf: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
) -> Result<Self, Error> {
let root = get_root_node_struct(rx_buf)?;
let req = SubscribeReq::from_tlv(&root)?;
let mut ctx = SubsCtx {
state: SubsState::Confirming,
// TODO
id: SUBS_ID.fetch_add(1, Ordering::SeqCst),
resume_read_req: None,
};
let mut resume_from = None;
ctx.do_read(&req, trans, tw, dm, &mut resume_from)?;
if resume_from.is_some() {
// This is a multi-hop read transaction, remember this read request
ctx.resume_read_req = Some(ResumeReadReq::new(rx_buf, &resume_from)?);
}
Ok(ctx)
}
pub fn handle_status_report(
&mut self,
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
) -> Result<(OpCode, ResponseRequired), Error> {
if self.state != SubsState::Confirming {
// Not relevant for us
trans.complete();
return Err(Error::Invalid);
}
// Is there a previous resume read pending
if self.resume_read_req.is_some() {
let mut resume_read_req = self.resume_read_req.take().unwrap();
if let Some(packet) = resume_read_req.pending_req.as_mut() {
let rx_buf = packet.get_parsebuf()?.as_borrow_slice();
let root = tlv::get_root_node(rx_buf)?;
let req = SubscribeReq::from_tlv(&root)?;
self.do_read(&req, trans, tw, dm, &mut resume_read_req.resume_from)?;
if resume_read_req.resume_from.is_some() {
// More chunks are pending, setup resume_read_req again
self.resume_read_req = Some(resume_read_req);
}
return Ok((OpCode::ReportData, ResponseRequired::Yes));
}
}
// We are here implies that the read is now complete
self.confirm_subscription(trans, tw)
}
fn confirm_subscription(
&mut self,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error> {
self.state = SubsState::Confirmed;
// TODO
let resp = SubscribeResp::new(self.id, 40);
resp.to_tlv(tw, TagType::Anonymous)?;
trans.complete();
Ok((OpCode::SubscriptResponse, ResponseRequired::Yes))
}
fn do_read(
&mut self,
req: &SubscribeReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
dm: &DataModel,
resume_from: &mut Option<GenericPath>,
) -> Result<(), Error> {
let read_req = req.to_read_req();
tw.start_struct(TagType::Anonymous)?;
tw.u32(
TagType::Context(msg::ReportDataTag::SubscriptionId as u8),
self.id,
)?;
dm.handle_read_attr_array(&read_req, trans, tw, resume_from)?;
tw.end_container()?;
Ok(())
}
}

View file

@ -15,60 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use super::cluster_basic_information::BasicInfoCluster; use super::objects::DeviceType;
use super::cluster_basic_information::BasicInfoConfig;
use super::cluster_on_off::OnOffCluster;
use super::objects::*;
use super::sdm::admin_commissioning::AdminCommCluster;
use super::sdm::dev_att::DevAttDataFetcher;
use super::sdm::general_commissioning::GenCommCluster;
use super::sdm::noc::NocCluster;
use super::sdm::nw_commissioning::NwCommCluster;
use super::system_model::access_control::AccessControlCluster;
use crate::acl::AclMgr;
use crate::error::*;
use crate::fabric::FabricMgr;
use crate::secure_channel::pake::PaseMgr;
use std::sync::Arc;
use std::sync::RwLockWriteGuard;
pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType { pub const DEV_TYPE_ROOT_NODE: DeviceType = DeviceType {
dtype: 0x0016, dtype: 0x0016,
drev: 1, drev: 1,
}; };
type WriteNode<'a> = RwLockWriteGuard<'a, Box<Node>>; pub const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType {
pub fn device_type_add_root_node(
node: &mut WriteNode,
dev_info: BasicInfoConfig,
dev_att: Box<dyn DevAttDataFetcher>,
fabric_mgr: Arc<FabricMgr>,
acl_mgr: Arc<AclMgr>,
pase_mgr: PaseMgr,
) -> Result<EndptId, Error> {
// Add the root endpoint
let endpoint = node.add_endpoint(DEV_TYPE_ROOT_NODE)?;
if endpoint != 0 {
// Somehow endpoint 0 was already added, this shouldn't be the case
return Err(Error::Invalid);
};
// Add the mandatory clusters
node.add_cluster(0, BasicInfoCluster::new(dev_info)?)?;
let general_commissioning = GenCommCluster::new()?;
let failsafe = general_commissioning.failsafe();
node.add_cluster(0, general_commissioning)?;
node.add_cluster(0, NwCommCluster::new()?)?;
node.add_cluster(0, AdminCommCluster::new(pase_mgr)?)?;
node.add_cluster(
0,
NocCluster::new(dev_att, fabric_mgr, acl_mgr.clone(), failsafe)?,
)?;
node.add_cluster(0, AccessControlCluster::new(acl_mgr)?)?;
Ok(endpoint)
}
const DEV_TYPE_ON_OFF_LIGHT: DeviceType = DeviceType {
dtype: 0x0100, dtype: 0x0100,
drev: 2, drev: 2,
}; };
@ -77,9 +31,3 @@ pub const DEV_TYPE_ON_SMART_SPEAKER: DeviceType = DeviceType {
dtype: 0x0022, dtype: 0x0022,
drev: 2, drev: 2,
}; };
pub fn device_type_add_on_off_light(node: &mut WriteNode) -> Result<EndptId, Error> {
let endpoint = node.add_endpoint(DEV_TYPE_ON_OFF_LIGHT)?;
node.add_cluster(endpoint, OnOffCluster::new()?)?;
Ok(endpoint)
}

View file

@ -20,8 +20,9 @@ pub mod device_types;
pub mod objects; pub mod objects;
pub mod cluster_basic_information; pub mod cluster_basic_information;
pub mod cluster_media_playback; // TODO pub mod cluster_media_playback;
pub mod cluster_on_off; pub mod cluster_on_off;
pub mod cluster_template; pub mod cluster_template;
pub mod root_endpoint;
pub mod sdm; pub mod sdm;
pub mod system_model; pub mod system_model;

View file

@ -15,15 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use super::{AttrId, GlobalElements, Privilege}; use crate::data_model::objects::GlobalElements;
use crate::{
error::*, use super::{AttrId, Privilege};
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{TLVElement, TLVWriter, TagType, ToTLV},
};
use bitflags::bitflags; use bitflags::bitflags;
use log::error; use core::fmt::{self, Debug};
use std::fmt::{self, Debug, Formatter};
bitflags! { bitflags! {
#[derive(Default)] #[derive(Default)]
@ -83,110 +79,24 @@ bitflags! {
} }
} }
/* This file needs some major revamp.
* - instead of allocating all over the heap, we should use some kind of slab/block allocator
* - instead of arrays, can use linked-lists to conserve space and avoid the internal fragmentation
*/
#[derive(PartialEq, PartialOrd, Clone)]
pub enum AttrValue {
Int64(i64),
Uint8(u8),
Uint16(u16),
Uint32(u32),
Uint64(u64),
Bool(bool),
Utf8(String),
Custom,
}
impl Debug for AttrValue {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> {
match &self {
AttrValue::Int64(v) => write!(f, "{:?}", *v),
AttrValue::Uint8(v) => write!(f, "{:?}", *v),
AttrValue::Uint16(v) => write!(f, "{:?}", *v),
AttrValue::Uint32(v) => write!(f, "{:?}", *v),
AttrValue::Uint64(v) => write!(f, "{:?}", *v),
AttrValue::Bool(v) => write!(f, "{:?}", *v),
AttrValue::Utf8(v) => write!(f, "{:?}", *v),
AttrValue::Custom => write!(f, "custom-attribute"),
}?;
Ok(())
}
}
impl ToTLV for AttrValue {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
// What is the time complexity of such long match statements?
match self {
AttrValue::Bool(v) => tw.bool(tag_type, *v),
AttrValue::Uint8(v) => tw.u8(tag_type, *v),
AttrValue::Uint16(v) => tw.u16(tag_type, *v),
AttrValue::Uint32(v) => tw.u32(tag_type, *v),
AttrValue::Uint64(v) => tw.u64(tag_type, *v),
AttrValue::Utf8(v) => tw.utf8(tag_type, v.as_bytes()),
_ => {
error!("Attribute type not yet supported");
Err(Error::AttributeNotFound)
}
}
}
}
impl AttrValue {
pub fn update_from_tlv(&mut self, tr: &TLVElement) -> Result<(), Error> {
match self {
AttrValue::Bool(v) => *v = tr.bool()?,
AttrValue::Uint8(v) => *v = tr.u8()?,
AttrValue::Uint16(v) => *v = tr.u16()?,
AttrValue::Uint32(v) => *v = tr.u32()?,
AttrValue::Uint64(v) => *v = tr.u64()?,
_ => {
error!("Attribute type not yet supported");
return Err(Error::AttributeNotFound);
}
}
Ok(())
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct Attribute { pub struct Attribute {
pub(super) id: AttrId, pub id: AttrId,
pub(super) value: AttrValue, pub quality: Quality,
pub(super) quality: Quality, pub access: Access,
pub(super) access: Access,
}
impl Default for Attribute {
fn default() -> Attribute {
Attribute {
id: 0,
value: AttrValue::Bool(true),
quality: Default::default(),
access: Default::default(),
}
}
} }
impl Attribute { impl Attribute {
pub fn new(id: AttrId, value: AttrValue, access: Access, quality: Quality) -> Self { pub const fn new(id: AttrId, access: Access, quality: Quality) -> Self {
Attribute { Self {
id, id,
value,
access, access,
quality, quality,
} }
} }
pub fn set_value(&mut self, value: AttrValue) -> Result<(), Error> { pub fn is_system(&self) -> bool {
if !self.quality.contains(Quality::FIXED) { Self::is_system_attr(self.id)
self.value = value;
Ok(())
} else {
Err(Error::Invalid)
}
} }
pub fn is_system_attr(attr_id: AttrId) -> bool { pub fn is_system_attr(attr_id: AttrId) -> bool {
@ -194,9 +104,9 @@ impl Attribute {
} }
} }
impl std::fmt::Display for Attribute { impl core::fmt::Display for Attribute {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}: {:?}", self.id, self.value) write!(f, "{}", self.id)
} }
} }

View file

@ -15,25 +15,31 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{
acl::AccessReq,
data_model::objects::{Access, AttrValue, Attribute, EncodeValue, Quality},
error::*,
interaction_model::{command::CommandReq, core::IMStatusCode},
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{Nullable, TLVElement, TLVWriter, TagType},
};
use log::error; use log::error;
use num_derive::FromPrimitive; use strum::FromRepr;
use rand::Rng;
use std::fmt::{self, Debug};
use super::{AttrId, ClusterId, Encoder}; use crate::{
acl::{AccessReq, Accessor},
attribute_enum,
data_model::objects::*,
error::{Error, ErrorCode},
interaction_model::{
core::IMStatusCode,
messages::{
ib::{AttrPath, AttrStatus, CmdPath, CmdStatus},
GenericPath,
},
},
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{Nullable, TLVWriter, TagType},
};
use core::{
convert::TryInto,
fmt::{self, Debug},
};
pub const ATTRS_PER_CLUSTER: usize = 10; #[derive(Clone, Copy, Debug, Eq, PartialEq, FromRepr)]
pub const CMDS_PER_CLUSTER: usize = 8; #[repr(u16)]
#[derive(FromPrimitive, Debug)]
pub enum GlobalElements { pub enum GlobalElements {
_ClusterRevision = 0xFFFD, _ClusterRevision = 0xFFFD,
FeatureMap = 0xFFFC, FeatureMap = 0xFFFC,
@ -44,297 +50,292 @@ pub enum GlobalElements {
FabricIndex = 0xFE, FabricIndex = 0xFE,
} }
attribute_enum!(GlobalElements);
pub const FEATURE_MAP: Attribute =
Attribute::new(GlobalElements::FeatureMap as _, Access::RV, Quality::NONE);
pub const ATTRIBUTE_LIST: Attribute = Attribute::new(
GlobalElements::AttributeList as _,
Access::RV,
Quality::NONE,
);
// TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write // TODO: What if we instead of creating this, we just pass the AttrData/AttrPath to the read/write
// methods? // methods?
/// The Attribute Details structure records the details about the attribute under consideration. /// The Attribute Details structure records the details about the attribute under consideration.
/// Typically this structure is progressively built as we proceed through the request processing. #[derive(Debug)]
pub struct AttrDetails { pub struct AttrDetails<'a> {
/// Fabric Filtering Activated pub node: &'a Node<'a>,
pub fab_filter: bool, /// The actual endpoint ID
/// The current Fabric Index pub endpoint_id: EndptId,
pub fab_idx: u8, /// The actual cluster ID
/// List Index, if any pub cluster_id: ClusterId,
pub list_index: Option<Nullable<u16>>,
/// The actual attribute ID /// The actual attribute ID
pub attr_id: AttrId, pub attr_id: AttrId,
/// List Index, if any
pub list_index: Option<Nullable<u16>>,
/// The current Fabric Index
pub fab_idx: u8,
/// Fabric Filtering Activated
pub fab_filter: bool,
pub dataver: Option<u32>,
pub wildcard: bool,
} }
impl AttrDetails { impl<'a> AttrDetails<'a> {
pub fn new(fab_idx: u8, fab_filter: bool) -> Self { pub fn is_system(&self) -> bool {
Self { Attribute::is_system_attr(self.attr_id)
fab_filter,
fab_idx,
list_index: None,
attr_id: 0,
}
}
} }
pub trait ClusterType { pub fn path(&self) -> AttrPath {
// TODO: 5 methods is going to be quite expensive for vtables of all the clusters AttrPath {
fn base(&self) -> &Cluster; endpoint: Some(self.endpoint_id),
fn base_mut(&mut self) -> &mut Cluster; cluster: Some(self.cluster_id),
fn read_custom_attribute(&self, _encoder: &mut dyn Encoder, _attr: &AttrDetails) {} attr: Some(self.attr_id),
list_index: self.list_index,
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { ..Default::default()
let cmd = cmd_req.cmd.path.leaf.map(|a| a as u16);
println!("Received command: {:?}", cmd);
Err(IMStatusCode::UnsupportedCommand)
}
/// Write an attribute
///
/// Note that if this method is defined, you must handle the write for all the attributes. Even those
/// that are not 'custom'. This is different from how you handle the read_custom_attribute() method.
/// The reason for this being, you may want to handle an attribute write request even though it is a
/// standard attribute like u16, u32 etc.
///
/// If you wish to update the standard attribute in the data model database, you must call the
/// write_attribute_from_tlv() method from the base cluster, as is shown here in the default case
fn write_attribute(
&mut self,
attr: &AttrDetails,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
self.base_mut().write_attribute_from_tlv(attr.attr_id, data)
} }
} }
pub struct Cluster { pub fn status(&self, status: IMStatusCode) -> Result<Option<AttrStatus>, Error> {
pub(super) id: ClusterId, if self.should_report(status) {
attributes: Vec<Attribute>, Ok(Some(AttrStatus::new(
data_ver: u32, &GenericPath {
endpoint: Some(self.endpoint_id),
cluster: Some(self.cluster_id),
leaf: Some(self.attr_id as _),
},
status,
0,
)))
} else {
Ok(None)
}
} }
impl Cluster { fn should_report(&self, status: IMStatusCode) -> bool {
pub fn new(id: ClusterId) -> Result<Cluster, Error> { !self.wildcard
let mut c = Cluster { || !matches!(
id, status,
attributes: Vec::with_capacity(ATTRS_PER_CLUSTER), IMStatusCode::UnsupportedEndpoint
data_ver: rand::thread_rng().gen_range(0..0xFFFFFFFF), | IMStatusCode::UnsupportedCluster
}; | IMStatusCode::UnsupportedAttribute
c.add_default_attributes()?; | IMStatusCode::UnsupportedCommand
Ok(c) | IMStatusCode::UnsupportedAccess
| IMStatusCode::UnsupportedRead
| IMStatusCode::UnsupportedWrite
| IMStatusCode::DataVersionMismatch
)
}
} }
pub fn id(&self) -> ClusterId { #[derive(Debug)]
self.id pub struct CmdDetails<'a> {
pub node: &'a Node<'a>,
pub endpoint_id: EndptId,
pub cluster_id: ClusterId,
pub cmd_id: CmdId,
pub wildcard: bool,
} }
pub fn get_dataver(&self) -> u32 { impl<'a> CmdDetails<'a> {
self.data_ver pub fn path(&self) -> CmdPath {
CmdPath::new(
Some(self.endpoint_id),
Some(self.cluster_id),
Some(self.cmd_id),
)
} }
pub fn set_feature_map(&mut self, map: u32) -> Result<(), Error> { pub fn success(&self, tracker: &CmdDataTracker) -> Option<CmdStatus> {
self.write_attribute_raw(GlobalElements::FeatureMap as u16, AttrValue::Uint32(map)) if tracker.needs_status() {
.map_err(|_| Error::Invalid)?; self.status(IMStatusCode::Success)
Ok(()) } else {
None
}
} }
fn add_default_attributes(&mut self) -> Result<(), Error> { pub fn status(&self, status: IMStatusCode) -> Option<CmdStatus> {
// Default feature map is 0 if self.should_report(status) {
self.add_attribute(Attribute::new( Some(CmdStatus::new(
GlobalElements::FeatureMap as u16, CmdPath::new(
AttrValue::Uint32(0), Some(self.endpoint_id),
Access::RV, Some(self.cluster_id),
Quality::NONE, Some(self.cmd_id),
))?; ),
status,
self.add_attribute(Attribute::new( 0,
GlobalElements::AttributeList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
)) ))
}
pub fn add_attributes(&mut self, attrs: &[Attribute]) -> Result<(), Error> {
if self.attributes.len() + attrs.len() <= self.attributes.capacity() {
self.attributes.extend_from_slice(attrs);
Ok(())
} else { } else {
Err(Error::NoSpace) None
} }
} }
pub fn add_attribute(&mut self, attr: Attribute) -> Result<(), Error> { fn should_report(&self, status: IMStatusCode) -> bool {
if self.attributes.len() < self.attributes.capacity() { !self.wildcard
self.attributes.push(attr); || !matches!(
Ok(()) status,
} else { IMStatusCode::UnsupportedEndpoint
Err(Error::NoSpace) | IMStatusCode::UnsupportedCluster
| IMStatusCode::UnsupportedAttribute
| IMStatusCode::UnsupportedCommand
| IMStatusCode::UnsupportedAccess
| IMStatusCode::UnsupportedRead
| IMStatusCode::UnsupportedWrite
)
} }
} }
fn get_attribute_index(&self, attr_id: AttrId) -> Option<usize> { #[derive(Debug, Clone)]
self.attributes.iter().position(|c| c.id == attr_id) pub struct Cluster<'a> {
pub id: ClusterId,
pub feature_map: u32,
pub attributes: &'a [Attribute],
pub commands: &'a [CmdId],
} }
fn get_attribute(&self, attr_id: AttrId) -> Result<&Attribute, Error> { impl<'a> Cluster<'a> {
let index = self pub const fn new(
.get_attribute_index(attr_id) id: ClusterId,
.ok_or(Error::AttributeNotFound)?; feature_map: u32,
Ok(&self.attributes[index]) attributes: &'a [Attribute],
commands: &'a [CmdId],
) -> Self {
Self {
id,
feature_map,
attributes,
commands,
}
} }
fn get_attribute_mut(&mut self, attr_id: AttrId) -> Result<&mut Attribute, Error> { pub fn match_attributes(
let index = self
.get_attribute_index(attr_id)
.ok_or(Error::AttributeNotFound)?;
Ok(&mut self.attributes[index])
}
// Returns a slice of attribute, with either a single attribute or all (wildcard)
pub fn get_wildcard_attribute(
&self, &self,
attribute: Option<AttrId>, attr: Option<AttrId>,
) -> Result<(&[Attribute], bool), IMStatusCode> { ) -> impl Iterator<Item = &'_ Attribute> + '_ {
if let Some(a) = attribute { self.attributes
if let Some(i) = self.get_attribute_index(a) { .iter()
Ok((&self.attributes[i..i + 1], false)) .filter(move |attribute| attr.map(|attr| attr == attribute.id).unwrap_or(true))
}
pub fn match_commands(&self, cmd: Option<CmdId>) -> impl Iterator<Item = CmdId> + '_ {
self.commands
.iter()
.filter(move |id| cmd.map(|cmd| **id == cmd).unwrap_or(true))
.copied()
}
pub fn check_attribute(
&self,
accessor: &Accessor,
ep: EndptId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
let attribute = self
.attributes
.iter()
.find(|attribute| attribute.id == attr)
.ok_or(IMStatusCode::UnsupportedAttribute)?;
Self::check_attr_access(
accessor,
GenericPath::new(Some(ep), Some(self.id), Some(attr as _)),
write,
attribute.access,
)
}
pub fn check_command(
&self,
accessor: &Accessor,
ep: EndptId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.commands
.iter()
.find(|id| **id == cmd)
.ok_or(IMStatusCode::UnsupportedCommand)?;
Self::check_cmd_access(
accessor,
GenericPath::new(Some(ep), Some(self.id), Some(cmd)),
)
}
pub(crate) fn check_attr_access(
accessor: &Accessor,
path: GenericPath,
write: bool,
target_perms: Access,
) -> Result<(), IMStatusCode> {
let mut access_req = AccessReq::new(
accessor,
path,
if write { Access::WRITE } else { Access::READ },
);
if !target_perms.contains(access_req.operation()) {
Err(if matches!(access_req.operation(), Access::WRITE) {
IMStatusCode::UnsupportedWrite
} else { } else {
Err(IMStatusCode::UnsupportedAttribute) IMStatusCode::UnsupportedRead
})?;
} }
access_req.set_target_perms(target_perms);
if access_req.allow() {
Ok(())
} else { } else {
Ok((&self.attributes[..], true)) Err(IMStatusCode::UnsupportedAccess)
} }
} }
pub fn read_attribute( pub(crate) fn check_cmd_access(
c: &dyn ClusterType, accessor: &Accessor,
access_req: &mut AccessReq, path: GenericPath,
encoder: &mut dyn Encoder, ) -> Result<(), IMStatusCode> {
attr: &AttrDetails, let mut access_req = AccessReq::new(accessor, path, Access::WRITE);
) {
let mut error = IMStatusCode::Success; access_req.set_target_perms(
let base = c.base(); Access::WRITE
let a = if let Ok(a) = base.get_attribute(attr.attr_id) { .union(Access::NEED_OPERATE)
a .union(Access::NEED_MANAGE)
.union(Access::NEED_ADMIN),
); // TODO
if access_req.allow() {
Ok(())
} else { } else {
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0); Err(IMStatusCode::UnsupportedAccess)
return;
};
if !a.access.contains(Access::READ) {
error = IMStatusCode::UnsupportedRead;
}
access_req.set_target_perms(a.access);
if !access_req.allow() {
error = IMStatusCode::UnsupportedAccess;
}
if error != IMStatusCode::Success {
encoder.encode_status(error, 0);
} else if Attribute::is_system_attr(attr.attr_id) {
c.base().read_system_attribute(encoder, a)
} else if a.value != AttrValue::Custom {
encoder.encode(EncodeValue::Value(&a.value))
} else {
c.read_custom_attribute(encoder, attr)
} }
} }
fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) { pub fn read(&self, attr: AttrId, mut writer: AttrDataWriter) -> Result<(), Error> {
let _ = tw.start_array(tag); match attr.try_into()? {
for a in &self.attributes {
let _ = tw.u16(TagType::Anonymous, a.id);
}
let _ = tw.end_container();
}
fn read_system_attribute(&self, encoder: &mut dyn Encoder, attr: &Attribute) {
let global_attr: Option<GlobalElements> = num::FromPrimitive::from_u16(attr.id);
if let Some(global_attr) = global_attr {
match global_attr {
GlobalElements::AttributeList => { GlobalElements::AttributeList => {
encoder.encode(EncodeValue::Closure(&|tag, tw| { self.encode_attribute_ids(AttrDataWriter::TAG, &mut writer)?;
self.encode_attribute_ids(tag, tw) writer.complete()
}));
return;
} }
GlobalElements::FeatureMap => { GlobalElements::FeatureMap => writer.set(self.feature_map),
encoder.encode(EncodeValue::Value(&attr.value)); other => {
return; error!("This attribute is not yet handled {:?}", other);
Err(ErrorCode::AttributeNotFound.into())
} }
_ => {
error!("This attribute not yet handled {:?}", global_attr);
}
}
}
encoder.encode_status(IMStatusCode::UnsupportedAttribute, 0)
}
pub fn read_attribute_raw(&self, attr_id: AttrId) -> Result<&AttrValue, IMStatusCode> {
let a = self
.get_attribute(attr_id)
.map_err(|_| IMStatusCode::UnsupportedAttribute)?;
Ok(&a.value)
}
pub fn write_attribute(
c: &mut dyn ClusterType,
access_req: &mut AccessReq,
data: &TLVElement,
attr: &AttrDetails,
) -> Result<(), IMStatusCode> {
let base = c.base_mut();
let a = if let Ok(a) = base.get_attribute_mut(attr.attr_id) {
a
} else {
return Err(IMStatusCode::UnsupportedAttribute);
};
if !a.access.contains(Access::WRITE) {
return Err(IMStatusCode::UnsupportedWrite);
}
access_req.set_target_perms(a.access);
if !access_req.allow() {
return Err(IMStatusCode::UnsupportedAccess);
}
c.write_attribute(attr, data)
}
pub fn write_attribute_from_tlv(
&mut self,
attr_id: AttrId,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
let a = self.get_attribute_mut(attr_id)?;
if a.value != AttrValue::Custom {
let mut value = a.value.clone();
value
.update_from_tlv(data)
.map_err(|_| IMStatusCode::Failure)?;
a.set_value(value)
.map(|_| {
self.cluster_changed();
})
.map_err(|_| IMStatusCode::UnsupportedWrite)
} else {
Err(IMStatusCode::UnsupportedAttribute)
} }
} }
pub fn write_attribute_raw(&mut self, attr_id: AttrId, value: AttrValue) -> Result<(), Error> { fn encode_attribute_ids(&self, tag: TagType, tw: &mut TLVWriter) -> Result<(), Error> {
let a = self.get_attribute_mut(attr_id)?; tw.start_array(tag)?;
a.set_value(value).map(|_| { for a in self.attributes {
self.cluster_changed(); tw.u16(TagType::Anonymous, a.id)?;
})
} }
/// This method must be called for any changes to the data model tw.end_container()
/// Currently this only increments the data version, but we can reuse the same
/// for raising events too
pub fn cluster_changed(&mut self) {
self.data_ver = self.data_ver.wrapping_add(1);
} }
} }
impl std::fmt::Display for Cluster { impl<'a> core::fmt::Display for Cluster<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "id:{}, ", self.id)?; write!(f, "id:{}, ", self.id)?;
write!(f, "attrs[")?; write!(f, "attrs[")?;

View file

@ -0,0 +1,57 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use core::cell::Cell;
use crate::utils::rand::Rand;
pub struct Dataver {
ver: Cell<u32>,
changed: Cell<bool>,
}
impl Dataver {
pub fn new(rand: Rand) -> Self {
let mut buf = [0; 4];
rand(&mut buf);
Self {
ver: Cell::new(u32::from_be_bytes(buf)),
changed: Cell::new(false),
}
}
pub fn get(&self) -> u32 {
self.ver.get()
}
pub fn changed(&self) -> u32 {
self.ver.set(self.ver.get().overflowing_add(1).0);
self.changed.set(true);
self.get()
}
pub fn consume_change<T>(&self, change: T) -> Option<T> {
if self.changed.get() {
self.changed.set(false);
Some(change)
} else {
None
}
}
}

View file

@ -15,21 +15,31 @@
* limitations under the License. * limitations under the License.
*/ */
use std::fmt::{Debug, Formatter}; use core::fmt::{Debug, Formatter};
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
use crate::interaction_model::core::IMStatusCode;
use crate::interaction_model::messages::ib::{
AttrPath, AttrResp, AttrStatus, CmdDataTag, CmdPath, CmdStatus, InvResp, InvRespTag,
};
use crate::tlv::UtfStr;
use crate::transport::exchange::Exchange;
use crate::{ use crate::{
error::Error, error::{Error, ErrorCode},
interaction_model::core::IMStatusCode, interaction_model::messages::ib::{AttrDataTag, AttrRespTag},
tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV},
}; };
use log::error; use log::error;
use super::{AttrDetails, CmdDetails, DataModelHandler};
// TODO: Should this return an IMStatusCode Error? But if yes, the higher layer // TODO: Should this return an IMStatusCode Error? But if yes, the higher layer
// may have already started encoding the 'success' headers, we might not to manage // may have already started encoding the 'success' headers, we might not want to manage
// the tw.rewind() in that case, if we add this support // the tw.rewind() in that case, if we add this support
pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter); pub type EncodeValueGen<'a> = &'a dyn Fn(TagType, &mut TLVWriter);
#[derive(Copy, Clone)] #[derive(Clone)]
/// A structure for encoding various types of values /// A structure for encoding various types of values
pub enum EncodeValue<'a> { pub enum EncodeValue<'a> {
/// This indicates a value that is dynamically generated. This variant /// This indicates a value that is dynamically generated. This variant
@ -56,13 +66,13 @@ impl<'a> EncodeValue<'a> {
impl<'a> PartialEq for EncodeValue<'a> { impl<'a> PartialEq for EncodeValue<'a> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match *self { match self {
EncodeValue::Closure(_) => { EncodeValue::Closure(_) => {
error!("PartialEq not yet supported"); error!("PartialEq not yet supported");
false false
} }
EncodeValue::Tlv(a) => { EncodeValue::Tlv(a) => {
if let EncodeValue::Tlv(b) = *other { if let EncodeValue::Tlv(b) = other {
a == b a == b
} else { } else {
false false
@ -78,8 +88,8 @@ impl<'a> PartialEq for EncodeValue<'a> {
} }
impl<'a> Debug for EncodeValue<'a> { impl<'a> Debug for EncodeValue<'a> {
fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), std::fmt::Error> { fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), core::fmt::Error> {
match *self { match self {
EncodeValue::Closure(_) => write!(f, "Contains closure"), EncodeValue::Closure(_) => write!(f, "Contains closure"),
EncodeValue::Tlv(t) => write!(f, "{:?}", t), EncodeValue::Tlv(t) => write!(f, "{:?}", t),
EncodeValue::Value(_) => write!(f, "Contains EncodeValue"), EncodeValue::Value(_) => write!(f, "Contains EncodeValue"),
@ -103,21 +113,435 @@ impl<'a> ToTLV for EncodeValue<'a> {
impl<'a> FromTLV<'a> for EncodeValue<'a> { impl<'a> FromTLV<'a> for EncodeValue<'a> {
fn from_tlv(data: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(data: &TLVElement<'a>) -> Result<Self, Error> {
Ok(EncodeValue::Tlv(*data)) Ok(EncodeValue::Tlv(data.clone()))
} }
} }
/// An object that can encode EncodeValue into the necessary hierarchical structure pub struct AttrDataEncoder<'a, 'b, 'c> {
/// as expected by the Interaction Model dataver_filter: Option<u32>,
pub trait Encoder { path: AttrPath,
/// Encode a given value tw: &'a mut TLVWriter<'b, 'c>,
fn encode(&mut self, value: EncodeValue);
/// Encode a status report
fn encode_status(&mut self, status: IMStatusCode, cluster_status: u16);
} }
#[derive(ToTLV, Copy, Clone)] impl<'a, 'b, 'c> AttrDataEncoder<'a, 'b, 'c> {
pub struct DeviceType { pub async fn handle_read<T: DataModelHandler>(
pub dtype: u16, item: &Result<AttrDetails<'_>, AttrStatus>,
pub drev: u16, handler: &T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<bool, Error> {
let status = match item {
Ok(attr) => {
let encoder = AttrDataEncoder::new(attr, tw);
let result = {
#[cfg(not(feature = "nightly"))]
{
handler.read(attr, encoder)
}
#[cfg(feature = "nightly")]
{
handler.read(&attr, encoder).await
}
};
match result {
Ok(()) => None,
Err(e) => {
if e.code() == ErrorCode::NoSpace {
return Ok(false);
} else {
attr.status(e.into())?
}
}
}
}
Err(status) => Some(status.clone()),
};
if let Some(status) = status {
AttrResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(true)
}
pub async fn handle_write<T: DataModelHandler>(
item: &Result<(AttrDetails<'_>, TLVElement<'_>), AttrStatus>,
handler: &T,
tw: &mut TLVWriter<'_, '_>,
) -> Result<(), Error> {
let status = match item {
Ok((attr, data)) => {
let result = {
#[cfg(not(feature = "nightly"))]
{
handler.write(attr, AttrData::new(attr.dataver, data))
}
#[cfg(feature = "nightly")]
{
handler
.write(&attr, AttrData::new(attr.dataver, &data))
.await
}
};
match result {
Ok(()) => attr.status(IMStatusCode::Success)?,
Err(error) => attr.status(error.into())?,
}
}
Err(status) => Some(status.clone()),
};
if let Some(status) = status {
status.to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new(attr: &AttrDetails, tw: &'a mut TLVWriter<'b, 'c>) -> Self {
Self {
dataver_filter: attr.dataver,
path: attr.path(),
tw,
}
}
pub fn with_dataver(self, dataver: u32) -> Result<Option<AttrDataWriter<'a, 'b, 'c>>, Error> {
if self
.dataver_filter
.map(|dataver_filter| dataver_filter != dataver)
.unwrap_or(true)
{
let mut writer = AttrDataWriter::new(self.tw);
writer.start_struct(TagType::Anonymous)?;
writer.start_struct(TagType::Context(AttrRespTag::Data as _))?;
writer.u32(TagType::Context(AttrDataTag::DataVer as _), dataver)?;
self.path
.to_tlv(&mut writer, TagType::Context(AttrDataTag::Path as _))?;
Ok(Some(writer))
} else {
Ok(None)
}
}
}
pub struct AttrDataWriter<'a, 'b, 'c> {
tw: &'a mut TLVWriter<'b, 'c>,
anchor: usize,
completed: bool,
}
impl<'a, 'b, 'c> AttrDataWriter<'a, 'b, 'c> {
pub const TAG: TagType = TagType::Context(AttrDataTag::Data as _);
fn new(tw: &'a mut TLVWriter<'b, 'c>) -> Self {
let anchor = tw.get_tail();
Self {
tw,
anchor,
completed: false,
}
}
pub fn set<T: ToTLV>(self, value: T) -> Result<(), Error> {
value.to_tlv(self.tw, Self::TAG)?;
self.complete()
}
pub fn complete(mut self) -> Result<(), Error> {
self.tw.end_container()?;
self.tw.end_container()?;
self.completed = true;
Ok(())
}
fn reset(&mut self) {
self.tw.rewind_to(self.anchor);
}
}
impl<'a, 'b, 'c> Drop for AttrDataWriter<'a, 'b, 'c> {
fn drop(&mut self) {
if !self.completed {
self.reset();
}
}
}
impl<'a, 'b, 'c> Deref for AttrDataWriter<'a, 'b, 'c> {
type Target = TLVWriter<'b, 'c>;
fn deref(&self) -> &Self::Target {
self.tw
}
}
impl<'a, 'b, 'c> DerefMut for AttrDataWriter<'a, 'b, 'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tw
}
}
pub struct AttrData<'a> {
for_dataver: Option<u32>,
data: &'a TLVElement<'a>,
}
impl<'a> AttrData<'a> {
pub fn new(for_dataver: Option<u32>, data: &'a TLVElement<'a>) -> Self {
Self { for_dataver, data }
}
pub fn with_dataver(self, dataver: u32) -> Result<&'a TLVElement<'a>, Error> {
if let Some(req_dataver) = self.for_dataver {
if req_dataver != dataver {
Err(ErrorCode::DataVersionMismatch)?;
}
}
Ok(self.data)
}
}
#[derive(Default)]
pub struct CmdDataTracker {
skip_status: bool,
}
impl CmdDataTracker {
pub const fn new() -> Self {
Self { skip_status: false }
}
pub(crate) fn complete(&mut self) {
self.skip_status = true;
}
pub fn needs_status(&self) -> bool {
!self.skip_status
}
}
pub struct CmdDataEncoder<'a, 'b, 'c> {
tracker: &'a mut CmdDataTracker,
path: CmdPath,
tw: &'a mut TLVWriter<'b, 'c>,
}
impl<'a, 'b, 'c> CmdDataEncoder<'a, 'b, 'c> {
pub async fn handle<T: DataModelHandler>(
item: &Result<(CmdDetails<'_>, TLVElement<'_>), CmdStatus>,
handler: &T,
tw: &mut TLVWriter<'_, '_>,
exchange: &Exchange<'_>,
) -> Result<(), Error> {
let status = match item {
Ok((cmd, data)) => {
let mut tracker = CmdDataTracker::new();
let encoder = CmdDataEncoder::new(cmd, &mut tracker, tw);
let result = {
#[cfg(not(feature = "nightly"))]
{
handler.invoke(exchange, cmd, data, encoder)
}
#[cfg(feature = "nightly")]
{
handler.invoke(exchange, &cmd, &data, encoder).await
}
};
match result {
Ok(()) => cmd.success(&tracker),
Err(error) => {
error!("Error invoking command: {}", error);
cmd.status(error.into())
}
}
}
Err(status) => {
error!("Error invoking command: {:?}", status);
Some(status.clone())
}
};
if let Some(status) = status {
InvResp::Status(status).to_tlv(tw, TagType::Anonymous)?;
}
Ok(())
}
pub fn new(
cmd: &CmdDetails,
tracker: &'a mut CmdDataTracker,
tw: &'a mut TLVWriter<'b, 'c>,
) -> Self {
Self {
tracker,
path: cmd.path(),
tw,
}
}
pub fn with_command(mut self, cmd: u16) -> Result<CmdDataWriter<'a, 'b, 'c>, Error> {
let mut writer = CmdDataWriter::new(self.tracker, self.tw);
writer.start_struct(TagType::Anonymous)?;
writer.start_struct(TagType::Context(InvRespTag::Cmd as _))?;
self.path.path.leaf = Some(cmd as _);
self.path
.to_tlv(&mut writer, TagType::Context(CmdDataTag::Path as _))?;
Ok(writer)
}
}
pub struct CmdDataWriter<'a, 'b, 'c> {
tracker: &'a mut CmdDataTracker,
tw: &'a mut TLVWriter<'b, 'c>,
anchor: usize,
completed: bool,
}
impl<'a, 'b, 'c> CmdDataWriter<'a, 'b, 'c> {
pub const TAG: TagType = TagType::Context(CmdDataTag::Data as _);
fn new(tracker: &'a mut CmdDataTracker, tw: &'a mut TLVWriter<'b, 'c>) -> Self {
let anchor = tw.get_tail();
Self {
tracker,
tw,
anchor,
completed: false,
}
}
pub fn set<T: ToTLV>(self, value: T) -> Result<(), Error> {
value.to_tlv(self.tw, Self::TAG)?;
self.complete()
}
pub fn complete(mut self) -> Result<(), Error> {
self.tw.end_container()?;
self.tw.end_container()?;
self.completed = true;
self.tracker.complete();
Ok(())
}
fn reset(&mut self) {
self.tw.rewind_to(self.anchor);
}
}
impl<'a, 'b, 'c> Drop for CmdDataWriter<'a, 'b, 'c> {
fn drop(&mut self) {
if !self.completed {
self.reset();
}
}
}
impl<'a, 'b, 'c> Deref for CmdDataWriter<'a, 'b, 'c> {
type Target = TLVWriter<'b, 'c>;
fn deref(&self) -> &Self::Target {
self.tw
}
}
impl<'a, 'b, 'c> DerefMut for CmdDataWriter<'a, 'b, 'c> {
fn deref_mut(&mut self) -> &mut Self::Target {
self.tw
}
}
#[derive(Copy, Clone, Debug)]
pub struct AttrType<T>(PhantomData<fn() -> T>);
impl<T> AttrType<T> {
pub const fn new() -> Self {
Self(PhantomData)
}
pub fn encode(&self, writer: AttrDataWriter, value: T) -> Result<(), Error>
where
T: ToTLV,
{
writer.set(value)
}
pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<T, Error>
where
T: FromTLV<'a>,
{
T::from_tlv(data)
}
}
impl<T> Default for AttrType<T> {
fn default() -> Self {
Self::new()
}
}
#[derive(Copy, Clone, Debug, Default)]
pub struct AttrUtfType;
impl AttrUtfType {
pub const fn new() -> Self {
Self
}
pub fn encode(&self, writer: AttrDataWriter, value: &str) -> Result<(), Error> {
writer.set(UtfStr::new(value.as_bytes()))
}
pub fn decode<'a>(&self, data: &'a TLVElement) -> Result<&'a str, IMStatusCode> {
data.str().map_err(|_| IMStatusCode::InvalidDataType)
}
}
#[allow(unused_macros)]
#[macro_export]
macro_rules! attribute_enum {
($en:ty) => {
impl core::convert::TryFrom<$crate::data_model::objects::AttrId> for $en {
type Error = $crate::error::Error;
fn try_from(id: $crate::data_model::objects::AttrId) -> Result<Self, Self::Error> {
<$en>::from_repr(id)
.ok_or_else(|| $crate::error::ErrorCode::AttributeNotFound.into())
}
}
};
}
#[allow(unused_macros)]
#[macro_export]
macro_rules! command_enum {
($en:ty) => {
impl core::convert::TryFrom<$crate::data_model::objects::CmdId> for $en {
type Error = $crate::error::Error;
fn try_from(id: $crate::data_model::objects::CmdId) -> Result<Self, Self::Error> {
<$en>::from_repr(id).ok_or_else(|| $crate::error::ErrorCode::CommandNotFound.into())
}
}
};
} }

View file

@ -15,104 +15,85 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{data_model::objects::ClusterType, error::*, interaction_model::core::IMStatusCode}; use crate::{acl::Accessor, interaction_model::core::IMStatusCode};
use std::fmt; use core::fmt;
use super::{ClusterId, DeviceType}; use super::{AttrId, Attribute, Cluster, ClusterId, CmdId, DeviceType, EndptId};
pub const CLUSTERS_PER_ENDPT: usize = 9; #[derive(Debug, Clone)]
pub struct Endpoint<'a> {
pub struct Endpoint { pub id: EndptId,
dev_type: DeviceType, pub device_type: DeviceType,
clusters: Vec<Box<dyn ClusterType>>, pub clusters: &'a [Cluster<'a>],
} }
pub type BoxedClusters = [Box<dyn ClusterType>]; impl<'a> Endpoint<'a> {
pub fn match_attributes(
impl Endpoint {
pub fn new(dev_type: DeviceType) -> Result<Box<Endpoint>, Error> {
Ok(Box::new(Endpoint {
dev_type,
clusters: Vec::with_capacity(CLUSTERS_PER_ENDPT),
}))
}
pub fn add_cluster(&mut self, cluster: Box<dyn ClusterType>) -> Result<(), Error> {
if self.clusters.len() < self.clusters.capacity() {
self.clusters.push(cluster);
Ok(())
} else {
Err(Error::NoSpace)
}
}
pub fn get_dev_type(&self) -> &DeviceType {
&self.dev_type
}
fn get_cluster_index(&self, cluster_id: ClusterId) -> Option<usize> {
self.clusters.iter().position(|c| c.base().id == cluster_id)
}
pub fn get_cluster(&self, cluster_id: ClusterId) -> Result<&dyn ClusterType, Error> {
let index = self
.get_cluster_index(cluster_id)
.ok_or(Error::ClusterNotFound)?;
Ok(self.clusters[index].as_ref())
}
pub fn get_cluster_mut(
&mut self,
cluster_id: ClusterId,
) -> Result<&mut dyn ClusterType, Error> {
let index = self
.get_cluster_index(cluster_id)
.ok_or(Error::ClusterNotFound)?;
Ok(self.clusters[index].as_mut())
}
// Returns a slice of clusters, with either a single cluster or all (wildcard)
pub fn get_wildcard_clusters(
&self, &self,
cluster: Option<ClusterId>, cl: Option<ClusterId>,
) -> Result<(&BoxedClusters, bool), IMStatusCode> { attr: Option<AttrId>,
if let Some(c) = cluster { ) -> impl Iterator<Item = (&'_ Cluster, &'_ Attribute)> + '_ {
if let Some(i) = self.get_cluster_index(c) { self.match_clusters(cl).flat_map(move |cluster| {
Ok((&self.clusters[i..i + 1], false)) cluster
} else { .match_attributes(attr)
Err(IMStatusCode::UnsupportedCluster) .map(move |attr| (cluster, attr))
})
} }
} else {
Ok((self.clusters.as_slice(), true)) pub fn match_commands(
&self,
cl: Option<ClusterId>,
cmd: Option<CmdId>,
) -> impl Iterator<Item = (&'_ Cluster, CmdId)> + '_ {
self.match_clusters(cl)
.flat_map(move |cluster| cluster.match_commands(cmd).map(move |cmd| (cluster, cmd)))
}
pub fn check_attribute(
&self,
accessor: &Accessor,
cl: ClusterId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
self.check_cluster(cl)
.and_then(|cluster| cluster.check_attribute(accessor, self.id, attr, write))
}
pub fn check_command(
&self,
accessor: &Accessor,
cl: ClusterId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.check_cluster(cl)
.and_then(|cluster| cluster.check_command(accessor, self.id, cmd))
}
pub fn match_clusters(&self, cl: Option<ClusterId>) -> impl Iterator<Item = &'_ Cluster> + '_ {
self.clusters
.iter()
.filter(move |cluster| cl.map(|id| id == cluster.id).unwrap_or(true))
}
pub fn check_cluster(&self, cl: ClusterId) -> Result<&Cluster, IMStatusCode> {
self.clusters
.iter()
.find(|cluster| cluster.id == cl)
.ok_or(IMStatusCode::UnsupportedCluster)
} }
} }
// Returns a slice of clusters, with either a single cluster or all (wildcard) impl<'a> core::fmt::Display for Endpoint<'a> {
pub fn get_wildcard_clusters_mut(
&mut self,
cluster: Option<ClusterId>,
) -> Result<(&mut BoxedClusters, bool), IMStatusCode> {
if let Some(c) = cluster {
if let Some(i) = self.get_cluster_index(c) {
Ok((&mut self.clusters[i..i + 1], false))
} else {
Err(IMStatusCode::UnsupportedCluster)
}
} else {
Ok((&mut self.clusters[..], true))
}
}
}
impl std::fmt::Display for Endpoint {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "clusters:[")?; write!(f, "clusters:[")?;
let mut comma = ""; let mut comma = "";
for element in self.clusters.iter() { for cluster in self.clusters {
write!(f, "{} {{ {} }}", comma, element.base())?; write!(f, "{} {{ {} }}", comma, cluster)?;
comma = ", "; comma = ", ";
} }
write!(f, "]") write!(f, "]")
} }
} }

View file

@ -0,0 +1,513 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::{Error, ErrorCode},
tlv::TLVElement,
transport::exchange::Exchange,
};
use super::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails};
#[cfg(feature = "nightly")]
pub use asynch::*;
#[cfg(not(feature = "nightly"))]
pub trait DataModelHandler: super::Metadata + Handler {}
#[cfg(not(feature = "nightly"))]
impl<T> DataModelHandler for T where T: super::Metadata + Handler {}
#[cfg(feature = "nightly")]
pub trait DataModelHandler: super::asynch::AsyncMetadata + asynch::AsyncHandler {}
#[cfg(feature = "nightly")]
impl<T> DataModelHandler for T where T: super::asynch::AsyncMetadata + asynch::AsyncHandler {}
pub trait ChangeNotifier<T> {
fn consume_change(&mut self) -> Option<T>;
}
pub trait Handler {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error>;
fn write(&self, _attr: &AttrDetails, _data: AttrData) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
fn invoke(
&self,
_exchange: &Exchange,
_cmd: &CmdDetails,
_data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
Err(ErrorCode::CommandNotFound.into())
}
}
impl<T> Handler for &T
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
(**self).read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
(**self).write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder)
}
}
impl<T> Handler for &mut T
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
(**self).read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
(**self).write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder)
}
}
pub trait NonBlockingHandler: Handler {}
impl<T> NonBlockingHandler for &T where T: NonBlockingHandler {}
impl<T> NonBlockingHandler for &mut T where T: NonBlockingHandler {}
impl<M, H> Handler for (M, H)
where
H: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
self.1.read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
self.1.write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
self.1.invoke(exchange, cmd, data, encoder)
}
}
impl<M, H> NonBlockingHandler for (M, H) where H: NonBlockingHandler {}
pub struct EmptyHandler;
impl EmptyHandler {
pub const fn chain<H>(
self,
handler_endpoint: u16,
handler_cluster: u32,
handler: H,
) -> ChainedHandler<H, Self> {
ChainedHandler {
handler_endpoint,
handler_cluster,
handler,
next: self,
}
}
}
impl Handler for EmptyHandler {
fn read(&self, _attr: &AttrDetails, _encoder: AttrDataEncoder) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
}
impl NonBlockingHandler for EmptyHandler {}
impl ChangeNotifier<(u16, u32)> for EmptyHandler {
fn consume_change(&mut self) -> Option<(u16, u32)> {
None
}
}
pub struct ChainedHandler<H, T> {
pub handler_endpoint: u16,
pub handler_cluster: u32,
pub handler: H,
pub next: T,
}
impl<H, T> ChainedHandler<H, T> {
pub const fn chain<H2>(
self,
handler_endpoint: u16,
handler_cluster: u32,
handler: H2,
) -> ChainedHandler<H2, Self> {
ChainedHandler {
handler_endpoint,
handler_cluster,
handler,
next: self,
}
}
}
impl<H, T> Handler for ChainedHandler<H, T>
where
H: Handler,
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id {
self.handler.read(attr, encoder)
} else {
self.next.read(attr, encoder)
}
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id {
self.handler.write(attr, data)
} else {
self.next.write(attr, data)
}
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(exchange, cmd, data, encoder)
} else {
self.next.invoke(exchange, cmd, data, encoder)
}
}
}
impl<H, T> NonBlockingHandler for ChainedHandler<H, T>
where
H: NonBlockingHandler,
T: NonBlockingHandler,
{
}
impl<H, T> ChangeNotifier<(u16, u32)> for ChainedHandler<H, T>
where
H: ChangeNotifier<()>,
T: ChangeNotifier<(u16, u32)>,
{
fn consume_change(&mut self) -> Option<(u16, u32)> {
if self.handler.consume_change().is_some() {
Some((self.handler_endpoint, self.handler_cluster))
} else {
self.next.consume_change()
}
}
}
/// Wrap your `NonBlockingHandler` or `AsyncHandler` implementation in this struct
/// to get your code compilable with and without the `nightly` feature
pub struct HandlerCompat<T>(pub T);
impl<T> Handler for HandlerCompat<T>
where
T: Handler,
{
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
self.0.read(attr, encoder)
}
fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
self.0.write(attr, data)
}
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
self.0.invoke(exchange, cmd, data, encoder)
}
}
impl<T> NonBlockingHandler for HandlerCompat<T> where T: NonBlockingHandler {}
#[allow(unused_macros)]
#[macro_export]
macro_rules! handler_chain_type {
($h:ty) => {
$crate::data_model::objects::ChainedHandler<$h, $crate::data_model::objects::EmptyHandler>
};
($h1:ty $(, $rest:ty)+) => {
$crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+)>
};
($h:ty | $f:ty) => {
$crate::data_model::objects::ChainedHandler<$h, $f>
};
($h1:ty $(, $rest:ty)+ | $f:ty) => {
$crate::data_model::objects::ChainedHandler<$h1, handler_chain_type!($($rest),+ | $f)>
};
}
#[cfg(feature = "nightly")]
mod asynch {
use crate::{
data_model::objects::{AttrData, AttrDataEncoder, AttrDetails, CmdDataEncoder, CmdDetails},
error::{Error, ErrorCode},
tlv::TLVElement,
transport::exchange::Exchange,
};
use super::{ChainedHandler, EmptyHandler, Handler, HandlerCompat, NonBlockingHandler};
pub trait AsyncHandler {
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error>;
async fn write<'a>(
&'a self,
_attr: &'a AttrDetails<'_>,
_data: AttrData<'a>,
) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
async fn invoke<'a>(
&'a self,
_exchange: &'a Exchange<'_>,
_cmd: &'a CmdDetails<'_>,
_data: &'a TLVElement<'_>,
_encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Err(ErrorCode::CommandNotFound.into())
}
}
impl<T> AsyncHandler for &mut T
where
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).read(attr, encoder).await
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
(**self).write(attr, data).await
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder).await
}
}
impl<T> AsyncHandler for &T
where
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).read(attr, encoder).await
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
(**self).write(attr, data).await
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
(**self).invoke(exchange, cmd, data, encoder).await
}
}
impl<M, H> AsyncHandler for (M, H)
where
H: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
self.1.read(attr, encoder).await
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
self.1.write(attr, data).await
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
self.1.invoke(exchange, cmd, data, encoder).await
}
}
impl<T> AsyncHandler for HandlerCompat<T>
where
T: NonBlockingHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Handler::read(&self.0, attr, encoder)
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
Handler::write(&self.0, attr, data)
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Handler::invoke(&self.0, exchange, cmd, data, encoder)
}
}
impl AsyncHandler for EmptyHandler {
async fn read<'a>(
&'a self,
_attr: &'a AttrDetails<'_>,
_encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
Err(ErrorCode::AttributeNotFound.into())
}
}
impl<H, T> AsyncHandler for ChainedHandler<H, T>
where
H: AsyncHandler,
T: AsyncHandler,
{
async fn read<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
encoder: AttrDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id
{
self.handler.read(attr, encoder).await
} else {
self.next.read(attr, encoder).await
}
}
async fn write<'a>(
&'a self,
attr: &'a AttrDetails<'_>,
data: AttrData<'a>,
) -> Result<(), Error> {
if self.handler_endpoint == attr.endpoint_id && self.handler_cluster == attr.cluster_id
{
self.handler.write(attr, data).await
} else {
self.next.write(attr, data).await
}
}
async fn invoke<'a>(
&'a self,
exchange: &'a Exchange<'_>,
cmd: &'a CmdDetails<'_>,
data: &'a TLVElement<'_>,
encoder: CmdDataEncoder<'a, '_, '_>,
) -> Result<(), Error> {
if self.handler_endpoint == cmd.endpoint_id && self.handler_cluster == cmd.cluster_id {
self.handler.invoke(exchange, cmd, data, encoder).await
} else {
self.next.invoke(exchange, cmd, data, encoder).await
}
}
}
}

View file

@ -0,0 +1,195 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::data_model::objects::Node;
#[cfg(feature = "nightly")]
pub use asynch::*;
use super::HandlerCompat;
pub trait MetadataGuard {
fn node(&self) -> Node<'_>;
}
impl<T> MetadataGuard for &T
where
T: MetadataGuard,
{
fn node(&self) -> Node<'_> {
(**self).node()
}
}
impl<T> MetadataGuard for &mut T
where
T: MetadataGuard,
{
fn node(&self) -> Node<'_> {
(**self).node()
}
}
pub trait Metadata {
type MetadataGuard<'a>: MetadataGuard
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_>;
}
impl<T> Metadata for &T
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock()
}
}
impl<T> Metadata for &mut T
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock()
}
}
impl<'a> MetadataGuard for Node<'a> {
fn node(&self) -> Node<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<'a> Metadata for Node<'a> {
type MetadataGuard<'g> = Node<'g> where Self: 'g;
fn lock(&self) -> Self::MetadataGuard<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<M, H> Metadata for (M, H)
where
M: Metadata,
{
type MetadataGuard<'a> = M::MetadataGuard<'a>
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
impl<T> Metadata for HandlerCompat<T>
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a>
where
Self: 'a;
fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
#[cfg(feature = "nightly")]
pub mod asynch {
use crate::data_model::objects::{HandlerCompat, Node};
use super::{Metadata, MetadataGuard};
pub trait AsyncMetadata {
type MetadataGuard<'a>: MetadataGuard
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_>;
}
impl<T> AsyncMetadata for &T
where
T: AsyncMetadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock().await
}
}
impl<T> AsyncMetadata for &mut T
where
T: AsyncMetadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a> where Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
(**self).lock().await
}
}
impl<'a> AsyncMetadata for Node<'a> {
type MetadataGuard<'g> = Node<'g> where Self: 'g;
async fn lock(&self) -> Self::MetadataGuard<'_> {
Node {
id: self.id,
endpoints: self.endpoints,
}
}
}
impl<M, H> AsyncMetadata for (M, H)
where
M: AsyncMetadata,
{
type MetadataGuard<'a> = M::MetadataGuard<'a>
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock().await
}
}
impl<T> AsyncMetadata for HandlerCompat<T>
where
T: Metadata,
{
type MetadataGuard<'a> = T::MetadataGuard<'a>
where
Self: 'a;
async fn lock(&self) -> Self::MetadataGuard<'_> {
self.0.lock()
}
}
}

View file

@ -14,11 +14,8 @@
* See the License for the specific language governing permissions and * See the License for the specific language governing permissions and
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error;
pub type EndptId = u16; use crate::tlv::{TLVWriter, TagType, ToTLV};
pub type ClusterId = u32;
pub type AttrId = u16;
pub type CmdId = u32;
mod attribute; mod attribute;
pub use attribute::*; pub use attribute::*;
@ -37,3 +34,23 @@ pub use privilege::*;
mod encoder; mod encoder;
pub use encoder::*; pub use encoder::*;
mod handler;
pub use handler::*;
mod dataver;
pub use dataver::*;
mod metadata;
pub use metadata::*;
pub type EndptId = u16;
pub type ClusterId = u32;
pub type AttrId = u16;
pub type CmdId = u32;
#[derive(Debug, ToTLV, Copy, Clone)]
pub struct DeviceType {
pub dtype: u16,
pub drev: u16,
}

View file

@ -16,283 +16,456 @@
*/ */
use crate::{ use crate::{
data_model::objects::{ClusterType, Endpoint}, acl::Accessor,
error::*, alloc,
interaction_model::{core::IMStatusCode, messages::GenericPath}, data_model::objects::Endpoint,
interaction_model::{
core::IMStatusCode,
messages::{
ib::{AttrPath, AttrStatus, CmdStatus, DataVersionFilter},
msg::{InvReq, ReadReq, SubscribeReq, WriteReq},
GenericPath,
},
},
// TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer // TODO: This layer shouldn't really depend on the TLV layer, should create an abstraction layer
tlv::{TLVArray, TLVElement},
};
use core::{
fmt,
iter::{once, Once},
}; };
use std::fmt;
use super::{ClusterId, DeviceType, EndptId}; use super::{AttrDetails, AttrId, Attribute, Cluster, ClusterId, CmdDetails, CmdId, EndptId};
pub trait ChangeConsumer { pub enum WildcardIter<T, E> {
fn endpoint_added(&self, id: EndptId, endpoint: &mut Endpoint) -> Result<(), Error>; None,
Single(Once<E>),
Wildcard(T),
} }
pub const ENDPTS_PER_ACC: usize = 3; impl<T, E> Iterator for WildcardIter<T, E>
where
T: Iterator<Item = E>,
{
type Item = E;
pub type BoxedEndpoints = [Option<Box<Endpoint>>]; fn next(&mut self) -> Option<Self::Item> {
match self {
#[derive(Default)] Self::None => None,
pub struct Node { Self::Single(iter) => iter.next(),
endpoints: [Option<Box<Endpoint>>; ENDPTS_PER_ACC], Self::Wildcard(iter) => iter.next(),
changes_cb: Option<Box<dyn ChangeConsumer>>, }
}
} }
impl std::fmt::Display for Node { #[derive(Debug, Clone)]
pub struct Node<'a> {
pub id: u16,
pub endpoints: &'a [Endpoint<'a>],
}
impl<'a> Node<'a> {
pub fn read<'s, 'm>(
&'s self,
req: &'m ReadReq,
from: Option<GenericPath>,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.attr_requests
.iter()
.flat_map(|attr_requests| attr_requests.iter()),
req.dataver_filters.as_ref(),
req.fabric_filtered,
accessor,
from,
)
}
pub fn subscribing_read<'s, 'm>(
&'s self,
req: &'m SubscribeReq,
from: Option<GenericPath>,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
{
self.read_attr_requests(
req.attr_requests
.iter()
.flat_map(|attr_requests| attr_requests.iter()),
req.dataver_filters.as_ref(),
req.fabric_filtered,
accessor,
from,
)
}
fn read_attr_requests<'s, 'm, P>(
&'s self,
attr_requests: P,
dataver_filters: Option<&'m TLVArray<DataVersionFilter>>,
fabric_filtered: bool,
accessor: &'m Accessor<'m>,
from: Option<GenericPath>,
) -> impl Iterator<Item = Result<AttrDetails, AttrStatus>> + 'm
where
's: 'm,
P: Iterator<Item = AttrPath> + 'm,
{
alloc!(attr_requests.flat_map(move |path| {
if path.to_gp().is_wildcard() {
let from = from.clone();
let iter = self
.match_attributes(path.endpoint, path.cluster, path.attr)
.skip_while(move |(ep, cl, attr)| {
!Self::matches(from.as_ref(), ep.id, cl.id, attr.id as _)
})
.filter(move |(ep, cl, attr)| {
Cluster::check_attr_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)),
false,
attr.access,
)
.is_ok()
})
.map(move |(ep, cl, attr)| {
let dataver = if let Some(dataver_filters) = dataver_filters {
dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep.id && filter.path.cluster == cl.id)
.then_some(filter.data_ver)
})
} else {
None
};
Ok(AttrDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
attr_id: attr.id,
list_index: path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: fabric_filtered,
dataver,
wildcard: true,
})
});
WildcardIter::Wildcard(iter)
} else {
let ep = path.endpoint.unwrap();
let cl = path.cluster.unwrap();
let attr = path.attr.unwrap();
let result = match self.check_attribute(accessor, ep, cl, attr, false) {
Ok(()) => {
let dataver = if let Some(dataver_filters) = dataver_filters {
dataver_filters.iter().find_map(|filter| {
(filter.path.endpoint == ep && filter.path.cluster == cl)
.then_some(filter.data_ver)
})
} else {
None
};
Ok(AttrDetails {
node: self,
endpoint_id: ep,
cluster_id: cl,
attr_id: attr,
list_index: path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: fabric_filtered,
dataver,
wildcard: false,
})
}
Err(err) => Err(AttrStatus::new(&path.to_gp(), err, 0)),
};
WildcardIter::Single(once(result))
}
}))
}
pub fn write<'m>(
&'m self,
req: &'m WriteReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(AttrDetails, TLVElement<'m>), AttrStatus>> + 'm {
alloc!(req.write_requests.iter().flat_map(move |attr_data| {
if attr_data.path.cluster.is_none() {
WildcardIter::Single(once(Err(AttrStatus::new(
&attr_data.path.to_gp(),
IMStatusCode::UnsupportedCluster,
0,
))))
} else if attr_data.path.attr.is_none() {
WildcardIter::Single(once(Err(AttrStatus::new(
&attr_data.path.to_gp(),
IMStatusCode::UnsupportedAttribute,
0,
))))
} else if attr_data.path.to_gp().is_wildcard() {
let iter = self
.match_attributes(
attr_data.path.endpoint,
attr_data.path.cluster,
attr_data.path.attr,
)
.filter(move |(ep, cl, attr)| {
Cluster::check_attr_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(attr.id as _)),
true,
attr.access,
)
.is_ok()
})
.map(move |(ep, cl, attr)| {
Ok((
AttrDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
attr_id: attr.id,
list_index: attr_data.path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: false,
dataver: attr_data.data_ver,
wildcard: true,
},
attr_data.data.clone().unwrap_tlv().unwrap(),
))
});
WildcardIter::Wildcard(iter)
} else {
let ep = attr_data.path.endpoint.unwrap();
let cl = attr_data.path.cluster.unwrap();
let attr = attr_data.path.attr.unwrap();
let result = match self.check_attribute(accessor, ep, cl, attr, true) {
Ok(()) => Ok((
AttrDetails {
node: self,
endpoint_id: ep,
cluster_id: cl,
attr_id: attr,
list_index: attr_data.path.list_index,
fab_idx: accessor.fab_idx,
fab_filter: false,
dataver: attr_data.data_ver,
wildcard: false,
},
attr_data.data.unwrap_tlv().unwrap(),
)),
Err(err) => Err(AttrStatus::new(&attr_data.path.to_gp(), err, 0)),
};
WildcardIter::Single(once(result))
}
}))
}
pub fn invoke<'m>(
&'m self,
req: &'m InvReq,
accessor: &'m Accessor<'m>,
) -> impl Iterator<Item = Result<(CmdDetails, TLVElement<'m>), CmdStatus>> + 'm {
alloc!(req
.inv_requests
.iter()
.flat_map(|inv_requests| inv_requests.iter())
.flat_map(move |cmd_data| {
if cmd_data.path.path.is_wildcard() {
let iter = self
.match_commands(
cmd_data.path.path.endpoint,
cmd_data.path.path.cluster,
cmd_data.path.path.leaf.map(|leaf| leaf as _),
)
.filter(move |(ep, cl, cmd)| {
Cluster::check_cmd_access(
accessor,
GenericPath::new(Some(ep.id), Some(cl.id), Some(*cmd)),
)
.is_ok()
})
.map(move |(ep, cl, cmd)| {
Ok((
CmdDetails {
node: self,
endpoint_id: ep.id,
cluster_id: cl.id,
cmd_id: cmd,
wildcard: true,
},
cmd_data.data.clone().unwrap_tlv().unwrap(),
))
});
WildcardIter::Wildcard(iter)
} else {
let ep = cmd_data.path.path.endpoint.unwrap();
let cl = cmd_data.path.path.cluster.unwrap();
let cmd = cmd_data.path.path.leaf.unwrap();
let result = match self.check_command(accessor, ep, cl, cmd) {
Ok(()) => Ok((
CmdDetails {
node: self,
endpoint_id: cmd_data.path.path.endpoint.unwrap(),
cluster_id: cmd_data.path.path.cluster.unwrap(),
cmd_id: cmd_data.path.path.leaf.unwrap(),
wildcard: false,
},
cmd_data.data.unwrap_tlv().unwrap(),
)),
Err(err) => Err(CmdStatus::new(cmd_data.path, err, 0)),
};
WildcardIter::Single(once(result))
}
}))
}
fn matches(path: Option<&GenericPath>, ep: EndptId, cl: ClusterId, leaf: u32) -> bool {
if let Some(path) = path {
path.endpoint.map(|id| id == ep).unwrap_or(true)
&& path.cluster.map(|id| id == cl).unwrap_or(true)
&& path.leaf.map(|id| id == leaf).unwrap_or(true)
} else {
true
}
}
pub fn match_attributes(
&self,
ep: Option<EndptId>,
cl: Option<ClusterId>,
attr: Option<AttrId>,
) -> impl Iterator<Item = (&'_ Endpoint, &'_ Cluster, &'_ Attribute)> + '_ {
self.match_endpoints(ep).flat_map(move |endpoint| {
endpoint
.match_attributes(cl, attr)
.map(move |(cl, attr)| (endpoint, cl, attr))
})
}
pub fn match_commands(
&self,
ep: Option<EndptId>,
cl: Option<ClusterId>,
cmd: Option<CmdId>,
) -> impl Iterator<Item = (&'_ Endpoint, &'_ Cluster, CmdId)> + '_ {
self.match_endpoints(ep).flat_map(move |endpoint| {
endpoint
.match_commands(cl, cmd)
.map(move |(cl, cmd)| (endpoint, cl, cmd))
})
}
pub fn check_attribute(
&self,
accessor: &Accessor,
ep: EndptId,
cl: ClusterId,
attr: AttrId,
write: bool,
) -> Result<(), IMStatusCode> {
self.check_endpoint(ep)
.and_then(|endpoint| endpoint.check_attribute(accessor, cl, attr, write))
}
pub fn check_command(
&self,
accessor: &Accessor,
ep: EndptId,
cl: ClusterId,
cmd: CmdId,
) -> Result<(), IMStatusCode> {
self.check_endpoint(ep)
.and_then(|endpoint| endpoint.check_command(accessor, cl, cmd))
}
pub fn match_endpoints(&self, ep: Option<EndptId>) -> impl Iterator<Item = &'_ Endpoint> + '_ {
self.endpoints
.iter()
.filter(move |endpoint| ep.map(|id| id == endpoint.id).unwrap_or(true))
}
pub fn check_endpoint(&self, ep: EndptId) -> Result<&Endpoint, IMStatusCode> {
self.endpoints
.iter()
.find(|endpoint| endpoint.id == ep)
.ok_or(IMStatusCode::UnsupportedEndpoint)
}
}
impl<'a> core::fmt::Display for Node<'a> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "node:")?; writeln!(f, "node:")?;
for (i, element) in self.endpoints.iter().enumerate() { for (index, endpoint) in self.endpoints.iter().enumerate() {
if let Some(e) = element { writeln!(f, "endpoint {}: {}", index, endpoint)?;
writeln!(f, "endpoint {}: {}", i, e)?;
}
} }
write!(f, "") write!(f, "")
} }
} }
impl Node { pub struct DynamicNode<'a, const N: usize> {
pub fn new() -> Result<Box<Node>, Error> { id: u16,
let node = Box::default(); endpoints: heapless::Vec<Endpoint<'a>, N>,
Ok(node)
} }
pub fn set_changes_cb(&mut self, consumer: Box<dyn ChangeConsumer>) { impl<'a, const N: usize> DynamicNode<'a, N> {
self.changes_cb = Some(consumer); pub const fn new(id: u16) -> Self {
Self {
id,
endpoints: heapless::Vec::new(),
}
} }
pub fn add_endpoint(&mut self, dev_type: DeviceType) -> Result<EndptId, Error> { pub fn node(&self) -> Node<'_> {
Node {
id: self.id,
endpoints: &self.endpoints,
}
}
pub fn add(&mut self, endpoint: Endpoint<'a>) -> Result<(), Endpoint<'a>> {
if !self.endpoints.iter().any(|ep| ep.id == endpoint.id) {
self.endpoints.push(endpoint)
} else {
Err(endpoint)
}
}
pub fn remove(&mut self, endpoint_id: u16) -> Option<Endpoint<'a>> {
let index = self let index = self
.endpoints .endpoints
.iter() .iter()
.position(|x| x.is_none()) .enumerate()
.ok_or(Error::NoSpace)?; .find_map(|(index, ep)| (ep.id == endpoint_id).then_some(index));
let mut endpoint = Endpoint::new(dev_type)?;
if let Some(cb) = &self.changes_cb {
cb.endpoint_added(index as EndptId, &mut endpoint)?;
}
self.endpoints[index] = Some(endpoint);
Ok(index as EndptId)
}
pub fn get_endpoint(&self, endpoint_id: EndptId) -> Result<&Endpoint, Error> { if let Some(index) = index {
if (endpoint_id as usize) < ENDPTS_PER_ACC { Some(self.endpoints.swap_remove(index))
let endpoint = self.endpoints[endpoint_id as usize]
.as_ref()
.ok_or(Error::EndpointNotFound)?;
Ok(endpoint)
} else { } else {
Err(Error::EndpointNotFound) None
}
} }
} }
pub fn get_endpoint_mut(&mut self, endpoint_id: EndptId) -> Result<&mut Endpoint, Error> { impl<'a, const N: usize> core::fmt::Display for DynamicNode<'a, N> {
if (endpoint_id as usize) < ENDPTS_PER_ACC { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let endpoint = self.endpoints[endpoint_id as usize] self.node().fmt(f)
.as_mut()
.ok_or(Error::EndpointNotFound)?;
Ok(endpoint)
} else {
Err(Error::EndpointNotFound)
}
}
pub fn get_cluster_mut(
&mut self,
e: EndptId,
c: ClusterId,
) -> Result<&mut dyn ClusterType, Error> {
self.get_endpoint_mut(e)?.get_cluster_mut(c)
}
pub fn get_cluster(&self, e: EndptId, c: ClusterId) -> Result<&dyn ClusterType, Error> {
self.get_endpoint(e)?.get_cluster(c)
}
pub fn add_cluster(
&mut self,
endpoint_id: EndptId,
cluster: Box<dyn ClusterType>,
) -> Result<(), Error> {
let endpoint_id = endpoint_id as usize;
if endpoint_id < ENDPTS_PER_ACC {
self.endpoints[endpoint_id]
.as_mut()
.ok_or(Error::NoEndpoint)?
.add_cluster(cluster)
} else {
Err(Error::Invalid)
}
}
// Returns a slice of endpoints, with either a single endpoint or all (wildcard)
pub fn get_wildcard_endpoints(
&self,
endpoint: Option<EndptId>,
) -> Result<(&BoxedEndpoints, usize, bool), IMStatusCode> {
if let Some(e) = endpoint {
let e = e as usize;
if self.endpoints.len() <= e || self.endpoints[e].is_none() {
Err(IMStatusCode::UnsupportedEndpoint)
} else {
Ok((&self.endpoints[e..e + 1], e, false))
}
} else {
Ok((&self.endpoints[..], 0, true))
}
}
pub fn get_wildcard_endpoints_mut(
&mut self,
endpoint: Option<EndptId>,
) -> Result<(&mut BoxedEndpoints, usize, bool), IMStatusCode> {
if let Some(e) = endpoint {
let e = e as usize;
if self.endpoints.len() <= e || self.endpoints[e].is_none() {
Err(IMStatusCode::UnsupportedEndpoint)
} else {
Ok((&mut self.endpoints[e..e + 1], e, false))
}
} else {
Ok((&mut self.endpoints[..], 0, true))
}
}
/// Run a closure for all endpoints as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_endpoint<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &Endpoint) -> Result<(), IMStatusCode>,
{
let mut current_path = *path;
let (endpoints, mut endpoint_id, wildcard) = self.get_wildcard_endpoints(path.endpoint)?;
for e in endpoints.iter() {
if let Some(e) = e {
current_path.endpoint = Some(endpoint_id as EndptId);
f(&current_path, e.as_ref())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
endpoint_id += 1;
}
Ok(())
}
/// Run a closure for all endpoints (mutable) as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_endpoint_mut<T>(
&mut self,
path: &GenericPath,
mut f: T,
) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &mut Endpoint) -> Result<(), IMStatusCode>,
{
let mut current_path = *path;
let (endpoints, mut endpoint_id, wildcard) =
self.get_wildcard_endpoints_mut(path.endpoint)?;
for e in endpoints.iter_mut() {
if let Some(e) = e {
current_path.endpoint = Some(endpoint_id as EndptId);
f(&current_path, e.as_mut())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
endpoint_id += 1;
}
Ok(())
}
/// Run a closure for all clusters as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_cluster<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_endpoint(path, |p, e| {
let mut current_path = *p;
let (clusters, wildcard) = e.get_wildcard_clusters(p.cluster)?;
for c in clusters.iter() {
current_path.cluster = Some(c.base().id);
f(&current_path, c.as_ref())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
}
/// Run a closure for all clusters (mutable) as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_cluster_mut<T>(
&mut self,
path: &GenericPath,
mut f: T,
) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &mut dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_endpoint_mut(path, |p, e| {
let mut current_path = *p;
let (clusters, wildcard) = e.get_wildcard_clusters_mut(p.cluster)?;
for c in clusters.iter_mut() {
current_path.cluster = Some(c.base().id);
f(&current_path, c.as_mut())
.or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
}
/// Run a closure for all attributes as specified in the path
///
/// Note that the path is a GenericPath and hence can be a wildcard path. The behaviour
/// of this function is to only capture the successful invocations and ignore the erroneous
/// ones. This is inline with the expected behaviour for wildcard, where it implies that
/// 'please run this operation on this wildcard path "wherever possible"'
///
/// It is expected that if the closure that you pass here returns an error it may not reach
/// out to the caller, in case there was a wildcard path specified
pub fn for_each_attribute<T>(&self, path: &GenericPath, mut f: T) -> Result<(), IMStatusCode>
where
T: FnMut(&GenericPath, &dyn ClusterType) -> Result<(), IMStatusCode>,
{
self.for_each_cluster(path, |current_path, c| {
let mut current_path = *current_path;
let (attributes, wildcard) = c
.base()
.get_wildcard_attribute(path.leaf.map(|at| at as u16))?;
for a in attributes.iter() {
current_path.leaf = Some(a.id as u32);
f(&current_path, c).or_else(|e| if !wildcard { Err(e) } else { Ok(()) })?;
}
Ok(())
})
} }
} }

View file

@ -16,7 +16,7 @@
*/ */
use crate::{ use crate::{
error::Error, error::{Error, ErrorCode},
tlv::{FromTLV, TLVElement, ToTLV}, tlv::{FromTLV, TLVElement, ToTLV},
}; };
use log::error; use log::error;
@ -47,12 +47,12 @@ impl FromTLV<'_> for Privilege {
1 => Ok(Privilege::VIEW), 1 => Ok(Privilege::VIEW),
2 => { 2 => {
error!("ProxyView privilege not yet supporteds"); error!("ProxyView privilege not yet supporteds");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
3 => Ok(Privilege::OPERATE), 3 => Ok(Privilege::OPERATE),
4 => Ok(Privilege::MANAGE), 4 => Ok(Privilege::MANAGE),
5 => Ok(Privilege::ADMIN), 5 => Ok(Privilege::ADMIN),
_ => Err(Error::Invalid), _ => Err(ErrorCode::Invalid.into()),
} }
} }
} }

View file

@ -0,0 +1,125 @@
use core::{borrow::Borrow, cell::RefCell};
use crate::{
acl::AclMgr,
fabric::FabricMgr,
handler_chain_type,
mdns::Mdns,
secure_channel::pake::PaseMgr,
utils::{epoch::Epoch, rand::Rand},
};
use super::{
cluster_basic_information::{self, BasicInfoCluster, BasicInfoConfig},
objects::{Cluster, EmptyHandler, Endpoint, EndptId},
sdm::{
admin_commissioning::{self, AdminCommCluster},
dev_att::DevAttDataFetcher,
failsafe::FailSafe,
general_commissioning::{self, GenCommCluster},
noc::{self, NocCluster},
nw_commissioning::{self, NwCommCluster},
},
system_model::{
access_control::{self, AccessControlCluster},
descriptor::{self, DescriptorCluster},
},
};
pub type RootEndpointHandler<'a> = handler_chain_type!(
DescriptorCluster<'static>,
BasicInfoCluster<'a>,
GenCommCluster<'a>,
NwCommCluster,
AdminCommCluster<'a>,
NocCluster<'a>,
AccessControlCluster<'a>
);
pub const CLUSTERS: [Cluster<'static>; 7] = [
descriptor::CLUSTER,
cluster_basic_information::CLUSTER,
general_commissioning::CLUSTER,
nw_commissioning::CLUSTER,
admin_commissioning::CLUSTER,
noc::CLUSTER,
access_control::CLUSTER,
];
pub const fn endpoint(id: EndptId) -> Endpoint<'static> {
Endpoint {
id,
device_type: super::device_types::DEV_TYPE_ROOT_NODE,
clusters: &CLUSTERS,
}
}
pub fn handler<'a, T>(endpoint_id: u16, matter: &'a T) -> RootEndpointHandler<'a>
where
T: Borrow<BasicInfoConfig<'a>>
+ Borrow<dyn DevAttDataFetcher + 'a>
+ Borrow<RefCell<PaseMgr>>
+ Borrow<RefCell<FabricMgr>>
+ Borrow<RefCell<AclMgr>>
+ Borrow<RefCell<FailSafe>>
+ Borrow<dyn Mdns + 'a>
+ Borrow<Epoch>
+ Borrow<Rand>
+ 'a,
{
wrap(
endpoint_id,
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
matter.borrow(),
*matter.borrow(),
*matter.borrow(),
)
}
#[allow(clippy::too_many_arguments)]
pub fn wrap<'a>(
endpoint_id: u16,
basic_info: &'a BasicInfoConfig<'a>,
dev_att: &'a dyn DevAttDataFetcher,
pase: &'a RefCell<PaseMgr>,
fabric: &'a RefCell<FabricMgr>,
acl: &'a RefCell<AclMgr>,
failsafe: &'a RefCell<FailSafe>,
mdns: &'a dyn Mdns,
epoch: Epoch,
rand: Rand,
) -> RootEndpointHandler<'a> {
EmptyHandler
.chain(
endpoint_id,
access_control::ID,
AccessControlCluster::new(acl, rand),
)
.chain(
endpoint_id,
noc::ID,
NocCluster::new(dev_att, fabric, acl, failsafe, mdns, epoch, rand),
)
.chain(
endpoint_id,
admin_commissioning::ID,
AdminCommCluster::new(pase, mdns, rand),
)
.chain(endpoint_id, nw_commissioning::ID, NwCommCluster::new(rand))
.chain(
endpoint_id,
general_commissioning::ID,
GenCommCluster::new(failsafe, rand),
)
.chain(
endpoint_id,
cluster_basic_information::ID,
BasicInfoCluster::new(basic_info, rand),
)
.chain(endpoint_id, descriptor::ID, DescriptorCluster::new(rand))
}

View file

@ -15,15 +15,21 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::cmd_enter; use core::cell::RefCell;
use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::interaction_model::core::IMStatusCode; use crate::mdns::Mdns;
use crate::secure_channel::pake::PaseMgr; use crate::secure_channel::pake::PaseMgr;
use crate::secure_channel::spake2p::VerifierData; use crate::secure_channel::spake2p::VerifierData;
use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement}; use crate::tlv::{FromTLV, Nullable, OctetStr, TLVElement};
use crate::{error::*, interaction_model::command::CommandReq}; use crate::transport::exchange::Exchange;
use log::{error, info}; use crate::utils::rand::Rand;
use crate::{attribute_enum, cmd_enter};
use crate::{command_enum, error::*};
use log::info;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use strum::{EnumDiscriminants, FromRepr};
pub const ID: u32 = 0x003C; pub const ID: u32 = 0x003C;
@ -34,120 +40,54 @@ pub enum WindowStatus {
BasicWindowOpen = 2, BasicWindowOpen = 2,
} }
#[derive(FromPrimitive)] #[derive(Copy, Clone, Debug, FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
WindowStatus = 0, WindowStatus(AttrType<u8>) = 0,
AdminFabricIndex = 1, AdminFabricIndex(AttrType<Nullable<u8>>) = 1,
AdminVendorId = 2, AdminVendorId(AttrType<Nullable<u8>>) = 2,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
OpenCommWindow = 0x00, OpenCommWindow = 0x00,
OpenBasicCommWindow = 0x01, OpenBasicCommWindow = 0x01,
RevokeComm = 0x02, RevokeComm = 0x02,
} }
fn attr_window_status_new() -> Attribute { command_enum!(Commands);
pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::WindowStatus as u16, AttributesDiscriminants::WindowStatus as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_admin_fabid_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::AdminFabricIndex as u16, AttributesDiscriminants::AdminFabricIndex as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NULLABLE, Quality::NULLABLE,
) ),
}
fn attr_admin_vid_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::AdminVendorId as u16, AttributesDiscriminants::AdminVendorId as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::NULLABLE, Quality::NULLABLE,
) ),
} ],
commands: &[
pub struct AdminCommCluster { Commands::OpenCommWindow as _,
pase_mgr: PaseMgr, // Commands::OpenBasicCommWindow as _,
base: Cluster, // Commands::RevokeComm as _,
} ],
};
impl ClusterType for AdminCommCluster {
fn base(&self) -> &Cluster {
&self.base
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
}
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) {
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::WindowStatus) => {
let status = 1_u8;
encoder.encode(EncodeValue::Value(&status))
}
Some(Attributes::AdminVendorId) => {
let vid = Nullable::NotNull(1_u8);
encoder.encode(EncodeValue::Value(&vid))
}
Some(Attributes::AdminFabricIndex) => {
let vid = Nullable::NotNull(1_u8);
encoder.encode(EncodeValue::Value(&vid))
}
_ => {
error!("Unsupported Attribute: this shouldn't happen");
}
}
}
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> {
let cmd = cmd_req
.cmd
.path
.leaf
.map(num::FromPrimitive::from_u32)
.ok_or(IMStatusCode::UnsupportedCommand)?
.ok_or(IMStatusCode::UnsupportedCommand)?;
match cmd {
Commands::OpenCommWindow => self.handle_command_opencomm_win(cmd_req),
_ => Err(IMStatusCode::UnsupportedCommand),
}
}
}
impl AdminCommCluster {
pub fn new(pase_mgr: PaseMgr) -> Result<Box<Self>, Error> {
let mut c = Box::new(AdminCommCluster {
pase_mgr,
base: Cluster::new(ID)?,
});
c.base.add_attribute(attr_window_status_new())?;
c.base.add_attribute(attr_admin_fabid_new())?;
c.base.add_attribute(attr_admin_vid_new())?;
Ok(c)
}
fn handle_command_opencomm_win(
&mut self,
cmd_req: &mut CommandReq,
) -> Result<(), IMStatusCode> {
cmd_enter!("Open Commissioning Window");
let req =
OpenCommWindowReq::from_tlv(&cmd_req.data).map_err(|_| IMStatusCode::InvalidCommand)?;
let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0);
self.pase_mgr
.enable_pase_session(verifier, req.discriminator)?;
Err(IMStatusCode::Success)
}
}
#[derive(FromTLV)] #[derive(FromTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
@ -158,3 +98,88 @@ pub struct OpenCommWindowReq<'a> {
iterations: u32, iterations: u32,
salt: OctetStr<'a>, salt: OctetStr<'a>,
} }
pub struct AdminCommCluster<'a> {
data_ver: Dataver,
pase_mgr: &'a RefCell<PaseMgr>,
mdns: &'a dyn Mdns,
}
impl<'a> AdminCommCluster<'a> {
pub fn new(pase_mgr: &'a RefCell<PaseMgr>, mdns: &'a dyn Mdns, rand: Rand) -> Self {
Self {
data_ver: Dataver::new(rand),
pase_mgr,
mdns,
}
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::WindowStatus(codec) => codec.encode(writer, 1),
Attributes::AdminVendorId(codec) => codec.encode(writer, Nullable::NotNull(1)),
Attributes::AdminFabricIndex(codec) => {
codec.encode(writer, Nullable::NotNull(1))
}
}
}
} else {
Ok(())
}
}
pub fn invoke(
&self,
cmd: &CmdDetails,
data: &TLVElement,
_encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::OpenCommWindow => self.handle_command_opencomm_win(data)?,
_ => Err(ErrorCode::CommandNotFound)?,
}
self.data_ver.changed();
Ok(())
}
fn handle_command_opencomm_win(&self, data: &TLVElement) -> Result<(), Error> {
cmd_enter!("Open Commissioning Window");
let req = OpenCommWindowReq::from_tlv(data)?;
let verifier = VerifierData::new(req.verifier.0, req.iterations, req.salt.0);
self.pase_mgr
.borrow_mut()
.enable_pase_session(verifier, req.discriminator, self.mdns)?;
Ok(())
}
}
impl<'a> Handler for AdminCommCluster<'a> {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
AdminCommCluster::read(self, attr, encoder)
}
fn invoke(
&self,
_exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
AdminCommCluster::invoke(self, cmd, data, encoder)
}
}
impl<'a> NonBlockingHandler for AdminCommCluster<'a> {}
impl<'a> ChangeNotifier<()> for AdminCommCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
}
}

View file

@ -15,9 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{error::Error, transport::session::SessionMode}; use crate::{
error::{Error, ErrorCode},
transport::session::SessionMode,
};
use log::error; use log::error;
use std::sync::RwLock;
#[derive(PartialEq)] #[derive(PartialEq)]
#[allow(dead_code)] #[allow(dead_code)]
@ -42,26 +44,20 @@ pub enum State {
Armed(ArmedCtx), Armed(ArmedCtx),
} }
pub struct FailSafeInner { pub struct FailSafe {
state: State, state: State,
} }
pub struct FailSafe {
state: RwLock<FailSafeInner>,
}
impl FailSafe { impl FailSafe {
pub fn new() -> Self { #[inline(always)]
Self { pub const fn new() -> Self {
state: RwLock::new(FailSafeInner { state: State::Idle }), Self { state: State::Idle }
}
} }
pub fn arm(&self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> { pub fn arm(&mut self, timeout: u8, session_mode: SessionMode) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state {
State::Idle => { State::Idle => {
inner.state = State::Armed(ArmedCtx { self.state = State::Armed(ArmedCtx {
session_mode, session_mode,
timeout, timeout,
noc_state: NocState::NocNotRecvd, noc_state: NocState::NocNotRecvd,
@ -69,7 +65,8 @@ impl FailSafe {
} }
State::Armed(c) => { State::Armed(c) => {
if c.session_mode != session_mode { if c.session_mode != session_mode {
return Err(Error::Invalid); error!("Received Fail-Safe Arm with different session modes; current {:?}, incoming {:?}", c.session_mode, session_mode);
Err(ErrorCode::Invalid)?;
} }
// re-arm // re-arm
c.timeout = timeout; c.timeout = timeout;
@ -78,58 +75,55 @@ impl FailSafe {
Ok(()) Ok(())
} }
pub fn disarm(&self, session_mode: SessionMode) -> Result<(), Error> { pub fn disarm(&mut self, session_mode: SessionMode) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state {
State::Idle => { State::Idle => {
error!("Received Fail-Safe Disarm without it being armed"); error!("Received Fail-Safe Disarm without it being armed");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
State::Armed(c) => { State::Armed(c) => {
match c.noc_state { match c.noc_state {
NocState::NocNotRecvd => return Err(Error::Invalid), NocState::NocNotRecvd => Err(ErrorCode::Invalid)?,
NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => { NocState::AddNocRecvd(idx) | NocState::UpdateNocRecvd(idx) => {
if let SessionMode::Case(c) = session_mode { if let SessionMode::Case(c) = session_mode {
if c.fab_idx != idx { if c.fab_idx != idx {
error!( error!(
"Received disarm in separate session from previous Add/Update NOC" "Received disarm in separate session from previous Add/Update NOC"
); );
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} else { } else {
error!("Received disarm in a non-CASE session"); error!("Received disarm in a non-CASE session");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} }
} }
inner.state = State::Idle; self.state = State::Idle;
} }
} }
Ok(()) Ok(())
} }
pub fn is_armed(&self) -> bool { pub fn is_armed(&self) -> bool {
self.state.read().unwrap().state != State::Idle self.state != State::Idle
} }
pub fn record_add_noc(&self, fabric_index: u8) -> Result<(), Error> { pub fn record_add_noc(&mut self, fabric_index: u8) -> Result<(), Error> {
let mut inner = self.state.write()?; match &mut self.state {
match &mut inner.state { State::Idle => Err(ErrorCode::Invalid.into()),
State::Idle => Err(Error::Invalid),
State::Armed(c) => { State::Armed(c) => {
if c.noc_state == NocState::NocNotRecvd { if c.noc_state == NocState::NocNotRecvd {
c.noc_state = NocState::AddNocRecvd(fabric_index); c.noc_state = NocState::AddNocRecvd(fabric_index);
Ok(()) Ok(())
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
} }
} }
pub fn allow_noc_change(&self) -> Result<bool, Error> { pub fn allow_noc_change(&self) -> Result<bool, Error> {
let mut inner = self.state.write()?; let allow = match &self.state {
let allow = match &mut inner.state {
State::Idle => false, State::Idle => false,
State::Armed(c) => c.noc_state == NocState::NocNotRecvd, State::Armed(c) => c.noc_state == NocState::NocNotRecvd,
}; };

View file

@ -15,16 +15,18 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::cmd_enter; use core::cell::RefCell;
use core::convert::TryInto;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::data_model::sdm::failsafe::FailSafe; use crate::data_model::sdm::failsafe::FailSafe;
use crate::interaction_model::core::IMStatusCode; use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV, UtfStr};
use crate::interaction_model::messages::ib; use crate::transport::exchange::Exchange;
use crate::tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}; use crate::utils::rand::Rand;
use crate::{error::*, interaction_model::command::CommandReq}; use crate::{attribute_enum, cmd_enter};
use log::{error, info}; use crate::{command_enum, error::*};
use num_derive::FromPrimitive; use log::info;
use std::sync::Arc; use strum::{EnumDiscriminants, FromRepr};
#[derive(Clone, Copy)] #[derive(Clone, Copy)]
#[allow(dead_code)] #[allow(dead_code)]
@ -38,65 +40,80 @@ enum CommissioningError {
pub const ID: u32 = 0x0030; pub const ID: u32 = 0x0030;
#[derive(FromPrimitive)] #[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
BreadCrumb = 0, BreadCrumb(AttrType<u64>) = 0,
BasicCommissioningInfo = 1, BasicCommissioningInfo(()) = 1,
RegConfig = 2, RegConfig(AttrType<u8>) = 2,
LocationCapability = 3, LocationCapability(AttrType<u8>) = 3,
} }
#[derive(FromPrimitive)] attribute_enum!(Attributes);
#[derive(FromRepr)]
#[repr(u32)]
pub enum Commands { pub enum Commands {
ArmFailsafe = 0x00, ArmFailsafe = 0x00,
ArmFailsafeResp = 0x01,
SetRegulatoryConfig = 0x02, SetRegulatoryConfig = 0x02,
SetRegulatoryConfigResp = 0x03,
CommissioningComplete = 0x04, CommissioningComplete = 0x04,
}
command_enum!(Commands);
#[repr(u16)]
pub enum RespCommands {
ArmFailsafeResp = 0x01,
SetRegulatoryConfigResp = 0x03,
CommissioningCompleteResp = 0x05, CommissioningCompleteResp = 0x05,
} }
#[derive(FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")]
struct CommonResponse<'a> {
error_code: u8,
debug_txt: UtfStr<'a>,
}
pub enum RegLocationType { pub enum RegLocationType {
Indoor = 0, Indoor = 0,
Outdoor = 1, Outdoor = 1,
IndoorOutdoor = 2, IndoorOutdoor = 2,
} }
fn attr_bread_crumb_new(bread_crumb: u64) -> Attribute { pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new( Attribute::new(
Attributes::BreadCrumb as u16, AttributesDiscriminants::BreadCrumb as u16,
AttrValue::Uint64(bread_crumb), Access::READ.union(Access::WRITE).union(Access::NEED_ADMIN),
Access::READ | Access::WRITE | Access::NEED_ADMIN,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_reg_config_new(reg_config: RegLocationType) -> Attribute {
Attribute::new( Attribute::new(
Attributes::RegConfig as u16, AttributesDiscriminants::RegConfig as u16,
AttrValue::Uint8(reg_config as u8),
Access::RV, Access::RV,
Quality::NONE, Quality::NONE,
) ),
}
fn attr_location_capability_new(reg_config: RegLocationType) -> Attribute {
Attribute::new( Attribute::new(
Attributes::LocationCapability as u16, AttributesDiscriminants::LocationCapability as u16,
AttrValue::Uint8(reg_config as u8),
Access::RV, Access::RV,
Quality::FIXED, Quality::FIXED,
) ),
}
fn attr_comm_info_new() -> Attribute {
Attribute::new( Attribute::new(
Attributes::BasicCommissioningInfo as u16, AttributesDiscriminants::BasicCommissioningInfo as u16,
AttrValue::Custom,
Access::RV, Access::RV,
Quality::FIXED, Quality::FIXED,
) ),
} ],
commands: &[
Commands::ArmFailsafe as _,
Commands::SetRegulatoryConfig as _,
Commands::CommissioningComplete as _,
],
};
#[derive(FromTLV, ToTLV)] #[derive(FromTLV, ToTLV)]
struct FailSafeParams { struct FailSafeParams {
@ -104,144 +121,152 @@ struct FailSafeParams {
bread_crumb: u8, bread_crumb: u8,
} }
pub struct GenCommCluster { pub struct GenCommCluster<'a> {
data_ver: Dataver,
expiry_len: u16, expiry_len: u16,
failsafe: Arc<FailSafe>, failsafe: &'a RefCell<FailSafe>,
base: Cluster,
} }
impl ClusterType for GenCommCluster { impl<'a> GenCommCluster<'a> {
fn base(&self) -> &Cluster { pub fn new(failsafe: &'a RefCell<FailSafe>, rand: Rand) -> Self {
&self.base Self {
} data_ver: Dataver::new(rand),
fn base_mut(&mut self) -> &mut Cluster { failsafe,
&mut self.base
}
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) {
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::BasicCommissioningInfo) => {
encoder.encode(EncodeValue::Closure(&|tag, tw| {
let _ = tw.start_struct(tag);
let _ = tw.u16(TagType::Context(0), self.expiry_len);
let _ = tw.end_container();
}))
}
_ => {
error!("Unsupported Attribute: this shouldn't happen");
}
}
}
fn handle_command(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> {
let cmd = cmd_req
.cmd
.path
.leaf
.map(num::FromPrimitive::from_u32)
.ok_or(IMStatusCode::UnsupportedCommand)?
.ok_or(IMStatusCode::UnsupportedCommand)?;
match cmd {
Commands::ArmFailsafe => self.handle_command_armfailsafe(cmd_req),
Commands::SetRegulatoryConfig => self.handle_command_setregulatoryconfig(cmd_req),
Commands::CommissioningComplete => self.handle_command_commissioningcomplete(cmd_req),
_ => Err(IMStatusCode::UnsupportedCommand),
}
}
}
impl GenCommCluster {
pub fn new() -> Result<Box<Self>, Error> {
let failsafe = Arc::new(FailSafe::new());
let mut c = Box::new(GenCommCluster {
// TODO: Arch-Specific // TODO: Arch-Specific
expiry_len: 120, expiry_len: 120,
failsafe, }
base: Cluster::new(ID)?,
});
c.base.add_attribute(attr_bread_crumb_new(0))?;
// TODO: Arch-Specific
c.base
.add_attribute(attr_reg_config_new(RegLocationType::IndoorOutdoor))?;
// TODO: Arch-Specific
c.base
.add_attribute(attr_location_capability_new(RegLocationType::IndoorOutdoor))?;
c.base.add_attribute(attr_comm_info_new())?;
Ok(c)
} }
pub fn failsafe(&self) -> Arc<FailSafe> { pub fn failsafe(&self) -> &RefCell<FailSafe> {
self.failsafe.clone() self.failsafe
} }
fn handle_command_armfailsafe(&mut self, cmd_req: &mut CommandReq) -> Result<(), IMStatusCode> { pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::BreadCrumb(codec) => codec.encode(writer, 0),
// TODO: Arch-Specific
Attributes::RegConfig(codec) => {
codec.encode(writer, RegLocationType::IndoorOutdoor as _)
}
// TODO: Arch-Specific
Attributes::LocationCapability(codec) => {
codec.encode(writer, RegLocationType::IndoorOutdoor as _)
}
Attributes::BasicCommissioningInfo(_) => {
writer.start_struct(AttrDataWriter::TAG)?;
writer.u16(TagType::Context(0), self.expiry_len)?;
writer.end_container()?;
writer.complete()
}
}
}
} else {
Ok(())
}
}
pub fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
match cmd.cmd_id.try_into()? {
Commands::ArmFailsafe => self.handle_command_armfailsafe(exchange, data, encoder)?,
Commands::SetRegulatoryConfig => {
self.handle_command_setregulatoryconfig(exchange, data, encoder)?
}
Commands::CommissioningComplete => {
self.handle_command_commissioningcomplete(exchange, encoder)?;
}
}
self.data_ver.changed();
Ok(())
}
fn handle_command_armfailsafe(
&self,
exchange: &Exchange,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("ARM Fail Safe"); cmd_enter!("ARM Fail Safe");
let p = FailSafeParams::from_tlv(&cmd_req.data)?; let p = FailSafeParams::from_tlv(data)?;
let mut status = CommissioningError::Ok as u8;
if self let status = if self
.failsafe .failsafe
.arm(p.expiry_len, cmd_req.trans.session.get_session_mode()) .borrow_mut()
.arm(
p.expiry_len,
exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?,
)
.is_err() .is_err()
{ {
status = CommissioningError::ErrBusyWithOtherAdmin as u8; CommissioningError::ErrBusyWithOtherAdmin as u8
} } else {
CommissioningError::Ok as u8
};
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: status, error_code: status,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::ArmFailsafeResp as _)?
Commands::ArmFailsafeResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
);
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
fn handle_command_setregulatoryconfig( fn handle_command_setregulatoryconfig(
&mut self, &self,
cmd_req: &mut CommandReq, _exchange: &Exchange,
) -> Result<(), IMStatusCode> { data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("Set Regulatory Config"); cmd_enter!("Set Regulatory Config");
let country_code = cmd_req let country_code = data
.data
.find_tag(1) .find_tag(1)
.map_err(|_| IMStatusCode::InvalidCommand)? .map_err(|_| ErrorCode::InvalidCommand)?
.slice() .slice()
.map_err(|_| IMStatusCode::InvalidCommand)?; .map_err(|_| ErrorCode::InvalidCommand)?;
info!("Received country code: {:?}", country_code); info!("Received country code: {:?}", country_code);
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: 0, error_code: 0,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::SetRegulatoryConfigResp as _)?
Commands::SetRegulatoryConfigResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
);
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
fn handle_command_commissioningcomplete( fn handle_command_commissioningcomplete(
&mut self, &self,
cmd_req: &mut CommandReq, exchange: &Exchange,
) -> Result<(), IMStatusCode> { encoder: CmdDataEncoder,
) -> Result<(), Error> {
cmd_enter!("Commissioning Complete"); cmd_enter!("Commissioning Complete");
let mut status: u8 = CommissioningError::Ok as u8; let mut status: u8 = CommissioningError::Ok as u8;
// Has to be a Case Session // Has to be a Case Session
if cmd_req.trans.session.get_local_fabric_idx().is_none() { if exchange
.with_session(|sess| Ok(sess.get_local_fabric_idx()))?
.is_none()
{
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
} }
@ -249,7 +274,8 @@ impl GenCommCluster {
// scope that is for this session // scope that is for this session
if self if self
.failsafe .failsafe
.disarm(cmd_req.trans.session.get_session_mode()) .borrow_mut()
.disarm(exchange.with_session(|sess| Ok(sess.get_session_mode().clone()))?)
.is_err() .is_err()
{ {
status = CommissioningError::ErrInvalidAuth as u8; status = CommissioningError::ErrInvalidAuth as u8;
@ -257,22 +283,37 @@ impl GenCommCluster {
let cmd_data = CommonResponse { let cmd_data = CommonResponse {
error_code: status, error_code: status,
debug_txt: "".to_owned(), debug_txt: UtfStr::new(b""),
}; };
let resp = ib::InvResp::cmd_new(
0, encoder
ID, .with_command(RespCommands::CommissioningCompleteResp as _)?
Commands::CommissioningCompleteResp as u16, .set(cmd_data)?;
EncodeValue::Value(&cmd_data),
);
let _ = resp.to_tlv(cmd_req.resp, TagType::Anonymous);
cmd_req.trans.complete();
Ok(()) Ok(())
} }
} }
#[derive(FromTLV, ToTLV)] impl<'a> Handler for GenCommCluster<'a> {
struct CommonResponse { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
error_code: u8, GenCommCluster::read(self, attr, encoder)
debug_txt: String, }
fn invoke(
&self,
exchange: &Exchange,
cmd: &CmdDetails,
data: &TLVElement,
encoder: CmdDataEncoder,
) -> Result<(), Error> {
GenCommCluster::invoke(self, exchange, cmd, data, encoder)
}
}
impl<'a> NonBlockingHandler for GenCommCluster<'a> {}
impl<'a> ChangeNotifier<()> for GenCommCluster<'a> {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
}
} }

File diff suppressed because it is too large Load diff

View file

@ -16,38 +16,59 @@
*/ */
use crate::{ use crate::{
data_model::objects::{Cluster, ClusterType}, data_model::objects::{
error::Error, AttrDataEncoder, AttrDetails, ChangeNotifier, Cluster, Dataver, Handler,
NonBlockingHandler, ATTRIBUTE_LIST, FEATURE_MAP,
},
error::{Error, ErrorCode},
utils::rand::Rand,
}; };
pub const ID: u32 = 0x0031; pub const ID: u32 = 0x0031;
pub struct NwCommCluster {
base: Cluster,
}
impl ClusterType for NwCommCluster {
fn base(&self) -> &Cluster {
&self.base
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
}
}
enum FeatureMap { enum FeatureMap {
_Wifi = 0x01, _Wifi = 0x01,
_Thread = 0x02, _Thread = 0x02,
Ethernet = 0x04, Ethernet = 0x04,
} }
pub const CLUSTER: Cluster<'static> = Cluster {
id: ID as _,
feature_map: FeatureMap::Ethernet as _,
attributes: &[FEATURE_MAP, ATTRIBUTE_LIST],
commands: &[],
};
pub struct NwCommCluster {
data_ver: Dataver,
}
impl NwCommCluster { impl NwCommCluster {
pub fn new() -> Result<Box<Self>, Error> { pub fn new(rand: Rand) -> Self {
let mut c = Box::new(Self { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
}); }
// TODO: Arch-Specific }
c.base.set_feature_map(FeatureMap::Ethernet as u32)?; }
Ok(c)
impl Handler for NwCommCluster {
fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
Err(ErrorCode::AttributeNotFound.into())
}
} else {
Ok(())
}
}
}
impl NonBlockingHandler for NwCommCluster {}
impl ChangeNotifier<()> for NwCommCluster {
fn consume_change(&mut self) -> Option<()> {
self.data_ver.consume_change(())
} }
} }

View file

@ -15,46 +15,135 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::cell::RefCell;
use core::convert::TryInto;
use num_derive::FromPrimitive; use strum::{EnumDiscriminants, FromRepr};
use crate::acl::{self, AclEntry, AclMgr}; use crate::acl::{self, AclEntry, AclMgr};
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::error::*;
use crate::interaction_model::core::IMStatusCode;
use crate::interaction_model::messages::ib::{attr_list_write, ListOperation}; use crate::interaction_model::messages::ib::{attr_list_write, ListOperation};
use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV}; use crate::tlv::{FromTLV, TLVElement, TagType, ToTLV};
use crate::utils::rand::Rand;
use crate::{attribute_enum, error::*};
use log::{error, info}; use log::{error, info};
pub const ID: u32 = 0x001F; pub const ID: u32 = 0x001F;
#[derive(FromPrimitive)] #[derive(FromRepr, EnumDiscriminants)]
#[repr(u16)]
pub enum Attributes { pub enum Attributes {
Acl = 0, Acl(()) = 0,
Extension = 1, Extension(()) = 1,
SubjectsPerEntry = 2, SubjectsPerEntry(AttrType<u16>) = 2,
TargetsPerEntry = 3, TargetsPerEntry(AttrType<u16>) = 3,
EntriesPerFabric = 4, EntriesPerFabric(AttrType<u16>) = 4,
} }
pub struct AccessControlCluster { attribute_enum!(Attributes);
base: Cluster,
acl_mgr: Arc<AclMgr>, pub const CLUSTER: Cluster<'static> = Cluster {
id: ID,
feature_map: 0,
attributes: &[
FEATURE_MAP,
ATTRIBUTE_LIST,
Attribute::new(
AttributesDiscriminants::Acl as u16,
Access::RWFA,
Quality::NONE,
),
Attribute::new(
AttributesDiscriminants::Extension as u16,
Access::RWFA,
Quality::NONE,
),
Attribute::new(
AttributesDiscriminants::SubjectsPerEntry as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::TargetsPerEntry as u16,
Access::RV,
Quality::FIXED,
),
Attribute::new(
AttributesDiscriminants::EntriesPerFabric as u16,
Access::RV,
Quality::FIXED,
),
],
commands: &[],
};
pub struct AccessControlCluster<'a> {
data_ver: Dataver,
acl_mgr: &'a RefCell<AclMgr>,
} }
impl AccessControlCluster { impl<'a> AccessControlCluster<'a> {
pub fn new(acl_mgr: Arc<AclMgr>) -> Result<Box<Self>, Error> { pub fn new(acl_mgr: &'a RefCell<AclMgr>, rand: Rand) -> Self {
let mut c = Box::new(AccessControlCluster { Self {
base: Cluster::new(ID)?, data_ver: Dataver::new(rand),
acl_mgr, acl_mgr,
}); }
c.base.add_attribute(attr_acl_new())?; }
c.base.add_attribute(attr_extension_new())?;
c.base.add_attribute(attr_subjects_per_entry_new())?; pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
c.base.add_attribute(attr_targets_per_entry_new())?; if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
c.base.add_attribute(attr_entries_per_fabric_new())?; if attr.is_system() {
Ok(c) CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::Acl(_) => {
writer.start_array(AttrDataWriter::TAG)?;
self.acl_mgr.borrow().for_each_acl(|entry| {
if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx {
entry.to_tlv(&mut writer, TagType::Anonymous)?;
}
Ok(())
})?;
writer.end_container()?;
writer.complete()
}
Attributes::Extension(_) => {
// Empty for now
writer.start_array(AttrDataWriter::TAG)?;
writer.end_container()?;
writer.complete()
}
Attributes::SubjectsPerEntry(codec) => {
codec.encode(writer, acl::SUBJECTS_PER_ENTRY as u16)
}
Attributes::TargetsPerEntry(codec) => {
codec.encode(writer, acl::TARGETS_PER_ENTRY as u16)
}
Attributes::EntriesPerFabric(codec) => {
codec.encode(writer, acl::ENTRIES_PER_FABRIC as u16)
}
}
}
} else {
Ok(())
}
}
pub fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
match attr.attr_id.try_into()? {
Attributes::Acl(_) => {
attr_list_write(attr, data.with_dataver(self.data_ver.get())?, |op, data| {
self.write_acl_attr(&op, data, attr.fab_idx)
})
}
_ => {
error!("Attribute not yet supported: this shouldn't happen");
Err(ErrorCode::AttributeNotFound.into())
}
}
} }
/// Write the ACL Attribute /// Write the ACL Attribute
@ -62,145 +151,63 @@ impl AccessControlCluster {
/// This takes care of 4 things, add item, edit item, delete item, delete list. /// This takes care of 4 things, add item, edit item, delete item, delete list.
/// Care about fabric-scoped behaviour is taken /// Care about fabric-scoped behaviour is taken
fn write_acl_attr( fn write_acl_attr(
&mut self, &self,
op: &ListOperation, op: &ListOperation,
data: &TLVElement, data: &TLVElement,
fab_idx: u8, fab_idx: u8,
) -> Result<(), IMStatusCode> { ) -> Result<(), Error> {
info!("Performing ACL operation {:?}", op); info!("Performing ACL operation {:?}", op);
let result = match op { match op {
ListOperation::AddItem | ListOperation::EditItem(_) => { ListOperation::AddItem | ListOperation::EditItem(_) => {
let mut acl_entry = let mut acl_entry = AclEntry::from_tlv(data)?;
AclEntry::from_tlv(data).map_err(|_| IMStatusCode::ConstraintError)?;
info!("ACL {:?}", acl_entry); info!("ACL {:?}", acl_entry);
// Overwrite the fabric index with our accessing fabric index // Overwrite the fabric index with our accessing fabric index
acl_entry.fab_idx = Some(fab_idx); acl_entry.fab_idx = Some(fab_idx);
if let ListOperation::EditItem(index) = op { if let ListOperation::EditItem(index) = op {
self.acl_mgr.edit(*index as u8, fab_idx, acl_entry) self.acl_mgr
.borrow_mut()
.edit(*index as u8, fab_idx, acl_entry)
} else { } else {
self.acl_mgr.add(acl_entry) self.acl_mgr.borrow_mut().add(acl_entry)
} }
} }
ListOperation::DeleteItem(index) => self.acl_mgr.delete(*index as u8, fab_idx), ListOperation::DeleteItem(index) => {
ListOperation::DeleteList => self.acl_mgr.delete_for_fabric(fab_idx), self.acl_mgr.borrow_mut().delete(*index as u8, fab_idx)
}; }
match result { ListOperation::DeleteList => self.acl_mgr.borrow_mut().delete_for_fabric(fab_idx),
Ok(_) => Ok(()),
Err(Error::NoSpace) => Err(IMStatusCode::ResourceExhausted),
_ => Err(IMStatusCode::ConstraintError),
} }
} }
} }
impl ClusterType for AccessControlCluster { impl<'a> Handler for AccessControlCluster<'a> {
fn base(&self) -> &Cluster { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
&self.base AccessControlCluster::read(self, attr, encoder)
}
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { fn write(&self, attr: &AttrDetails, data: AttrData) -> Result<(), Error> {
match num::FromPrimitive::from_u16(attr.attr_id) { AccessControlCluster::write(self, attr, data)
Some(Attributes::Acl) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
let _ = tw.start_array(tag);
let _ = self.acl_mgr.for_each_acl(|entry| {
if !attr.fab_filter || Some(attr.fab_idx) == entry.fab_idx {
let _ = entry.to_tlv(tw, TagType::Anonymous);
}
});
let _ = tw.end_container();
})),
Some(Attributes::Extension) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
// Empty for now
let _ = tw.start_array(tag);
let _ = tw.end_container();
})),
_ => {
error!("Attribute not yet supported: this shouldn't happen");
}
} }
} }
fn write_attribute( impl<'a> NonBlockingHandler for AccessControlCluster<'a> {}
&mut self,
attr: &AttrDetails,
data: &TLVElement,
) -> Result<(), IMStatusCode> {
let result = if let Some(Attributes::Acl) = num::FromPrimitive::from_u16(attr.attr_id) {
attr_list_write(attr, data, |op, data| {
self.write_acl_attr(&op, data, attr.fab_idx)
})
} else {
error!("Attribute not yet supported: this shouldn't happen");
Err(IMStatusCode::NotFound)
};
if result.is_ok() {
self.base.cluster_changed();
}
result
}
}
fn attr_acl_new() -> Attribute { impl<'a> ChangeNotifier<()> for AccessControlCluster<'a> {
Attribute::new( fn consume_change(&mut self) -> Option<()> {
Attributes::Acl as u16, self.data_ver.consume_change(())
AttrValue::Custom,
Access::RWFA,
Quality::NONE,
)
} }
fn attr_extension_new() -> Attribute {
Attribute::new(
Attributes::Extension as u16,
AttrValue::Custom,
Access::RWFA,
Quality::NONE,
)
}
fn attr_subjects_per_entry_new() -> Attribute {
Attribute::new(
Attributes::SubjectsPerEntry as u16,
AttrValue::Uint16(acl::SUBJECTS_PER_ENTRY as u16),
Access::RV,
Quality::FIXED,
)
}
fn attr_targets_per_entry_new() -> Attribute {
Attribute::new(
Attributes::TargetsPerEntry as u16,
AttrValue::Uint16(acl::TARGETS_PER_ENTRY as u16),
Access::RV,
Quality::FIXED,
)
}
fn attr_entries_per_fabric_new() -> Attribute {
Attribute::new(
Attributes::EntriesPerFabric as u16,
AttrValue::Uint16(acl::ENTRIES_PER_FABRIC as u16),
Access::RV,
Quality::FIXED,
)
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use std::sync::Arc; use core::cell::RefCell;
use crate::{ use crate::{
acl::{AclEntry, AclMgr, AuthMode}, acl::{AclEntry, AclMgr, AuthMode},
data_model::{ data_model::objects::{AttrDataEncoder, AttrDetails, Node, Privilege},
core::read::AttrReadEncoder,
objects::{AttrDetails, ClusterType, Privilege},
},
interaction_model::messages::ib::ListOperation, interaction_model::messages::ib::ListOperation,
tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{get_root_node_struct, ElementType, TLVElement, TLVWriter, TagType, ToTLV},
utils::writebuf::WriteBuf, utils::{rand::dummy_rand, writebuf::WriteBuf},
}; };
use super::AccessControlCluster; use super::AccessControlCluster;
@ -209,26 +216,27 @@ mod tests {
/// Add an ACL entry /// Add an ACL entry
fn acl_cluster_add() { fn acl_cluster_add() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap();
// Test, ACL has fabric index 2, but the accessing fabric is 1 // Test, ACL has fabric index 2, but the accessing fabric is 1
// the fabric index in the TLV should be ignored and the ACL should be created with entry 1 // the fabric index in the TLV should be ignored and the ACL should be created with entry 1
let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1); let result = acl.write_acl_attr(&ListOperation::AddItem, &data, 1);
assert_eq!(result, Ok(())); assert!(result.is_ok());
let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case); let verifier = AclEntry::new(1, Privilege::VIEW, AuthMode::Case);
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier); assert_eq!(*a, verifier);
Ok(())
}) })
.unwrap(); .unwrap();
} }
@ -237,38 +245,39 @@ mod tests {
/// - The listindex used for edit should be relative to the current fabric /// - The listindex used for edit should be relative to the current fabric
fn acl_cluster_edit() { fn acl_cluster_edit() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let mut verifier = [ let mut verifier = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in verifier { for i in &verifier {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case); let new = AclEntry::new(2, Privilege::VIEW, AuthMode::Case);
new.to_tlv(&mut tw, TagType::Anonymous).unwrap(); new.to_tlv(&mut tw, TagType::Anonymous).unwrap();
let data = get_root_node_struct(writebuf.as_borrow_slice()).unwrap(); let data = get_root_node_struct(writebuf.as_slice()).unwrap();
// Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow // Test, Edit Fabric 2's index 1 - with accessing fabring as 2 - allow
let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2); let result = acl.write_acl_attr(&ListOperation::EditItem(1), &data, 2);
// Fabric 2's index 1, is actually our index 2, update the verifier // Fabric 2's index 1, is actually our index 2, update the verifier
verifier[2] = new; verifier[2] = new;
assert_eq!(result, Ok(())); assert!(result.is_ok());
// Also validate in the acl_mgr that the entries are in the right order // Also validate in the acl_mgr that the entries are in the right order
let mut index = 0; let mut index = 0;
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier[index]); assert_eq!(*a, verifier[index]);
index += 1; index += 1;
Ok(())
}) })
.unwrap(); .unwrap();
} }
@ -277,30 +286,32 @@ mod tests {
/// - The listindex used for delete should be relative to the current fabric /// - The listindex used for delete should be relative to the current fabric
fn acl_cluster_delete() { fn acl_cluster_delete() {
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let input = [ let input = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in input { for i in &input {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i.clone()).unwrap();
} }
let mut acl = AccessControlCluster::new(acl_mgr.clone()).unwrap(); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
// data is don't-care actually // data is don't-care actually
let data = TLVElement::new(TagType::Anonymous, ElementType::True); let data = TLVElement::new(TagType::Anonymous, ElementType::True);
// Test , Delete Fabric 1's index 0 // Test , Delete Fabric 1's index 0
let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1); let result = acl.write_acl_attr(&ListOperation::DeleteItem(0), &data, 1);
assert_eq!(result, Ok(())); assert!(result.is_ok());
let verifier = [input[0], input[2]]; let verifier = [input[0].clone(), input[2].clone()];
// Also validate in the acl_mgr that the entries are in the right order // Also validate in the acl_mgr that the entries are in the right order
let mut index = 0; let mut index = 0;
acl_mgr acl_mgr
.borrow()
.for_each_acl(|a| { .for_each_acl(|a| {
assert_eq!(*a, verifier[index]); assert_eq!(*a, verifier[index]);
index += 1; index += 1;
Ok(())
}) })
.unwrap(); .unwrap();
} }
@ -309,84 +320,126 @@ mod tests {
/// - acl read with and without fabric filtering /// - acl read with and without fabric filtering
fn acl_cluster_read() { fn acl_cluster_read() {
let mut buf: [u8; 100] = [0; 100]; let mut buf: [u8; 100] = [0; 100];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
// Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order // Add 3 ACLs, belonging to fabric index 2, 1 and 2, in that order
let acl_mgr = Arc::new(AclMgr::new_with(false).unwrap()); let acl_mgr = RefCell::new(AclMgr::new());
let input = [ let input = [
AclEntry::new(2, Privilege::VIEW, AuthMode::Case), AclEntry::new(2, Privilege::VIEW, AuthMode::Case),
AclEntry::new(1, Privilege::VIEW, AuthMode::Case), AclEntry::new(1, Privilege::VIEW, AuthMode::Case),
AclEntry::new(2, Privilege::ADMIN, AuthMode::Case), AclEntry::new(2, Privilege::ADMIN, AuthMode::Case),
]; ];
for i in input { for i in input {
acl_mgr.add(i).unwrap(); acl_mgr.borrow_mut().add(i).unwrap();
} }
let acl = AccessControlCluster::new(acl_mgr).unwrap(); let acl = AccessControlCluster::new(&acl_mgr, dummy_rand);
// Test 1, all 3 entries are read in the response without fabric filtering // Test 1, all 3 entries are read in the response without fabric filtering
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
let attr_details = AttrDetails { id: 0,
endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 1, fab_idx: 1,
fab_filter: false, fab_filter: false,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254,
// 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24,
// 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 1, 36, 2, 2, 54,
1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 3, 24, 54, 4, 24, 36, 254, 1, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24,
24 36, 254, 2, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
writebuf.reset(0); writebuf.reset();
// Test 2, only single entry is read in the response with fabric filtering and fabric idx 1 // Test 2, only single entry is read in the response with fabric filtering and fabric idx 1
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
id: 0,
let attr_details = AttrDetails { endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 1, fab_idx: 1,
fab_filter: true, fab_filter: true,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 1, 24, 24, 24, 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 1, 24, 24, 24, 24 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
writebuf.reset(0); writebuf.reset();
// Test 3, only single entry is read in the response with fabric filtering and fabric idx 2 // Test 3, only single entry is read in the response with fabric filtering and fabric idx 2
{ {
let mut tw = TLVWriter::new(&mut writebuf); let attr = AttrDetails {
let mut encoder = AttrReadEncoder::new(&mut tw); node: &Node {
id: 0,
let attr_details = AttrDetails { endpoints: &[],
},
endpoint_id: 0,
cluster_id: 0,
attr_id: 0, attr_id: 0,
list_index: None, list_index: None,
fab_idx: 2, fab_idx: 2,
fab_filter: true, fab_filter: true,
dataver: None,
wildcard: false,
}; };
acl.read_custom_attribute(&mut encoder, &attr_details);
let mut tw = TLVWriter::new(&mut writebuf);
let encoder = AttrDataEncoder::new(&attr, &mut tw);
acl.read(&attr, encoder).unwrap();
assert_eq!( assert_eq!(
// &[
// 21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54,
// 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254,
// 2, 24, 24, 24, 24
// ],
&[ &[
21, 53, 1, 36, 0, 0, 55, 1, 24, 54, 2, 21, 36, 1, 1, 36, 2, 2, 54, 3, 24, 54, 21, 53, 1, 36, 0, 0, 55, 1, 36, 2, 0, 36, 3, 0, 36, 4, 0, 24, 54, 2, 21, 36, 1,
4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 1, 36, 2, 2, 54, 3, 24, 54, 4, 24, 36, 254, 2, 24, 21, 36, 1, 5, 36, 2, 2, 54,
2, 24, 24, 24, 24 3, 24, 54, 4, 24, 36, 254, 2, 24, 24, 24, 24
], ],
writebuf.as_borrow_slice() writebuf.as_slice()
); );
} }
} }

View file

@ -15,18 +15,20 @@
* limitations under the License. * limitations under the License.
*/ */
use num_derive::FromPrimitive; use core::convert::TryInto;
use crate::data_model::core::DataModel; use strum::FromRepr;
use crate::attribute_enum;
use crate::data_model::objects::*; use crate::data_model::objects::*;
use crate::error::*; use crate::error::Error;
use crate::interaction_model::messages::GenericPath;
use crate::tlv::{TLVWriter, TagType, ToTLV}; use crate::tlv::{TLVWriter, TagType, ToTLV};
use log::error; use crate::utils::rand::Rand;
pub const ID: u32 = 0x001D; pub const ID: u32 = 0x001D;
#[derive(FromPrimitive)] #[derive(FromRepr)]
#[repr(u16)]
#[allow(clippy::enum_variant_names)] #[allow(clippy::enum_variant_names)]
pub enum Attributes { pub enum Attributes {
DeviceTypeList = 0, DeviceTypeList = 0,
@ -35,134 +37,210 @@ pub enum Attributes {
PartsList = 3, PartsList = 3,
} }
pub struct DescriptorCluster { attribute_enum!(Attributes);
base: Cluster,
endpoint_id: EndptId,
data_model: DataModel,
}
impl DescriptorCluster { pub const CLUSTER: Cluster<'static> = Cluster {
pub fn new(endpoint_id: EndptId, data_model: DataModel) -> Result<Box<Self>, Error> { id: ID as _,
let mut c = Box::new(DescriptorCluster { feature_map: 0,
endpoint_id, attributes: &[
data_model, FEATURE_MAP,
base: Cluster::new(ID)?, ATTRIBUTE_LIST,
}); Attribute::new(Attributes::DeviceTypeList as u16, Access::RV, Quality::NONE),
let attrs = [ Attribute::new(Attributes::ServerList as u16, Access::RV, Quality::NONE),
Attribute::new( Attribute::new(Attributes::PartsList as u16, Access::RV, Quality::NONE),
Attributes::DeviceTypeList as u16, Attribute::new(Attributes::ClientList as u16, Access::RV, Quality::NONE),
AttrValue::Custom, ],
Access::RV, commands: &[],
Quality::NONE,
),
Attribute::new(
Attributes::ServerList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
Attribute::new(
Attributes::PartsList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
Attribute::new(
Attributes::ClientList as u16,
AttrValue::Custom,
Access::RV,
Quality::NONE,
),
];
c.base.add_attributes(&attrs[..])?;
Ok(c)
}
fn encode_devtype_list(&self, tag: TagType, tw: &mut TLVWriter) {
let path = GenericPath {
endpoint: Some(self.endpoint_id),
cluster: None,
leaf: None,
}; };
let _ = tw.start_array(tag);
let dm = self.data_model.node.read().unwrap(); struct StandardPartsMatcher;
let _ = dm.for_each_endpoint(&path, |_, e| {
let dev_type = e.get_dev_type(); impl PartsMatcher for StandardPartsMatcher {
let _ = dev_type.to_tlv(tw, TagType::Anonymous); fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
Ok(()) our_endpoint == 0 && endpoint != our_endpoint
}); }
let _ = tw.end_container();
} }
fn encode_server_list(&self, tag: TagType, tw: &mut TLVWriter) { struct AggregatorPartsMatcher;
let path = GenericPath {
endpoint: Some(self.endpoint_id), impl PartsMatcher for AggregatorPartsMatcher {
cluster: None, fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
leaf: None, endpoint != our_endpoint && endpoint != 0
}; }
let _ = tw.start_array(tag);
let dm = self.data_model.node.read().unwrap();
let _ = dm.for_each_cluster(&path, |_current_path, c| {
let _ = tw.u32(TagType::Anonymous, c.base().id());
Ok(())
});
let _ = tw.end_container();
} }
fn encode_parts_list(&self, tag: TagType, tw: &mut TLVWriter) { pub trait PartsMatcher {
let path = GenericPath { fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool;
endpoint: None,
cluster: None,
leaf: None,
};
let _ = tw.start_array(tag);
if self.endpoint_id == 0 {
// TODO: If endpoint is another than 0, need to figure out what to do
let dm = self.data_model.node.read().unwrap();
let _ = dm.for_each_endpoint(&path, |current_path, _| {
if let Some(endpoint_id) = current_path.endpoint {
if endpoint_id != 0 {
let _ = tw.u16(TagType::Anonymous, endpoint_id);
}
}
Ok(())
});
}
let _ = tw.end_container();
} }
fn encode_client_list(&self, tag: TagType, tw: &mut TLVWriter) { impl<T> PartsMatcher for &T
where
T: PartsMatcher,
{
fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
(**self).describe(our_endpoint, endpoint)
}
}
impl<T> PartsMatcher for &mut T
where
T: PartsMatcher,
{
fn describe(&self, our_endpoint: EndptId, endpoint: EndptId) -> bool {
(**self).describe(our_endpoint, endpoint)
}
}
pub struct DescriptorCluster<'a> {
matcher: &'a dyn PartsMatcher,
data_ver: Dataver,
}
impl DescriptorCluster<'static> {
pub fn new(rand: Rand) -> Self {
Self::new_matching(&StandardPartsMatcher, rand)
}
pub fn new_aggregator(rand: Rand) -> Self {
Self::new_matching(&AggregatorPartsMatcher, rand)
}
}
impl<'a> DescriptorCluster<'a> {
pub fn new_matching(matcher: &'a dyn PartsMatcher, rand: Rand) -> DescriptorCluster<'a> {
Self {
matcher,
data_ver: Dataver::new(rand),
}
}
pub fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
if let Some(mut writer) = encoder.with_dataver(self.data_ver.get())? {
if attr.is_system() {
CLUSTER.read(attr.attr_id, writer)
} else {
match attr.attr_id.try_into()? {
Attributes::DeviceTypeList => {
self.encode_devtype_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::ServerList => {
self.encode_server_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::PartsList => {
self.encode_parts_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
Attributes::ClientList => {
self.encode_client_list(
attr.node,
attr.endpoint_id,
AttrDataWriter::TAG,
&mut writer,
)?;
writer.complete()
}
}
}
} else {
Ok(())
}
}
fn encode_devtype_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if endpoint.id == endpoint_id {
let dev_type = endpoint.device_type;
dev_type.to_tlv(tw, TagType::Anonymous)?;
}
}
tw.end_container()
}
fn encode_server_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if endpoint.id == endpoint_id {
for cluster in endpoint.clusters {
tw.u32(TagType::Anonymous, cluster.id as _)?;
}
}
}
tw.end_container()
}
fn encode_parts_list(
&self,
node: &Node,
endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
tw.start_array(tag)?;
for endpoint in node.endpoints {
if self.matcher.describe(endpoint_id, endpoint.id) {
tw.u16(TagType::Anonymous, endpoint.id)?;
}
}
tw.end_container()
}
fn encode_client_list(
&self,
_node: &Node,
_endpoint_id: u16,
tag: TagType,
tw: &mut TLVWriter,
) -> Result<(), Error> {
// No Clients supported // No Clients supported
let _ = tw.start_array(tag); tw.start_array(tag)?;
let _ = tw.end_container(); tw.end_container()
} }
} }
impl ClusterType for DescriptorCluster { impl<'a> Handler for DescriptorCluster<'a> {
fn base(&self) -> &Cluster { fn read(&self, attr: &AttrDetails, encoder: AttrDataEncoder) -> Result<(), Error> {
&self.base DescriptorCluster::read(self, attr, encoder)
} }
fn base_mut(&mut self) -> &mut Cluster {
&mut self.base
} }
fn read_custom_attribute(&self, encoder: &mut dyn Encoder, attr: &AttrDetails) { impl<'a> NonBlockingHandler for DescriptorCluster<'a> {}
match num::FromPrimitive::from_u16(attr.attr_id) {
Some(Attributes::DeviceTypeList) => encoder.encode(EncodeValue::Closure(&|tag, tw| { impl<'a> ChangeNotifier<()> for DescriptorCluster<'a> {
self.encode_devtype_list(tag, tw) fn consume_change(&mut self) -> Option<()> {
})), self.data_ver.consume_change(())
Some(Attributes::ServerList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_server_list(tag, tw)
})),
Some(Attributes::PartsList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_parts_list(tag, tw)
})),
Some(Attributes::ClientList) => encoder.encode(EncodeValue::Closure(&|tag, tw| {
self.encode_client_list(tag, tw)
})),
_ => {
error!("Attribute not supported: this shouldn't happen");
}
}
} }
} }

View file

@ -15,15 +15,10 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{array::TryFromSliceError, fmt, str::Utf8Error};
array::TryFromSliceError, fmt, string::FromUtf8Error, sync::PoisonError, time::SystemTimeError,
};
use async_channel::{SendError, TryRecvError}; #[derive(Debug, PartialEq, Eq, Clone, Copy)]
use log::error; pub enum ErrorCode {
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum Error {
AttributeNotFound, AttributeNotFound,
AttributeIsCustom, AttributeIsCustom,
BufferTooSmall, BufferTooSmall,
@ -31,6 +26,13 @@ pub enum Error {
CommandNotFound, CommandNotFound,
Duplicate, Duplicate,
EndpointNotFound, EndpointNotFound,
InvalidAction,
InvalidCommand,
InvalidDataType,
UnsupportedAccess,
ResourceExhausted,
Busy,
DataVersionMismatch,
Crypto, Crypto,
TLSStack, TLSStack,
MdnsError, MdnsError,
@ -71,78 +73,155 @@ pub enum Error {
Utf8Fail, Utf8Fail,
} }
impl From<ErrorCode> for Error {
fn from(code: ErrorCode) -> Self {
Self::new(code)
}
}
pub struct Error {
code: ErrorCode,
#[cfg(all(feature = "std", feature = "backtrace"))]
backtrace: std::backtrace::Backtrace,
}
impl Error {
pub fn new(code: ErrorCode) -> Self {
Self {
code,
#[cfg(all(feature = "std", feature = "backtrace"))]
backtrace: std::backtrace::Backtrace::capture(),
}
}
pub const fn code(&self) -> ErrorCode {
self.code
}
#[cfg(all(feature = "std", feature = "backtrace"))]
pub const fn backtrace(&self) -> &std::backtrace::Backtrace {
&self.backtrace
}
pub fn remap<F>(self, matcher: F, to: Self) -> Self
where
F: FnOnce(&Self) -> bool,
{
if matcher(&self) {
to
} else {
self
}
}
pub fn map_invalid(self, to: Self) -> Self {
self.remap(
|e| matches!(e.code(), ErrorCode::Invalid | ErrorCode::InvalidData),
to,
)
}
pub fn map_invalid_command(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidCommand))
}
pub fn map_invalid_action(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidAction))
}
pub fn map_invalid_data_type(self) -> Self {
self.map_invalid(Error::new(ErrorCode::InvalidDataType))
}
}
#[cfg(feature = "std")]
impl From<std::io::Error> for Error { impl From<std::io::Error> for Error {
fn from(_e: std::io::Error) -> Self { fn from(_e: std::io::Error) -> Self {
// Keep things simple for now // Keep things simple for now
Self::StdIoError Self::new(ErrorCode::StdIoError)
} }
} }
impl<T> From<PoisonError<T>> for Error { #[cfg(feature = "std")]
fn from(_e: PoisonError<T>) -> Self { impl<T> From<std::sync::PoisonError<T>> for Error {
Self::RwLock fn from(_e: std::sync::PoisonError<T>) -> Self {
Self::new(ErrorCode::RwLock)
} }
} }
#[cfg(feature = "crypto_openssl")] #[cfg(feature = "openssl")]
impl From<openssl::error::ErrorStack> for Error { impl From<openssl::error::ErrorStack> for Error {
fn from(e: openssl::error::ErrorStack) -> Self { fn from(e: openssl::error::ErrorStack) -> Self {
error!("Error in TLS: {}", e); ::log::error!("Error in TLS: {}", e);
Self::TLSStack Self::new(ErrorCode::TLSStack)
} }
} }
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "mbedtls", not(target_os = "espidf")))]
impl From<mbedtls::Error> for Error { impl From<mbedtls::Error> for Error {
fn from(e: mbedtls::Error) -> Self { fn from(e: mbedtls::Error) -> Self {
error!("Error in TLS: {}", e); ::log::error!("Error in TLS: {}", e);
Self::TLSStack Self::new(ErrorCode::TLSStack)
} }
} }
#[cfg(feature = "crypto_rustcrypto")] #[cfg(target_os = "espidf")]
impl From<esp_idf_sys::EspError> for Error {
fn from(e: esp_idf_sys::EspError) -> Self {
::log::error!("Error in ESP: {}", e);
Self::new(ErrorCode::TLSStack) // TODO: Not a good mapping
}
}
#[cfg(feature = "rustcrypto")]
impl From<ccm::aead::Error> for Error { impl From<ccm::aead::Error> for Error {
fn from(_e: ccm::aead::Error) -> Self { fn from(_e: ccm::aead::Error) -> Self {
Self::Crypto Self::new(ErrorCode::Crypto)
} }
} }
impl From<SystemTimeError> for Error { #[cfg(feature = "std")]
fn from(_e: SystemTimeError) -> Self { impl From<std::time::SystemTimeError> for Error {
Self::SysTimeFail fn from(_e: std::time::SystemTimeError) -> Self {
Error::new(ErrorCode::SysTimeFail)
} }
} }
impl From<TryFromSliceError> for Error { impl From<TryFromSliceError> for Error {
fn from(_e: TryFromSliceError) -> Self { fn from(_e: TryFromSliceError) -> Self {
Self::Invalid Self::new(ErrorCode::Invalid)
} }
} }
impl<T> From<SendError<T>> for Error { impl From<Utf8Error> for Error {
fn from(e: SendError<T>) -> Self { fn from(_e: Utf8Error) -> Self {
error!("Error in channel send {}", e); Self::new(ErrorCode::Utf8Fail)
Self::Invalid
} }
} }
impl From<FromUtf8Error> for Error { impl fmt::Debug for Error {
fn from(_e: FromUtf8Error) -> Self { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Self::Utf8Fail #[cfg(not(all(feature = "std", feature = "backtrace")))]
} {
write!(f, "Error::{}", self)?;
} }
impl From<TryRecvError> for Error { #[cfg(all(feature = "std", feature = "backtrace"))]
fn from(e: TryRecvError) -> Self { {
error!("Error in channel try_recv {}", e); writeln!(f, "Error::{} {{", self)?;
Self::Invalid write!(f, "{}", self.backtrace())?;
writeln!(f, "}}")?;
}
Ok(())
} }
} }
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self) write!(f, "{:?}", self.code())
} }
} }
#[cfg(feature = "std")]
impl std::error::Error for Error {} impl std::error::Error for Error {}

View file

@ -15,56 +15,26 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, MutexGuard, RwLock}; use core::fmt::Write;
use byteorder::{BigEndian, ByteOrder, LittleEndian}; use byteorder::{BigEndian, ByteOrder, LittleEndian};
use log::{error, info}; use heapless::{String, Vec};
use owning_ref::RwLockReadGuardRef; use log::info;
use crate::{ use crate::{
cert::Cert, cert::{Cert, MAX_CERT_TLV_LEN},
crypto::{self, crypto_dummy::KeyPairDummy, hkdf_sha256, CryptoKeyPair, HmacSha256, KeyPair}, crypto::{self, hkdf_sha256, HmacSha256, KeyPair},
error::Error, error::{Error, ErrorCode},
group_keys::KeySet, group_keys::KeySet,
mdns::{self, Mdns}, mdns::{Mdns, ServiceMode},
sys::{Psm, SysMdnsService}, tlv::{self, FromTLV, OctetStr, TLVList, TLVWriter, TagType, ToTLV, UtfStr},
tlv::{OctetStr, TLVWriter, TagType, ToTLV, UtfStr}, utils::writebuf::WriteBuf,
}; };
const MAX_CERT_TLV_LEN: usize = 350;
const COMPRESSED_FABRIC_ID_LEN: usize = 8; const COMPRESSED_FABRIC_ID_LEN: usize = 8;
macro_rules! fb_key {
($index:ident, $key:ident) => {
&format!("fb{}{}", $index, $key)
};
}
const ST_VID: &str = "vid";
const ST_RCA: &str = "rca";
const ST_ICA: &str = "ica";
const ST_NOC: &str = "noc";
const ST_IPK: &str = "ipk";
const ST_LBL: &str = "label";
const ST_PBKEY: &str = "pubkey";
const ST_PRKEY: &str = "privkey";
#[allow(dead_code)] #[allow(dead_code)]
pub struct Fabric { #[derive(Debug, ToTLV)]
node_id: u64,
fabric_id: u64,
vendor_id: u16,
key_pair: Box<dyn CryptoKeyPair>,
pub root_ca: Cert,
pub icac: Option<Cert>,
pub noc: Cert,
pub ipk: KeySet,
label: String,
compressed_id: [u8; COMPRESSED_FABRIC_ID_LEN],
mdns_service: Option<SysMdnsService>,
}
#[derive(ToTLV)]
#[tlvargs(lifetime = "'a", start = 1)] #[tlvargs(lifetime = "'a", start = 1)]
pub struct FabricDescriptor<'a> { pub struct FabricDescriptor<'a> {
root_public_key: OctetStr<'a>, root_public_key: OctetStr<'a>,
@ -77,64 +47,70 @@ pub struct FabricDescriptor<'a> {
pub fab_idx: Option<u8>, pub fab_idx: Option<u8>,
} }
#[derive(Debug, ToTLV, FromTLV)]
pub struct Fabric {
node_id: u64,
fabric_id: u64,
vendor_id: u16,
key_pair: KeyPair,
pub root_ca: Vec<u8, { MAX_CERT_TLV_LEN }>,
pub icac: Option<Vec<u8, { MAX_CERT_TLV_LEN }>>,
pub noc: Vec<u8, { MAX_CERT_TLV_LEN }>,
pub ipk: KeySet,
label: String<32>,
mdns_service_name: String<33>,
}
impl Fabric { impl Fabric {
pub fn new( pub fn new(
key_pair: KeyPair, key_pair: KeyPair,
root_ca: Cert, root_ca: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
icac: Option<Cert>, icac: Option<heapless::Vec<u8, { MAX_CERT_TLV_LEN }>>,
noc: Cert, noc: heapless::Vec<u8, { MAX_CERT_TLV_LEN }>,
ipk: &[u8], ipk: &[u8],
vendor_id: u16, vendor_id: u16,
label: &str,
) -> Result<Self, Error> { ) -> Result<Self, Error> {
let node_id = noc.get_node_id()?; let (node_id, fabric_id) = {
let fabric_id = noc.get_fabric_id()?; let noc_p = Cert::new(&noc)?;
(noc_p.get_node_id()?, noc_p.get_fabric_id()?)
let mut f = Self {
node_id,
fabric_id,
vendor_id,
key_pair: Box::new(key_pair),
root_ca,
icac,
noc,
ipk: KeySet::default(),
compressed_id: [0; COMPRESSED_FABRIC_ID_LEN],
label: "".into(),
mdns_service: None,
}; };
Fabric::get_compressed_id(f.root_ca.get_pubkey(), fabric_id, &mut f.compressed_id)?;
f.ipk = KeySet::new(ipk, &f.compressed_id)?;
let mut mdns_service_name = String::with_capacity(33); let mut compressed_id = [0_u8; COMPRESSED_FABRIC_ID_LEN];
for c in f.compressed_id {
mdns_service_name.push_str(&format!("{:02X}", c)); let ipk = {
let root_ca_p = Cert::new(&root_ca)?;
Fabric::get_compressed_id(root_ca_p.get_pubkey(), fabric_id, &mut compressed_id)?;
KeySet::new(ipk, &compressed_id)?
};
let mut mdns_service_name = heapless::String::<33>::new();
for c in compressed_id {
let mut hex = heapless::String::<4>::new();
write!(&mut hex, "{:02X}", c).unwrap();
mdns_service_name.push_str(&hex).unwrap();
} }
mdns_service_name.push('-'); mdns_service_name.push('-').unwrap();
let mut node_id_be: [u8; 8] = [0; 8]; let mut node_id_be: [u8; 8] = [0; 8];
BigEndian::write_u64(&mut node_id_be, node_id); BigEndian::write_u64(&mut node_id_be, node_id);
for c in node_id_be { for c in node_id_be {
mdns_service_name.push_str(&format!("{:02X}", c)); let mut hex = heapless::String::<4>::new();
write!(&mut hex, "{:02X}", c).unwrap();
mdns_service_name.push_str(&hex).unwrap();
} }
info!("MDNS Service Name: {}", mdns_service_name); info!("MDNS Service Name: {}", mdns_service_name);
f.mdns_service = Some(
Mdns::get()?.publish_service(&mdns_service_name, mdns::ServiceMode::Commissioned)?,
);
Ok(f)
}
pub fn dummy() -> Result<Self, Error> {
Ok(Self { Ok(Self {
node_id: 0, node_id,
fabric_id: 0, fabric_id,
vendor_id: 0, vendor_id,
key_pair: Box::new(KeyPairDummy::new()?), key_pair,
root_ca: Cert::default(), root_ca,
icac: Some(Cert::default()), icac,
noc: Cert::default(), noc,
ipk: KeySet::default(), ipk,
label: "".into(), label: label.into(),
compressed_id: [0; COMPRESSED_FABRIC_ID_LEN], mdns_service_name,
mdns_service: None,
}) })
} }
@ -147,14 +123,14 @@ impl Fabric {
0x69, 0x63, 0x69, 0x63,
]; ];
hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out) hkdf_sha256(&fabric_id_be, root_pubkey, &COMPRESSED_FABRIC_ID_INFO, out)
.map_err(|_| Error::NoSpace) .map_err(|_| Error::from(ErrorCode::NoSpace))
} }
pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> { pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<(), Error> {
let mut mac = HmacSha256::new(self.ipk.op_key())?; let mut mac = HmacSha256::new(self.ipk.op_key())?;
mac.update(random)?; mac.update(random)?;
mac.update(self.root_ca.get_pubkey())?; mac.update(self.get_root_ca()?.get_pubkey())?;
let mut buf: [u8; 8] = [0; 8]; let mut buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut buf, self.fabric_id); LittleEndian::write_u64(&mut buf, self.fabric_id);
@ -168,7 +144,7 @@ impl Fabric {
if id.as_slice() == target { if id.as_slice() == target {
Ok(()) Ok(())
} else { } else {
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
} }
@ -184,255 +160,180 @@ impl Fabric {
self.fabric_id self.fabric_id
} }
pub fn get_fabric_desc(&self, fab_idx: u8) -> FabricDescriptor { pub fn get_root_ca(&self) -> Result<Cert<'_>, Error> {
FabricDescriptor { Cert::new(&self.root_ca)
root_public_key: OctetStr::new(self.root_ca.get_pubkey()), }
pub fn get_fabric_desc<'a>(
&'a self,
fab_idx: u8,
root_ca_cert: &'a Cert,
) -> Result<FabricDescriptor<'a>, Error> {
let desc = FabricDescriptor {
root_public_key: OctetStr::new(root_ca_cert.get_pubkey()),
vendor_id: self.vendor_id, vendor_id: self.vendor_id,
fabric_id: self.fabric_id, fabric_id: self.fabric_id,
node_id: self.node_id, node_id: self.node_id,
label: UtfStr(self.label.as_bytes()), label: UtfStr(self.label.as_bytes()),
fab_idx: Some(fab_idx), fab_idx: Some(fab_idx),
}
}
fn rm_store(&self, index: usize, psm: &MutexGuard<Psm>) {
psm.rm(fb_key!(index, ST_RCA));
psm.rm(fb_key!(index, ST_ICA));
psm.rm(fb_key!(index, ST_NOC));
psm.rm(fb_key!(index, ST_IPK));
psm.rm(fb_key!(index, ST_LBL));
psm.rm(fb_key!(index, ST_PBKEY));
psm.rm(fb_key!(index, ST_PRKEY));
psm.rm(fb_key!(index, ST_VID));
}
fn store(&self, index: usize, psm: &MutexGuard<Psm>) -> Result<(), Error> {
let mut key = [0u8; MAX_CERT_TLV_LEN];
let len = self.root_ca.as_tlv(&mut key)?;
psm.set_kv_slice(fb_key!(index, ST_RCA), &key[..len])?;
let len = if let Some(icac) = &self.icac {
icac.as_tlv(&mut key)?
} else {
0
};
psm.set_kv_slice(fb_key!(index, ST_ICA), &key[..len])?;
let len = self.noc.as_tlv(&mut key)?;
psm.set_kv_slice(fb_key!(index, ST_NOC), &key[..len])?;
psm.set_kv_slice(fb_key!(index, ST_IPK), self.ipk.epoch_key())?;
psm.set_kv_slice(fb_key!(index, ST_LBL), self.label.as_bytes())?;
let mut key = [0_u8; crypto::EC_POINT_LEN_BYTES];
let len = self.key_pair.get_public_key(&mut key)?;
let key = &key[..len];
psm.set_kv_slice(fb_key!(index, ST_PBKEY), key)?;
let mut key = [0_u8; crypto::BIGNUM_LEN_BYTES];
let len = self.key_pair.get_private_key(&mut key)?;
let key = &key[..len];
psm.set_kv_slice(fb_key!(index, ST_PRKEY), key)?;
psm.set_kv_u64(fb_key!(index, ST_VID), self.vendor_id.into())?;
Ok(())
}
fn load(index: usize, psm: &MutexGuard<Psm>) -> Result<Self, Error> {
let mut root_ca = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_RCA), &mut root_ca)?;
let root_ca = Cert::new(root_ca.as_slice())?;
let mut icac = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_ICA), &mut icac)?;
let icac = if !icac.is_empty() {
Some(Cert::new(icac.as_slice())?)
} else {
None
}; };
let mut noc = Vec::new(); Ok(desc)
psm.get_kv_slice(fb_key!(index, ST_NOC), &mut noc)?;
let noc = Cert::new(noc.as_slice())?;
let mut ipk = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_IPK), &mut ipk)?;
let mut label = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_LBL), &mut label)?;
let label = String::from_utf8(label).map_err(|_| {
error!("Couldn't read label");
Error::Invalid
})?;
let mut pub_key = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_PBKEY), &mut pub_key)?;
let mut priv_key = Vec::new();
psm.get_kv_slice(fb_key!(index, ST_PRKEY), &mut priv_key)?;
let keypair = KeyPair::new_from_components(pub_key.as_slice(), priv_key.as_slice())?;
let mut vendor_id = 0;
psm.get_kv_u64(fb_key!(index, ST_VID), &mut vendor_id)?;
let f = Fabric::new(
keypair,
root_ca,
icac,
noc,
ipk.as_slice(),
vendor_id as u16,
);
f.map(|mut f| {
f.label = label;
f
})
} }
} }
pub const MAX_SUPPORTED_FABRICS: usize = 3; pub const MAX_SUPPORTED_FABRICS: usize = 3;
#[derive(Default)]
pub struct FabricMgrInner { type FabricEntries = Vec<Option<Fabric>, MAX_SUPPORTED_FABRICS>;
// The outside world expects Fabric Index to be one more than the actual one
// since 0 is not allowed. Need to handle this cleanly somehow
pub fabrics: [Option<Fabric>; MAX_SUPPORTED_FABRICS],
}
pub struct FabricMgr { pub struct FabricMgr {
inner: RwLock<FabricMgrInner>, fabrics: FabricEntries,
psm: Arc<Mutex<Psm>>, changed: bool,
} }
impl FabricMgr { impl FabricMgr {
pub fn new() -> Result<Self, Error> { #[inline(always)]
let dummy_fabric = Fabric::dummy()?; pub const fn new() -> Self {
let mut mgr = FabricMgrInner::default(); Self {
mgr.fabrics[0] = Some(dummy_fabric); fabrics: FabricEntries::new(),
let mut fm = Self { changed: false,
inner: RwLock::new(mgr), }
psm: Psm::get()?,
};
fm.load()?;
Ok(fm)
} }
fn store(&self, index: usize, fabric: &Fabric) -> Result<(), Error> { pub fn load(&mut self, data: &[u8], mdns: &dyn Mdns) -> Result<(), Error> {
let psm = self.psm.lock().unwrap(); for fabric in self.fabrics.iter().flatten() {
fabric.store(index, &psm) mdns.remove(&fabric.mdns_service_name)?;
} }
fn load(&mut self) -> Result<(), Error> { let root = TLVList::new(data).iter().next().ok_or(ErrorCode::Invalid)?;
let mut mgr = self.inner.write()?;
let psm = self.psm.lock().unwrap(); tlv::from_tlv(&mut self.fabrics, &root)?;
for i in 0..MAX_SUPPORTED_FABRICS {
let result = Fabric::load(i, &psm); for fabric in self.fabrics.iter().flatten() {
if let Ok(fabric) = result { mdns.add(&fabric.mdns_service_name, ServiceMode::Commissioned)?;
info!("Adding new fabric at index {}", i);
mgr.fabrics[i] = Some(fabric);
}
} }
self.changed = false;
Ok(()) Ok(())
} }
pub fn add(&self, f: Fabric) -> Result<u8, Error> { pub fn store<'a>(&mut self, buf: &'a mut [u8]) -> Result<Option<&'a [u8]>, Error> {
let mut mgr = self.inner.write()?; if self.changed {
let index = mgr let mut wb = WriteBuf::new(buf);
.fabrics let mut tw = TLVWriter::new(&mut wb);
.iter()
.position(|f| f.is_none())
.ok_or(Error::NoSpace)?;
self.store(index, &f)?; self.fabrics
.as_slice()
.to_tlv(&mut tw, TagType::Anonymous)?;
mgr.fabrics[index] = Some(f); self.changed = false;
Ok(index as u8)
let len = tw.get_tail();
Ok(Some(&buf[..len]))
} else {
Ok(None)
}
} }
pub fn remove(&self, fab_idx: u8) -> Result<(), Error> { pub fn is_changed(&self) -> bool {
let fab_idx = fab_idx as usize; self.changed
let mut mgr = self.inner.write().unwrap(); }
let psm = self.psm.lock().unwrap();
if let Some(f) = &mgr.fabrics[fab_idx] { pub fn add(&mut self, f: Fabric, mdns: &dyn Mdns) -> Result<u8, Error> {
f.rm_store(fab_idx, &psm); let slot = self.fabrics.iter().position(|x| x.is_none());
mgr.fabrics[fab_idx] = None;
if slot.is_some() || self.fabrics.len() < MAX_SUPPORTED_FABRICS {
mdns.add(&f.mdns_service_name, ServiceMode::Commissioned)?;
self.changed = true;
if let Some(index) = slot {
self.fabrics[index] = Some(f);
Ok((index + 1) as u8)
} else {
self.fabrics
.push(Some(f))
.map_err(|_| ErrorCode::NoSpace)
.unwrap();
Ok(self.fabrics.len() as u8)
}
} else {
Err(ErrorCode::NoSpace.into())
}
}
pub fn remove(&mut self, fab_idx: u8, mdns: &dyn Mdns) -> Result<(), Error> {
if fab_idx > 0 && fab_idx as usize <= self.fabrics.len() {
if let Some(f) = self.fabrics[(fab_idx - 1) as usize].take() {
mdns.remove(&f.mdns_service_name)?;
self.changed = true;
Ok(()) Ok(())
} else { } else {
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
}
} else {
Err(ErrorCode::NotFound.into())
} }
} }
pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<usize, Error> { pub fn match_dest_id(&self, random: &[u8], target: &[u8]) -> Result<usize, Error> {
let mgr = self.inner.read()?; for (index, fabric) in self.fabrics.iter().enumerate() {
for i in 0..MAX_SUPPORTED_FABRICS { if let Some(fabric) = fabric {
if let Some(fabric) = &mgr.fabrics[i] {
if fabric.match_dest_id(random, target).is_ok() { if fabric.match_dest_id(random, target).is_ok() {
return Ok(i); return Ok(index + 1);
} }
} }
} }
Err(Error::NotFound) Err(ErrorCode::NotFound.into())
} }
pub fn get_fabric<'ret, 'me: 'ret>( pub fn get_fabric(&self, idx: usize) -> Result<Option<&Fabric>, Error> {
&'me self, if idx == 0 {
idx: usize, Ok(None)
) -> Result<RwLockReadGuardRef<'ret, FabricMgrInner, Option<Fabric>>, Error> { } else {
Ok(RwLockReadGuardRef::new(self.inner.read()?).map(|fm| &fm.fabrics[idx])) Ok(self.fabrics[idx - 1].as_ref())
}
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
let mgr = self.inner.read().unwrap(); !self.fabrics.iter().any(Option::is_some)
for i in 1..MAX_SUPPORTED_FABRICS {
if mgr.fabrics[i].is_some() {
return false;
}
}
true
} }
pub fn used_count(&self) -> usize { pub fn used_count(&self) -> usize {
let mgr = self.inner.read().unwrap(); self.fabrics.iter().filter(|f| f.is_some()).count()
let mut count = 0;
for i in 1..MAX_SUPPORTED_FABRICS {
if mgr.fabrics[i].is_some() {
count += 1;
}
}
count
} }
// Parameters to T are the Fabric and its Fabric Index // Parameters to T are the Fabric and its Fabric Index
pub fn for_each<T>(&self, mut f: T) -> Result<(), Error> pub fn for_each<T>(&self, mut f: T) -> Result<(), Error>
where where
T: FnMut(&Fabric, u8), T: FnMut(&Fabric, u8) -> Result<(), Error>,
{ {
let mgr = self.inner.read().unwrap(); for (index, fabric) in self.fabrics.iter().enumerate() {
for i in 1..MAX_SUPPORTED_FABRICS { if let Some(fabric) = fabric {
if let Some(fabric) = &mgr.fabrics[i] { f(fabric, (index + 1) as u8)?;
f(fabric, i as u8)
} }
} }
Ok(()) Ok(())
} }
pub fn set_label(&self, index: u8, label: String) -> Result<(), Error> { pub fn set_label(&mut self, index: u8, label: &str) -> Result<(), Error> {
let index = index as usize; if !label.is_empty()
let mut mgr = self.inner.write()?; && self
if !label.is_empty() { .fabrics
for i in 1..MAX_SUPPORTED_FABRICS { .iter()
if let Some(fabric) = &mgr.fabrics[i] { .filter_map(|f| f.as_ref())
if fabric.label == label { .any(|f| f.label == label)
return Err(Error::Invalid); {
} return Err(ErrorCode::Invalid.into());
}
}
}
if let Some(fabric) = &mut mgr.fabrics[index] {
let old = fabric.label.clone();
fabric.label = label;
let psm = self.psm.lock().unwrap();
if fabric.store(index, &psm).is_err() {
fabric.label = old;
return Err(Error::StdIoError);
} }
let index = (index - 1) as usize;
if let Some(fabric) = &mut self.fabrics[index] {
fabric.label = label.into();
self.changed = true;
} }
Ok(()) Ok(())
} }

View file

@ -15,39 +15,18 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, Once}; use crate::{
crypto::{self, SYMM_KEY_LEN_BYTES},
error::{Error, ErrorCode},
tlv::{FromTLV, TLVWriter, TagType, ToTLV},
};
use crate::{crypto, error::Error}; type KeySetKey = [u8; SYMM_KEY_LEN_BYTES];
// This is just makeshift implementation for now, not used anywhere #[derive(Debug, Default, FromTLV, ToTLV)]
pub struct GroupKeys {}
static mut G_GRP_KEYS: Option<Arc<Mutex<GroupKeys>>> = None;
static INIT: Once = Once::new();
impl GroupKeys {
fn new() -> Self {
Self {}
}
pub fn get() -> Result<Arc<Mutex<Self>>, Error> {
unsafe {
INIT.call_once(|| {
G_GRP_KEYS = Some(Arc::new(Mutex::new(GroupKeys::new())));
});
Ok(G_GRP_KEYS.as_ref().ok_or(Error::Invalid)?.clone())
}
}
pub fn insert_key() -> Result<(), Error> {
Ok(())
}
}
#[derive(Debug, Default)]
pub struct KeySet { pub struct KeySet {
pub epoch_key: [u8; crypto::SYMM_KEY_LEN_BYTES], pub epoch_key: KeySetKey,
pub op_key: [u8; crypto::SYMM_KEY_LEN_BYTES], pub op_key: KeySetKey,
} }
impl KeySet { impl KeySet {
@ -63,7 +42,8 @@ impl KeySet {
0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30, 0x47, 0x72, 0x6f, 0x75, 0x70, 0x4b, 0x65, 0x79, 0x20, 0x76, 0x31, 0x2e, 0x30,
]; ];
crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey).map_err(|_| Error::NoSpace) crypto::hkdf_sha256(compressed_id, ipk, &GRP_KEY_INFO, opkey)
.map_err(|_| ErrorCode::NoSpace.into())
} }
pub fn op_key(&self) -> &[u8] { pub fn op_key(&self) -> &[u8] {

View file

@ -1,88 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use super::core::IMStatusCode;
use super::core::OpCode;
use super::messages::ib;
use super::messages::msg;
use super::messages::msg::InvReq;
use super::InteractionModel;
use super::Transaction;
use crate::{
error::*,
tlv::{get_root_node_struct, print_tlv_list, FromTLV, TLVElement, TLVWriter, TagType},
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use log::error;
#[macro_export]
macro_rules! cmd_enter {
($e:expr) => {{
use colored::Colorize;
info! {"{} {}", "Handling Command".cyan(), $e.cyan()}
}};
}
pub struct CommandReq<'a, 'b, 'c, 'd, 'e> {
pub cmd: ib::CmdPath,
pub data: TLVElement<'a>,
pub resp: &'a mut TLVWriter<'b, 'c>,
pub trans: &'a mut Transaction<'d, 'e>,
}
impl InteractionModel {
pub fn handle_invoke_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
if InteractionModel::req_timeout_handled(trans, proto_tx)? {
return Ok(ResponseRequired::Yes);
}
proto_tx.set_proto_opcode(OpCode::InvokeResponse as u8);
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let root = get_root_node_struct(rx_buf)?;
let inv_req = InvReq::from_tlv(&root)?;
let timed_tx = trans.get_timeout().map(|_| true);
let timed_request = inv_req.timed_request.filter(|a| *a);
// Either both should be None, or both should be Some(true)
if timed_tx != timed_request {
InteractionModel::create_status_response(proto_tx, IMStatusCode::TimedRequestMisMatch)?;
return Ok(ResponseRequired::Yes);
}
tw.start_struct(TagType::Anonymous)?;
// Suppress Response -> TODO: Need to revisit this for cases where we send a command back
tw.bool(
TagType::Context(msg::InvRespTag::SupressResponse as u8),
false,
)?;
self.consumer
.consume_invoke_cmd(&inv_req, trans, &mut tw)
.map_err(|e| {
error!("Error in handling command: {:?}", e);
print_tlv_list(rx_buf);
e
})?;
tw.end_container()?;
Ok(ResponseRequired::Yes)
}
}

View file

@ -15,212 +15,29 @@
* limitations under the License. * limitations under the License.
*/ */
use std::time::{Duration, SystemTime}; use core::time::Duration;
use crate::{ use crate::{
acl::Accessor,
error::*, error::*,
interaction_model::messages::msg::StatusResp, tlv::{get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV},
tlv::{self, get_root_node_struct, FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, transport::{exchange::Exchange, packet::Packet},
transport::{ utils::epoch::Epoch,
exchange::Exchange,
packet::Packet,
proto_demux::{self, ProtoCtx, ResponseRequired},
session::SessionHandle,
},
}; };
use colored::Colorize; use log::error;
use log::{error, info}; use num::{self, FromPrimitive};
use num;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use super::InteractionModel; use super::messages::msg::{
use super::Transaction; self, InvReq, ReadReq, StatusResp, SubscribeReq, SubscribeResp, TimedReq, WriteReq,
use super::TransactionState;
use super::{messages::msg::TimedReq, InteractionConsumer};
/* Handle messages related to the Interation Model
*/
/* Interaction Model ID as per the Matter Spec */
const PROTO_ID_INTERACTION_MODEL: usize = 0x01;
#[derive(FromPrimitive, Debug, Copy, Clone, PartialEq)]
pub enum OpCode {
Reserved = 0,
StatusResponse = 1,
ReadRequest = 2,
SubscribeRequest = 3,
SubscriptResponse = 4,
ReportData = 5,
WriteRequest = 6,
WriteResponse = 7,
InvokeRequest = 8,
InvokeResponse = 9,
TimedRequest = 10,
}
impl<'a, 'b> Transaction<'a, 'b> {
pub fn new(session: &'a mut SessionHandle<'b>, exch: &'a mut Exchange) -> Self {
Self {
state: TransactionState::Ongoing,
session,
exch,
}
}
/// Terminates the transaction, no communication (even ACKs) happens hence forth
pub fn terminate(&mut self) {
self.state = TransactionState::Terminate
}
pub fn is_terminate(&self) -> bool {
self.state == TransactionState::Terminate
}
/// Marks the transaction as completed from the application's perspective
pub fn complete(&mut self) {
self.state = TransactionState::Complete
}
pub fn is_complete(&self) -> bool {
self.state == TransactionState::Complete
}
pub fn set_timeout(&mut self, timeout: u64) {
self.exch
.set_data_time(SystemTime::now().checked_add(Duration::from_millis(timeout)));
}
pub fn get_timeout(&mut self) -> Option<SystemTime> {
self.exch.get_data_time()
}
pub fn has_timed_out(&self) -> bool {
if let Some(timeout) = self.exch.get_data_time() {
if SystemTime::now() > timeout {
return true;
}
}
false
}
}
impl InteractionModel {
pub fn new(consumer: Box<dyn InteractionConsumer>) -> InteractionModel {
InteractionModel { consumer }
}
pub fn handle_subscribe_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let (opcode, resp) = self.consumer.consume_subscribe(rx_buf, trans, &mut tw)?;
proto_tx.set_proto_opcode(opcode as u8);
Ok(resp)
}
pub fn handle_status_resp(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let root = get_root_node_struct(rx_buf)?;
let req = StatusResp::from_tlv(&root)?;
let (opcode, resp) = self.consumer.consume_status_report(&req, trans, &mut tw)?;
proto_tx.set_proto_opcode(opcode as u8);
Ok(resp)
}
pub fn handle_timed_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
proto_tx.set_proto_opcode(OpCode::StatusResponse as u8);
let root = get_root_node_struct(rx_buf)?;
let req = TimedReq::from_tlv(&root)?;
trans.set_timeout(req.timeout.into());
let status = StatusResp {
status: IMStatusCode::Success,
};
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let _ = status.to_tlv(&mut tw, TagType::Anonymous);
Ok(ResponseRequired::Yes)
}
/// Handle Request Timeouts
/// This API checks if a request was a timed request, and if so, and if the timeout has
/// expired, it will generate the appropriate response as expected
pub(super) fn req_timeout_handled(
trans: &mut Transaction,
proto_tx: &mut Packet,
) -> Result<bool, Error> {
if trans.has_timed_out() {
trans.complete();
InteractionModel::create_status_response(proto_tx, IMStatusCode::Timeout)?;
Ok(true)
} else {
Ok(false)
}
}
pub(super) fn create_status_response(
proto_tx: &mut Packet,
status: IMStatusCode,
) -> Result<(), Error> {
proto_tx.set_proto_opcode(OpCode::StatusResponse as u8);
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let status = StatusResp { status };
status.to_tlv(&mut tw, TagType::Anonymous)
}
}
impl proto_demux::HandleProto for InteractionModel {
fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> {
let mut trans = Transaction::new(&mut ctx.exch_ctx.sess, ctx.exch_ctx.exch);
let proto_opcode: OpCode =
num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?;
ctx.tx.set_proto_id(PROTO_ID_INTERACTION_MODEL as u16);
let buf = ctx.rx.as_borrow_slice();
info!("{} {:?}", "Received command".cyan(), proto_opcode);
tlv::print_tlv_list(buf);
let result = match proto_opcode {
OpCode::InvokeRequest => self.handle_invoke_req(&mut trans, buf, &mut ctx.tx)?,
OpCode::ReadRequest => self.handle_read_req(&mut trans, buf, &mut ctx.tx)?,
OpCode::WriteRequest => self.handle_write_req(&mut trans, buf, &mut ctx.tx)?,
OpCode::TimedRequest => self.handle_timed_req(&mut trans, buf, &mut ctx.tx)?,
OpCode::SubscribeRequest => self.handle_subscribe_req(&mut trans, buf, &mut ctx.tx)?,
OpCode::StatusResponse => self.handle_status_resp(&mut trans, buf, &mut ctx.tx)?,
_ => {
error!("Opcode Not Handled: {:?}", proto_opcode);
return Err(Error::InvalidOpcode);
}
}; };
if result == ResponseRequired::Yes { #[macro_export]
info!("Sending response"); macro_rules! cmd_enter {
tlv::print_tlv_list(ctx.tx.as_borrow_slice()); ($e:expr) => {{
} use owo_colors::OwoColorize;
if trans.is_terminate() { info! {"{} {}", "Handling command".cyan(), $e.cyan()}
ctx.exch_ctx.exch.terminate(); }};
} else if trans.is_complete() {
ctx.exch_ctx.exch.close();
}
Ok(result)
}
fn get_proto_id(&self) -> usize {
PROTO_ID_INTERACTION_MODEL
}
} }
#[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)] #[derive(FromPrimitive, Debug, Clone, Copy, PartialEq)]
@ -253,21 +70,33 @@ pub enum IMStatusCode {
FailSafeRequired = 0xca, FailSafeRequired = 0xca,
} }
impl From<Error> for IMStatusCode { impl From<ErrorCode> for IMStatusCode {
fn from(e: Error) -> Self { fn from(e: ErrorCode) -> Self {
match e { match e {
Error::EndpointNotFound => IMStatusCode::UnsupportedEndpoint, ErrorCode::EndpointNotFound => IMStatusCode::UnsupportedEndpoint,
Error::ClusterNotFound => IMStatusCode::UnsupportedCluster, ErrorCode::ClusterNotFound => IMStatusCode::UnsupportedCluster,
Error::AttributeNotFound => IMStatusCode::UnsupportedAttribute, ErrorCode::AttributeNotFound => IMStatusCode::UnsupportedAttribute,
Error::CommandNotFound => IMStatusCode::UnsupportedCommand, ErrorCode::CommandNotFound => IMStatusCode::UnsupportedCommand,
ErrorCode::InvalidAction => IMStatusCode::InvalidAction,
ErrorCode::InvalidCommand => IMStatusCode::InvalidCommand,
ErrorCode::UnsupportedAccess => IMStatusCode::UnsupportedAccess,
ErrorCode::Busy => IMStatusCode::Busy,
ErrorCode::DataVersionMismatch => IMStatusCode::DataVersionMismatch,
ErrorCode::ResourceExhausted => IMStatusCode::ResourceExhausted,
_ => IMStatusCode::Failure, _ => IMStatusCode::Failure,
} }
} }
} }
impl From<Error> for IMStatusCode {
fn from(value: Error) -> Self {
Self::from(value.code())
}
}
impl FromTLV<'_> for IMStatusCode { impl FromTLV<'_> for IMStatusCode {
fn from_tlv(t: &TLVElement) -> Result<Self, Error> { fn from_tlv(t: &TLVElement) -> Result<Self, Error> {
num::FromPrimitive::from_u16(t.u16()?).ok_or(Error::Invalid) FromPrimitive::from_u16(t.u16()?).ok_or_else(|| ErrorCode::Invalid.into())
} }
} }
@ -276,3 +105,629 @@ impl ToTLV for IMStatusCode {
tw.u16(tag_type, *self as u16) tw.u16(tag_type, *self as u16)
} }
} }
#[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)]
pub enum OpCode {
Reserved = 0,
StatusResponse = 1,
ReadRequest = 2,
SubscribeRequest = 3,
SubscribeResponse = 4,
ReportData = 5,
WriteRequest = 6,
WriteResponse = 7,
InvokeRequest = 8,
InvokeResponse = 9,
TimedRequest = 10,
}
/* Interaction Model ID as per the Matter Spec */
pub const PROTO_ID_INTERACTION_MODEL: u16 = 0x01;
// This is the amount of space we reserve for other things to be attached towards
// the end of long reads.
const LONG_READS_TLV_RESERVE_SIZE: usize = 24;
impl<'a> ReadReq<'a> {
pub fn tx_start<'r, 'p>(&self, tx: &'r mut Packet<'p>) -> Result<TLVWriter<'r, 'p>, Error> {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::ReportData as u8);
let mut tw = Self::reserve_long_read_space(tx)?;
tw.start_struct(TagType::Anonymous)?;
if self.attr_requests.is_some() {
tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?;
}
Ok(tw)
}
pub fn tx_finish_chunk(&self, tx: &mut Packet) -> Result<(), Error> {
self.complete(tx, true)
}
pub fn tx_finish(&self, tx: &mut Packet) -> Result<(), Error> {
self.complete(tx, false)
}
fn complete(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> {
let mut tw = Self::restore_long_read_space(tx)?;
if self.attr_requests.is_some() {
tw.end_container()?;
}
if more_chunks {
tw.bool(
TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8),
true,
)?;
}
tw.bool(
TagType::Context(msg::ReportDataTag::SupressResponse as u8),
!more_chunks,
)?;
tw.end_container()
}
fn reserve_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result<TLVWriter<'p, 'b>, Error> {
let wb = tx.get_writebuf()?;
wb.shrink(LONG_READS_TLV_RESERVE_SIZE)?;
Ok(TLVWriter::new(wb))
}
fn restore_long_read_space<'p, 'b>(tx: &'p mut Packet<'b>) -> Result<TLVWriter<'p, 'b>, Error> {
let wb = tx.get_writebuf()?;
wb.expand(LONG_READS_TLV_RESERVE_SIZE)?;
Ok(TLVWriter::new(wb))
}
}
impl<'a> WriteReq<'a> {
pub fn tx_start<'r, 'p>(
&self,
tx: &'r mut Packet<'p>,
epoch: Epoch,
timeout: Option<Duration>,
) -> Result<Option<TLVWriter<'r, 'p>>, Error> {
if has_timed_out(epoch, timeout) {
Interaction::status_response(tx, IMStatusCode::Timeout)?;
Ok(None)
} else {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::WriteResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.start_struct(TagType::Anonymous)?;
tw.start_array(TagType::Context(msg::WriteRespTag::WriteResponses as u8))?;
Ok(Some(tw))
}
}
pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.end_container()?;
tw.end_container()
}
}
impl<'a> InvReq<'a> {
pub fn tx_start<'r, 'p>(
&self,
tx: &'r mut Packet<'p>,
epoch: Epoch,
timeout: Option<Duration>,
) -> Result<Option<TLVWriter<'r, 'p>>, Error> {
if has_timed_out(epoch, timeout) {
Interaction::status_response(tx, IMStatusCode::Timeout)?;
Ok(None)
} else {
let timed_tx = timeout.map(|_| true);
let timed_request = self.timed_request.filter(|a| *a);
// Either both should be None, or both should be Some(true)
if timed_tx != timed_request {
Interaction::status_response(tx, IMStatusCode::TimedRequestMisMatch)?;
Ok(None)
} else {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::InvokeResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.start_struct(TagType::Anonymous)?;
// Suppress Response -> TODO: Need to revisit this for cases where we send a command back
tw.bool(
TagType::Context(msg::InvRespTag::SupressResponse as u8),
false,
)?;
if self.inv_requests.is_some() {
tw.start_array(TagType::Context(msg::InvRespTag::InvokeResponses as u8))?;
}
Ok(Some(tw))
}
}
}
pub fn tx_finish(&self, tx: &mut Packet<'_>) -> Result<(), Error> {
let mut tw = TLVWriter::new(tx.get_writebuf()?);
if self.inv_requests.is_some() {
tw.end_container()?;
}
tw.end_container()
}
}
impl TimedReq {
pub fn timeout(&self, epoch: Epoch) -> Duration {
epoch()
.checked_add(Duration::from_millis(self.timeout as _))
.unwrap()
}
pub fn tx_process(self, tx: &mut Packet<'_>, epoch: Epoch) -> Result<Duration, Error> {
Interaction::status_response(tx, IMStatusCode::Success)?;
Ok(epoch()
.checked_add(Duration::from_millis(self.timeout as _))
.unwrap())
}
}
impl<'a> SubscribeReq<'a> {
pub fn tx_start<'r, 'p>(
&self,
tx: &'r mut Packet<'p>,
subscription_id: u32,
) -> Result<TLVWriter<'r, 'p>, Error> {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::ReportData as u8);
let mut tw = ReadReq::reserve_long_read_space(tx)?;
tw.start_struct(TagType::Anonymous)?;
tw.u32(
TagType::Context(msg::ReportDataTag::SubscriptionId as u8),
subscription_id,
)?;
if self.attr_requests.is_some() {
tw.start_array(TagType::Context(msg::ReportDataTag::AttributeReports as u8))?;
}
Ok(tw)
}
pub fn tx_finish_chunk(&self, tx: &mut Packet<'_>, more_chunks: bool) -> Result<(), Error> {
let mut tw = ReadReq::restore_long_read_space(tx)?;
if self.attr_requests.is_some() {
tw.end_container()?;
}
if more_chunks {
tw.bool(
TagType::Context(msg::ReportDataTag::MoreChunkedMsgs as u8),
true,
)?;
}
tw.bool(
TagType::Context(msg::ReportDataTag::SupressResponse as u8),
false,
)?;
tw.end_container()
}
pub fn tx_process_final(&self, tx: &mut Packet, subscription_id: u32) -> Result<(), Error> {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::SubscribeResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let resp = SubscribeResp::new(subscription_id, 40);
resp.to_tlv(&mut tw, TagType::Anonymous)
}
}
pub struct ReadDriver<'a, 'r, 'p> {
exchange: &'r mut Exchange<'a>,
tx: &'r mut Packet<'p>,
rx: &'r mut Packet<'p>,
completed: bool,
}
impl<'a, 'r, 'p> ReadDriver<'a, 'r, 'p> {
fn new(exchange: &'r mut Exchange<'a>, tx: &'r mut Packet<'p>, rx: &'r mut Packet<'p>) -> Self {
Self {
exchange,
tx,
rx,
completed: false,
}
}
fn start(&mut self, req: &ReadReq) -> Result<(), Error> {
req.tx_start(self.tx)?;
Ok(())
}
pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
self.exchange.accessor()
}
pub fn writer(&mut self) -> Result<TLVWriter<'_, 'p>, Error> {
if self.completed {
Err(ErrorCode::Invalid.into()) // TODO
} else {
Ok(TLVWriter::new(self.tx.get_writebuf()?))
}
}
pub async fn send_chunk(&mut self, req: &ReadReq<'_>) -> Result<bool, Error> {
req.tx_finish_chunk(self.tx)?;
if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success {
self.completed = true;
Ok(false)
} else {
req.tx_start(self.tx)?;
Ok(true)
}
}
pub async fn complete(&mut self, req: &ReadReq<'_>) -> Result<(), Error> {
req.tx_finish(self.tx)?;
self.exchange.send_complete(self.tx).await
}
}
pub struct WriteDriver<'a, 'r, 'p> {
exchange: &'r mut Exchange<'a>,
tx: &'r mut Packet<'p>,
epoch: Epoch,
timeout: Option<Duration>,
}
impl<'a, 'r, 'p> WriteDriver<'a, 'r, 'p> {
fn new(
exchange: &'r mut Exchange<'a>,
epoch: Epoch,
timeout: Option<Duration>,
tx: &'r mut Packet<'p>,
) -> Self {
Self {
exchange,
tx,
epoch,
timeout,
}
}
async fn start(&mut self, req: &WriteReq<'_>) -> Result<bool, Error> {
if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() {
Ok(true)
} else {
self.exchange.send_complete(self.tx).await?;
Ok(false)
}
}
pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
self.exchange.accessor()
}
pub fn writer(&mut self) -> Result<TLVWriter<'_, 'p>, Error> {
Ok(TLVWriter::new(self.tx.get_writebuf()?))
}
pub async fn complete(&mut self, req: &WriteReq<'_>) -> Result<(), Error> {
if !req.supress_response.unwrap_or_default() {
req.tx_finish(self.tx)?;
self.exchange.send_complete(self.tx).await?;
}
Ok(())
}
}
pub struct InvokeDriver<'a, 'r, 'p> {
exchange: &'r mut Exchange<'a>,
tx: &'r mut Packet<'p>,
epoch: Epoch,
timeout: Option<Duration>,
}
impl<'a, 'r, 'p> InvokeDriver<'a, 'r, 'p> {
fn new(
exchange: &'r mut Exchange<'a>,
epoch: Epoch,
timeout: Option<Duration>,
tx: &'r mut Packet<'p>,
) -> Self {
Self {
exchange,
tx,
epoch,
timeout,
}
}
async fn start(&mut self, req: &InvReq<'_>) -> Result<bool, Error> {
if req.tx_start(self.tx, self.epoch, self.timeout)?.is_some() {
Ok(true)
} else {
self.exchange.send_complete(self.tx).await?;
Ok(false)
}
}
pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
self.exchange.accessor()
}
pub fn writer(&mut self) -> Result<TLVWriter<'_, 'p>, Error> {
Ok(TLVWriter::new(self.tx.get_writebuf()?))
}
pub fn writer_exchange(&mut self) -> Result<(TLVWriter<'_, 'p>, &Exchange<'a>), Error> {
Ok((TLVWriter::new(self.tx.get_writebuf()?), (self.exchange)))
}
pub async fn complete(&mut self, req: &InvReq<'_>) -> Result<(), Error> {
if !req.suppress_response.unwrap_or_default() {
req.tx_finish(self.tx)?;
self.exchange.send_complete(self.tx).await?;
}
Ok(())
}
}
pub struct SubscribeDriver<'a, 'r, 'p> {
exchange: &'r mut Exchange<'a>,
tx: &'r mut Packet<'p>,
rx: &'r mut Packet<'p>,
subscription_id: u32,
completed: bool,
}
impl<'a, 'r, 'p> SubscribeDriver<'a, 'r, 'p> {
fn new(
exchange: &'r mut Exchange<'a>,
subscription_id: u32,
tx: &'r mut Packet<'p>,
rx: &'r mut Packet<'p>,
) -> Self {
Self {
exchange,
tx,
rx,
subscription_id,
completed: false,
}
}
fn start(&mut self, req: &SubscribeReq) -> Result<(), Error> {
req.tx_start(self.tx, self.subscription_id)?;
Ok(())
}
pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
self.exchange.accessor()
}
pub fn writer(&mut self) -> Result<TLVWriter<'_, 'p>, Error> {
if self.completed {
Err(ErrorCode::Invalid.into()) // TODO
} else {
Ok(TLVWriter::new(self.tx.get_writebuf()?))
}
}
pub async fn send_chunk(&mut self, req: &SubscribeReq<'_>) -> Result<bool, Error> {
req.tx_finish_chunk(self.tx, true)?;
if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success {
self.completed = true;
Ok(false)
} else {
req.tx_start(self.tx, self.subscription_id)?;
Ok(true)
}
}
pub async fn complete(&mut self, req: &SubscribeReq<'_>) -> Result<(), Error> {
if !self.completed {
req.tx_finish_chunk(self.tx, false)?;
if exchange_confirm(self.exchange, self.tx, self.rx).await? != IMStatusCode::Success {
self.completed = true;
} else {
req.tx_process_final(self.tx, self.subscription_id)?;
self.exchange.send_complete(self.tx).await?;
}
}
Ok(())
}
}
pub enum Interaction<'a, 'r, 'p> {
Read {
req: ReadReq<'r>,
driver: ReadDriver<'a, 'r, 'p>,
},
Write {
req: WriteReq<'r>,
driver: WriteDriver<'a, 'r, 'p>,
},
Invoke {
req: InvReq<'r>,
driver: InvokeDriver<'a, 'r, 'p>,
},
Subscribe {
req: SubscribeReq<'r>,
driver: SubscribeDriver<'a, 'r, 'p>,
},
}
impl<'a, 'r, 'p> Interaction<'a, 'r, 'p> {
pub async fn timeout(
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
) -> Result<Option<Duration>, Error> {
let epoch = exchange.matter.epoch;
let mut opcode: OpCode = rx.get_proto_opcode()?;
let mut timeout = None;
while opcode == OpCode::TimedRequest {
let rx_data = rx.as_slice();
let req = TimedReq::from_tlv(&get_root_node_struct(rx_data)?)?;
timeout = Some(req.tx_process(tx, epoch)?);
exchange.exchange(tx, rx).await?;
opcode = rx.get_proto_opcode()?;
}
Ok(timeout)
}
#[inline(always)]
pub fn new<S>(
exchange: &'r mut Exchange<'a>,
rx: &'r mut Packet<'p>,
tx: &'r mut Packet<'p>,
rx_status: &'r mut Packet<'p>,
subscription_id: S,
timeout: Option<Duration>,
) -> Result<Interaction<'a, 'r, 'p>, Error>
where
S: FnOnce() -> u32,
{
let epoch = exchange.matter.epoch;
let opcode = rx.get_proto_opcode()?;
let rx_data = rx.as_slice();
match opcode {
OpCode::ReadRequest => {
let req = ReadReq::from_tlv(&get_root_node_struct(rx_data)?)?;
let driver = ReadDriver::new(exchange, tx, rx_status);
Ok(Self::Read { req, driver })
}
OpCode::WriteRequest => {
let req = WriteReq::from_tlv(&get_root_node_struct(rx_data)?)?;
let driver = WriteDriver::new(exchange, epoch, timeout, tx);
Ok(Self::Write { req, driver })
}
OpCode::InvokeRequest => {
let req = InvReq::from_tlv(&get_root_node_struct(rx_data)?)?;
let driver = InvokeDriver::new(exchange, epoch, timeout, tx);
Ok(Self::Invoke { req, driver })
}
OpCode::SubscribeRequest => {
let req = SubscribeReq::from_tlv(&get_root_node_struct(rx_data)?)?;
let driver = SubscribeDriver::new(exchange, subscription_id(), tx, rx_status);
Ok(Self::Subscribe { req, driver })
}
_ => {
error!("Opcode not handled: {:?}", opcode);
Err(ErrorCode::InvalidOpcode.into())
}
}
}
pub async fn start(&mut self) -> Result<bool, Error> {
let started = match self {
Self::Read { req, driver } => {
driver.start(req)?;
true
}
Self::Write { req, driver } => driver.start(req).await?,
Self::Invoke { req, driver } => driver.start(req).await?,
Self::Subscribe { req, driver } => {
driver.start(req)?;
true
}
};
Ok(started)
}
fn status_response(tx: &mut Packet, status: IMStatusCode) -> Result<(), Error> {
tx.reset();
tx.set_proto_id(PROTO_ID_INTERACTION_MODEL);
tx.set_proto_opcode(OpCode::StatusResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let status = StatusResp { status };
status.to_tlv(&mut tw, TagType::Anonymous)
}
}
async fn exchange_confirm(
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
rx: &mut Packet<'_>,
) -> Result<IMStatusCode, Error> {
exchange.exchange(tx, rx).await?;
let opcode: OpCode = rx.get_proto_opcode()?;
if opcode == OpCode::StatusResponse {
let resp = StatusResp::from_tlv(&get_root_node_struct(rx.as_slice())?)?;
Ok(resp.status)
} else {
Interaction::status_response(tx, IMStatusCode::Busy)?; // TODO
exchange.send_complete(tx).await?;
Err(ErrorCode::Invalid.into()) // TODO
}
}
fn has_timed_out(epoch: Epoch, timeout: Option<Duration>) -> bool {
timeout.map(|timeout| epoch() > timeout).unwrap_or(false)
}

View file

@ -17,13 +17,13 @@
use crate::{ use crate::{
data_model::objects::{ClusterId, EndptId}, data_model::objects::{ClusterId, EndptId},
error::Error, error::{Error, ErrorCode},
tlv::{FromTLV, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVWriter, TagType, ToTLV},
}; };
// A generic path with endpoint, clusters, and a leaf // A generic path with endpoint, clusters, and a leaf
// The leaf could be command, attribute, event // The leaf could be command, attribute, event
#[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct GenericPath { pub struct GenericPath {
pub endpoint: Option<EndptId>, pub endpoint: Option<EndptId>,
@ -48,7 +48,7 @@ impl GenericPath {
cluster: Some(c), cluster: Some(c),
leaf: Some(l), leaf: Some(l),
} => Ok((e, c, l)), } => Ok((e, c, l)),
_ => Err(Error::Invalid), _ => Err(ErrorCode::Invalid.into()),
} }
} }
/// Returns true, if the path is wildcard /// Returns true, if the path is wildcard
@ -69,7 +69,7 @@ pub mod msg {
use crate::{ use crate::{
error::Error, error::Error,
interaction_model::core::IMStatusCode, interaction_model::core::IMStatusCode,
tlv::{FromTLV, TLVArray, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, TLVArray, TLVWriter, TagType, ToTLV},
}; };
use super::ib::{ use super::ib::{
@ -77,7 +77,7 @@ pub mod msg {
EventPath, EventPath,
}; };
#[derive(Default, FromTLV, ToTLV)] #[derive(Debug, Default, FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct SubscribeReq<'a> { pub struct SubscribeReq<'a> {
pub keep_subs: bool, pub keep_subs: bool,
@ -106,16 +106,6 @@ pub mod msg {
self.attr_requests = Some(TLVArray::new(requests)); self.attr_requests = Some(TLVArray::new(requests));
self self
} }
pub fn to_read_req(&self) -> ReadReq<'a> {
ReadReq {
attr_requests: self.attr_requests,
event_requests: self.event_requests,
event_filters: self.event_filters,
fabric_filtered: self.fabric_filtered,
dataver_filters: self.dataver_filters,
}
}
} }
#[derive(Debug, FromTLV, ToTLV)] #[derive(Debug, FromTLV, ToTLV)]
@ -160,13 +150,6 @@ pub mod msg {
pub inv_requests: Option<TLVArray<'a, CmdData<'a>>>, pub inv_requests: Option<TLVArray<'a, CmdData<'a>>>,
} }
// This enum is helpful when we are constructing the response
// step by step in incremental manner
pub enum InvRespTag {
SupressResponse = 0,
InvokeResponses = 1,
}
#[derive(FromTLV, ToTLV, Debug)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct InvResp<'a> { pub struct InvResp<'a> {
@ -174,7 +157,14 @@ pub mod msg {
pub inv_responses: Option<TLVArray<'a, ib::InvResp<'a>>>, pub inv_responses: Option<TLVArray<'a, ib::InvResp<'a>>>,
} }
#[derive(Default, ToTLV, FromTLV)] // This enum is helpful when we are constructing the response
// step by step in incremental manner
pub enum InvRespTag {
SupressResponse = 0,
InvokeResponses = 1,
}
#[derive(Default, ToTLV, FromTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct ReadReq<'a> { pub struct ReadReq<'a> {
pub attr_requests: Option<TLVArray<'a, AttrPath>>, pub attr_requests: Option<TLVArray<'a, AttrPath>>,
@ -198,17 +188,17 @@ pub mod msg {
} }
} }
#[derive(ToTLV, FromTLV)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'b")] #[tlvargs(lifetime = "'a")]
pub struct WriteReq<'a, 'b> { pub struct WriteReq<'a> {
pub supress_response: Option<bool>, pub supress_response: Option<bool>,
timed_request: Option<bool>, timed_request: Option<bool>,
pub write_requests: TLVArray<'a, AttrData<'b>>, pub write_requests: TLVArray<'a, AttrData<'a>>,
more_chunked: Option<bool>, more_chunked: Option<bool>,
} }
impl<'a, 'b> WriteReq<'a, 'b> { impl<'a> WriteReq<'a> {
pub fn new(supress_response: bool, write_requests: &'a [AttrData<'b>]) -> Self { pub fn new(supress_response: bool, write_requests: &'a [AttrData<'a>]) -> Self {
let mut w = Self { let mut w = Self {
supress_response: None, supress_response: None,
write_requests: TLVArray::new(write_requests), write_requests: TLVArray::new(write_requests),
@ -223,7 +213,7 @@ pub mod msg {
} }
// Report Data // Report Data
#[derive(FromTLV, ToTLV)] #[derive(FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct ReportDataMsg<'a> { pub struct ReportDataMsg<'a> {
pub subscription_id: Option<u32>, pub subscription_id: Option<u32>,
@ -243,7 +233,7 @@ pub mod msg {
} }
// Write Response // Write Response
#[derive(ToTLV, FromTLV)] #[derive(ToTLV, FromTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct WriteResp<'a> { pub struct WriteResp<'a> {
pub write_responses: TLVArray<'a, AttrStatus>, pub write_responses: TLVArray<'a, AttrStatus>,
@ -255,11 +245,11 @@ pub mod msg {
} }
pub mod ib { pub mod ib {
use std::fmt::Debug; use core::fmt::Debug;
use crate::{ use crate::{
data_model::objects::{AttrDetails, AttrId, ClusterId, EncodeValue, EndptId}, data_model::objects::{AttrDetails, AttrId, ClusterId, CmdId, EncodeValue, EndptId},
error::Error, error::{Error, ErrorCode},
interaction_model::core::IMStatusCode, interaction_model::core::IMStatusCode,
tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV}, tlv::{FromTLV, Nullable, TLVElement, TLVWriter, TagType, ToTLV},
}; };
@ -268,7 +258,7 @@ pub mod ib {
use super::GenericPath; use super::GenericPath;
// Command Response // Command Response
#[derive(Clone, Copy, FromTLV, ToTLV, Debug)] #[derive(Clone, FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub enum InvResp<'a> { pub enum InvResp<'a> {
Cmd(CmdData<'a>), Cmd(CmdData<'a>),
@ -276,18 +266,6 @@ pub mod ib {
} }
impl<'a> InvResp<'a> { impl<'a> InvResp<'a> {
pub fn cmd_new(
endpoint: EndptId,
cluster: ClusterId,
cmd: u16,
data: EncodeValue<'a>,
) -> Self {
Self::Cmd(CmdData::new(
CmdPath::new(Some(endpoint), Some(cluster), Some(cmd)),
data,
))
}
pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self { pub fn status_new(cmd_path: CmdPath, status: IMStatusCode, cluster_status: u16) -> Self {
Self::Status(CmdStatus { Self::Status(CmdStatus {
path: cmd_path, path: cmd_path,
@ -296,7 +274,24 @@ pub mod ib {
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone, PartialEq, Debug)] impl<'a> From<CmdData<'a>> for InvResp<'a> {
fn from(value: CmdData<'a>) -> Self {
Self::Cmd(value)
}
}
pub enum InvRespTag {
Cmd = 0,
Status = 1,
}
impl<'a> From<CmdStatus> for InvResp<'a> {
fn from(value: CmdStatus) -> Self {
Self::Status(value)
}
}
#[derive(FromTLV, ToTLV, Clone, PartialEq, Debug)]
pub struct CmdStatus { pub struct CmdStatus {
path: CmdPath, path: CmdPath,
status: Status, status: Status,
@ -314,7 +309,7 @@ pub mod ib {
} }
} }
#[derive(Debug, Clone, Copy, FromTLV, ToTLV)] #[derive(Debug, Clone, FromTLV, ToTLV)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct CmdData<'a> { pub struct CmdData<'a> {
pub path: CmdPath, pub path: CmdPath,
@ -327,8 +322,13 @@ pub mod ib {
} }
} }
pub enum CmdDataTag {
Path = 0,
Data = 1,
}
// Status // Status
#[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)]
pub struct Status { pub struct Status {
pub status: IMStatusCode, pub status: IMStatusCode,
pub cluster_status: u16, pub cluster_status: u16,
@ -344,7 +344,7 @@ pub mod ib {
} }
// Attribute Response // Attribute Response
#[derive(Clone, Copy, FromTLV, ToTLV, PartialEq, Debug)] #[derive(Clone, FromTLV, ToTLV, PartialEq, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub enum AttrResp<'a> { pub enum AttrResp<'a> {
Status(AttrStatus), Status(AttrStatus),
@ -352,10 +352,6 @@ pub mod ib {
} }
impl<'a> AttrResp<'a> { impl<'a> AttrResp<'a> {
pub fn new(data_ver: u32, path: &AttrPath, data: EncodeValue<'a>) -> Self {
AttrResp::Data(AttrData::new(Some(data_ver), *path, data))
}
pub fn unwrap_data(self) -> AttrData<'a> { pub fn unwrap_data(self) -> AttrData<'a> {
match self { match self {
AttrResp::Data(d) => d, AttrResp::Data(d) => d,
@ -366,8 +362,25 @@ pub mod ib {
} }
} }
impl<'a> From<AttrData<'a>> for AttrResp<'a> {
fn from(value: AttrData<'a>) -> Self {
Self::Data(value)
}
}
impl<'a> From<AttrStatus> for AttrResp<'a> {
fn from(value: AttrStatus) -> Self {
Self::Status(value)
}
}
pub enum AttrRespTag {
Status = 0,
Data = 1,
}
// Attribute Data // Attribute Data
#[derive(Clone, Copy, PartialEq, FromTLV, ToTLV, Debug)] #[derive(Clone, PartialEq, FromTLV, ToTLV, Debug)]
#[tlvargs(lifetime = "'a")] #[tlvargs(lifetime = "'a")]
pub struct AttrData<'a> { pub struct AttrData<'a> {
pub data_ver: Option<u32>, pub data_ver: Option<u32>,
@ -385,6 +398,12 @@ pub mod ib {
} }
} }
pub enum AttrDataTag {
DataVer = 0,
Path = 1,
Data = 2,
}
#[derive(Debug)] #[derive(Debug)]
/// Operations on an Interaction Model List /// Operations on an Interaction Model List
pub enum ListOperation { pub enum ListOperation {
@ -399,13 +418,9 @@ pub mod ib {
} }
/// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function /// Attribute Lists in Attribute Data are special. Infer the correct meaning using this function
pub fn attr_list_write<F>( pub fn attr_list_write<F>(attr: &AttrDetails, data: &TLVElement, mut f: F) -> Result<(), Error>
attr: &AttrDetails,
data: &TLVElement,
mut f: F,
) -> Result<(), IMStatusCode>
where where
F: FnMut(ListOperation, &TLVElement) -> Result<(), IMStatusCode>, F: FnMut(ListOperation, &TLVElement) -> Result<(), Error>,
{ {
if let Some(Nullable::NotNull(index)) = attr.list_index { if let Some(Nullable::NotNull(index)) = attr.list_index {
// If list index is valid, // If list index is valid,
@ -422,7 +437,7 @@ pub mod ib {
f(ListOperation::DeleteList, data)?; f(ListOperation::DeleteList, data)?;
// Now the data must be a list, that should be added item by item // Now the data must be a list, that should be added item by item
let container = data.enter().ok_or(Error::Invalid)?; let container = data.enter().ok_or(ErrorCode::Invalid)?;
for d in container { for d in container {
f(ListOperation::AddItem, &d)?; f(ListOperation::AddItem, &d)?;
} }
@ -433,7 +448,7 @@ pub mod ib {
} }
} }
#[derive(Debug, Clone, Copy, PartialEq, FromTLV, ToTLV)] #[derive(Debug, Clone, PartialEq, FromTLV, ToTLV)]
pub struct AttrStatus { pub struct AttrStatus {
path: AttrPath, path: AttrPath,
status: Status, status: Status,
@ -449,7 +464,7 @@ pub mod ib {
} }
// Attribute Path // Attribute Path
#[derive(Default, Clone, Copy, Debug, PartialEq, FromTLV, ToTLV)] #[derive(Default, Clone, Debug, PartialEq, FromTLV, ToTLV)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct AttrPath { pub struct AttrPath {
pub tag_compression: Option<bool>, pub tag_compression: Option<bool>,
@ -476,7 +491,7 @@ pub mod ib {
} }
// Command Path // Command Path
#[derive(Default, Debug, Copy, Clone, PartialEq)] #[derive(Default, Debug, Clone, PartialEq)]
pub struct CmdPath { pub struct CmdPath {
pub path: GenericPath, pub path: GenericPath,
} }
@ -499,13 +514,13 @@ pub mod ib {
pub fn new( pub fn new(
endpoint: Option<EndptId>, endpoint: Option<EndptId>,
cluster: Option<ClusterId>, cluster: Option<ClusterId>,
command: Option<u16>, command: Option<CmdId>,
) -> Self { ) -> Self {
Self { Self {
path: GenericPath { path: GenericPath {
endpoint, endpoint,
cluster, cluster,
leaf: command.map(|a| a as u32), leaf: command,
}, },
} }
} }
@ -519,7 +534,7 @@ pub mod ib {
if c.path.leaf.is_none() { if c.path.leaf.is_none() {
error!("Wildcard command parameter not supported"); error!("Wildcard command parameter not supported");
Err(Error::CommandNotFound) Err(ErrorCode::CommandNotFound.into())
} else { } else {
Ok(c) Ok(c)
} }
@ -532,20 +547,20 @@ pub mod ib {
} }
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct ClusterPath { pub struct ClusterPath {
pub node: Option<u64>, pub node: Option<u64>,
pub endpoint: EndptId, pub endpoint: EndptId,
pub cluster: ClusterId, pub cluster: ClusterId,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct DataVersionFilter { pub struct DataVersionFilter {
pub path: ClusterPath, pub path: ClusterPath,
pub data_ver: u32, pub data_ver: u32,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
#[tlvargs(datatype = "list")] #[tlvargs(datatype = "list")]
pub struct EventPath { pub struct EventPath {
pub node: Option<u64>, pub node: Option<u64>,
@ -555,7 +570,7 @@ pub mod ib {
pub is_urgent: Option<bool>, pub is_urgent: Option<bool>,
} }
#[derive(FromTLV, ToTLV, Copy, Clone)] #[derive(FromTLV, ToTLV, Clone, Debug)]
pub struct EventFilter { pub struct EventFilter {
pub node: Option<u64>, pub node: Option<u64>,
pub event_min: Option<u64>, pub event_min: Option<u64>,

View file

@ -15,73 +15,5 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::{
error::Error,
tlv::TLVWriter,
transport::{exchange::Exchange, proto_demux::ResponseRequired, session::SessionHandle},
};
use self::{
core::OpCode,
messages::msg::{InvReq, StatusResp, WriteReq},
};
#[derive(PartialEq)]
pub enum TransactionState {
Ongoing,
Complete,
Terminate,
}
pub struct Transaction<'a, 'b> {
pub state: TransactionState,
pub session: &'a mut SessionHandle<'b>,
pub exch: &'a mut Exchange,
}
pub trait InteractionConsumer {
fn consume_invoke_cmd(
&self,
req: &InvReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_read_attr(
&self,
// TODO: This handling is different from the other APIs here, identify
// consistent options for this trait
req: &[u8],
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_write_attr(
&self,
req: &WriteReq,
trans: &mut Transaction,
tw: &mut TLVWriter,
) -> Result<(), Error>;
fn consume_status_report(
&self,
_req: &StatusResp,
_trans: &mut Transaction,
_tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error>;
fn consume_subscribe(
&self,
_req: &[u8],
_trans: &mut Transaction,
_tw: &mut TLVWriter,
) -> Result<(OpCode, ResponseRequired), Error>;
}
pub struct InteractionModel {
consumer: Box<dyn InteractionConsumer>,
}
pub mod command;
pub mod core; pub mod core;
pub mod messages; pub mod messages;
pub mod read;
pub mod write;

View file

@ -1,42 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::Error,
interaction_model::core::OpCode,
tlv::TLVWriter,
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use super::{InteractionModel, Transaction};
impl InteractionModel {
pub fn handle_read_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
proto_tx.set_proto_opcode(OpCode::ReportData as u8);
let proto_tx_wb = proto_tx.get_writebuf()?;
let mut tw = TLVWriter::new(proto_tx_wb);
self.consumer.consume_read_attr(rx_buf, trans, &mut tw)?;
Ok(ResponseRequired::Yes)
}
}

View file

@ -1,58 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use log::error;
use crate::{
error::Error,
tlv::{get_root_node_struct, FromTLV, TLVWriter, TagType},
transport::{packet::Packet, proto_demux::ResponseRequired},
};
use super::{core::OpCode, messages::msg::WriteReq, InteractionModel, Transaction};
impl InteractionModel {
pub fn handle_write_req(
&mut self,
trans: &mut Transaction,
rx_buf: &[u8],
proto_tx: &mut Packet,
) -> Result<ResponseRequired, Error> {
if InteractionModel::req_timeout_handled(trans, proto_tx)? {
return Ok(ResponseRequired::Yes);
}
proto_tx.set_proto_opcode(OpCode::WriteResponse as u8);
let mut tw = TLVWriter::new(proto_tx.get_writebuf()?);
let root = get_root_node_struct(rx_buf)?;
let write_req = WriteReq::from_tlv(&root)?;
let supress_response = write_req.supress_response.unwrap_or_default();
tw.start_struct(TagType::Anonymous)?;
self.consumer
.consume_write_attr(&write_req, trans, &mut tw)?;
tw.end_container()?;
trans.complete();
if supress_response {
error!("Supress response is set, is this the expected handling?");
Ok(ResponseRequired::No)
} else {
Ok(ResponseRequired::Yes)
}
}
}

View file

@ -23,7 +23,7 @@
//! Currently Ethernet based transport is supported. //! Currently Ethernet based transport is supported.
//! //!
//! # Examples //! # Examples
//! ``` //! TODO: Fix once new API has stabilized a bit
//! use matter::{Matter, CommissioningData}; //! use matter::{Matter, CommissioningData};
//! use matter::data_model::device_types::device_type_add_on_off_light; //! use matter::data_model::device_types::device_type_add_on_off_light;
//! use matter::data_model::cluster_basic_information::BasicInfoConfig; //! use matter::data_model::cluster_basic_information::BasicInfoConfig;
@ -65,8 +65,12 @@
//! } //! }
//! // Start the Matter Daemon //! // Start the Matter Daemon
//! // matter.start_daemon().unwrap(); //! // matter.start_daemon().unwrap();
//! ``` //!
//! Start off exploring by going to the [Matter] object. //! Start off exploring by going to the [Matter] object.
#![cfg_attr(not(feature = "std"), no_std)]
#![cfg_attr(feature = "nightly", feature(async_fn_in_trait))]
#![cfg_attr(feature = "nightly", feature(impl_trait_projections))]
#![cfg_attr(feature = "nightly", allow(incomplete_features))]
pub mod acl; pub mod acl;
pub mod cert; pub mod cert;
@ -80,10 +84,29 @@ pub mod group_keys;
pub mod interaction_model; pub mod interaction_model;
pub mod mdns; pub mod mdns;
pub mod pairing; pub mod pairing;
pub mod persist;
pub mod secure_channel; pub mod secure_channel;
pub mod sys;
pub mod tlv; pub mod tlv;
pub mod transport; pub mod transport;
pub mod utils; pub mod utils;
pub use crate::core::*; pub use crate::core::*;
#[cfg(feature = "alloc")]
extern crate alloc;
#[cfg(feature = "alloc")]
#[macro_export]
macro_rules! alloc {
($val:expr) => {
alloc::boxed::Box::new($val)
};
}
#[cfg(not(feature = "alloc"))]
#[macro_export]
macro_rules! alloc {
($val:expr) => {
$val
};
}

View file

@ -15,35 +15,71 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::{Arc, Mutex, Once}; use core::fmt::Write;
use crate::{ use crate::{data_model::cluster_basic_information::BasicInfoConfig, error::Error};
error::Error,
sys::{sys_publish_service, SysMdnsService},
transport::udp::MATTER_PORT,
};
#[derive(Default)] #[cfg(all(feature = "std", target_os = "macos"))]
/// The mDNS service handler pub mod astro;
pub struct MdnsInner { pub mod builtin;
/// Vendor ID pub mod proto;
vid: u16,
/// Product ID pub trait Mdns {
pid: u16, fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error>;
/// Device name fn remove(&self, service: &str) -> Result<(), Error>;
device_name: String,
} }
pub struct Mdns { impl<T> Mdns for &mut T
inner: Mutex<MdnsInner>, where
T: Mdns,
{
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
(**self).add(service, mode)
} }
const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00; fn remove(&self, service: &str) -> Result<(), Error> {
const SHORT_DISCRIMINATOR_SHIFT: u16 = 8; (**self).remove(service)
}
}
static mut G_MDNS: Option<Arc<Mdns>> = None; impl<T> Mdns for &T
static INIT: Once = Once::new(); where
T: Mdns,
{
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
(**self).add(service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
(**self).remove(service)
}
}
#[cfg(all(feature = "std", target_os = "macos"))]
pub use astro::MdnsService;
#[cfg(all(feature = "std", target_os = "macos"))]
pub use astro::MdnsUdpBuffers;
#[cfg(any(feature = "std", feature = "embassy-net"))]
pub use builtin::MdnsRunBuffers;
#[cfg(not(all(feature = "std", target_os = "macos")))]
pub use builtin::MdnsService;
pub struct DummyMdns;
impl Mdns for DummyMdns {
fn add(&self, _service: &str, _mode: ServiceMode) -> Result<(), Error> {
Ok(())
}
fn remove(&self, _service: &str) -> Result<(), Error> {
Ok(())
}
}
pub type Service<'a> = proto::Service<'a>;
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum ServiceMode { pub enum ServiceMode {
/// The commissioned state /// The commissioned state
Commissioned, Commissioned,
@ -51,69 +87,88 @@ pub enum ServiceMode {
Commissionable(u16), Commissionable(u16),
} }
impl Mdns { impl ServiceMode {
fn new() -> Self { pub fn service<R, F: for<'a> FnOnce(&Service<'a>) -> Result<R, Error>>(
Self { &self,
inner: Mutex::new(MdnsInner { dev_att: &BasicInfoConfig,
..Default::default() matter_port: u16,
name: &str,
f: F,
) -> Result<R, Error> {
match self {
Self::Commissioned => f(&Service {
name,
service: "_matter",
protocol: "_tcp",
port: matter_port,
service_subtypes: &[],
txt_kvs: &[],
}), }),
}
}
/// Get a handle to the globally unique mDNS instance
pub fn get() -> Result<Arc<Self>, Error> {
unsafe {
INIT.call_once(|| {
G_MDNS = Some(Arc::new(Mdns::new()));
});
Ok(G_MDNS.as_ref().ok_or(Error::Invalid)?.clone())
}
}
/// Set mDNS service specific values
/// Values like vid, pid, discriminator etc
// TODO: More things like device-type etc can be added here
pub fn set_values(&self, vid: u16, pid: u16, device_name: &str) {
let mut inner = self.inner.lock().unwrap();
inner.vid = vid;
inner.pid = pid;
inner.device_name = device_name.chars().take(32).collect();
}
/// Publish a mDNS service
/// name - is the service name (comma separated subtypes may follow)
/// mode - the current service mode
#[allow(clippy::needless_pass_by_value)]
pub fn publish_service(&self, name: &str, mode: ServiceMode) -> Result<SysMdnsService, Error> {
match mode {
ServiceMode::Commissioned => {
sys_publish_service(name, "_matter._tcp", MATTER_PORT, &[])
}
ServiceMode::Commissionable(discriminator) => { ServiceMode::Commissionable(discriminator) => {
let inner = self.inner.lock().unwrap(); let discriminator_str = Self::get_discriminator_str(*discriminator);
let short = compute_short_discriminator(discriminator); let vp = Self::get_vp(dev_att.vid, dev_att.pid);
let serv_type = format!("_matterc._udp,_S{},_L{}", short, discriminator);
let str_discriminator = format!("{}", discriminator); let txt_kvs = &[
let txt_kvs = [ ("D", discriminator_str.as_str()),
["D", &str_discriminator], ("CM", "1"),
["CM", "1"], ("DN", dev_att.device_name),
["DN", &inner.device_name], ("VP", &vp),
["VP", &format!("{}+{}", inner.vid, inner.pid)], ("SII", "5000"), /* Sleepy Idle Interval */
["SII", "5000"], /* Sleepy Idle Interval */ ("SAI", "300"), /* Sleepy Active Interval */
["SAI", "300"], /* Sleepy Active Interval */ ("PH", "33"), /* Pairing Hint */
["PH", "33"], /* Pairing Hint */ ("PI", ""), /* Pairing Instruction */
["PI", ""], /* Pairing Instruction */
]; ];
sys_publish_service(name, &serv_type, MATTER_PORT, &txt_kvs)
f(&Service {
name,
service: "_matterc",
protocol: "_udp",
port: matter_port,
service_subtypes: &[
&Self::get_long_service_subtype(*discriminator),
&Self::get_short_service_type(*discriminator),
],
txt_kvs,
})
} }
} }
} }
fn get_long_service_subtype(discriminator: u16) -> heapless::String<32> {
let mut serv_type = heapless::String::new();
write!(&mut serv_type, "_L{}", discriminator).unwrap();
serv_type
}
fn get_short_service_type(discriminator: u16) -> heapless::String<32> {
let short = Self::compute_short_discriminator(discriminator);
let mut serv_type = heapless::String::new();
write!(&mut serv_type, "_S{}", short).unwrap();
serv_type
}
fn get_discriminator_str(discriminator: u16) -> heapless::String<5> {
discriminator.into()
}
fn get_vp(vid: u16, pid: u16) -> heapless::String<11> {
let mut vp = heapless::String::new();
write!(&mut vp, "{}+{}", vid, pid).unwrap();
vp
} }
fn compute_short_discriminator(discriminator: u16) -> u16 { fn compute_short_discriminator(discriminator: u16) -> u16 {
const SHORT_DISCRIMINATOR_MASK: u16 = 0xF00;
const SHORT_DISCRIMINATOR_SHIFT: u16 = 8;
(discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT (discriminator & SHORT_DISCRIMINATOR_MASK) >> SHORT_DISCRIMINATOR_SHIFT
} }
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
@ -122,11 +177,11 @@ mod tests {
#[test] #[test]
fn can_compute_short_discriminator() { fn can_compute_short_discriminator() {
let discriminator: u16 = 0b0000_1111_0000_0000; let discriminator: u16 = 0b0000_1111_0000_0000;
let short = compute_short_discriminator(discriminator); let short = ServiceMode::compute_short_discriminator(discriminator);
assert_eq!(short, 0b1111); assert_eq!(short, 0b1111);
let discriminator: u16 = 840; let discriminator: u16 = 840;
let short = compute_short_discriminator(discriminator); let short = ServiceMode::compute_short_discriminator(discriminator);
assert_eq!(short, 3); assert_eq!(short, 3);
} }
} }

112
matter/src/mdns/astro.rs Normal file
View file

@ -0,0 +1,112 @@
use core::cell::RefCell;
use std::collections::HashMap;
use crate::{
data_model::cluster_basic_information::BasicInfoConfig,
error::{Error, ErrorCode},
transport::pipe::Pipe,
};
use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService};
use log::info;
use super::ServiceMode;
/// Only for API-compatibility with builtin::MdnsRunner
pub struct MdnsUdpBuffers(());
/// Only for API-compatibility with builtin::MdnsRunner
impl MdnsUdpBuffers {
#[inline(always)]
pub const fn new() -> Self {
Self(())
}
}
pub struct MdnsService<'a> {
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
services: RefCell<HashMap<String, RegisteredDnsService>>,
}
impl<'a> MdnsService<'a> {
/// This constructor takes extra parameters for API-compatibility with builtin::MdnsRunner
pub fn new(
_id: u16,
_hostname: &str,
_ip: [u8; 4],
_ipv6: Option<([u8; 16], u32)>,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
) -> Self {
Self::native_new(dev_det, matter_port)
}
pub fn native_new(dev_det: &'a BasicInfoConfig<'a>, matter_port: u16) -> Self {
Self {
dev_det,
matter_port,
services: RefCell::new(HashMap::new()),
}
}
pub fn add(&self, name: &str, mode: ServiceMode) -> Result<(), Error> {
info!("Registering mDNS service {}/{:?}", name, mode);
let _ = self.remove(name);
mode.service(self.dev_det, self.matter_port, name, |service| {
let composite_service_type = if !service.service_subtypes.is_empty() {
format!(
"{}.{},{}",
service.service,
service.protocol,
service.service_subtypes.join(",")
)
} else {
format!("{}.{}", service.service, service.protocol)
};
let mut builder = DNSServiceBuilder::new(&composite_service_type, service.port)
.with_name(service.name);
for kvs in service.txt_kvs {
info!("mDNS TXT key {} val {}", kvs.0, kvs.1);
builder = builder.with_key_value(kvs.0.to_string(), kvs.1.to_string());
}
let svc = builder.register().map_err(|_| ErrorCode::MdnsError)?;
self.services.borrow_mut().insert(service.name.into(), svc);
Ok(())
})
}
pub fn remove(&self, name: &str) -> Result<(), Error> {
if self.services.borrow_mut().remove(name).is_some() {
info!("Deregistering mDNS service {}", name);
}
Ok(())
}
/// Only for API-compatibility with builtin::MdnsRunner
pub async fn run_udp(&mut self, buffers: &mut MdnsUdpBuffers) -> Result<(), Error> {
core::future::pending::<Result<(), Error>>().await
}
/// Only for API-compatibility with builtin::MdnsRunner
pub async fn run(&self, _tx_pipe: &Pipe<'_>, _rx_pipe: &Pipe<'_>) -> Result<(), Error> {
core::future::pending::<Result<(), Error>>().await
}
}
impl<'a> super::Mdns for MdnsService<'a> {
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
MdnsService::add(self, service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
MdnsService::remove(self, service)
}
}

351
matter/src/mdns/builtin.rs Normal file
View file

@ -0,0 +1,351 @@
use core::{cell::RefCell, pin::pin};
use domain::base::name::FromStrError;
use domain::base::{octets::ParseError, ShortBuf};
use embassy_futures::select::select;
use embassy_time::{Duration, Timer};
use log::info;
use crate::data_model::cluster_basic_information::BasicInfoConfig;
use crate::error::{Error, ErrorCode};
use crate::transport::network::{Address, IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::transport::pipe::{Chunk, Pipe};
use crate::utils::select::{EitherUnwrap, Notification};
use super::{
proto::{Host, Services},
Service, ServiceMode,
};
const IP_BROADCAST_ADDR: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
const IPV6_BROADCAST_ADDR: Ipv6Addr = Ipv6Addr::new(0xff02, 0, 0, 0, 0, 0, 0, 0x00fb);
const PORT: u16 = 5353;
#[cfg(any(feature = "std", feature = "embassy-net"))]
pub struct MdnsRunBuffers {
udp: crate::transport::udp::UdpBuffers,
tx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_TX_BUF_SIZE]>,
rx_buf: core::mem::MaybeUninit<[u8; crate::transport::packet::MAX_RX_BUF_SIZE]>,
}
#[cfg(any(feature = "std", feature = "embassy-net"))]
impl MdnsRunBuffers {
#[inline(always)]
pub const fn new() -> Self {
Self {
udp: crate::transport::udp::UdpBuffers::new(),
tx_buf: core::mem::MaybeUninit::uninit(),
rx_buf: core::mem::MaybeUninit::uninit(),
}
}
}
pub struct MdnsService<'a> {
host: Host<'a>,
#[allow(unused)]
interface: Option<u32>,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
services: RefCell<heapless::Vec<(heapless::String<40>, ServiceMode), 4>>,
notification: Notification,
}
impl<'a> MdnsService<'a> {
#[inline(always)]
pub const fn new(
id: u16,
hostname: &'a str,
ip: [u8; 4],
ipv6: Option<([u8; 16], u32)>,
dev_det: &'a BasicInfoConfig<'a>,
matter_port: u16,
) -> Self {
Self {
host: Host {
id,
hostname,
ip,
ipv6: if let Some((ipv6, _)) = ipv6 {
Some(ipv6)
} else {
None
},
},
interface: if let Some((_, interface)) = ipv6 {
Some(interface)
} else {
None
},
dev_det,
matter_port,
services: RefCell::new(heapless::Vec::new()),
notification: Notification::new(),
}
}
pub fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
let mut services = self.services.borrow_mut();
services.retain(|(name, _)| name != service);
services
.push((service.into(), mode))
.map_err(|_| ErrorCode::NoSpace)?;
self.notification.signal(());
Ok(())
}
pub fn remove(&self, service: &str) -> Result<(), Error> {
let mut services = self.services.borrow_mut();
services.retain(|(name, _)| name != service);
Ok(())
}
pub fn for_each<F>(&self, mut callback: F) -> Result<(), Error>
where
F: FnMut(&Service) -> Result<(), Error>,
{
let services = self.services.borrow();
for (service, mode) in &*services {
mode.service(self.dev_det, self.matter_port, service, |service| {
callback(service)
})?;
}
Ok(())
}
#[cfg(any(feature = "std", feature = "embassy-net"))]
pub async fn run<D>(
&self,
stack: &crate::transport::network::NetworkStack<D>,
buffers: &mut MdnsRunBuffers,
) -> Result<(), Error>
where
D: crate::transport::network::NetworkStackDriver,
{
let mut udp = crate::transport::udp::UdpListener::new(
stack,
crate::transport::network::SocketAddr::new(IpAddr::V6(Ipv6Addr::UNSPECIFIED), PORT),
&mut buffers.udp,
)
.await?;
// V6 multicast does not work with smoltcp yet (see https://github.com/smoltcp-rs/smoltcp/pull/602)
#[cfg(not(feature = "embassy-net"))]
if let Some(interface) = self.interface {
udp.join_multicast_v6(IPV6_BROADCAST_ADDR, interface)
.await?;
}
udp.join_multicast_v4(
IP_BROADCAST_ADDR,
crate::transport::network::Ipv4Addr::from(self.host.ip),
)
.await?;
let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() });
let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() });
let tx_pipe = &tx_pipe;
let rx_pipe = &rx_pipe;
let udp = &udp;
let mut tx = pin!(async move {
loop {
{
let mut data = tx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end])
.await?;
data.chunk = None;
tx_pipe.data_consumed_notification.signal(());
}
}
tx_pipe.data_supplied_notification.wait().await;
}
});
let mut rx = pin!(async move {
loop {
{
let mut data = rx_pipe.data.lock().await;
if data.chunk.is_none() {
let (len, addr) = udp.recv(data.buf).await?;
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: Address::Udp(addr),
});
rx_pipe.data_supplied_notification.signal(());
}
}
rx_pipe.data_consumed_notification.wait().await;
}
});
let mut run = pin!(async move { self.run_piped(tx_pipe, rx_pipe).await });
embassy_futures::select::select3(&mut tx, &mut rx, &mut run)
.await
.unwrap()
}
pub async fn run_piped(&self, tx_pipe: &Pipe<'_>, rx_pipe: &Pipe<'_>) -> Result<(), Error> {
let mut broadcast = pin!(self.broadcast(tx_pipe));
let mut respond = pin!(self.respond(rx_pipe, tx_pipe));
select(&mut broadcast, &mut respond).await.unwrap()
}
#[allow(clippy::await_holding_refcell_ref)]
async fn broadcast(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
select(
self.notification.wait(),
Timer::after(Duration::from_secs(30)),
)
.await;
for addr in [
IpAddr::V4(IP_BROADCAST_ADDR),
IpAddr::V6(IPV6_BROADCAST_ADDR),
] {
if self.interface.is_some() || addr == IpAddr::V4(IP_BROADCAST_ADDR) {
loop {
let sent = {
let mut data = tx_pipe.data.lock().await;
if data.chunk.is_none() {
let len = self.host.broadcast(self, data.buf, 60)?;
if len > 0 {
info!("Broadasting mDNS entry to {}:{}", addr, PORT);
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: Address::Udp(SocketAddr::new(addr, PORT)),
});
tx_pipe.data_supplied_notification.signal(());
}
true
} else {
false
}
};
if sent {
break;
} else {
tx_pipe.data_consumed_notification.wait().await;
}
}
}
}
}
}
#[allow(clippy::await_holding_refcell_ref)]
async fn respond(&self, rx_pipe: &Pipe<'_>, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
{
let mut rx_data = rx_pipe.data.lock().await;
if let Some(rx_chunk) = rx_data.chunk {
let data = &rx_data.buf[rx_chunk.start..rx_chunk.end];
loop {
let sent = {
let mut tx_data = tx_pipe.data.lock().await;
if tx_data.chunk.is_none() {
let len = self.host.respond(self, data, tx_data.buf, 60)?;
if len > 0 {
info!("Replying to mDNS query from {}", rx_chunk.addr);
tx_data.chunk = Some(Chunk {
start: 0,
end: len,
addr: rx_chunk.addr,
});
tx_pipe.data_supplied_notification.signal(());
}
true
} else {
false
}
};
if sent {
break;
} else {
tx_pipe.data_consumed_notification.wait().await;
}
}
// info!("Got mDNS query");
rx_data.chunk = None;
rx_pipe.data_consumed_notification.signal(());
}
}
rx_pipe.data_supplied_notification.wait().await;
}
}
}
impl<'a> Services for MdnsService<'a> {
type Error = crate::error::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Error>
where
F: FnMut(&Service) -> Result<(), Error>,
{
MdnsService::for_each(self, callback)
}
}
impl<'a> super::Mdns for MdnsService<'a> {
fn add(&self, service: &str, mode: ServiceMode) -> Result<(), Error> {
MdnsService::add(self, service, mode)
}
fn remove(&self, service: &str) -> Result<(), Error> {
MdnsService::remove(self, service)
}
}
impl From<ShortBuf> for Error {
fn from(_e: ShortBuf) -> Self {
Self::new(ErrorCode::NoSpace)
}
}
impl From<ParseError> for Error {
fn from(_e: ParseError) -> Self {
Self::new(ErrorCode::MdnsError)
}
}
impl From<FromStrError> for Error {
fn from(_e: FromStrError) -> Self {
Self::new(ErrorCode::MdnsError)
}
}

508
matter/src/mdns/proto.rs Normal file
View file

@ -0,0 +1,508 @@
use core::fmt::Write;
use core::str::FromStr;
use domain::{
base::{
header::Flags,
iana::Class,
message_builder::AnswerBuilder,
name::FromStrError,
octets::{Octets256, Octets64, OctetsBuilder, ParseError},
Dname, Message, MessageBuilder, Record, Rtype, ShortBuf, ToDname,
},
rdata::{Aaaa, Ptr, Srv, Txt, A},
};
use log::trace;
pub trait Services {
type Error: From<ShortBuf> + From<ParseError> + From<FromStrError>;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>;
}
impl<T> Services for &mut T
where
T: Services,
{
type Error = T::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>,
{
(**self).for_each(callback)
}
}
impl<T> Services for &T
where
T: Services,
{
type Error = T::Error;
fn for_each<F>(&self, callback: F) -> Result<(), Self::Error>
where
F: FnMut(&Service) -> Result<(), Self::Error>,
{
(**self).for_each(callback)
}
}
pub struct Host<'a> {
pub id: u16,
pub hostname: &'a str,
pub ip: [u8; 4],
pub ipv6: Option<[u8; 16]>,
}
impl<'a> Host<'a> {
pub fn broadcast<T: Services>(
&self,
services: T,
buf: &mut [u8],
ttl_sec: u32,
) -> Result<usize, T::Error> {
let buf = Buf(buf, 0);
let message = MessageBuilder::from_target(buf)?;
let mut answer = message.answer();
self.set_broadcast(services, &mut answer, ttl_sec)?;
let buf = answer.finish();
Ok(buf.1)
}
pub fn respond<T: Services>(
&self,
services: T,
data: &[u8],
buf: &mut [u8],
ttl_sec: u32,
) -> Result<usize, T::Error> {
let buf = Buf(buf, 0);
let message = MessageBuilder::from_target(buf)?;
let mut answer = message.answer();
if self.set_response(data, services, &mut answer, ttl_sec)? {
let buf = answer.finish();
Ok(buf.1)
} else {
Ok(0)
}
}
fn set_broadcast<T, F>(
&self,
services: F,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), F::Error>
where
T: OctetsBuilder + AsMut<[u8]>,
F: Services,
{
self.set_header(answer);
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
services.for_each(|service| {
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
Ok(())
})?;
Ok(())
}
fn set_response<T, F>(
&self,
data: &[u8],
services: F,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<bool, F::Error>
where
T: OctetsBuilder + AsMut<[u8]>,
F: Services,
{
self.set_header(answer);
let message = Message::from_octets(data)?;
let mut replied = false;
for question in message.question() {
trace!("Handling question {:?}", question);
let question = question?;
match question.qtype() {
Rtype::A
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?) =>
{
self.add_ipv4(answer, ttl_sec)?;
replied = true;
}
Rtype::Aaaa
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?) =>
{
self.add_ipv6(answer, ttl_sec)?;
replied = true;
}
Rtype::Srv => {
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Ptr => {
services.for_each(|service| {
if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) {
service.add_service_type(answer, ttl_sec)?;
replied = true;
} else if question.qname().name_eq(&service.service_type_fqdn(true)?) {
// TODO
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Txt => {
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
Rtype::Any => {
// A / AAAA
if question
.qname()
.name_eq(&Host::host_fqdn(self.hostname, true)?)
{
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
replied = true;
}
// PTR
services.for_each(|service| {
if question.qname().name_eq(&Service::dns_sd_fqdn(true)?) {
service.add_service_type(answer, ttl_sec)?;
replied = true;
} else if question.qname().name_eq(&service.service_type_fqdn(true)?) {
// TODO
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
service.add_service_type(answer, ttl_sec)?;
service.add_service_subtypes(answer, ttl_sec)?;
service.add_txt(answer, ttl_sec)?;
replied = true;
}
Ok(())
})?;
// SRV
services.for_each(|service| {
if question.qname().name_eq(&service.service_fqdn(true)?) {
self.add_ipv4(answer, ttl_sec)?;
self.add_ipv6(answer, ttl_sec)?;
service.add_service(answer, self.hostname, ttl_sec)?;
replied = true;
}
Ok(())
})?;
}
_ => (),
}
}
Ok(replied)
}
fn set_header<T: OctetsBuilder + AsMut<[u8]>>(&self, answer: &mut AnswerBuilder<T>) {
let header = answer.header_mut();
header.set_id(self.id);
header.set_opcode(domain::base::iana::Opcode::Query);
header.set_rcode(domain::base::iana::Rcode::NoError);
let mut flags = Flags::new();
flags.qr = true;
flags.aa = true;
header.set_flags(flags);
}
fn add_ipv4<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, A>::new(
Self::host_fqdn(self.hostname, false).unwrap(),
Class::In,
ttl_sec,
A::from_octets(self.ip[0], self.ip[1], self.ip[2], self.ip[3]),
))
}
fn add_ipv6<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
if let Some(ip) = &self.ipv6 {
answer.push(Record::<Dname<Octets64>, Aaaa>::new(
Self::host_fqdn(self.hostname, false).unwrap(),
Class::In,
ttl_sec,
Aaaa::new((*ip).into()),
))
} else {
Ok(())
}
}
fn host_fqdn(hostname: &str, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut host_fqdn = heapless::String::<60>::new();
write!(host_fqdn, "{}.local{}", hostname, suffix,).unwrap();
Dname::from_str(&host_fqdn)
}
}
pub struct Service<'a> {
pub name: &'a str,
pub service: &'a str,
pub protocol: &'a str,
pub port: u16,
pub service_subtypes: &'a [&'a str],
pub txt_kvs: &'a [(&'a str, &'a str)],
}
impl<'a> Service<'a> {
fn add_service<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
hostname: &str,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Srv<_>>::new(
self.service_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Srv::new(0, 0, self.port, Host::host_fqdn(hostname, false).unwrap()),
))
}
fn add_service_type<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
Self::dns_sd_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_type_fqdn(false).unwrap()),
))?;
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
self.service_type_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_fqdn(false).unwrap()),
))
}
fn add_service_subtypes<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
for service_subtype in self.service_subtypes {
self.add_service_subtype(answer, service_subtype, ttl_sec)?;
}
Ok(())
}
fn add_service_subtype<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
service_subtype: &str,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
Self::dns_sd_fqdn(false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_subtype_fqdn(service_subtype, false).unwrap()),
))?;
answer.push(Record::<Dname<Octets64>, Ptr<_>>::new(
self.service_subtype_fqdn(service_subtype, false).unwrap(),
Class::In,
ttl_sec,
Ptr::new(self.service_fqdn(false).unwrap()),
))
}
fn add_txt<T: OctetsBuilder + AsMut<[u8]>>(
&self,
answer: &mut AnswerBuilder<T>,
ttl_sec: u32,
) -> Result<(), ShortBuf> {
// only way I found to create multiple parts in a Txt
// each slice is the length and then the data
let mut octets = Octets256::new();
//octets.append_slice(&[1u8, b'X'])?;
//octets.append_slice(&[2u8, b'A', b'B'])?;
//octets.append_slice(&[0u8])?;
for (k, v) in self.txt_kvs {
octets.append_slice(&[(k.len() + v.len() + 1) as u8])?;
octets.append_slice(k.as_bytes())?;
octets.append_slice(&[b'='])?;
octets.append_slice(v.as_bytes())?;
}
let txt = Txt::from_octets(&mut octets).unwrap();
answer.push(Record::<Dname<Octets64>, Txt<_>>::new(
self.service_fqdn(false).unwrap(),
Class::In,
ttl_sec,
txt,
))
}
fn service_fqdn(&self, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_fqdn = heapless::String::<60>::new();
write!(
service_fqdn,
"{}.{}.{}.local{}",
self.name, self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_fqdn)
}
fn service_type_fqdn(&self, suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_type_fqdn = heapless::String::<60>::new();
write!(
service_type_fqdn,
"{}.{}.local{}",
self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_type_fqdn)
}
fn service_subtype_fqdn(
&self,
service_subtype: &str,
suffix: bool,
) -> Result<Dname<Octets64>, FromStrError> {
let suffix = if suffix { "." } else { "" };
let mut service_subtype_fqdn = heapless::String::<40>::new();
write!(
service_subtype_fqdn,
"{}._sub.{}.{}.local{}",
service_subtype, self.service, self.protocol, suffix,
)
.unwrap();
Dname::from_str(&service_subtype_fqdn)
}
fn dns_sd_fqdn(suffix: bool) -> Result<Dname<Octets64>, FromStrError> {
if suffix {
Dname::from_str("_services._dns-sd._udp.local.")
} else {
Dname::from_str("_services._dns-sd._udp.local")
}
}
}
struct Buf<'a>(pub &'a mut [u8], pub usize);
impl<'a> OctetsBuilder for Buf<'a> {
type Octets = Self;
fn append_slice(&mut self, slice: &[u8]) -> Result<(), ShortBuf> {
if self.1 + slice.len() <= self.0.len() {
let end = self.1 + slice.len();
self.0[self.1..end].copy_from_slice(slice);
self.1 = end;
Ok(())
} else {
Err(ShortBuf)
}
}
fn truncate(&mut self, len: usize) {
self.1 = len;
}
fn freeze(self) -> Self::Octets {
self
}
fn len(&self) -> usize {
self.1
}
fn is_empty(&self) -> bool {
self.1 == 0
}
}
impl<'a> AsMut<[u8]> for Buf<'a> {
fn as_mut(&mut self) -> &mut [u8] {
&mut self.0[..self.1]
}
}

View file

@ -15,56 +15,66 @@
* limitations under the License. * limitations under the License.
*/ */
use core::fmt::Write;
use super::*; use super::*;
pub(super) fn compute_pairing_code(comm_data: &CommissioningData) -> String { pub fn compute_pairing_code(comm_data: &CommissioningData) -> heapless::String<32> {
// 0: no Vendor ID and Product ID present in Manual Pairing Code // 0: no Vendor ID and Product ID present in Manual Pairing Code
const VID_PID_PRESENT: u8 = 0; const VID_PID_PRESENT: u8 = 0;
let passwd = passwd_from_comm_data(comm_data); let passwd = passwd_from_comm_data(comm_data);
let CommissioningData { discriminator, .. } = comm_data; let CommissioningData { discriminator, .. } = comm_data;
let mut digits = String::new(); let mut digits = heapless::String::<32>::new();
digits.push_str(&((VID_PID_PRESENT << 2) | (discriminator >> 10) as u8).to_string()); write!(
digits.push_str(&format!( &mut digits,
"{:0>5}", "{}{:0>5}{:0>4}",
((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16 (VID_PID_PRESENT << 2) | (discriminator >> 10) as u8,
)); ((discriminator & 0x300) << 6) | (passwd & 0x3FFF) as u16,
digits.push_str(&format!("{:0>4}", passwd >> 14)); passwd >> 14
)
.unwrap();
let check_digit = digits.calculate_verhoeff_check_digit(); let mut final_digits = heapless::String::<32>::new();
digits.push_str(&check_digit.to_string()); write!(
&mut final_digits,
"{}{}",
digits,
digits.calculate_verhoeff_check_digit()
)
.unwrap();
digits final_digits
} }
pub(super) fn pretty_print_pairing_code(pairing_code: &str) { pub(super) fn pretty_print_pairing_code(pairing_code: &str) {
assert!(pairing_code.len() == 11); assert!(pairing_code.len() == 11);
let mut pretty = String::new(); let mut pretty = heapless::String::<32>::new();
pretty.push_str(&pairing_code[..4]); pretty.push_str(&pairing_code[..4]).unwrap();
pretty.push('-'); pretty.push('-').unwrap();
pretty.push_str(&pairing_code[4..8]); pretty.push_str(&pairing_code[4..8]).unwrap();
pretty.push('-'); pretty.push('-').unwrap();
pretty.push_str(&pairing_code[8..]); pretty.push_str(&pairing_code[8..]).unwrap();
info!("Pairing Code: {}", pretty); info!("Pairing Code: {}", pretty);
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::secure_channel::spake2p::VerifierData; use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand};
#[test] #[test]
fn can_compute_pairing_code() { fn can_compute_pairing_code() {
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(123456), verifier: VerifierData::new_with_pw(123456, dummy_rand),
discriminator: 250, discriminator: 250,
}; };
let pairing_code = compute_pairing_code(&comm_data); let pairing_code = compute_pairing_code(&comm_data);
assert_eq!(pairing_code, "00876800071"); assert_eq!(pairing_code, "00876800071");
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(34567890), verifier: VerifierData::new_with_pw(34567890, dummy_rand),
discriminator: 2976, discriminator: 2976,
}; };
let pairing_code = compute_pairing_code(&comm_data); let pairing_code = compute_pairing_code(&comm_data);

View file

@ -22,7 +22,6 @@ pub mod qr;
pub mod vendor_identifiers; pub mod vendor_identifiers;
use log::info; use log::info;
use qrcode::{render::unicode, QrCode, Version};
use verhoeff::Verhoeff; use verhoeff::Verhoeff;
use crate::{ use crate::{
@ -32,7 +31,7 @@ use crate::{
use self::{ use self::{
code::{compute_pairing_code, pretty_print_pairing_code}, code::{compute_pairing_code, pretty_print_pairing_code},
qr::{payload_base38_representation, print_qr_code, QrSetupPayload}, qr::{compute_qr_code, print_qr_code},
}; };
pub struct DiscoveryCapabilities { pub struct DiscoveryCapabilities {
@ -86,16 +85,18 @@ pub fn print_pairing_code_and_qr(
dev_det: &BasicInfoConfig, dev_det: &BasicInfoConfig,
comm_data: &CommissioningData, comm_data: &CommissioningData,
discovery_capabilities: DiscoveryCapabilities, discovery_capabilities: DiscoveryCapabilities,
) { buf: &mut [u8],
) -> Result<(), Error> {
let pairing_code = compute_pairing_code(comm_data); let pairing_code = compute_pairing_code(comm_data);
let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities); let qr_code = compute_qr_code(dev_det, comm_data, discovery_capabilities, buf)?;
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode");
pretty_print_pairing_code(&pairing_code); pretty_print_pairing_code(&pairing_code);
print_qr_code(&data_str); print_qr_code(qr_code);
Ok(())
} }
pub(self) fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 { fn passwd_from_comm_data(comm_data: &CommissioningData) -> u32 {
// todo: should this be part of the comm_data implementation? // todo: should this be part of the comm_data implementation?
match comm_data.verifier.data { match comm_data.verifier.data {
VerifierOption::Password(pwd) => pwd, VerifierOption::Password(pwd) => pwd,

View file

@ -15,9 +15,8 @@
* limitations under the License. * limitations under the License.
*/ */
use std::collections::BTreeMap;
use crate::{ use crate::{
error::ErrorCode,
tlv::{TLVWriter, TagType}, tlv::{TLVWriter, TagType},
utils::writebuf::WriteBuf, utils::writebuf::WriteBuf,
}; };
@ -45,6 +44,7 @@ const TOTAL_PAYLOAD_DATA_SIZE_IN_BITS: usize = VERSION_FIELD_LENGTH_IN_BITS
+ PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS + PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS
+ SETUP_PINCODE_FIELD_LENGTH_IN_BITS + SETUP_PINCODE_FIELD_LENGTH_IN_BITS
+ PADDING_FIELD_LENGTH_IN_BITS; + PADDING_FIELD_LENGTH_IN_BITS;
const TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES: usize = TOTAL_PAYLOAD_DATA_SIZE_IN_BITS / 8; const TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES: usize = TOTAL_PAYLOAD_DATA_SIZE_IN_BITS / 8;
// Spec 5.1.4.2 CHIP-Common Reserved Tags // Spec 5.1.4.2 CHIP-Common Reserved Tags
@ -55,7 +55,7 @@ const SERIAL_NUMBER_TAG: u8 = 0x00;
// const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04; // const COMMISSIONING_TIMEOUT_TAG: u8 = 0x04;
pub enum QRCodeInfoType { pub enum QRCodeInfoType {
String(String), String(heapless::String<128>), // TODO: Big enough?
Int32(i32), Int32(i32),
Int64(i64), Int64(i64),
UInt32(u32), UInt32(u32),
@ -63,7 +63,7 @@ pub enum QRCodeInfoType {
} }
pub enum SerialNumber { pub enum SerialNumber {
String(String), String(heapless::String<128>),
UInt32(u32), UInt32(u32),
} }
@ -78,10 +78,10 @@ pub struct QrSetupPayload<'data> {
version: u8, version: u8,
flow_type: CommissionningFlowType, flow_type: CommissionningFlowType,
discovery_capabilities: DiscoveryCapabilities, discovery_capabilities: DiscoveryCapabilities,
dev_det: &'data BasicInfoConfig, dev_det: &'data BasicInfoConfig<'data>,
comm_data: &'data CommissioningData, comm_data: &'data CommissioningData,
// we use a BTreeMap to keep the order of the optional data stable // The vec is ordered by the tag of OptionalQRCodeInfo
optional_data: BTreeMap<u8, OptionalQRCodeInfo>, optional_data: heapless::Vec<OptionalQRCodeInfo, 16>,
} }
impl<'data> QrSetupPayload<'data> { impl<'data> QrSetupPayload<'data> {
@ -98,11 +98,11 @@ impl<'data> QrSetupPayload<'data> {
discovery_capabilities, discovery_capabilities,
dev_det, dev_det,
comm_data, comm_data,
optional_data: BTreeMap::new(), optional_data: heapless::Vec::new(),
}; };
if !dev_det.serial_no.is_empty() { if !dev_det.serial_no.is_empty() {
result.add_serial_number(SerialNumber::String(dev_det.serial_no.clone())); result.add_serial_number(SerialNumber::String(dev_det.serial_no.into()));
} }
result result
@ -132,13 +132,11 @@ impl<'data> QrSetupPayload<'data> {
/// * `tag` - tag number in the [0x80-0xFF] range /// * `tag` - tag number in the [0x80-0xFF] range
/// * `data` - Data to add /// * `data` - Data to add
pub fn add_optional_vendor_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> { pub fn add_optional_vendor_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> {
if !is_vendor_tag(tag) { if is_vendor_tag(tag) {
return Err(Error::InvalidArgument); self.add_optional_data(tag, data)
} else {
Err(ErrorCode::InvalidArgument.into())
} }
self.optional_data
.insert(tag, OptionalQRCodeInfo { tag, data });
Ok(())
} }
/// A function to add an optional QR Code info CHIP object /// A function to add an optional QR Code info CHIP object
@ -150,16 +148,26 @@ impl<'data> QrSetupPayload<'data> {
tag: u8, tag: u8,
data: QRCodeInfoType, data: QRCodeInfoType,
) -> Result<(), Error> { ) -> Result<(), Error> {
if !is_common_tag(tag) { if is_common_tag(tag) {
return Err(Error::InvalidArgument); self.add_optional_data(tag, data)
} else {
Err(ErrorCode::InvalidArgument.into())
}
} }
self.optional_data fn add_optional_data(&mut self, tag: u8, data: QRCodeInfoType) -> Result<(), Error> {
.insert(tag, OptionalQRCodeInfo { tag, data }); let item = OptionalQRCodeInfo { tag, data };
Ok(()) let index = self.optional_data.iter().position(|info| tag < info.tag);
if let Some(index) = index {
self.optional_data.insert(index, item)
} else {
self.optional_data.push(item)
}
.map_err(|_| ErrorCode::NoSpace.into())
} }
pub fn get_all_optional_data(&self) -> &BTreeMap<u8, OptionalQRCodeInfo> { pub fn get_all_optional_data(&self) -> &[OptionalQRCodeInfo] {
&self.optional_data &self.optional_data
} }
@ -245,35 +253,30 @@ pub enum CommissionningFlowType {
Custom = 2, Custom = 2,
} }
struct TlvData { pub(super) fn payload_base38_representation<'a>(
max_data_length_in_bytes: u32, payload: &QrSetupPayload,
data_length_in_bytes: Option<usize>, buf: &'a mut [u8],
data: Option<Vec<u8>>, ) -> Result<&'a str, Error> {
} if payload.is_valid() {
let (str_buf, bits_buf, tlv_buf) = if payload.has_tlv() {
let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2);
pub(super) fn payload_base38_representation(payload: &QrSetupPayload) -> Result<String, Error> { let (bits_buf, tlv_buf) = buf.split_at_mut(buf.len() / 3);
let (mut bits, tlv_data) = if payload.has_tlv() {
let buffer_size = estimate_buffer_size(payload)?; (str_buf, bits_buf, Some(tlv_buf))
(
vec![0; buffer_size],
Some(TlvData {
max_data_length_in_bytes: buffer_size as u32,
data_length_in_bytes: None,
data: None,
}),
)
} else { } else {
(vec![0; TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES], None) let (str_buf, buf) = buf.split_at_mut(buf.len() / 3 * 2);
(str_buf, buf, None)
}; };
if !payload.is_valid() { payload_base38_representation_with_tlv(payload, str_buf, bits_buf, tlv_buf)
return Err(Error::InvalidArgument); } else {
Err(ErrorCode::InvalidArgument.into())
}
} }
payload_base38_representation_with_tlv(payload, &mut bits, tlv_data) pub fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
}
fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
// Estimate the size of the needed buffer; initialize with the size of the standard payload. // Estimate the size of the needed buffer; initialize with the size of the standard payload.
let mut estimate = TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES; let mut estimate = TOTAL_PAYLOAD_DATA_SIZE_IN_BYTES;
@ -294,15 +297,14 @@ fn estimate_buffer_size(payload: &QrSetupPayload) -> Result<usize, Error> {
size size
}; };
let vendor_data = payload.get_all_optional_data(); for data in payload.get_all_optional_data() {
vendor_data.values().for_each(|data| {
estimate += data_item_size_estimate(&data.data); estimate += data_item_size_estimate(&data.data);
}); }
estimate = estimate_struct_overhead(estimate); estimate = estimate_struct_overhead(estimate);
if estimate > u32::MAX as usize { if estimate > u32::MAX as usize {
return Err(Error::NoMemory); Err(ErrorCode::NoMemory)?;
} }
Ok(estimate) Ok(estimate)
@ -317,18 +319,38 @@ fn estimate_struct_overhead(first_field_size: usize) -> usize {
first_field_size + 4 + 2 first_field_size + 4 + 2
} }
pub(super) fn print_qr_code(qr_data: &str) { pub(super) fn print_qr_code(qr_code: &str) {
let needed_version = compute_qr_version(qr_data); info!("QR Code: {}", qr_code);
#[cfg(feature = "std")]
{
use qrcode::{render::unicode, QrCode, Version};
let needed_version = compute_qr_version(qr_code);
let code = let code =
QrCode::with_version(qr_data, Version::Normal(needed_version), qrcode::EcLevel::M).unwrap(); QrCode::with_version(qr_code, Version::Normal(needed_version), qrcode::EcLevel::M)
.unwrap();
let image = code let image = code
.render::<unicode::Dense1x2>() .render::<unicode::Dense1x2>()
.dark_color(unicode::Dense1x2::Light) .dark_color(unicode::Dense1x2::Light)
.light_color(unicode::Dense1x2::Dark) .light_color(unicode::Dense1x2::Dark)
.build(); .build();
info!("\n{}", image); info!("\n{}", image);
} }
}
pub fn compute_qr_code<'a>(
dev_det: &BasicInfoConfig,
comm_data: &CommissioningData,
discovery_capabilities: DiscoveryCapabilities,
buf: &'a mut [u8],
) -> Result<&'a str, Error> {
let qr_code_data = QrSetupPayload::new(dev_det, comm_data, discovery_capabilities);
payload_base38_representation(&qr_code_data, buf)
}
#[cfg(feature = "std")]
fn compute_qr_version(qr_data: &str) -> i16 { fn compute_qr_version(qr_data: &str) -> i16 {
match qr_data.len() { match qr_data.len() {
0..=38 => 2, 0..=38 => 2,
@ -346,11 +368,11 @@ fn populate_bits(
total_payload_data_size_in_bits: usize, total_payload_data_size_in_bits: usize,
) -> Result<(), Error> { ) -> Result<(), Error> {
if *offset + number_of_bits > total_payload_data_size_in_bits { if *offset + number_of_bits > total_payload_data_size_in_bits {
return Err(Error::InvalidArgument); Err(ErrorCode::InvalidArgument)?;
} }
if input >= 1u64 << number_of_bits { if input >= 1u64 << number_of_bits {
return Err(Error::InvalidArgument); Err(ErrorCode::InvalidArgument)?;
} }
let mut index = *offset; let mut index = *offset;
@ -368,70 +390,90 @@ fn populate_bits(
Ok(()) Ok(())
} }
fn payload_base38_representation_with_tlv( fn payload_base38_representation_with_tlv<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
bits: &mut [u8], str_buf: &'a mut [u8],
mut tlv_data: Option<TlvData>, bits_buf: &mut [u8],
) -> Result<String, Error> { tlv_buf: Option<&mut [u8]>,
if let Some(tlv_data) = tlv_data.as_mut() { ) -> Result<&'a str, Error> {
generate_tlv_from_optional_data(payload, tlv_data)?; let tlv_data = if let Some(tlv_buf) = tlv_buf {
Some(generate_tlv_from_optional_data(payload, tlv_buf)?)
} else {
None
};
let bits = generate_bit_set(payload, bits_buf, tlv_data)?;
let prefix = "MT:";
if str_buf.len() < prefix.as_bytes().len() {
Err(ErrorCode::NoSpace)?;
} }
let bytes_written = generate_bit_set(payload, bits, tlv_data)?; str_buf[..prefix.as_bytes().len()].copy_from_slice(prefix.as_bytes());
let base38_encoded = base38::encode(&*bits, Some(bytes_written));
Ok(format!("MT:{}", base38_encoded)) let mut offset = prefix.len();
for c in base38::encode(bits) {
let mut char_buf = [0; 4];
let str = c.encode_utf8(&mut char_buf);
if str_buf.len() - offset < str.as_bytes().len() {
Err(ErrorCode::NoSpace)?;
} }
fn generate_tlv_from_optional_data( str_buf[offset..offset + str.as_bytes().len()].copy_from_slice(str.as_bytes());
offset += str.as_bytes().len();
}
Ok(core::str::from_utf8(&str_buf[..offset])?)
}
fn generate_tlv_from_optional_data<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
tlv_data: &mut TlvData, tlv_buf: &'a mut [u8],
) -> Result<(), Error> { ) -> Result<&'a [u8], Error> {
let size_needed = tlv_data.max_data_length_in_bytes as usize; let mut wb = WriteBuf::new(tlv_buf);
let mut tlv_buffer = vec![0u8; size_needed];
let mut wb = WriteBuf::new(&mut tlv_buffer, size_needed);
let mut tw = TLVWriter::new(&mut wb); let mut tw = TLVWriter::new(&mut wb);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
let data = payload.get_all_optional_data();
for (tag, value) in data { for info in payload.get_all_optional_data() {
match &value.data { match &info.data {
QRCodeInfoType::String(data) => tw.utf8(TagType::Context(*tag), data.as_bytes())?, QRCodeInfoType::String(data) => tw.utf8(TagType::Context(info.tag), data.as_bytes())?,
QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(*tag), *data)?, QRCodeInfoType::Int32(data) => tw.i32(TagType::Context(info.tag), *data)?,
QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(*tag), *data)?, QRCodeInfoType::Int64(data) => tw.i64(TagType::Context(info.tag), *data)?,
QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(*tag), *data)?, QRCodeInfoType::UInt32(data) => tw.u32(TagType::Context(info.tag), *data)?,
QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(*tag), *data)?, QRCodeInfoType::UInt64(data) => tw.u64(TagType::Context(info.tag), *data)?,
} }
} }
tw.end_container()?; tw.end_container()?;
tlv_data.data_length_in_bytes = Some(tw.get_tail());
tlv_data.data = Some(tlv_buffer);
Ok(()) let tail = tw.get_tail();
Ok(&tlv_buf[..tail])
} }
fn generate_bit_set( fn generate_bit_set<'a>(
payload: &QrSetupPayload, payload: &QrSetupPayload,
bits: &mut [u8], bits_buf: &'a mut [u8],
tlv_data: Option<TlvData>, tlv_data: Option<&[u8]>,
) -> Result<usize, Error> { ) -> Result<&'a [u8], Error> {
let mut offset: usize = 0; let total_payload_size_in_bits =
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + tlv_data.map(|tlv_data| tlv_data.len() * 8).unwrap_or(0);
let total_payload_size_in_bits = if let Some(tlv_data) = &tlv_data { if bits_buf.len() * 8 < total_payload_size_in_bits {
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS + (tlv_data.data_length_in_bytes.unwrap_or_default() * 8) Err(ErrorCode::BufferTooSmall)?;
} else {
TOTAL_PAYLOAD_DATA_SIZE_IN_BITS
};
if bits.len() * 8 < total_payload_size_in_bits {
return Err(Error::BufferTooSmall);
}; };
let passwd = passwd_from_comm_data(payload.comm_data); let passwd = passwd_from_comm_data(payload.comm_data);
let mut offset: usize = 0;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.version as u64, payload.version as u64,
VERSION_FIELD_LENGTH_IN_BITS, VERSION_FIELD_LENGTH_IN_BITS,
@ -439,7 +481,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.dev_det.vid as u64, payload.dev_det.vid as u64,
VENDOR_IDFIELD_LENGTH_IN_BITS, VENDOR_IDFIELD_LENGTH_IN_BITS,
@ -447,7 +489,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.dev_det.pid as u64, payload.dev_det.pid as u64,
PRODUCT_IDFIELD_LENGTH_IN_BITS, PRODUCT_IDFIELD_LENGTH_IN_BITS,
@ -455,7 +497,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.flow_type as u64, payload.flow_type as u64,
COMMISSIONING_FLOW_FIELD_LENGTH_IN_BITS, COMMISSIONING_FLOW_FIELD_LENGTH_IN_BITS,
@ -463,7 +505,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.discovery_capabilities.as_bits() as u64, payload.discovery_capabilities.as_bits() as u64,
RENDEZVOUS_INFO_FIELD_LENGTH_IN_BITS, RENDEZVOUS_INFO_FIELD_LENGTH_IN_BITS,
@ -471,7 +513,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
payload.comm_data.discriminator as u64, payload.comm_data.discriminator as u64,
PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS, PAYLOAD_DISCRIMINATOR_FIELD_LENGTH_IN_BITS,
@ -479,7 +521,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
passwd as u64, passwd as u64,
SETUP_PINCODE_FIELD_LENGTH_IN_BITS, SETUP_PINCODE_FIELD_LENGTH_IN_BITS,
@ -487,7 +529,7 @@ fn generate_bit_set(
)?; )?;
populate_bits( populate_bits(
bits, bits_buf,
&mut offset, &mut offset,
0, 0,
PADDING_FIELD_LENGTH_IN_BITS, PADDING_FIELD_LENGTH_IN_BITS,
@ -495,26 +537,22 @@ fn generate_bit_set(
)?; )?;
if let Some(tlv_data) = tlv_data { if let Some(tlv_data) = tlv_data {
populate_tlv_bits(bits, &mut offset, tlv_data, total_payload_size_in_bits)?; populate_tlv_bits(bits_buf, &mut offset, tlv_data, total_payload_size_in_bits)?;
} }
let bytes_written = (offset + 7) / 8; let bytes_written = (offset + 7) / 8;
Ok(bytes_written)
Ok(&bits_buf[..bytes_written])
} }
fn populate_tlv_bits( fn populate_tlv_bits(
bits: &mut [u8], bits_buf: &mut [u8],
offset: &mut usize, offset: &mut usize,
tlv_data: TlvData, tlv_data: &[u8],
total_payload_size_in_bits: usize, total_payload_size_in_bits: usize,
) -> Result<(), Error> { ) -> Result<(), Error> {
if let (Some(data), Some(data_length_in_bytes)) = (tlv_data.data, tlv_data.data_length_in_bytes) for b in tlv_data {
{ populate_bits(bits_buf, offset, *b as u64, 8, total_payload_size_in_bits)?;
for b in data.iter().take(data_length_in_bytes) {
populate_bits(bits, offset, *b as u64, 8, total_payload_size_in_bits)?;
}
} else {
return Err(Error::InvalidArgument);
} }
Ok(()) Ok(())
@ -532,16 +570,15 @@ fn is_common_tag(tag: u8) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::secure_channel::spake2p::VerifierData; use crate::{secure_channel::spake2p::VerifierData, utils::rand::dummy_rand};
#[test] #[test]
fn can_base38_encode() { fn can_base38_encode() {
const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00"; const QR_CODE: &str = "MT:YNJV7VSC00CMVH7SR00";
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(34567890), verifier: VerifierData::new_with_pw(34567890, dummy_rand),
discriminator: 2976, discriminator: 2976,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
@ -552,7 +589,9 @@ mod tests {
let disc_cap = DiscoveryCapabilities::new(false, true, false); let disc_cap = DiscoveryCapabilities::new(false, true, false);
let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap);
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
@ -561,19 +600,21 @@ mod tests {
const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40"; const QR_CODE: &str = "MT:-24J0AFN00KA064IJ3P0IXZB0DK5N1K8SQ1RYCU1-A40";
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(20202021), verifier: VerifierData::new_with_pw(20202021, dummy_rand),
discriminator: 3840, discriminator: 3840,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
vid: 65521, vid: 65521,
pid: 32769, pid: 32769,
serial_no: "1234567890".to_string(), serial_no: "1234567890",
..Default::default() ..Default::default()
}; };
let disc_cap = DiscoveryCapabilities::new(true, false, false); let disc_cap = DiscoveryCapabilities::new(true, false, false);
let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap); let qr_code_data = QrSetupPayload::new(&dev_det, &comm_data, disc_cap);
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
@ -588,13 +629,13 @@ mod tests {
const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550; const OPTIONAL_DEFAULT_INT_VALUE: i32 = 65550;
let comm_data = CommissioningData { let comm_data = CommissioningData {
verifier: VerifierData::new_with_pw(20202021), verifier: VerifierData::new_with_pw(20202021, dummy_rand),
discriminator: 3840, discriminator: 3840,
}; };
let dev_det = BasicInfoConfig { let dev_det = BasicInfoConfig {
vid: 65521, vid: 65521,
pid: 32769, pid: 32769,
serial_no: "1234567890".to_string(), serial_no: "1234567890",
..Default::default() ..Default::default()
}; };
@ -604,7 +645,7 @@ mod tests {
qr_code_data qr_code_data
.add_optional_vendor_data( .add_optional_vendor_data(
OPTIONAL_DEFAULT_STRING_TAG, OPTIONAL_DEFAULT_STRING_TAG,
QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.to_string()), QRCodeInfoType::String(OPTIONAL_DEFAULT_STRING_VALUE.into()),
) )
.expect("Failed to add optional data"); .expect("Failed to add optional data");
@ -617,7 +658,9 @@ mod tests {
) )
.expect("Failed to add optional data"); .expect("Failed to add optional data");
let data_str = payload_base38_representation(&qr_code_data).expect("Failed to encode"); let mut buf = [0; 1024];
let data_str =
payload_base38_representation(&qr_code_data, &mut buf).expect("Failed to encode");
assert_eq!(data_str, QR_CODE) assert_eq!(data_str, QR_CODE)
} }
} }

116
matter/src/persist.rs Normal file
View file

@ -0,0 +1,116 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#[cfg(feature = "std")]
pub use fileio::*;
#[cfg(feature = "std")]
pub mod fileio {
use std::fs;
use std::io::{Read, Write};
use std::path::{Path, PathBuf};
use log::info;
use crate::error::{Error, ErrorCode};
use crate::Matter;
pub struct Psm<'a> {
matter: &'a Matter<'a>,
dir: PathBuf,
buf: [u8; 4096],
}
impl<'a> Psm<'a> {
#[inline(always)]
pub fn new(matter: &'a Matter<'a>, dir: PathBuf) -> Result<Self, Error> {
fs::create_dir_all(&dir)?;
info!("Persisting from/to {}", dir.display());
let mut buf = [0; 4096];
if let Some(data) = Self::load(&dir, "acls", &mut buf)? {
matter.load_acls(data)?;
}
if let Some(data) = Self::load(&dir, "fabrics", &mut buf)? {
matter.load_fabrics(data)?;
}
Ok(Self { matter, dir, buf })
}
pub async fn run(&mut self) -> Result<(), Error> {
loop {
self.matter.wait_changed().await;
if self.matter.is_changed() {
if let Some(data) = self.matter.store_acls(&mut self.buf)? {
Self::store(&self.dir, "acls", data)?;
}
if let Some(data) = self.matter.store_fabrics(&mut self.buf)? {
Self::store(&self.dir, "fabrics", data)?;
}
}
}
}
fn load<'b>(dir: &Path, key: &str, buf: &'b mut [u8]) -> Result<Option<&'b [u8]>, Error> {
let path = dir.join(key);
match fs::File::open(path) {
Ok(mut file) => {
let mut offset = 0;
loop {
if offset == buf.len() {
Err(ErrorCode::NoSpace)?;
}
let len = file.read(&mut buf[offset..])?;
if len == 0 {
break;
}
offset += len;
}
let data = &buf[..offset];
info!("Key {}: loaded {} bytes {:?}", key, data.len(), data);
Ok(Some(data))
}
Err(_) => Ok(None),
}
}
fn store(dir: &Path, key: &str, data: &[u8]) -> Result<(), Error> {
let path = dir.join(key);
let mut file = fs::File::create(path)?;
file.write_all(data)?;
info!("Key {}: stored {} bytes {:?}", key, data.len(), data);
Ok(())
}
}
}

View file

@ -15,37 +15,30 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::cell::RefCell;
use log::{error, trace}; use log::{error, trace};
use owning_ref::RwLockReadGuardRef;
use rand::prelude::*;
use crate::{ use crate::{
alloc,
cert::Cert, cert::Cert,
crypto::{self, CryptoKeyPair, KeyPair, Sha256}, crypto::{self, KeyPair, Sha256},
error::Error, error::{Error, ErrorCode},
fabric::{Fabric, FabricMgr, FabricMgrInner}, fabric::{Fabric, FabricMgr},
secure_channel::common::SCStatusCodes, secure_channel::common::{self, OpCode, PROTO_ID_SECURE_CHANNEL},
secure_channel::common::{self, OpCode}, secure_channel::common::{complete_with_status, SCStatusCodes},
tlv::{get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType}, tlv::{get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType},
transport::{ transport::{
exchange::Exchange,
network::Address, network::Address,
proto_demux::{ProtoCtx, ResponseRequired}, packet::Packet,
queue::{Msg, WorkQ},
session::{CaseDetails, CloneData, NocCatIds, SessionMode}, session::{CaseDetails, CloneData, NocCatIds, SessionMode},
}, },
utils::writebuf::WriteBuf, utils::{rand::Rand, writebuf::WriteBuf},
}; };
#[derive(PartialEq)] #[derive(Debug, Clone)]
enum State { struct CaseSession {
Sigma1Rx,
Sigma3Rx,
}
pub struct CaseSession {
state: State,
peer_sessid: u16, peer_sessid: u16,
local_sessid: u16, local_sessid: u16,
tt_hash: Sha256, tt_hash: Sha256,
@ -54,12 +47,13 @@ pub struct CaseSession {
peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES], peer_pub_key: [u8; crypto::EC_POINT_LEN_BYTES],
local_fabric_idx: usize, local_fabric_idx: usize,
} }
impl CaseSession { impl CaseSession {
pub fn new(peer_sessid: u16, local_sessid: u16) -> Result<Self, Error> { #[inline(always)]
pub fn new() -> Result<Self, Error> {
Ok(Self { Ok(Self {
state: State::Sigma1Rx, peer_sessid: 0,
peer_sessid, local_sessid: 0,
local_sessid,
tt_hash: Sha256::new()?, tt_hash: Sha256::new()?,
shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES], shared_secret: [0; crypto::ECDH_SHARED_SECRET_LEN_BYTES],
our_pub_key: [0; crypto::EC_POINT_LEN_BYTES], our_pub_key: [0; crypto::EC_POINT_LEN_BYTES],
@ -69,70 +63,91 @@ impl CaseSession {
} }
} }
pub struct Case { pub struct Case<'a> {
fabric_mgr: Arc<FabricMgr>, fabric_mgr: &'a RefCell<FabricMgr>,
rand: Rand,
} }
impl Case { impl<'a> Case<'a> {
pub fn new(fabric_mgr: Arc<FabricMgr>) -> Self { #[inline(always)]
Self { fabric_mgr } pub fn new(fabric_mgr: &'a RefCell<FabricMgr>, rand: Rand) -> Self {
Self { fabric_mgr, rand }
} }
pub fn casesigma3_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { pub async fn handle(
let mut case_session = ctx &mut self,
.exch_ctx exchange: &mut Exchange<'_>,
.exch rx: &mut Packet<'_>,
.take_data_boxed::<CaseSession>() tx: &mut Packet<'_>,
.ok_or(Error::InvalidState)?; ) -> Result<(), Error> {
if case_session.state != State::Sigma1Rx { let mut session = alloc!(CaseSession::new()?);
return Err(Error::Invalid);
}
case_session.state = State::Sigma3Rx;
let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; self.handle_casesigma1(exchange, rx, tx, &mut session)
.await?;
self.handle_casesigma3(exchange, rx, tx, &mut session).await
}
#[allow(clippy::await_holding_refcell_ref)]
async fn handle_casesigma3(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
case_session: &mut CaseSession,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma3 as _)?;
let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( drop(fabric_mgr);
&mut ctx.tx, complete_with_status(
exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok(ResponseRequired::Yes); return Ok(());
} }
// Safe to unwrap here // Safe to unwrap here
let fabric = fabric.as_ref().as_ref().unwrap(); let fabric = fabric.unwrap();
let root = get_root_node_struct(ctx.rx.as_borrow_slice())?; let root = get_root_node_struct(rx.as_slice())?;
let encrypted = root.find_tag(1)?.slice()?; let encrypted = root.find_tag(1)?.slice()?;
let mut decrypted: [u8; 800] = [0; 800]; let mut decrypted = alloc!([0; 800]);
if encrypted.len() > decrypted.len() { if encrypted.len() > decrypted.len() {
error!("Data too large"); error!("Data too large");
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let decrypted = &mut decrypted[..encrypted.len()]; let decrypted = &mut decrypted[..encrypted.len()];
decrypted.copy_from_slice(encrypted); decrypted.copy_from_slice(encrypted);
let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), &case_session, decrypted)?; let len = Case::get_sigma3_decryption(fabric.ipk.op_key(), case_session, decrypted)?;
let decrypted = &decrypted[..len]; let decrypted = &decrypted[..len];
let root = get_root_node_struct(decrypted)?; let root = get_root_node_struct(decrypted)?;
let d = Sigma3Decrypt::from_tlv(&root)?; let d = Sigma3Decrypt::from_tlv(&root)?;
let initiator_noc = Cert::new(d.initiator_noc.0)?; let initiator_noc = alloc!(Cert::new(d.initiator_noc.0)?);
let mut initiator_icac = None; let mut initiator_icac = None;
if let Some(icac) = d.initiator_icac { if let Some(icac) = d.initiator_icac {
initiator_icac = Some(Cert::new(icac.0)?); initiator_icac = Some(alloc!(Cert::new(icac.0)?));
} }
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, &initiator_icac) {
#[cfg(feature = "alloc")]
let initiator_icac_mut = initiator_icac.as_deref();
#[cfg(not(feature = "alloc"))]
let initiator_icac_mut = initiator_icac.as_ref();
if let Err(e) = Case::validate_certs(fabric, &initiator_noc, initiator_icac_mut) {
error!("Certificate Chain doesn't match: {}", e); error!("Certificate Chain doesn't match: {}", e);
common::create_sc_status_report( complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
&mut ctx.tx, .await?;
common::SCStatusCodes::InvalidParameter, return Ok(());
None,
)?;
ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes);
} }
if Case::validate_sigma3_sign( if Case::validate_sigma3_sign(
@ -140,74 +155,80 @@ impl Case {
d.initiator_icac.map(|a| a.0), d.initiator_icac.map(|a| a.0),
&initiator_noc, &initiator_noc,
d.signature.0, d.signature.0,
&case_session, case_session,
) )
.is_err() .is_err()
{ {
error!("Sigma3 Signature doesn't match"); error!("Sigma3 Signature doesn't match");
common::create_sc_status_report( complete_with_status(exchange, tx, common::SCStatusCodes::InvalidParameter, None)
&mut ctx.tx, .await?;
common::SCStatusCodes::InvalidParameter, return Ok(());
None,
)?;
ctx.exch_ctx.exch.close();
return Ok(ResponseRequired::Yes);
} }
// Only now do we add this message to the TT Hash // Only now do we add this message to the TT Hash
let mut peer_catids: NocCatIds = Default::default(); let mut peer_catids: NocCatIds = Default::default();
initiator_noc.get_cat_ids(&mut peer_catids); initiator_noc.get_cat_ids(&mut peer_catids);
case_session.tt_hash.update(ctx.rx.as_borrow_slice())?; case_session.tt_hash.update(rx.as_slice())?;
let clone_data = Case::get_session_clone_data( let clone_data = Case::get_session_clone_data(
fabric.ipk.op_key(), fabric.ipk.op_key(),
fabric.get_node_id(), fabric.get_node_id(),
initiator_noc.get_node_id()?, initiator_noc.get_node_id()?,
ctx.exch_ctx.sess.get_peer_addr(), exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
&case_session, case_session,
&peer_catids, &peer_catids,
)?; )?;
// Queue a transport mgr request to add a new session
WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?;
common::create_sc_status_report( // TODO: Handle NoSpace
&mut ctx.tx, exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
complete_with_status(
exchange,
tx,
SCStatusCodes::SessionEstablishmentSuccess, SCStatusCodes::SessionEstablishmentSuccess,
None, None,
)?; )
ctx.exch_ctx.exch.clear_data_boxed(); .await
ctx.exch_ctx.exch.close();
Ok(ResponseRequired::Yes)
} }
pub fn casesigma1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { #[allow(clippy::await_holding_refcell_ref)]
ctx.tx.set_proto_opcode(OpCode::CASESigma2 as u8); async fn handle_casesigma1(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
case_session: &mut CaseSession,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::CASESigma1 as _)?;
let rx_buf = ctx.rx.as_borrow_slice(); let rx_buf = rx.as_slice();
let root = get_root_node_struct(rx_buf)?; let root = get_root_node_struct(rx_buf)?;
let r = Sigma1Req::from_tlv(&root)?; let r = Sigma1Req::from_tlv(&root)?;
let local_fabric_idx = self let local_fabric_idx = self
.fabric_mgr .fabric_mgr
.borrow_mut()
.match_dest_id(r.initiator_random.0, r.dest_id.0); .match_dest_id(r.initiator_random.0, r.dest_id.0);
if local_fabric_idx.is_err() { if local_fabric_idx.is_err() {
error!("Fabric Index mismatch"); error!("Fabric Index mismatch");
common::create_sc_status_report( complete_with_status(
&mut ctx.tx, exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok(ResponseRequired::Yes);
return Ok(());
} }
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let mut case_session = Box::new(CaseSession::new(r.initiator_sessid, local_sessid)?); case_session.peer_sessid = r.initiator_sessid;
case_session.local_sessid = local_sessid;
case_session.tt_hash.update(rx_buf)?; case_session.tt_hash.update(rx_buf)?;
case_session.local_fabric_idx = local_fabric_idx?; case_session.local_fabric_idx = local_fabric_idx?;
if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES { if r.peer_pub_key.0.len() != crypto::EC_POINT_LEN_BYTES {
error!("Invalid public key length"); error!("Invalid public key length");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0); case_session.peer_pub_key.copy_from_slice(r.peer_pub_key.0);
trace!( trace!(
@ -216,66 +237,88 @@ impl Case {
); );
// Create an ephemeral Key Pair // Create an ephemeral Key Pair
let key_pair = KeyPair::new()?; let key_pair = KeyPair::new(self.rand)?;
let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?; let _ = key_pair.get_public_key(&mut case_session.our_pub_key)?;
// Derive the Shared Secret // Derive the Shared Secret
let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?; let len = key_pair.derive_secret(r.peer_pub_key.0, &mut case_session.shared_secret)?;
if len != 32 { if len != 32 {
error!("Derived secret length incorrect"); error!("Derived secret length incorrect");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
// println!("Derived secret: {:x?} len: {}", secret, len); // println!("Derived secret: {:x?} len: {}", secret, len);
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
rand::thread_rng().fill_bytes(&mut our_random); (self.rand)(&mut our_random);
// Derive the Encrypted Part // Derive the Encrypted Part
const MAX_ENCRYPTED_SIZE: usize = 800; const MAX_ENCRYPTED_SIZE: usize = 800;
let mut encrypted: [u8; MAX_ENCRYPTED_SIZE] = [0; MAX_ENCRYPTED_SIZE]; let mut encrypted = alloc!([0; MAX_ENCRYPTED_SIZE]);
let encrypted_len = { let encrypted_len = {
let mut signature = [0u8; crypto::EC_SIGNATURE_LEN_BYTES]; let mut signature = alloc!([0u8; crypto::EC_SIGNATURE_LEN_BYTES]);
let fabric = self.fabric_mgr.get_fabric(case_session.local_fabric_idx)?; let fabric_mgr = self.fabric_mgr.borrow();
let fabric = fabric_mgr.get_fabric(case_session.local_fabric_idx)?;
if fabric.is_none() { if fabric.is_none() {
common::create_sc_status_report( drop(fabric_mgr);
&mut ctx.tx, complete_with_status(
exchange,
tx,
common::SCStatusCodes::NoSharedTrustRoots, common::SCStatusCodes::NoSharedTrustRoots,
None, None,
)?; )
ctx.exch_ctx.exch.close(); .await?;
return Ok(ResponseRequired::Yes); return Ok(());
} }
#[cfg(feature = "alloc")]
let signature_mut = &mut *signature;
#[cfg(not(feature = "alloc"))]
let signature_mut = &mut signature;
let sign_len = Case::get_sigma2_sign( let sign_len = Case::get_sigma2_sign(
&fabric, fabric.unwrap(),
&case_session.our_pub_key, &case_session.our_pub_key,
&case_session.peer_pub_key, &case_session.peer_pub_key,
&mut signature, signature_mut,
)?; )?;
let signature = &signature[..sign_len]; let signature = &signature[..sign_len];
#[cfg(feature = "alloc")]
let encrypted_mut = &mut *encrypted;
#[cfg(not(feature = "alloc"))]
let encrypted_mut = &mut encrypted;
Case::get_sigma2_encryption( Case::get_sigma2_encryption(
&fabric, fabric.unwrap(),
self.rand,
&our_random, &our_random,
&mut case_session, case_session,
signature, signature,
&mut encrypted, encrypted_mut,
)? )?
}; };
let encrypted = &encrypted[0..encrypted_len]; let encrypted = &encrypted[0..encrypted_len];
// Generate our Response Body // Generate our Response Body
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::CASESigma2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str8(TagType::Context(1), &our_random)?; tw.str8(TagType::Context(1), &our_random)?;
tw.u16(TagType::Context(2), local_sessid)?; tw.u16(TagType::Context(2), local_sessid)?;
tw.str8(TagType::Context(3), &case_session.our_pub_key)?; tw.str8(TagType::Context(3), &case_session.our_pub_key)?;
tw.str16(TagType::Context(4), encrypted)?; tw.str16(TagType::Context(4), encrypted)?;
tw.end_container()?; tw.end_container()?;
case_session.tt_hash.update(ctx.tx.as_borrow_slice())?;
ctx.exch_ctx.exch.set_data_boxed(case_session); case_session.tt_hash.update(tx.as_mut_slice())?;
Ok(ResponseRequired::Yes)
exchange.exchange(tx, rx).await
} }
fn get_session_clone_data( fn get_session_clone_data(
@ -322,8 +365,8 @@ impl Case {
case_session: &CaseSession, case_session: &CaseSession,
) -> Result<(), Error> { ) -> Result<(), Error> {
const MAX_TBS_SIZE: usize = 800; const MAX_TBS_SIZE: usize = 800;
let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; let mut buf = [0; MAX_TBS_SIZE];
let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); let mut write_buf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16(TagType::Context(1), initiator_noc)?; tw.str16(TagType::Context(1), initiator_noc)?;
@ -339,24 +382,26 @@ impl Case {
Ok(()) Ok(())
} }
fn validate_certs(fabric: &Fabric, noc: &Cert, icac: &Option<Cert>) -> Result<(), Error> { fn validate_certs(fabric: &Fabric, noc: &Cert, icac: Option<&Cert>) -> Result<(), Error> {
let mut verifier = noc.verify_chain_start(); let mut verifier = noc.verify_chain_start();
if fabric.get_fabric_id() != noc.get_fabric_id()? { if fabric.get_fabric_id() != noc.get_fabric_id()? {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
if let Some(icac) = icac { if let Some(icac) = icac {
// If ICAC is present handle it // If ICAC is present handle it
if let Ok(fid) = icac.get_fabric_id() { if let Ok(fid) = icac.get_fabric_id() {
if fid != fabric.get_fabric_id() { if fid != fabric.get_fabric_id() {
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
} }
verifier = verifier.add_cert(icac)?; verifier = verifier.add_cert(icac)?;
} }
verifier.add_cert(&fabric.root_ca)?.finalise()?; verifier
.add_cert(&Cert::new(&fabric.root_ca)?)?
.finalise()?;
Ok(()) Ok(())
} }
@ -370,18 +415,18 @@ impl Case {
0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73, 0x53, 0x65, 0x73, 0x73, 0x69, 0x6f, 0x6e, 0x4b, 0x65, 0x79, 0x73,
]; ];
if key.len() < 48 { if key.len() < 48 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
let tt = tt.clone(); let tt = tt.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Session Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Session Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key) crypto::hkdf_sha256(salt.as_slice(), shared_secret, &SEKEYS_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Session Key: key: {:x?}", key); // println!("Session Key: key: {:x?}", key);
Ok(()) Ok(())
@ -418,20 +463,20 @@ impl Case {
) -> Result<(), Error> { ) -> Result<(), Error> {
const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33]; const S3K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x33];
if key.len() < 16 { if key.len() < 16 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
let tt = tt.clone(); let tt = tt.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Sigma3Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key) crypto::hkdf_sha256(salt.as_slice(), shared_secret, &S3K_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Sigma3Key: key: {:x?}", key); // println!("Sigma3Key: key: {:x?}", key);
Ok(()) Ok(())
@ -445,39 +490,37 @@ impl Case {
) -> Result<(), Error> { ) -> Result<(), Error> {
const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32]; const S2K_INFO: [u8; 6] = [0x53, 0x69, 0x67, 0x6d, 0x61, 0x32];
if key.len() < 16 { if key.len() < 16 {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
let mut salt = Vec::<u8>::with_capacity(256); let mut salt = heapless::Vec::<u8, 256>::new();
salt.extend_from_slice(ipk); salt.extend_from_slice(ipk).unwrap();
salt.extend_from_slice(our_random); salt.extend_from_slice(our_random).unwrap();
salt.extend_from_slice(&case_session.our_pub_key); salt.extend_from_slice(&case_session.our_pub_key).unwrap();
let tt = case_session.tt_hash.clone(); let tt = case_session.tt_hash.clone();
let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut tt_hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
tt.finish(&mut tt_hash)?; tt.finish(&mut tt_hash)?;
salt.extend_from_slice(&tt_hash); salt.extend_from_slice(&tt_hash).unwrap();
// println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len()); // println!("Sigma2Key: salt: {:x?}, len: {}", salt, salt.len());
crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key) crypto::hkdf_sha256(salt.as_slice(), &case_session.shared_secret, &S2K_INFO, key)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// println!("Sigma2Key: key: {:x?}", key); // println!("Sigma2Key: key: {:x?}", key);
Ok(()) Ok(())
} }
fn get_sigma2_encryption( fn get_sigma2_encryption(
fabric: &RwLockReadGuardRef<FabricMgrInner, Option<Fabric>>, fabric: &Fabric,
rand: Rand,
our_random: &[u8], our_random: &[u8],
case_session: &mut CaseSession, case_session: &mut CaseSession,
signature: &[u8], signature: &[u8],
out: &mut [u8], out: &mut [u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
let mut resumption_id: [u8; 16] = [0; 16]; let mut resumption_id: [u8; 16] = [0; 16];
rand::thread_rng().fill_bytes(&mut resumption_id); rand(&mut resumption_id);
// We are guaranteed this unwrap will work
let fabric = fabric.as_ref().as_ref().unwrap();
let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES]; let mut sigma2_key = [0_u8; crypto::SYMM_KEY_LEN_BYTES];
Case::get_sigma2_key( Case::get_sigma2_key(
@ -487,12 +530,12 @@ impl Case {
&mut sigma2_key, &mut sigma2_key,
)?; )?;
let mut write_buf = WriteBuf::new(out, out.len()); let mut write_buf = WriteBuf::new(out);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_ref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))? tw.str16(TagType::Context(2), icac_cert)?
}; };
tw.str8(TagType::Context(3), signature)?; tw.str8(TagType::Context(3), signature)?;
@ -521,21 +564,20 @@ impl Case {
} }
fn get_sigma2_sign( fn get_sigma2_sign(
fabric: &RwLockReadGuardRef<FabricMgrInner, Option<Fabric>>, fabric: &Fabric,
our_pub_key: &[u8], our_pub_key: &[u8],
peer_pub_key: &[u8], peer_pub_key: &[u8],
signature: &mut [u8], signature: &mut [u8],
) -> Result<usize, Error> { ) -> Result<usize, Error> {
// We are guaranteed this unwrap will work // We are guaranteed this unwrap will work
let fabric = fabric.as_ref().as_ref().unwrap();
const MAX_TBS_SIZE: usize = 800; const MAX_TBS_SIZE: usize = 800;
let mut buf: [u8; MAX_TBS_SIZE] = [0; MAX_TBS_SIZE]; let mut buf = [0; MAX_TBS_SIZE];
let mut write_buf = WriteBuf::new(&mut buf, MAX_TBS_SIZE); let mut write_buf = WriteBuf::new(&mut buf);
let mut tw = TLVWriter::new(&mut write_buf); let mut tw = TLVWriter::new(&mut write_buf);
tw.start_struct(TagType::Anonymous)?; tw.start_struct(TagType::Anonymous)?;
tw.str16_as(TagType::Context(1), |buf| fabric.noc.as_tlv(buf))?; tw.str16(TagType::Context(1), &fabric.noc)?;
if let Some(icac_cert) = &fabric.icac { if let Some(icac_cert) = fabric.icac.as_deref() {
tw.str16_as(TagType::Context(2), |buf| icac_cert.as_tlv(buf))?; tw.str16(TagType::Context(2), icac_cert)?;
} }
tw.str8(TagType::Context(3), our_pub_key)?; tw.str8(TagType::Context(3), our_pub_key)?;
tw.str8(TagType::Context(4), peer_pub_key)?; tw.str8(TagType::Context(4), peer_pub_key)?;

View file

@ -15,25 +15,19 @@
* limitations under the License. * limitations under the License.
*/ */
use boxslab::Slab;
use log::info;
use num_derive::FromPrimitive; use num_derive::FromPrimitive;
use crate::{ use crate::{
error::Error, error::Error,
transport::{ transport::{exchange::Exchange, packet::Packet},
exchange::Exchange,
packet::{Packet, PacketPool},
session::SessionHandle,
},
}; };
use super::status_report::{create_status_report, GeneralCode}; use super::status_report::{create_status_report, GeneralCode};
/* Interaction Model ID as per the Matter Spec */ /* Interaction Model ID as per the Matter Spec */
pub const PROTO_ID_SECURE_CHANNEL: usize = 0x00; pub const PROTO_ID_SECURE_CHANNEL: u16 = 0x00;
#[derive(FromPrimitive, Debug)] #[derive(FromPrimitive, Debug, Copy, Clone, Eq, PartialEq)]
pub enum OpCode { pub enum OpCode {
MsgCounterSyncReq = 0x00, MsgCounterSyncReq = 0x00,
MsgCounterSyncResp = 0x01, MsgCounterSyncResp = 0x01,
@ -60,6 +54,17 @@ pub enum SCStatusCodes {
SessionNotFound = 5, SessionNotFound = 5,
} }
pub async fn complete_with_status(
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
status_code: SCStatusCodes,
proto_data: Option<&[u8]>,
) -> Result<(), Error> {
create_sc_status_report(tx, status_code, proto_data)?;
exchange.send_complete(tx).await
}
pub fn create_sc_status_report( pub fn create_sc_status_report(
proto_tx: &mut Packet, proto_tx: &mut Packet,
status_code: SCStatusCodes, status_code: SCStatusCodes,
@ -78,6 +83,7 @@ pub fn create_sc_status_report(
| SCStatusCodes::NoSharedTrustRoots | SCStatusCodes::NoSharedTrustRoots
| SCStatusCodes::SessionNotFound => GeneralCode::Failure, | SCStatusCodes::SessionNotFound => GeneralCode::Failure,
}; };
create_status_report( create_status_report(
proto_tx, proto_tx,
general_code, general_code,
@ -88,14 +94,16 @@ pub fn create_sc_status_report(
} }
pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) { pub fn create_mrp_standalone_ack(proto_tx: &mut Packet) {
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); proto_tx.reset();
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8); proto_tx.set_proto_opcode(OpCode::MRPStandAloneAck as u8);
proto_tx.unset_reliable(); proto_tx.unset_reliable();
} }
pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> { // TODO
info!("Sending standalone ACK"); // pub fn send_mrp_standalone_ack(exch: &mut Exchange, sess: &mut SessionHandle) -> Result<(), Error> {
let mut ack_packet = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?; // info!("Sending standalone ACK");
create_mrp_standalone_ack(&mut ack_packet); // let mut ack_packet = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoMemory)?;
exch.send(ack_packet, sess) // create_mrp_standalone_ack(&mut ack_packet);
} // exch.send(ack_packet, sess)
// }

View file

@ -15,65 +15,87 @@
* limitations under the License. * limitations under the License.
*/ */
use std::sync::Arc; use core::borrow::Borrow;
use core::cell::RefCell;
use log::error;
use crate::{ use crate::{
error::*, error::*,
fabric::FabricMgr, fabric::FabricMgr,
secure_channel::common::*, mdns::Mdns,
tlv, secure_channel::{common::*, pake::Pake},
transport::proto_demux::{self, ProtoCtx, ResponseRequired}, transport::{exchange::Exchange, packet::Packet},
utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info};
use num;
use super::{case::Case, pake::PaseMgr}; use super::{case::Case, pake::PaseMgr};
/* Handle messages related to the Secure Channel /* Handle messages related to the Secure Channel
*/ */
pub struct SecureChannel { pub struct SecureChannel<'a> {
case: Case, pase: &'a RefCell<PaseMgr>,
pase: PaseMgr, fabric: &'a RefCell<FabricMgr>,
mdns: &'a dyn Mdns,
rand: Rand,
} }
impl SecureChannel { impl<'a> SecureChannel<'a> {
pub fn new(pase: PaseMgr, fabric_mgr: Arc<FabricMgr>) -> SecureChannel { #[inline(always)]
SecureChannel { pub fn new<
T: Borrow<RefCell<FabricMgr>>
+ Borrow<RefCell<PaseMgr>>
+ Borrow<dyn Mdns + 'a>
+ Borrow<Epoch>
+ Borrow<Rand>,
>(
matter: &'a T,
) -> Self {
Self::wrap(
matter.borrow(),
matter.borrow(),
matter.borrow(),
*matter.borrow(),
)
}
#[inline(always)]
pub fn wrap(
pase: &'a RefCell<PaseMgr>,
fabric: &'a RefCell<FabricMgr>,
mdns: &'a dyn Mdns,
rand: Rand,
) -> Self {
Self {
fabric,
pase, pase,
case: Case::new(fabric_mgr), mdns,
} rand,
} }
} }
impl proto_demux::HandleProto for SecureChannel { pub async fn handle(
fn handle_proto_id(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { &self,
let proto_opcode: OpCode = exchange: &mut Exchange<'_>,
num::FromPrimitive::from_u8(ctx.rx.get_proto_opcode()).ok_or(Error::Invalid)?; rx: &mut Packet<'_>,
ctx.tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); tx: &mut Packet<'_>,
info!("Received Opcode: {:?}", proto_opcode); ) -> Result<(), Error> {
info!("Received Data:"); match rx.get_proto_opcode()? {
tlv::print_tlv_list(ctx.rx.as_borrow_slice()); OpCode::PBKDFParamRequest => {
let result = match proto_opcode { Pake::new(self.pase)
OpCode::MRPStandAloneAck => Ok(ResponseRequired::No), .handle(exchange, rx, tx, self.mdns)
OpCode::PBKDFParamRequest => self.pase.pbkdfparamreq_handler(ctx), .await
OpCode::PASEPake1 => self.pase.pasepake1_handler(ctx), }
OpCode::PASEPake3 => self.pase.pasepake3_handler(ctx), OpCode::CASESigma1 => {
OpCode::CASESigma1 => self.case.casesigma1_handler(ctx), Case::new(self.fabric, self.rand)
OpCode::CASESigma3 => self.case.casesigma3_handler(ctx), .handle(exchange, rx, tx)
_ => { .await
error!("OpCode Not Handled: {:?}", proto_opcode); }
Err(Error::InvalidOpcode) proto_opcode => {
error!("OpCode not handled: {:?}", proto_opcode);
Err(ErrorCode::InvalidOpcode.into())
} }
};
if result == Ok(ResponseRequired::Yes) {
info!("Sending response");
tlv::print_tlv_list(ctx.tx.as_borrow_slice());
} }
result
}
fn get_proto_id(&self) -> usize {
PROTO_ID_SECURE_CHANNEL
} }
} }

View file

@ -15,40 +15,13 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; #[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))]
pub use super::crypto_dummy::CryptoSpake2;
// This trait allows us to switch between crypto providers like OpenSSL and mbedTLS for Spake2 #[cfg(all(feature = "mbedtls", target_os = "espidf"))]
// Currently this is only validate for a verifier(responder) pub use super::crypto_esp_mbedtls::CryptoSpake2;
#[cfg(all(feature = "mbedtls", not(target_os = "espidf")))]
// A verifier will typically do: pub use super::crypto_mbedtls::CryptoSpake2;
// Step 1: w0 and L #[cfg(feature = "openssl")]
// set_w0_from_w0s pub use super::crypto_openssl::CryptoSpake2;
// set_L #[cfg(feature = "rustcrypto")]
// Step 2: get_pB pub use super::crypto_rustcrypto::CryptoSpake2;
// Step 3: get_TT_as_verifier(pA)
// Step 4: Computation of cA and cB happens outside since it doesn't use either BigNum or EcPoint
pub trait CryptoSpake2 {
fn new() -> Result<Self, Error>
where
Self: Sized;
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error>;
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>;
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error>;
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn set_L(&mut self, l: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error>;
#[allow(non_snake_case)]
fn get_TT_as_verifier(
&mut self,
context: &[u8],
pA: &[u8],
pB: &[u8],
out: &mut [u8],
) -> Result<(), Error>;
}

View file

@ -0,0 +1,76 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
};
#[allow(non_snake_case)]
pub struct CryptoSpake2 {}
impl CryptoSpake2 {
#[allow(non_snake_case)]
pub fn new() -> Result<Self, Error> {
Ok(Self {})
}
// Computes w0 from w0s respectively
pub fn set_w0_from_w0s(&mut self, _w0s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w1_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w0(&mut self, _w0: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
pub fn set_w1(&mut self, _w1: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn set_L(&mut self, _l: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
#[allow(dead_code)]
pub fn set_L_from_w1s(&mut self, _w1s: &[u8]) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn get_pB(&mut self, _pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
#[allow(non_snake_case)]
pub fn get_TT_as_verifier(
&mut self,
_context: &[u8],
_pA: &[u8],
_pB: &[u8],
_out: &mut [u8],
) -> Result<(), Error> {
Err(ErrorCode::Invalid.into())
}
}

View file

@ -16,8 +16,7 @@
*/ */
use crate::error::Error; use crate::error::Error;
use crate::utils::rand::Rand;
use super::crypto::CryptoSpake2;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
@ -36,16 +35,16 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoEspMbedTls {} pub struct CryptoSpake2 {}
impl CryptoSpake2 for CryptoEspMbedTls { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
Ok(CryptoEspMbedTls {}) Ok(Self {})
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -53,7 +52,7 @@ impl CryptoSpake2 for CryptoEspMbedTls {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -61,17 +60,17 @@ impl CryptoSpake2 for CryptoEspMbedTls {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -79,7 +78,16 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { #[allow(dead_code)]
pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec,
// L = w1 * P
// where P is the generator of the underlying elliptic curve
Ok(())
}
#[allow(non_snake_case)]
pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -90,7 +98,7 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -101,13 +109,10 @@ impl CryptoSpake2 for CryptoEspMbedTls {
} }
} }
impl CryptoEspMbedTls {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoEspMbedTls; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use openssl::bn::BigNum; use openssl::bn::BigNum;
use openssl::ec::{EcPoint, PointConversionForm}; use openssl::ec::{EcPoint, PointConversionForm};
@ -116,13 +121,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
CryptoEspMbedTls::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.X, t.X,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -136,12 +140,11 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
CryptoEspMbedTls::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.Y, t.Y,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -155,12 +158,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoEspMbedTls::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -191,12 +194,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoEspMbedTls::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap();
let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoEspMbedTls::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,

View file

@ -15,14 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use alloc::sync::Arc;
ops::{Mul, Sub}, use core::ops::{Mul, Sub};
sync::Arc,
use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
}; };
use crate::error::Error;
use super::crypto::CryptoSpake2;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use mbedtls::{ use mbedtls::{
@ -33,6 +33,8 @@ use mbedtls::{
rng::{CtrDrbg, OsEntropy}, rng::{CtrDrbg, OsEntropy},
}; };
extern crate alloc;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1, 0x99, 0x3b, 0x64, 0xe1, 0x6e, 0xf3, 0xdc, 0xab, 0x95, 0xaf, 0xd4, 0x97, 0x33, 0x3d, 0x8f, 0xa1,
@ -50,7 +52,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoMbedTLS { pub struct CryptoSpake2 {
group: EcGroup, group: EcGroup,
order: Mpi, order: Mpi,
xy: Mpi, xy: Mpi,
@ -62,15 +64,15 @@ pub struct CryptoMbedTLS {
pB: EcPoint, pB: EcPoint,
} }
impl CryptoSpake2 for CryptoMbedTLS { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?; let group = EcGroup::new(mbedtls::pk::EcGroupId::SecP256R1)?;
let order = group.order()?; let order = group.order()?;
let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?; let M = EcPoint::from_binary(&group, &MATTER_M_BIN)?;
let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?; let N = EcPoint::from_binary(&group, &MATTER_N_BIN)?;
Ok(CryptoMbedTLS { Ok(Self {
group, group,
order, order,
xy: Mpi::new(0)?, xy: Mpi::new(0)?,
@ -84,7 +86,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -94,7 +96,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -104,24 +106,25 @@ impl CryptoSpake2 for CryptoMbedTLS {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = Mpi::from_binary(w0)?; self.w0 = Mpi::from_binary(w0)?;
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = Mpi::from_binary(w1)?; self.w1 = Mpi::from_binary(w1)?;
Ok(()) Ok(())
} }
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = EcPoint::from_binary(&self.group, l)?; self.L = EcPoint::from_binary(&self.group, l)?;
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -132,7 +135,7 @@ impl CryptoSpake2 for CryptoMbedTLS {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -150,14 +153,14 @@ impl CryptoSpake2 for CryptoMbedTLS {
let pB_internal = pB_internal.as_slice(); let pB_internal = pB_internal.as_slice();
if pB_internal.len() != pB.len() { if pB_internal.len() != pB.len() {
error!("pB length mismatch"); error!("pB length mismatch");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
pB.copy_from_slice(pB_internal); pB.copy_from_slice(pB_internal);
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -166,21 +169,21 @@ impl CryptoSpake2 for CryptoMbedTLS {
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut TT = Md::new(mbedtls::hash::Type::Sha256)?; let mut TT = Md::new(mbedtls::hash::Type::Sha256)?;
// context // context
CryptoMbedTLS::add_to_tt(&mut TT, context)?; Self::add_to_tt(&mut TT, context)?;
// 2 empty identifiers // 2 empty identifiers
CryptoMbedTLS::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
CryptoMbedTLS::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
// M // M
CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_M_BIN)?; Self::add_to_tt(&mut TT, &MATTER_M_BIN)?;
// N // N
CryptoMbedTLS::add_to_tt(&mut TT, &MATTER_N_BIN)?; Self::add_to_tt(&mut TT, &MATTER_N_BIN)?;
// X = pA // X = pA
CryptoMbedTLS::add_to_tt(&mut TT, pA)?; Self::add_to_tt(&mut TT, pA)?;
// Y = pB // Y = pB
CryptoMbedTLS::add_to_tt(&mut TT, pB)?; Self::add_to_tt(&mut TT, pB)?;
let X = EcPoint::from_binary(&self.group, pA)?; let X = EcPoint::from_binary(&self.group, pA)?;
let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( let (Z, V) = Self::get_ZV_as_verifier(
&self.w0, &self.w0,
&self.L, &self.L,
&mut self.M, &mut self.M,
@ -193,24 +196,22 @@ impl CryptoSpake2 for CryptoMbedTLS {
// Z // Z
let tmp = Z.to_binary(&self.group, false)?; let tmp = Z.to_binary(&self.group, false)?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// V // V
let tmp = V.to_binary(&self.group, false)?; let tmp = V.to_binary(&self.group, false)?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// w0 // w0
let tmp = self.w0.to_binary()?; let tmp = self.w0.to_binary()?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoMbedTLS::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
TT.finish(out)?; TT.finish(out)?;
Ok(()) Ok(())
} }
}
impl CryptoMbedTLS {
fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut Md, buf: &[u8]) -> Result<(), Error> {
let mut len_buf: [u8; 8] = [0; 8]; let mut len_buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut len_buf, buf.len() as u64); LittleEndian::write_u64(&mut len_buf, buf.len() as u64);
@ -247,7 +248,7 @@ impl CryptoMbedTLS {
let mut tmp = x.mul(w0)?; let mut tmp = x.mul(w0)?;
tmp = tmp.modulo(order)?; tmp = tmp.modulo(order)?;
let inverted_N = CryptoMbedTLS::invert(group, N)?; let inverted_N = Self::invert(group, N)?;
let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?; let Z = EcPoint::muladd(group, Y, x, &inverted_N, &tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
@ -283,7 +284,7 @@ impl CryptoMbedTLS {
let mut tmp = y.mul(w0)?; let mut tmp = y.mul(w0)?;
tmp = tmp.modulo(order)?; tmp = tmp.modulo(order)?;
let inverted_M = CryptoMbedTLS::invert(group, M)?; let inverted_M = Self::invert(group, M)?;
let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?; let Z = EcPoint::muladd(group, X, y, &inverted_M, &tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
@ -302,8 +303,7 @@ impl CryptoMbedTLS {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoMbedTLS; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use mbedtls::bignum::Mpi; use mbedtls::bignum::Mpi;
use mbedtls::ecp::EcPoint; use mbedtls::ecp::EcPoint;
@ -312,7 +312,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = Mpi::from_binary(&t.x).unwrap(); let x = Mpi::from_binary(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator().unwrap(); let P = c.group.generator().unwrap();
@ -326,7 +326,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = Mpi::from_binary(&t.y).unwrap(); let y = Mpi::from_binary(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator().unwrap(); let P = c.group.generator().unwrap();
@ -339,12 +339,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = Mpi::from_binary(&t.x).unwrap(); let x = Mpi::from_binary(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap(); let Y = EcPoint::from_binary(&c.group, &t.Y).unwrap();
let (Z, V) = CryptoMbedTLS::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -364,12 +364,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoMbedTLS::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = Mpi::from_binary(&t.y).unwrap(); let y = Mpi::from_binary(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_binary(&c.group, &t.X).unwrap(); let X = EcPoint::from_binary(&c.group, &t.X).unwrap();
let L = EcPoint::from_binary(&c.group, &t.L).unwrap(); let L = EcPoint::from_binary(&c.group, &t.L).unwrap();
let (Z, V) = CryptoMbedTLS::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,

View file

@ -15,9 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use crate::{
error::{Error, ErrorCode},
utils::rand::Rand,
};
use super::crypto::CryptoSpake2;
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use openssl::{ use openssl::{
@ -44,7 +46,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoOpenSSL { pub struct CryptoSpake2 {
group: EcGroup, group: EcGroup,
bn_ctx: BigNumContext, bn_ctx: BigNumContext,
// Stores the randomly generated x or y depending upon who we are // Stores the randomly generated x or y depending upon who we are
@ -58,9 +60,9 @@ pub struct CryptoOpenSSL {
order: BigNum, order: BigNum,
} }
impl CryptoSpake2 for CryptoOpenSSL { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?; let group = EcGroup::from_curve_name(Nid::X9_62_PRIME256V1)?;
let mut bn_ctx = BigNumContext::new()?; let mut bn_ctx = BigNumContext::new()?;
let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?; let M = EcPoint::from_bytes(&group, &MATTER_M_BIN, &mut bn_ctx)?;
@ -70,7 +72,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
let mut order = BigNum::new()?; let mut order = BigNum::new()?;
group.as_ref().order(&mut order, &mut bn_ctx)?; group.as_ref().order(&mut order, &mut bn_ctx)?;
Ok(CryptoOpenSSL { Ok(Self {
group, group,
bn_ctx, bn_ctx,
xy: BigNum::new()?, xy: BigNum::new()?,
@ -85,7 +87,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -96,7 +98,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -107,24 +109,25 @@ impl CryptoSpake2 for CryptoOpenSSL {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = BigNum::from_slice(w0)?; self.w0 = BigNum::from_slice(w0)?;
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = BigNum::from_slice(w1)?; self.w1 = BigNum::from_slice(w1)?;
Ok(()) Ok(())
} }
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?; self.L = EcPoint::from_bytes(&self.group, l, &mut self.bn_ctx)?;
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -135,7 +138,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], _rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
@ -143,7 +146,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
// - pB = Y // - pB = Y
self.order.rand_range(&mut self.xy)?; self.order.rand_range(&mut self.xy)?;
let P = self.group.generator(); let P = self.group.generator();
self.pB = CryptoOpenSSL::do_add_mul( self.pB = Self::do_add_mul(
P, P,
&self.xy, &self.xy,
&self.N, &self.N,
@ -159,14 +162,14 @@ impl CryptoSpake2 for CryptoOpenSSL {
let pB_internal = pB_internal.as_slice(); let pB_internal = pB_internal.as_slice();
if pB_internal.len() != pB.len() { if pB_internal.len() != pB.len() {
error!("pB length mismatch"); error!("pB length mismatch");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
pB.copy_from_slice(pB_internal); pB.copy_from_slice(pB_internal);
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -175,21 +178,21 @@ impl CryptoSpake2 for CryptoOpenSSL {
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut TT = Hasher::new(MessageDigest::sha256())?; let mut TT = Hasher::new(MessageDigest::sha256())?;
// context // context
CryptoOpenSSL::add_to_tt(&mut TT, context)?; Self::add_to_tt(&mut TT, context)?;
// 2 empty identifiers // 2 empty identifiers
CryptoOpenSSL::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
CryptoOpenSSL::add_to_tt(&mut TT, &[])?; Self::add_to_tt(&mut TT, &[])?;
// M // M
CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_M_BIN)?; Self::add_to_tt(&mut TT, &MATTER_M_BIN)?;
// N // N
CryptoOpenSSL::add_to_tt(&mut TT, &MATTER_N_BIN)?; Self::add_to_tt(&mut TT, &MATTER_N_BIN)?;
// X = pA // X = pA
CryptoOpenSSL::add_to_tt(&mut TT, pA)?; Self::add_to_tt(&mut TT, pA)?;
// Y = pB // Y = pB
CryptoOpenSSL::add_to_tt(&mut TT, pB)?; Self::add_to_tt(&mut TT, pB)?;
let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?; let X = EcPoint::from_bytes(&self.group, pA, &mut self.bn_ctx)?;
let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( let (Z, V) = Self::get_ZV_as_verifier(
&self.w0, &self.w0,
&self.L, &self.L,
&mut self.M, &mut self.M,
@ -207,7 +210,7 @@ impl CryptoSpake2 for CryptoOpenSSL {
&mut self.bn_ctx, &mut self.bn_ctx,
)?; )?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// V // V
let tmp = V.to_bytes( let tmp = V.to_bytes(
@ -216,20 +219,18 @@ impl CryptoSpake2 for CryptoOpenSSL {
&mut self.bn_ctx, &mut self.bn_ctx,
)?; )?;
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
// w0 // w0
let tmp = self.w0.to_vec(); let tmp = self.w0.to_vec();
let tmp = tmp.as_slice(); let tmp = tmp.as_slice();
CryptoOpenSSL::add_to_tt(&mut TT, tmp)?; Self::add_to_tt(&mut TT, tmp)?;
let h = TT.finish()?; let h = TT.finish()?;
TT_hash.copy_from_slice(h.as_ref()); TT_hash.copy_from_slice(h.as_ref());
Ok(()) Ok(())
} }
}
impl CryptoOpenSSL {
fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut Hasher, buf: &[u8]) -> Result<(), Error> {
let mut len_buf: [u8; 8] = [0; 8]; let mut len_buf: [u8; 8] = [0; 8];
LittleEndian::write_u64(&mut len_buf, buf.len() as u64); LittleEndian::write_u64(&mut len_buf, buf.len() as u64);
@ -286,11 +287,11 @@ impl CryptoOpenSSL {
let mut tmp = BigNum::new()?; let mut tmp = BigNum::new()?;
tmp.mod_mul(x, w0, order, bn_ctx)?; tmp.mod_mul(x, w0, order, bn_ctx)?;
N.invert(group, bn_ctx)?; N.invert(group, bn_ctx)?;
let Z = CryptoOpenSSL::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?; let Z = Self::do_add_mul(Y, x, N, &tmp, group, bn_ctx)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
tmp.mod_mul(w1, w0, order, bn_ctx)?; tmp.mod_mul(w1, w0, order, bn_ctx)?;
let V = CryptoOpenSSL::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?; let V = Self::do_add_mul(Y, w1, N, &tmp, group, bn_ctx)?;
Ok((Z, V)) Ok((Z, V))
} }
@ -321,7 +322,7 @@ impl CryptoOpenSSL {
let mut tmp = BigNum::new()?; let mut tmp = BigNum::new()?;
tmp.mod_mul(y, w0, order, bn_ctx)?; tmp.mod_mul(y, w0, order, bn_ctx)?;
M.invert(group, bn_ctx)?; M.invert(group, bn_ctx)?;
let Z = CryptoOpenSSL::do_add_mul(X, y, M, &tmp, group, bn_ctx)?; let Z = Self::do_add_mul(X, y, M, &tmp, group, bn_ctx)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
let mut V = EcPoint::new(group)?; let mut V = EcPoint::new(group)?;
@ -333,8 +334,7 @@ impl CryptoOpenSSL {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::CryptoOpenSSL; use super::CryptoSpake2;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
use openssl::bn::BigNum; use openssl::bn::BigNum;
use openssl::ec::{EcPoint, PointConversionForm}; use openssl::ec::{EcPoint, PointConversionForm};
@ -343,12 +343,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = CryptoOpenSSL::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); let r = CryptoSpake2::do_add_mul(P, &x, &c.M, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.X, t.X,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -362,11 +362,11 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = c.group.generator(); let P = c.group.generator();
let r = CryptoOpenSSL::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap(); let r = CryptoSpake2::do_add_mul(P, &y, &c.N, &c.w0, &c.group, &mut c.bn_ctx).unwrap();
assert_eq!( assert_eq!(
t.Y, t.Y,
r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx) r.to_bytes(&c.group, PointConversionForm::UNCOMPRESSED, &mut c.bn_ctx)
@ -380,12 +380,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = BigNum::from_slice(&t.x).unwrap(); let x = BigNum::from_slice(&t.x).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
c.set_w1(&t.w1).unwrap(); c.set_w1(&t.w1).unwrap();
let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap(); let Y = EcPoint::from_bytes(&c.group, &t.Y, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoOpenSSL::get_ZV_as_prover( let (Z, V) = CryptoSpake2::get_ZV_as_prover(
&c.w0, &c.w0,
&c.w1, &c.w1,
&mut c.N, &mut c.N,
@ -416,12 +416,12 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoOpenSSL::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = BigNum::from_slice(&t.y).unwrap(); let y = BigNum::from_slice(&t.y).unwrap();
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap(); let X = EcPoint::from_bytes(&c.group, &t.X, &mut c.bn_ctx).unwrap();
let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap(); let L = EcPoint::from_bytes(&c.group, &t.L, &mut c.bn_ctx).unwrap();
let (Z, V) = CryptoOpenSSL::get_ZV_as_verifier( let (Z, V) = CryptoSpake2::get_ZV_as_verifier(
&c.w0, &c.w0,
&L, &L,
&mut c.M, &mut c.M,

View file

@ -21,11 +21,12 @@ use elliptic_curve::ops::*;
use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint}; use elliptic_curve::sec1::{FromEncodedPoint, ToEncodedPoint};
use elliptic_curve::Field; use elliptic_curve::Field;
use elliptic_curve::PrimeField; use elliptic_curve::PrimeField;
use rand_core::CryptoRng;
use rand_core::RngCore;
use sha2::Digest; use sha2::Digest;
use crate::error::Error; use crate::error::Error;
use crate::utils::rand::Rand;
use super::crypto::CryptoSpake2;
const MATTER_M_BIN: [u8; 65] = [ const MATTER_M_BIN: [u8; 65] = [
0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2, 0x04, 0x88, 0x6e, 0x2f, 0x97, 0xac, 0xe4, 0x6e, 0x55, 0xba, 0x9d, 0xd7, 0x24, 0x25, 0x79, 0xf2,
@ -44,7 +45,7 @@ const MATTER_N_BIN: [u8; 65] = [
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub struct CryptoRustCrypto { pub struct CryptoSpake2 {
xy: p256::Scalar, xy: p256::Scalar,
w0: p256::Scalar, w0: p256::Scalar,
w1: p256::Scalar, w1: p256::Scalar,
@ -54,15 +55,15 @@ pub struct CryptoRustCrypto {
pB: p256::EncodedPoint, pB: p256::EncodedPoint,
} }
impl CryptoSpake2 for CryptoRustCrypto { impl CryptoSpake2 {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn new() -> Result<Self, Error> { pub fn new() -> Result<Self, Error> {
let M = p256::EncodedPoint::from_bytes(MATTER_M_BIN).unwrap(); let M = p256::EncodedPoint::from_bytes(MATTER_M_BIN).unwrap();
let N = p256::EncodedPoint::from_bytes(MATTER_N_BIN).unwrap(); let N = p256::EncodedPoint::from_bytes(MATTER_N_BIN).unwrap();
let L = p256::EncodedPoint::default(); let L = p256::EncodedPoint::default();
let pB = p256::EncodedPoint::default(); let pB = p256::EncodedPoint::default();
Ok(CryptoRustCrypto { Ok(Self {
xy: p256::Scalar::ZERO, xy: p256::Scalar::ZERO,
w0: p256::Scalar::ZERO, w0: p256::Scalar::ZERO,
w1: p256::Scalar::ZERO, w1: p256::Scalar::ZERO,
@ -74,7 +75,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
// Computes w0 from w0s respectively // Computes w0 from w0s respectively
fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> { pub fn set_w0_from_w0s(&mut self, w0s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w0 = w0s mod p // w0 = w0s mod p
// where p is the order of the curve // where p is the order of the curve
@ -103,7 +104,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { pub fn set_w1_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter Spec, // From the Matter Spec,
// w1 = w1s mod p // w1 = w1s mod p
// where p is the order of the curve // where p is the order of the curve
@ -132,14 +133,14 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> { pub fn set_w0(&mut self, w0: &[u8]) -> Result<(), Error> {
self.w0 = self.w0 =
p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w0)) p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w0))
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> { pub fn set_w1(&mut self, w1: &[u8]) -> Result<(), Error> {
self.w1 = self.w1 =
p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w1)) p256::Scalar::from_repr(*elliptic_curve::generic_array::GenericArray::from_slice(w1))
.unwrap(); .unwrap();
@ -148,12 +149,13 @@ impl CryptoSpake2 for CryptoRustCrypto {
#[allow(non_snake_case)] #[allow(non_snake_case)]
#[allow(dead_code)] #[allow(dead_code)]
fn set_L(&mut self, l: &[u8]) -> Result<(), Error> { pub fn set_L(&mut self, l: &[u8]) -> Result<(), Error> {
self.L = p256::EncodedPoint::from_bytes(l).unwrap(); self.L = p256::EncodedPoint::from_bytes(l).unwrap();
Ok(()) Ok(())
} }
fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> { #[allow(non_snake_case)]
pub fn set_L_from_w1s(&mut self, w1s: &[u8]) -> Result<(), Error> {
// From the Matter spec, // From the Matter spec,
// L = w1 * P // L = w1 * P
// where P is the generator of the underlying elliptic curve // where P is the generator of the underlying elliptic curve
@ -163,14 +165,14 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_pB(&mut self, pB: &mut [u8]) -> Result<(), Error> { pub fn get_pB(&mut self, pB: &mut [u8], rand: Rand) -> Result<(), Error> {
// From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/) // From the SPAKE2+ spec (https://datatracker.ietf.org/doc/draft-bar-cfrg-spake2plus/)
// for y // for y
// - select random y between 0 to p // - select random y between 0 to p
// - Y = y*P + w0*N // - Y = y*P + w0*N
// - pB = Y // - pB = Y
let mut rng = rand::thread_rng(); let mut rand = RandRngCore(rand);
self.xy = p256::Scalar::random(&mut rng); self.xy = p256::Scalar::random(&mut rand);
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let N = p256::AffinePoint::from_encoded_point(&self.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&self.N).unwrap();
@ -182,7 +184,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn get_TT_as_verifier( pub fn get_TT_as_verifier(
&mut self, &mut self,
context: &[u8], context: &[u8],
pA: &[u8], pA: &[u8],
@ -222,9 +224,7 @@ impl CryptoSpake2 for CryptoRustCrypto {
Ok(()) Ok(())
} }
}
impl CryptoRustCrypto {
fn add_to_tt(tt: &mut sha2::Sha256, buf: &[u8]) -> Result<(), Error> { fn add_to_tt(tt: &mut sha2::Sha256, buf: &[u8]) -> Result<(), Error> {
tt.update((buf.len() as u64).to_le_bytes()); tt.update((buf.len() as u64).to_le_bytes());
if !buf.is_empty() { if !buf.is_empty() {
@ -266,11 +266,11 @@ impl CryptoRustCrypto {
let mut tmp = x * w0; let mut tmp = x * w0;
let N_neg = N.neg(); let N_neg = N.neg();
let Z = CryptoRustCrypto::do_add_mul(Y, x, N_neg, tmp)?; let Z = Self::do_add_mul(Y, x, N_neg, tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
tmp = w1 * w0; tmp = w1 * w0;
let V = CryptoRustCrypto::do_add_mul(Y, w1, N_neg, tmp)?; let V = Self::do_add_mul(Y, w1, N_neg, tmp)?;
Ok((Z, V)) Ok((Z, V))
} }
@ -297,27 +297,55 @@ impl CryptoRustCrypto {
let tmp = y * w0; let tmp = y * w0;
let M_neg = M.neg(); let M_neg = M.neg();
let Z = CryptoRustCrypto::do_add_mul(X, y, M_neg, tmp)?; let Z = Self::do_add_mul(X, y, M_neg, tmp)?;
// Cofactor for P256 is 1, so that is a No-Op // Cofactor for P256 is 1, so that is a No-Op
let V = (L * y).to_encoded_point(false); let V = (L * y).to_encoded_point(false);
Ok((Z, V)) Ok((Z, V))
} }
} }
pub struct RandRngCore(pub Rand);
impl RngCore for RandRngCore {
fn next_u32(&mut self) -> u32 {
let mut buf = [0; 4];
self.fill_bytes(&mut buf);
u32::from_be_bytes(buf)
}
fn next_u64(&mut self) -> u64 {
let mut buf = [0; 8];
self.fill_bytes(&mut buf);
u64::from_be_bytes(buf)
}
fn fill_bytes(&mut self, dest: &mut [u8]) {
(self.0)(dest);
}
fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> {
self.fill_bytes(dest);
Ok(())
}
}
impl CryptoRng for RandRngCore {}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use elliptic_curve::sec1::FromEncodedPoint; use elliptic_curve::sec1::FromEncodedPoint;
use crate::secure_channel::crypto::CryptoSpake2;
use crate::secure_channel::spake2p_test_vectors::test_vectors::*; use crate::secure_channel::spake2p_test_vectors::test_vectors::*;
#[test] #[test]
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_X() { fn test_get_X() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = p256::Scalar::from_repr( let x = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.x), *elliptic_curve::generic_array::GenericArray::from_slice(&t.x),
) )
@ -325,7 +353,7 @@ mod tests {
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap();
let r: p256::EncodedPoint = CryptoRustCrypto::do_add_mul(P, x, M, c.w0).unwrap(); let r: p256::EncodedPoint = CryptoSpake2::do_add_mul(P, x, M, c.w0).unwrap();
assert_eq!(&t.X, r.as_bytes()); assert_eq!(&t.X, r.as_bytes());
} }
} }
@ -334,7 +362,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_Y() { fn test_get_Y() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = p256::Scalar::from_repr( let y = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.y), *elliptic_curve::generic_array::GenericArray::from_slice(&t.y),
) )
@ -342,7 +370,7 @@ mod tests {
c.set_w0(&t.w0).unwrap(); c.set_w0(&t.w0).unwrap();
let P = p256::AffinePoint::GENERATOR; let P = p256::AffinePoint::GENERATOR;
let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap();
let r = CryptoRustCrypto::do_add_mul(P, y, N, c.w0).unwrap(); let r = CryptoSpake2::do_add_mul(P, y, N, c.w0).unwrap();
assert_eq!(&t.Y, r.as_bytes()); assert_eq!(&t.Y, r.as_bytes());
} }
} }
@ -351,7 +379,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_prover() { fn test_get_ZV_as_prover() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let x = p256::Scalar::from_repr( let x = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.x), *elliptic_curve::generic_array::GenericArray::from_slice(&t.x),
) )
@ -361,7 +389,7 @@ mod tests {
let Y = p256::EncodedPoint::from_bytes(t.Y).unwrap(); let Y = p256::EncodedPoint::from_bytes(t.Y).unwrap();
let Y = p256::AffinePoint::from_encoded_point(&Y).unwrap(); let Y = p256::AffinePoint::from_encoded_point(&Y).unwrap();
let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap(); let N = p256::AffinePoint::from_encoded_point(&c.N).unwrap();
let (Z, V) = CryptoRustCrypto::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap(); let (Z, V) = CryptoSpake2::get_ZV_as_prover(c.w0, c.w1, N, Y, x).unwrap();
assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.Z, Z.as_bytes());
assert_eq!(&t.V, V.as_bytes()); assert_eq!(&t.V, V.as_bytes());
@ -372,7 +400,7 @@ mod tests {
#[allow(non_snake_case)] #[allow(non_snake_case)]
fn test_get_ZV_as_verifier() { fn test_get_ZV_as_verifier() {
for t in RFC_T { for t in RFC_T {
let mut c = CryptoRustCrypto::new().unwrap(); let mut c = CryptoSpake2::new().unwrap();
let y = p256::Scalar::from_repr( let y = p256::Scalar::from_repr(
*elliptic_curve::generic_array::GenericArray::from_slice(&t.y), *elliptic_curve::generic_array::GenericArray::from_slice(&t.y),
) )
@ -383,7 +411,7 @@ mod tests {
let L = p256::EncodedPoint::from_bytes(t.L).unwrap(); let L = p256::EncodedPoint::from_bytes(t.L).unwrap();
let L = p256::AffinePoint::from_encoded_point(&L).unwrap(); let L = p256::AffinePoint::from_encoded_point(&L).unwrap();
let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap(); let M = p256::AffinePoint::from_encoded_point(&c.M).unwrap();
let (Z, V) = CryptoRustCrypto::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap(); let (Z, V) = CryptoSpake2::get_ZV_as_verifier(c.w0, L, M, X, y).unwrap();
assert_eq!(&t.Z, Z.as_bytes()); assert_eq!(&t.Z, Z.as_bytes());
assert_eq!(&t.V, V.as_bytes()); assert_eq!(&t.V, V.as_bytes());

View file

@ -17,13 +17,15 @@
pub mod case; pub mod case;
pub mod common; pub mod common;
#[cfg(feature = "crypto_esp_mbedtls")] #[cfg(not(any(feature = "openssl", feature = "mbedtls", feature = "rustcrypto")))]
pub mod crypto_esp_mbedtls; mod crypto_dummy;
#[cfg(feature = "crypto_mbedtls")] #[cfg(all(feature = "mbedtls", target_os = "espidf"))]
pub mod crypto_mbedtls; mod crypto_esp_mbedtls;
#[cfg(feature = "crypto_openssl")] #[cfg(all(feature = "mbedtls", not(target_os = "espidf")))]
mod crypto_mbedtls;
#[cfg(feature = "openssl")]
pub mod crypto_openssl; pub mod crypto_openssl;
#[cfg(feature = "crypto_rustcrypto")] #[cfg(feature = "rustcrypto")]
pub mod crypto_rustcrypto; pub mod crypto_rustcrypto;
pub mod core; pub mod core;

View file

@ -15,109 +15,88 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::{cell::RefCell, fmt::Write, time::Duration};
sync::{Arc, Mutex},
time::{Duration, SystemTime},
};
use super::{ use super::{
common::{create_sc_status_report, SCStatusCodes}, common::{SCStatusCodes, PROTO_ID_SECURE_CHANNEL},
spake2p::{Spake2P, VerifierData}, spake2p::{Spake2P, VerifierData},
}; };
use crate::{ use crate::{
crypto, alloc, crypto,
error::Error, error::{Error, ErrorCode},
mdns::{self, Mdns}, mdns::{Mdns, ServiceMode},
secure_channel::common::OpCode, secure_channel::common::{complete_with_status, OpCode},
sys::SysMdnsService, tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV},
tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV},
transport::{ transport::{
exchange::ExchangeCtx, exchange::{Exchange, ExchangeId},
network::Address, packet::Packet,
proto_demux::{ProtoCtx, ResponseRequired},
queue::{Msg, WorkQ},
session::{CloneData, SessionMode}, session::{CloneData, SessionMode},
}, },
utils::{epoch::Epoch, rand::Rand},
}; };
use log::{error, info}; use log::{error, info};
use rand::prelude::*;
enum PaseMgrState { struct PaseSession {
Enabled(PAKE, SysMdnsService), mdns_service_name: heapless::String<16>,
Disabled, verifier: VerifierData,
} }
pub struct PaseMgrInternal { pub struct PaseMgr {
state: PaseMgrState, session: Option<PaseSession>,
timeout: Option<Timeout>,
epoch: Epoch,
rand: Rand,
} }
#[derive(Clone)]
// Could this lock be avoided?
pub struct PaseMgr(Arc<Mutex<PaseMgrInternal>>);
impl PaseMgr { impl PaseMgr {
pub fn new() -> Self { #[inline(always)]
Self(Arc::new(Mutex::new(PaseMgrInternal { pub fn new(epoch: Epoch, rand: Rand) -> Self {
state: PaseMgrState::Disabled, Self {
}))) session: None,
timeout: None,
epoch,
rand,
}
}
pub fn is_pase_session_enabled(&self) -> bool {
self.session.is_some()
} }
pub fn enable_pase_session( pub fn enable_pase_session(
&mut self, &mut self,
verifier: VerifierData, verifier: VerifierData,
discriminator: u16, discriminator: u16,
mdns: &dyn Mdns,
) -> Result<(), Error> { ) -> Result<(), Error> {
let mut s = self.0.lock().unwrap(); let mut buf = [0; 8];
let name: u64 = rand::thread_rng().gen_range(0..0xFFFFFFFFFFFFFFFF); (self.rand)(&mut buf);
let name = format!("{:016X}", name); let num = u64::from_be_bytes(buf);
let mdns = Mdns::get()?
.publish_service(&name, mdns::ServiceMode::Commissionable(discriminator))?; let mut mdns_service_name = heapless::String::<16>::new();
s.state = PaseMgrState::Enabled(PAKE::new(verifier), mdns); write!(&mut mdns_service_name, "{:016X}", num).unwrap();
mdns.add(
&mdns_service_name,
ServiceMode::Commissionable(discriminator),
)?;
self.session = Some(PaseSession {
mdns_service_name,
verifier,
});
Ok(()) Ok(())
} }
pub fn disable_pase_session(&mut self) { pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> {
let mut s = self.0.lock().unwrap(); if let Some(session) = self.session.as_ref() {
s.state = PaseMgrState::Disabled; mdns.remove(&session.mdns_service_name)?;
} }
/// If the PASE Session is enabled, execute the closure, self.session = None;
/// if not enabled, generate SC Status Report
fn if_enabled<F>(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<(), Error>
where
F: FnOnce(&mut PAKE, &mut ProtoCtx) -> Result<(), Error>,
{
let mut s = self.0.lock().unwrap();
if let PaseMgrState::Enabled(pake, _) = &mut s.state {
f(pake, ctx)
} else {
error!("PASE Not enabled");
create_sc_status_report(&mut ctx.tx, SCStatusCodes::InvalidParameter, None)
}
}
pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> { Ok(())
ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?;
Ok(ResponseRequired::Yes)
}
pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> {
ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8);
self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?;
Ok(ResponseRequired::Yes)
}
pub fn pasepake3_handler(&mut self, ctx: &mut ProtoCtx) -> Result<ResponseRequired, Error> {
self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?;
self.disable_pase_session();
Ok(ResponseRequired::Yes)
}
}
impl Default for PaseMgr {
fn default() -> Self {
Self::new()
} }
} }
@ -130,101 +109,75 @@ const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_secs(60);
const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys"; const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys";
struct SessionData { struct Timeout {
start_time: SystemTime, start_time: Duration,
exch_id: u16, exch_id: ExchangeId,
peer_addr: Address,
spake2p: Box<Spake2P>,
} }
impl SessionData { impl Timeout {
fn is_sess_expired(&self) -> Result<bool, Error> { fn new(exchange: &Exchange, epoch: Epoch) -> Self {
if SystemTime::now().duration_since(self.start_time)? > PASE_DISCARD_TIMEOUT_SECS { Self {
Ok(true) start_time: epoch(),
} else { exch_id: exchange.id().clone(),
Ok(false)
}
} }
} }
enum PakeState { fn is_sess_expired(&self, epoch: Epoch) -> bool {
Idle, epoch() - self.start_time > PASE_DISCARD_TIMEOUT_SECS
InProgress(SessionData),
}
impl PakeState {
fn take(&mut self) -> Result<SessionData, Error> {
let new = std::mem::replace(self, PakeState::Idle);
if let PakeState::InProgress(s) = new {
Ok(s)
} else {
Err(Error::InvalidSignature)
} }
} }
fn is_idle(&self) -> bool { pub struct Pake<'a> {
std::mem::discriminant(self) == std::mem::discriminant(&PakeState::Idle) pase: &'a RefCell<PaseMgr>,
} }
fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result<SessionData, Error> { impl<'a> Pake<'a> {
let sd = self.take()?; pub const fn new(pase: &'a RefCell<PaseMgr>) -> Self {
if sd.exch_id != exch_ctx.exch.get_id() || sd.peer_addr != exch_ctx.sess.get_peer_addr() {
Err(Error::InvalidState)
} else {
Ok(sd)
}
}
fn make_in_progress(&mut self, spake2p: Box<Spake2P>, exch_ctx: &ExchangeCtx) {
*self = PakeState::InProgress(SessionData {
start_time: SystemTime::now(),
spake2p,
exch_id: exch_ctx.exch.get_id(),
peer_addr: exch_ctx.sess.get_peer_addr(),
});
}
fn set_sess_data(&mut self, sd: SessionData) {
*self = PakeState::InProgress(sd);
}
}
impl Default for PakeState {
fn default() -> Self {
Self::Idle
}
}
pub struct PAKE {
pub verifier: VerifierData,
state: PakeState,
}
impl PAKE {
pub fn new(verifier: VerifierData) -> Self {
// TODO: Can any PBKDF2 calculation be pre-computed here // TODO: Can any PBKDF2 calculation be pre-computed here
PAKE { Self { pase }
verifier,
state: Default::default(),
} }
pub async fn handle(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
mdns: &dyn Mdns,
) -> Result<(), Error> {
let mut spake2p = alloc!(Spake2P::new());
self.handle_pbkdfparamrequest(exchange, rx, tx, &mut spake2p)
.await?;
self.handle_pasepake1(exchange, rx, tx, &mut spake2p)
.await?;
self.handle_pasepake3(exchange, rx, tx, mdns, &mut spake2p)
.await
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { async fn handle_pasepake3(
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; &mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
mdns: &dyn Mdns,
spake2p: &mut Spake2P,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::PASEPake3 as _)?;
self.update_timeout(exchange, tx, true).await?;
let cA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; let cA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let (status_code, Ke) = sd.spake2p.handle_cA(cA); let (status_code, ke) = spake2p.handle_cA(cA);
if status_code == SCStatusCodes::SessionEstablishmentSuccess { let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
// Get the keys // Get the keys
let Ke = Ke.ok_or(Error::Invalid)?; let ke = ke.ok_or(ErrorCode::Invalid)?;
let mut session_keys: [u8; 48] = [0; 48]; let mut session_keys: [u8; 48] = [0; 48];
crypto::hkdf_sha256(&[], Ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys) crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
// Create a session // Create a session
let data = sd.spake2p.get_app_data(); let data = spake2p.get_app_data();
let peer_sessid: u16 = (data & 0xffff) as u16; let peer_sessid: u16 = (data & 0xffff) as u16;
let local_sessid: u16 = ((data >> 16) & 0xffff) as u16; let local_sessid: u16 = ((data >> 16) & 0xffff) as u16;
let mut clone_data = CloneData::new( let mut clone_data = CloneData::new(
@ -232,7 +185,7 @@ impl PAKE {
0, 0,
peer_sessid, peer_sessid,
local_sessid, local_sessid,
ctx.exch_ctx.sess.get_peer_addr(), exchange.with_session(|sess| Ok(sess.get_peer_addr()))?,
SessionMode::Pase, SessionMode::Pase,
); );
clone_data.dec_key.copy_from_slice(&session_keys[0..16]); clone_data.dec_key.copy_from_slice(&session_keys[0..16]);
@ -242,67 +195,94 @@ impl PAKE {
.copy_from_slice(&session_keys[32..48]); .copy_from_slice(&session_keys[32..48]);
// Queue a transport mgr request to add a new session // Queue a transport mgr request to add a new session
WorkQ::get()?.sync_send(Msg::NewSession(clone_data))?; Some(clone_data)
} else {
None
};
if let Some(clone_data) = clone_data {
// TODO: Handle NoSpace
exchange.with_session_mgr_mut(|sess_mgr| sess_mgr.clone_session(&clone_data))?;
self.pase.borrow_mut().disable_pase_session(mdns)?;
} }
create_sc_status_report(&mut ctx.tx, status_code, None)?; complete_with_status(exchange, tx, status_code, None).await?;
ctx.exch_ctx.exch.close();
Ok(()) Ok(())
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { #[allow(clippy::await_holding_refcell_ref)]
let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?; async fn handle_pasepake1(
&mut self,
exchange: &mut Exchange<'_>,
rx: &mut Packet<'_>,
tx: &mut Packet<'_>,
spake2p: &mut Spake2P,
) -> Result<(), Error> {
rx.check_proto_opcode(OpCode::PASEPake1 as _)?;
self.update_timeout(exchange, tx, false).await?;
let pA = extract_pasepake_1_or_3_params(ctx.rx.as_borrow_slice())?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let pA = extract_pasepake_1_or_3_params(rx.as_slice())?;
let mut pB: [u8; 65] = [0; 65]; let mut pB: [u8; 65] = [0; 65];
let mut cB: [u8; 32] = [0; 32]; let mut cB: [u8; 32] = [0; 32];
sd.spake2p.start_verifier(&self.verifier)?; spake2p.start_verifier(&session.verifier)?;
sd.spake2p.handle_pA(pA, &mut pB, &mut cB)?; spake2p.handle_pA(pA, &mut pB, &mut cB, pase.rand)?;
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); // Generate response
tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PASEPake2 as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let resp = Pake1Resp { let resp = Pake1Resp {
pb: OctetStr(&pB), pb: OctetStr(&pB),
cb: OctetStr(&cB), cb: OctetStr(&cB),
}; };
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
self.state.set_sess_data(sd); drop(pase);
exchange.exchange(tx, rx).await
Ok(())
} }
pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> { #[allow(clippy::await_holding_refcell_ref)]
if !self.state.is_idle() { async fn handle_pbkdfparamrequest(
let sd = self.state.take()?; &mut self,
if sd.is_sess_expired()? { exchange: &mut Exchange<'_>,
info!("Previous session expired, clearing it"); rx: &mut Packet<'_>,
self.state = PakeState::Idle; tx: &mut Packet<'_>,
} else { spake2p: &mut Spake2P,
info!("Previous session in-progress, denying new request"); ) -> Result<(), Error> {
// little-endian timeout (here we've hardcoded 500ms) rx.check_proto_opcode(OpCode::PBKDFParamRequest as _)?;
create_sc_status_report(&mut ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?; self.update_timeout(exchange, tx, true).await?;
return Ok(());
}
}
let root = tlv::get_root_node(ctx.rx.as_borrow_slice())?; let pase = self.pase.borrow();
let session = pase.session.as_ref().ok_or(ErrorCode::NoSession)?;
let root = tlv::get_root_node(rx.as_slice())?;
let a = PBKDFParamReq::from_tlv(&root)?; let a = PBKDFParamReq::from_tlv(&root)?;
if a.passcode_id != 0 { if a.passcode_id != 0 {
error!("Can't yet handle passcode_id != 0"); error!("Can't yet handle passcode_id != 0");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
let mut our_random: [u8; 32] = [0; 32]; let mut our_random: [u8; 32] = [0; 32];
rand::thread_rng().fill_bytes(&mut our_random); (self.pase.borrow().rand)(&mut our_random);
let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id(); let local_sessid = exchange.with_session_mgr_mut(|mgr| Ok(mgr.get_next_sess_id()))?;
let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32; let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
let mut spake2p = Box::new(Spake2P::new());
spake2p.set_app_data(spake2p_data); spake2p.set_app_data(spake2p_data);
// Generate response // Generate response
let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?); tx.reset();
tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
let mut tw = TLVWriter::new(tx.get_writebuf()?);
let mut resp = PBKDFParamResp { let mut resp = PBKDFParamResp {
init_random: a.initiator_random, init_random: a.initiator_random,
our_random: OctetStr(&our_random), our_random: OctetStr(&our_random),
@ -311,20 +291,79 @@ impl PAKE {
}; };
if !a.has_params { if !a.has_params {
let params_resp = PBKDFParamRespParams { let params_resp = PBKDFParamRespParams {
count: self.verifier.count, count: session.verifier.count,
salt: OctetStr(&self.verifier.salt), salt: OctetStr(&session.verifier.salt),
}; };
resp.params = Some(params_resp); resp.params = Some(params_resp);
} }
resp.to_tlv(&mut tw, TagType::Anonymous)?; resp.to_tlv(&mut tw, TagType::Anonymous)?;
spake2p.set_context(ctx.rx.as_borrow_slice(), ctx.tx.as_borrow_slice())?; spake2p.set_context(rx.as_slice(), tx.as_mut_slice())?;
self.state.make_in_progress(spake2p, &ctx.exch_ctx);
drop(pase);
exchange.exchange(tx, rx).await
}
#[allow(clippy::await_holding_refcell_ref)]
async fn update_timeout(
&mut self,
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
new: bool,
) -> Result<(), Error> {
self.check_session(exchange, tx).await?;
let mut pase = self.pase.borrow_mut();
if pase
.timeout
.as_ref()
.map(|sd| sd.is_sess_expired(pase.epoch))
.unwrap_or(false)
{
pase.timeout = None;
}
let status = if let Some(sd) = pase.timeout.as_mut() {
if &sd.exch_id != exchange.id() {
info!("Other PAKE session in progress");
Some(SCStatusCodes::Busy)
} else {
None
}
} else if new {
None
} else {
error!("PAKE session not found or expired");
Some(SCStatusCodes::SessionNotFound)
};
if let Some(status) = status {
drop(pase);
complete_with_status(exchange, tx, status, None).await
} else {
pase.timeout = Some(Timeout::new(exchange, pase.epoch));
Ok(()) Ok(())
} }
} }
async fn check_session(
&mut self,
exchange: &mut Exchange<'_>,
tx: &mut Packet<'_>,
) -> Result<(), Error> {
if self.pase.borrow().session.is_none() {
error!("PASE not enabled");
complete_with_status(exchange, tx, SCStatusCodes::InvalidParameter, None).await
} else {
Ok(())
}
}
}
#[derive(ToTLV)] #[derive(ToTLV)]
#[tlvargs(start = 1)] #[tlvargs(start = 1)]
struct Pake1Resp<'a> { struct Pake1Resp<'a> {

View file

@ -17,33 +17,20 @@
use crate::{ use crate::{
crypto::{self, HmacSha256}, crypto::{self, HmacSha256},
sys, utils::rand::Rand,
}; };
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use log::error; use log::error;
use rand::prelude::*;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use crate::{ use crate::{
crypto::{pbkdf2_hmac, Sha256}, crypto::{pbkdf2_hmac, Sha256},
error::Error, error::{Error, ErrorCode},
}; };
#[cfg(feature = "crypto_openssl")]
use super::crypto_openssl::CryptoOpenSSL;
#[cfg(feature = "crypto_mbedtls")]
use super::crypto_mbedtls::CryptoMbedTLS;
#[cfg(feature = "crypto_esp_mbedtls")]
use super::crypto_esp_mbedtls::CryptoEspMbedTls;
#[cfg(feature = "crypto_rustcrypto")]
use super::crypto_rustcrypto::CryptoRustCrypto;
use super::{common::SCStatusCodes, crypto::CryptoSpake2}; use super::{common::SCStatusCodes, crypto::CryptoSpake2};
// This file handle Spake2+ specific instructions. In itself, this file is // This file handles Spake2+ specific instructions. In itself, this file is
// independent from the BigNum and EC operations that are typically required // independent from the BigNum and EC operations that are typically required
// Spake2+. We use the CryptoSpake2 trait object that allows us to abstract // Spake2+. We use the CryptoSpake2 trait object that allows us to abstract
// out the specific implementations. // out the specific implementations.
@ -51,6 +38,8 @@ use super::{common::SCStatusCodes, crypto::CryptoSpake2};
// In the case of the verifier, we don't actually release the Ke until we // In the case of the verifier, we don't actually release the Ke until we
// validate that the cA is confirmed. // validate that the cA is confirmed.
pub const SPAKE2_ITERATION_COUNT: u32 = 2000;
#[derive(PartialEq, Copy, Clone, Debug)] #[derive(PartialEq, Copy, Clone, Debug)]
pub enum Spake2VerifierState { pub enum Spake2VerifierState {
// Initialised - w0, L are set // Initialised - w0, L are set
@ -74,7 +63,7 @@ pub struct Spake2P {
context: Option<Sha256>, context: Option<Sha256>,
Ke: [u8; 16], Ke: [u8; 16],
cA: [u8; 32], cA: [u8; 32],
crypto_spake2: Option<Box<dyn CryptoSpake2>>, crypto_spake2: Option<CryptoSpake2>,
app_data: u32, app_data: u32,
} }
@ -87,24 +76,8 @@ const CRYPTO_PUBLIC_KEY_SIZE_BYTES: usize = (2 * CRYPTO_GROUP_SIZE_BYTES) + 1;
const MAX_SALT_SIZE_BYTES: usize = 32; const MAX_SALT_SIZE_BYTES: usize = 32;
const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES; const VERIFIER_SIZE_BYTES: usize = CRYPTO_GROUP_SIZE_BYTES + CRYPTO_PUBLIC_KEY_SIZE_BYTES;
#[cfg(feature = "crypto_openssl")] fn crypto_spake2_new() -> Result<CryptoSpake2, Error> {
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> { CryptoSpake2::new()
Ok(Box::new(CryptoOpenSSL::new()?))
}
#[cfg(feature = "crypto_mbedtls")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoMbedTLS::new()?))
}
#[cfg(feature = "crypto_esp_mbedtls")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoEspMbedTls::new()?))
}
#[cfg(feature = "crypto_rustcrypto")]
fn crypto_spake2_new() -> Result<Box<dyn CryptoSpake2>, Error> {
Ok(Box::new(CryptoRustCrypto::new()?))
} }
impl Default for Spake2P { impl Default for Spake2P {
@ -129,13 +102,13 @@ pub enum VerifierOption {
} }
impl VerifierData { impl VerifierData {
pub fn new_with_pw(pw: u32) -> Self { pub fn new_with_pw(pw: u32, rand: Rand) -> Self {
let mut s = Self { let mut s = Self {
salt: [0; MAX_SALT_SIZE_BYTES], salt: [0; MAX_SALT_SIZE_BYTES],
count: sys::SPAKE2_ITERATION_COUNT, count: SPAKE2_ITERATION_COUNT,
data: VerifierOption::Password(pw), data: VerifierOption::Password(pw),
}; };
rand::thread_rng().fill_bytes(&mut s.salt); rand(&mut s.salt);
s s
} }
@ -158,7 +131,7 @@ impl VerifierData {
} }
impl Spake2P { impl Spake2P {
pub fn new() -> Self { pub const fn new() -> Self {
Spake2P { Spake2P {
mode: Spake2Mode::Unknown, mode: Spake2Mode::Unknown,
context: None, context: None,
@ -198,7 +171,7 @@ impl Spake2P {
match verifier.data { match verifier.data {
VerifierOption::Password(pw) => { VerifierOption::Password(pw) => {
// Derive w0 and L from the password // Derive w0 and L from the password
let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)];
Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s); Spake2P::get_w0w1s(pw, verifier.count, &verifier.salt, &mut w0w1s);
let w0s_len = w0w1s.len() / 2; let w0s_len = w0w1s.len() / 2;
@ -223,13 +196,19 @@ impl Spake2P {
} }
#[allow(non_snake_case)] #[allow(non_snake_case)]
pub fn handle_pA(&mut self, pA: &[u8], pB: &mut [u8], cB: &mut [u8]) -> Result<(), Error> { pub fn handle_pA(
&mut self,
pA: &[u8],
pB: &mut [u8],
cB: &mut [u8],
rand: Rand,
) -> Result<(), Error> {
if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) { if self.mode != Spake2Mode::Verifier(Spake2VerifierState::Init) {
return Err(Error::InvalidState); Err(ErrorCode::InvalidState)?;
} }
if let Some(crypto_spake2) = &mut self.crypto_spake2 { if let Some(crypto_spake2) = &mut self.crypto_spake2 {
crypto_spake2.get_pB(pB)?; crypto_spake2.get_pB(pB, rand)?;
if let Some(context) = self.context.take() { if let Some(context) = self.context.take() {
let mut hash = [0u8; crypto::SHA256_HASH_LEN_BYTES]; let mut hash = [0u8; crypto::SHA256_HASH_LEN_BYTES];
context.finish(&mut hash)?; context.finish(&mut hash)?;
@ -278,13 +257,13 @@ impl Spake2P {
if ke_internal.len() == Ke.len() { if ke_internal.len() == Ke.len() {
Ke.copy_from_slice(ke_internal); Ke.copy_from_slice(ke_internal);
} else { } else {
return Err(Error::NoSpace); Err(ErrorCode::NoSpace)?;
} }
// Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys") // Step 2: KcA || KcB = KDF(nil, Ka, "ConfirmationKeys")
let mut KcAKcB: [u8; 32] = [0; 32]; let mut KcAKcB: [u8; 32] = [0; 32];
crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB) crypto::hkdf_sha256(&[], Ka, &SPAKE2P_KEY_CONFIRM_INFO, &mut KcAKcB)
.map_err(|_x| Error::NoSpace)?; .map_err(|_x| ErrorCode::NoSpace)?;
let KcA = &KcAKcB[0..(KcAKcB.len() / 2)]; let KcA = &KcAKcB[0..(KcAKcB.len() / 2)];
let KcB = &KcAKcB[(KcAKcB.len() / 2)..]; let KcB = &KcAKcB[(KcAKcB.len() / 2)..];
@ -317,7 +296,7 @@ mod tests {
0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82, 0x4, 0xa1, 0xd2, 0xc6, 0x11, 0xf0, 0xbd, 0x36, 0x78, 0x67, 0x79, 0x7b, 0xfe, 0x82,
0x36, 0x0, 0x36, 0x0,
]; ];
let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; 2 * CRYPTO_W_SIZE_BYTES]; let mut w0w1s: [u8; 2 * CRYPTO_W_SIZE_BYTES] = [0; (2 * CRYPTO_W_SIZE_BYTES)];
Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s); Spake2P::get_w0w1s(123456, 2000, &salt, &mut w0w1s);
assert_eq!( assert_eq!(
w0w1s, w0w1s,

View file

@ -39,6 +39,7 @@ pub enum GeneralCode {
PermissionDenied = 15, PermissionDenied = 15,
DataLoss = 16, DataLoss = 16,
} }
pub fn create_status_report( pub fn create_status_report(
proto_tx: &mut Packet, proto_tx: &mut Packet,
general_code: GeneralCode, general_code: GeneralCode,
@ -46,7 +47,8 @@ pub fn create_status_report(
proto_code: u16, proto_code: u16,
proto_data: Option<&[u8]>, proto_data: Option<&[u8]>,
) -> Result<(), Error> { ) -> Result<(), Error> {
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL as u16); proto_tx.reset();
proto_tx.set_proto_id(PROTO_ID_SECURE_CHANNEL);
proto_tx.set_proto_opcode(OpCode::StatusReport as u8); proto_tx.set_proto_opcode(OpCode::StatusReport as u8);
let wb = proto_tx.get_writebuf()?; let wb = proto_tx.get_writebuf()?;
wb.le_u16(general_code as u16)?; wb.le_u16(general_code as u16)?;

View file

@ -1,31 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#[cfg(target_os = "macos")]
mod sys_macos;
#[cfg(target_os = "macos")]
pub use self::sys_macos::*;
#[cfg(target_os = "linux")]
mod sys_linux;
#[cfg(target_os = "linux")]
pub use self::sys_linux::*;
#[cfg(any(target_os = "macos", target_os = "linux"))]
mod posix;
#[cfg(any(target_os = "macos", target_os = "linux"))]
pub use self::posix::*;

View file

@ -1,96 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use std::{
convert::TryInto,
fs::{remove_file, DirBuilder, File},
io::{Read, Write},
sync::{Arc, Mutex, Once},
};
use crate::error::Error;
pub const SPAKE2_ITERATION_COUNT: u32 = 2000;
// The Packet Pool that is allocated from. POSIX systems can use
// higher values unlike embedded systems
pub const MAX_PACKET_POOL_SIZE: usize = 25;
pub struct Psm {}
static mut G_PSM: Option<Arc<Mutex<Psm>>> = None;
static INIT: Once = Once::new();
const PSM_DIR: &str = "/tmp/matter_psm";
macro_rules! psm_path {
($key:ident) => {
format!("{}/{}", PSM_DIR, $key)
};
}
impl Psm {
fn new() -> Result<Self, Error> {
let result = DirBuilder::new().create(PSM_DIR);
if let Err(e) = result {
if e.kind() != std::io::ErrorKind::AlreadyExists {
return Err(e.into());
}
}
Ok(Self {})
}
pub fn get() -> Result<Arc<Mutex<Self>>, Error> {
unsafe {
INIT.call_once(|| {
G_PSM = Some(Arc::new(Mutex::new(Psm::new().unwrap())));
});
Ok(G_PSM.as_ref().ok_or(Error::Invalid)?.clone())
}
}
pub fn set_kv_slice(&self, key: &str, val: &[u8]) -> Result<(), Error> {
let mut f = File::create(psm_path!(key))?;
f.write_all(val)?;
Ok(())
}
pub fn get_kv_slice(&self, key: &str, val: &mut Vec<u8>) -> Result<usize, Error> {
let mut f = File::open(psm_path!(key))?;
let len = f.read_to_end(val)?;
Ok(len)
}
pub fn set_kv_u64(&self, key: &str, val: u64) -> Result<(), Error> {
let mut f = File::create(psm_path!(key))?;
f.write_all(&val.to_be_bytes())?;
Ok(())
}
pub fn get_kv_u64(&self, key: &str, val: &mut u64) -> Result<(), Error> {
let mut f = File::open(psm_path!(key))?;
let mut vec = Vec::new();
let _ = f.read_to_end(&mut vec)?;
*val = u64::from_be_bytes(vec.as_slice().try_into()?);
Ok(())
}
pub fn rm(&self, key: &str) {
let _ = remove_file(psm_path!(key));
}
}

View file

@ -1,58 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use lazy_static::lazy_static;
use libmdns::{Responder, Service};
use log::info;
use std::sync::{Arc, Mutex};
use std::vec::Vec;
#[allow(dead_code)]
pub struct SysMdnsService {
service: Service,
}
lazy_static! {
static ref RESPONDER: Arc<Mutex<Responder>> = Arc::new(Mutex::new(Responder::new().unwrap()));
}
/// Publish a mDNS service
/// name - can be a service name (comma separate subtypes may follow)
/// regtype - registration type (e.g. _matter_.tcp etc)
/// port - the port
pub fn sys_publish_service(
name: &str,
regtype: &str,
port: u16,
txt_kvs: &[[&str; 2]],
) -> Result<SysMdnsService, Error> {
info!("mDNS Registration Type {}", regtype);
info!("mDNS properties {:?}", txt_kvs);
let mut properties = Vec::new();
for kvs in txt_kvs {
info!("mDNS TXT key {} val {}", kvs[0], kvs[1]);
properties.push(format!("{}={}", kvs[0], kvs[1]));
}
let properties: Vec<&str> = properties.iter().map(|entry| entry.as_str()).collect();
let responder = RESPONDER.lock().map_err(|_| Error::MdnsError)?;
let service = responder.register(regtype.to_owned(), name.to_owned(), port, &properties);
Ok(SysMdnsService { service })
}

View file

@ -1,46 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use crate::error::Error;
use astro_dnssd::{DNSServiceBuilder, RegisteredDnsService};
use log::info;
#[allow(dead_code)]
pub struct SysMdnsService {
s: RegisteredDnsService,
}
/// Publish a mDNS service
/// name - can be a service name (comma separate subtypes may follow)
/// regtype - registration type (e.g. _matter_.tcp etc)
/// port - the port
pub fn sys_publish_service(
name: &str,
regtype: &str,
port: u16,
txt_kvs: &[[&str; 2]],
) -> Result<SysMdnsService, Error> {
let mut builder = DNSServiceBuilder::new(regtype, port).with_name(name);
info!("mDNS Registration Type {}", regtype);
for kvs in txt_kvs {
info!("mDNS TXT key {} val {}", kvs[0], kvs[1]);
builder = builder.with_key_value(kvs[0].to_string(), kvs[1].to_string());
}
let s = builder.register().map_err(|_| Error::MdnsError)?;
Ok(SysMdnsService { s })
}

View file

@ -15,11 +15,11 @@
* limitations under the License. * limitations under the License.
*/ */
use crate::error::Error; use crate::error::{Error, ErrorCode};
use byteorder::{ByteOrder, LittleEndian}; use byteorder::{ByteOrder, LittleEndian};
use core::fmt;
use log::{error, info}; use log::{error, info};
use std::fmt;
use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK}; use super::{TagType, MAX_TAG_INDEX, TAG_MASK, TAG_SHIFT_BITS, TAG_SIZE_MAP, TYPE_MASK};
@ -33,14 +33,7 @@ impl<'a> TLVList<'a> {
} }
} }
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Clone, PartialEq)]
pub struct Pointer<'a> {
buf: &'a [u8],
current: usize,
left: usize,
}
#[derive(Debug, Copy, Clone, PartialEq)]
pub enum ElementType<'a> { pub enum ElementType<'a> {
S8(i8), S8(i8),
S16(i16), S16(i16),
@ -63,9 +56,9 @@ pub enum ElementType<'a> {
Str32l, Str32l,
Str64l, Str64l,
Null, Null,
Struct(Pointer<'a>), Struct(&'a [u8]),
Array(Pointer<'a>), Array(&'a [u8]),
List(Pointer<'a>), List(&'a [u8]),
EndCnt, EndCnt,
Last, Last,
} }
@ -204,44 +197,11 @@ static VALUE_EXTRACTOR: [ExtractValue; MAX_VALUE_INDEX] = [
// Null 20 // Null 20
{ |_t| (0, ElementType::Null) }, { |_t| (0, ElementType::Null) },
// Struct 21 // Struct 21
{ { |t| (0, ElementType::Struct(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::Struct(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// Array 22 // Array 22
{ { |t| (0, ElementType::Array(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::Array(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// List 23 // List 23
{ { |t| (0, ElementType::List(&t.buf[t.current..])) },
|t| {
(
0,
ElementType::List(Pointer {
buf: t.buf,
current: t.current,
left: t.left,
}),
)
}
},
// EndCnt 24 // EndCnt 24
{ |_t| (0, ElementType::EndCnt) }, { |_t| (0, ElementType::EndCnt) },
]; ];
@ -282,9 +242,9 @@ fn read_length_value<'a>(
// The current offset is the string size // The current offset is the string size
let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize; let length: usize = LittleEndian::read_uint(&t.buf[t.current..], size_of_length_field) as usize;
// We'll consume the current offset (len) + the entire string // We'll consume the current offset (len) + the entire string
if length + size_of_length_field > t.left { if length + size_of_length_field > t.buf.len() - t.current {
// Return Error // Return Error
Err(Error::NoSpace) Err(ErrorCode::NoSpace.into())
} else { } else {
Ok(( Ok((
// return the additional size only // return the additional size only
@ -294,7 +254,7 @@ fn read_length_value<'a>(
} }
} }
#[derive(Debug, Copy, Clone)] #[derive(Debug, Clone)]
pub struct TLVElement<'a> { pub struct TLVElement<'a> {
tag_type: TagType, tag_type: TagType,
element_type: ElementType<'a>, element_type: ElementType<'a>,
@ -303,11 +263,11 @@ pub struct TLVElement<'a> {
impl<'a> PartialEq for TLVElement<'a> { impl<'a> PartialEq for TLVElement<'a> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
match self.element_type { match self.element_type {
ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => { ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => {
let mut our_iter = TLVListIterator::from_pointer(a); let mut our_iter = TLVListIterator::from_buf(buf);
let mut their = match other.element_type { let mut their = match other.element_type {
ElementType::Struct(b) | ElementType::Array(b) | ElementType::List(b) => { ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => {
TLVListIterator::from_pointer(b) TLVListIterator::from_buf(buf)
} }
_ => { _ => {
// If we are a container, the other must be a container, else this is a mismatch // If we are a container, the other must be a container, else this is a mismatch
@ -318,7 +278,7 @@ impl<'a> PartialEq for TLVElement<'a> {
loop { loop {
let ours = our_iter.next(); let ours = our_iter.next();
let theirs = their.next(); let theirs = their.next();
if std::mem::discriminant(&ours) != std::mem::discriminant(&theirs) { if core::mem::discriminant(&ours) != core::mem::discriminant(&theirs) {
// One of us reached end of list, but the other didn't, that's a mismatch // One of us reached end of list, but the other didn't, that's a mismatch
return false; return false;
} }
@ -336,13 +296,13 @@ impl<'a> PartialEq for TLVElement<'a> {
} }
nest_level -= 1; nest_level -= 1;
} else { } else {
if is_container(ours.element_type) { if is_container(&ours.element_type) {
nest_level += 1; nest_level += 1;
// Only compare the discriminants in case of array/list/structures, // Only compare the discriminants in case of array/list/structures,
// instead of actual element values. Those will be subsets within this same // instead of actual element values. Those will be subsets within this same
// list that will get validated anyway // list that will get validated anyway
if std::mem::discriminant(&ours.element_type) if core::mem::discriminant(&ours.element_type)
!= std::mem::discriminant(&theirs.element_type) != core::mem::discriminant(&theirs.element_type)
{ {
return false; return false;
} }
@ -364,15 +324,11 @@ impl<'a> PartialEq for TLVElement<'a> {
impl<'a> TLVElement<'a> { impl<'a> TLVElement<'a> {
pub fn enter(&self) -> Option<TLVContainerIterator<'a>> { pub fn enter(&self) -> Option<TLVContainerIterator<'a>> {
let ptr = match self.element_type { let buf = match self.element_type {
ElementType::Struct(a) | ElementType::Array(a) | ElementType::List(a) => a, ElementType::Struct(buf) | ElementType::Array(buf) | ElementType::List(buf) => buf,
_ => return None, _ => return None,
}; };
let list_iter = TLVListIterator { let list_iter = TLVListIterator { buf, current: 0 };
buf: ptr.buf,
current: ptr.current,
left: ptr.left,
};
Some(TLVContainerIterator { Some(TLVContainerIterator {
list_iter, list_iter,
prev_container: false, prev_container: false,
@ -390,14 +346,22 @@ impl<'a> TLVElement<'a> {
pub fn i8(&self) -> Result<i8, Error> { pub fn i8(&self) -> Result<i8, Error> {
match self.element_type { match self.element_type {
ElementType::S8(a) => Ok(a), ElementType::S8(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn u8(&self) -> Result<u8, Error> { pub fn u8(&self) -> Result<u8, Error> {
match self.element_type { match self.element_type {
ElementType::U8(a) => Ok(a), ElementType::U8(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i16(&self) -> Result<i16, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -405,7 +369,16 @@ impl<'a> TLVElement<'a> {
match self.element_type { match self.element_type {
ElementType::U8(a) => Ok(a.into()), ElementType::U8(a) => Ok(a.into()),
ElementType::U16(a) => Ok(a), ElementType::U16(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i32(&self) -> Result<i32, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a.into()),
ElementType::S32(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -414,7 +387,17 @@ impl<'a> TLVElement<'a> {
ElementType::U8(a) => Ok(a.into()), ElementType::U8(a) => Ok(a.into()),
ElementType::U16(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()),
ElementType::U32(a) => Ok(a), ElementType::U32(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn i64(&self) -> Result<i64, Error> {
match self.element_type {
ElementType::S8(a) => Ok(a.into()),
ElementType::S16(a) => Ok(a.into()),
ElementType::S32(a) => Ok(a.into()),
ElementType::S64(a) => Ok(a),
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -424,7 +407,7 @@ impl<'a> TLVElement<'a> {
ElementType::U16(a) => Ok(a.into()), ElementType::U16(a) => Ok(a.into()),
ElementType::U32(a) => Ok(a.into()), ElementType::U32(a) => Ok(a.into()),
ElementType::U64(a) => Ok(a), ElementType::U64(a) => Ok(a),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -434,7 +417,19 @@ impl<'a> TLVElement<'a> {
| ElementType::Utf8l(s) | ElementType::Utf8l(s)
| ElementType::Str16l(s) | ElementType::Str16l(s)
| ElementType::Utf16l(s) => Ok(s), | ElementType::Utf16l(s) => Ok(s),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
}
}
pub fn str(&self) -> Result<&'a str, Error> {
match self.element_type {
ElementType::Str8l(s)
| ElementType::Utf8l(s)
| ElementType::Str16l(s)
| ElementType::Utf16l(s) => {
Ok(core::str::from_utf8(s).map_err(|_| Error::from(ErrorCode::InvalidData))?)
}
_ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
@ -442,48 +437,48 @@ impl<'a> TLVElement<'a> {
match self.element_type { match self.element_type {
ElementType::False => Ok(false), ElementType::False => Ok(false),
ElementType::True => Ok(true), ElementType::True => Ok(true),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn null(&self) -> Result<(), Error> { pub fn null(&self) -> Result<(), Error> {
match self.element_type { match self.element_type {
ElementType::Null => Ok(()), ElementType::Null => Ok(()),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_struct(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_struct(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::Struct(_) => Ok(*self), ElementType::Struct(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_array(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_array(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::Array(_) => Ok(*self), ElementType::Array(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn confirm_list(&self) -> Result<TLVElement<'a>, Error> { pub fn confirm_list(&self) -> Result<&TLVElement<'a>, Error> {
match self.element_type { match self.element_type {
ElementType::List(_) => Ok(*self), ElementType::List(_) => Ok(self),
_ => Err(Error::TLVTypeMismatch), _ => Err(ErrorCode::TLVTypeMismatch.into()),
} }
} }
pub fn find_tag(&self, tag: u32) -> Result<TLVElement<'a>, Error> { pub fn find_tag(&self, tag: u32) -> Result<TLVElement<'a>, Error> {
let match_tag: TagType = TagType::Context(tag as u8); let match_tag: TagType = TagType::Context(tag as u8);
let iter = self.enter().ok_or(Error::TLVTypeMismatch)?; let iter = self.enter().ok_or(ErrorCode::TLVTypeMismatch)?;
for a in iter { for a in iter {
if match_tag == a.tag_type { if match_tag == a.tag_type {
return Ok(a); return Ok(a);
} }
} }
Err(Error::NoTagFound) Err(ErrorCode::NoTagFound.into())
} }
pub fn get_tag(&self) -> TagType { pub fn get_tag(&self) -> TagType {
@ -499,8 +494,8 @@ impl<'a> TLVElement<'a> {
false false
} }
pub fn get_element_type(&self) -> ElementType { pub fn get_element_type(&self) -> &ElementType {
self.element_type &self.element_type
} }
} }
@ -522,7 +517,7 @@ impl<'a> fmt::Display for TLVElement<'a> {
| ElementType::Utf8l(a) | ElementType::Utf8l(a)
| ElementType::Str16l(a) | ElementType::Str16l(a)
| ElementType::Utf16l(a) => { | ElementType::Utf16l(a) => {
if let Ok(s) = std::str::from_utf8(a) { if let Ok(s) = core::str::from_utf8(a) {
write!(f, "len[{}]\"{}\"", s.len(), s) write!(f, "len[{}]\"{}\"", s.len(), s)
} else { } else {
write!(f, "len[{}]{:x?}", a.len(), a) write!(f, "len[{}]{:x?}", a.len(), a)
@ -534,25 +529,19 @@ impl<'a> fmt::Display for TLVElement<'a> {
} }
// This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list // This is a TLV List iterator, it only iterates over the individual TLVs in a TLV list
#[derive(Copy, Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub struct TLVListIterator<'a> { pub struct TLVListIterator<'a> {
buf: &'a [u8], buf: &'a [u8],
current: usize, current: usize,
left: usize,
} }
impl<'a> TLVListIterator<'a> { impl<'a> TLVListIterator<'a> {
fn from_pointer(p: Pointer<'a>) -> Self { fn from_buf(buf: &'a [u8]) -> Self {
Self { Self { buf, current: 0 }
buf: p.buf,
current: p.current,
left: p.left,
}
} }
fn advance(&mut self, len: usize) { fn advance(&mut self, len: usize) {
self.current += len; self.current += len;
self.left -= len;
} }
// Caller should ensure they are reading the _right_ tag at the _right_ place // Caller should ensure they are reading the _right_ tag at the _right_ place
@ -561,7 +550,7 @@ impl<'a> TLVListIterator<'a> {
return None; return None;
} }
let tag_size = TAG_SIZE_MAP[tag_type as usize]; let tag_size = TAG_SIZE_MAP[tag_type as usize];
if tag_size > self.left { if tag_size > self.buf.len() - self.current {
return None; return None;
} }
let tag = (TAG_EXTRACTOR[tag_type as usize])(self); let tag = (TAG_EXTRACTOR[tag_type as usize])(self);
@ -574,7 +563,7 @@ impl<'a> TLVListIterator<'a> {
return None; return None;
} }
let mut size = VALUE_SIZE_MAP[element_type as usize]; let mut size = VALUE_SIZE_MAP[element_type as usize];
if size > self.left { if size > self.buf.len() - self.current {
error!( error!(
"Invalid value found: {} self {:?} size {}", "Invalid value found: {} self {:?} size {}",
element_type, self, size element_type, self, size
@ -597,7 +586,7 @@ impl<'a> Iterator for TLVListIterator<'a> {
type Item = TLVElement<'a>; type Item = TLVElement<'a>;
/* Code for going to the next Element */ /* Code for going to the next Element */
fn next(&mut self) -> Option<TLVElement<'a>> { fn next(&mut self) -> Option<TLVElement<'a>> {
if self.left < 1 { if self.buf.len() - self.current < 1 {
return None; return None;
} }
/* Read Control */ /* Read Control */
@ -623,13 +612,12 @@ impl<'a> TLVList<'a> {
pub fn iter(&self) -> TLVListIterator<'a> { pub fn iter(&self) -> TLVListIterator<'a> {
TLVListIterator { TLVListIterator {
current: 0, current: 0,
left: self.buf.len(),
buf: self.buf, buf: self.buf,
} }
} }
} }
fn is_container(element_type: ElementType) -> bool { fn is_container(element_type: &ElementType) -> bool {
matches!( matches!(
element_type, element_type,
ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_) ElementType::Struct(_) | ElementType::Array(_) | ElementType::List(_)
@ -668,7 +656,7 @@ impl<'a> TLVContainerIterator<'a> {
nest_level -= 1; nest_level -= 1;
} }
_ => { _ => {
if is_container(element.element_type) { if is_container(&element.element_type) {
nest_level += 1; nest_level += 1;
} }
} }
@ -699,33 +687,38 @@ impl<'a> Iterator for TLVContainerIterator<'a> {
return None; return None;
} }
if is_container(element.element_type) { self.prev_container = is_container(&element.element_type);
self.prev_container = true;
} else {
self.prev_container = false;
}
Some(element) Some(element)
} }
} }
pub fn get_root_node(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b).iter().next().ok_or(Error::InvalidData) Ok(TLVList::new(b)
.iter()
.next()
.ok_or(ErrorCode::InvalidData)?)
} }
pub fn get_root_node_struct(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node_struct(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b) let root = TLVList::new(b)
.iter() .iter()
.next() .next()
.ok_or(Error::InvalidData)? .ok_or(ErrorCode::InvalidData)?;
.confirm_struct()
root.confirm_struct()?;
Ok(root)
} }
pub fn get_root_node_list(b: &[u8]) -> Result<TLVElement, Error> { pub fn get_root_node_list(b: &[u8]) -> Result<TLVElement, Error> {
TLVList::new(b) let root = TLVList::new(b)
.iter() .iter()
.next() .next()
.ok_or(Error::InvalidData)? .ok_or(ErrorCode::InvalidData)?;
.confirm_list()
root.confirm_list()?;
Ok(root)
} }
pub fn print_tlv_list(b: &[u8]) { pub fn print_tlv_list(b: &[u8]) {
@ -752,7 +745,7 @@ pub fn print_tlv_list(b: &[u8]) {
match a.element_type { match a.element_type {
ElementType::Struct(_) => { ElementType::Struct(_) => {
if index < MAX_DEPTH { if index < MAX_DEPTH {
println!("{}{}", space[index], a); info!("{}{}", space[index], a);
stack[index] = '}'; stack[index] = '}';
index += 1; index += 1;
} else { } else {
@ -761,7 +754,7 @@ pub fn print_tlv_list(b: &[u8]) {
} }
ElementType::Array(_) | ElementType::List(_) => { ElementType::Array(_) | ElementType::List(_) => {
if index < MAX_DEPTH { if index < MAX_DEPTH {
println!("{}{}", space[index], a); info!("{}{}", space[index], a);
stack[index] = ']'; stack[index] = ']';
index += 1; index += 1;
} else { } else {
@ -771,24 +764,25 @@ pub fn print_tlv_list(b: &[u8]) {
ElementType::EndCnt => { ElementType::EndCnt => {
if index > 0 { if index > 0 {
index -= 1; index -= 1;
println!("{}{}", space[index], stack[index]); info!("{}{}", space[index], stack[index]);
} else { } else {
error!("Incorrect TLV List"); error!("Incorrect TLV List");
} }
} }
_ => println!("{}{}", space[index], a), _ => info!("{}{}", space[index], a),
} }
} }
println!("---------"); info!("---------");
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use super::{ use super::{
get_root_node_list, get_root_node_struct, ElementType, Pointer, TLVElement, TLVList, get_root_node_list, get_root_node_struct, ElementType, TLVElement, TLVList, TagType,
TagType,
}; };
use crate::error::Error; use crate::error::ErrorCode;
#[test] #[test]
fn test_short_length_tag() { fn test_short_length_tag() {
@ -846,11 +840,7 @@ mod tests {
tlv_iter.next(), tlv_iter.next(),
Some(TLVElement { Some(TLVElement {
tag_type: TagType::Context(0), tag_type: TagType::Context(0),
element_type: ElementType::Array(Pointer { element_type: ElementType::Array(&[]),
buf: &[21, 54, 0],
current: 3,
left: 0
}),
}) })
); );
} }
@ -1105,12 +1095,13 @@ mod tests {
.unwrap() .unwrap()
.enter() .enter()
.unwrap(); .unwrap();
println!("Command list iterator: {:?}", cmd_list_iter); info!("Command list iterator: {:?}", cmd_list_iter);
// This is an array of CommandDataIB, but we'll only use the first element // This is an array of CommandDataIB, but we'll only use the first element
let cmd_data_ib = cmd_list_iter.next().unwrap(); let cmd_data_ib = cmd_list_iter.next().unwrap();
let cmd_path = cmd_data_ib.find_tag(0).unwrap().confirm_list().unwrap(); let cmd_path = cmd_data_ib.find_tag(0).unwrap();
let cmd_path = cmd_path.confirm_list().unwrap();
assert_eq!( assert_eq!(
cmd_path.find_tag(0).unwrap(), cmd_path.find_tag(0).unwrap(),
TLVElement { TLVElement {
@ -1132,7 +1123,10 @@ mod tests {
element_type: ElementType::U32(1), element_type: ElementType::U32(1),
} }
); );
assert_eq!(cmd_path.find_tag(3), Err(Error::NoTagFound)); assert_eq!(
cmd_path.find_tag(3).map_err(|e| e.code()),
Err(ErrorCode::NoTagFound)
);
// This is the variable of the invoke command // This is the variable of the invoke command
assert_eq!( assert_eq!(
@ -1172,11 +1166,7 @@ mod tests {
0x35, 0x1, 0x18, 0x18, 0x18, 0x18, 0x35, 0x1, 0x18, 0x18, 0x18, 0x18,
]; ];
let dummy_pointer = Pointer { let dummy_pointer = &b[1..];
buf: &b,
current: 1,
left: 21,
};
// These are the decoded elements that we expect from this input // These are the decoded elements that we expect from this input
let verify_matrix: [(TagType, ElementType); 13] = [ let verify_matrix: [(TagType, ElementType); 13] = [
(TagType::Anonymous, ElementType::Struct(dummy_pointer)), (TagType::Anonymous, ElementType::Struct(dummy_pointer)),
@ -1203,8 +1193,8 @@ mod tests {
Some(a) => { Some(a) => {
assert_eq!(a.tag_type, verify_matrix[index].0); assert_eq!(a.tag_type, verify_matrix[index].0);
assert_eq!( assert_eq!(
std::mem::discriminant(&a.element_type), core::mem::discriminant(&a.element_type),
std::mem::discriminant(&verify_matrix[index].1) core::mem::discriminant(&verify_matrix[index].1)
); );
} }
} }

View file

@ -16,10 +16,10 @@
*/ */
use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType}; use super::{ElementType, TLVContainerIterator, TLVElement, TLVWriter, TagType};
use crate::error::Error; use crate::error::{Error, ErrorCode};
use core::fmt::Debug;
use core::slice::Iter; use core::slice::Iter;
use log::error; use log::error;
use std::fmt::Debug;
pub trait FromTLV<'a> { pub trait FromTLV<'a> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error>
@ -31,33 +31,54 @@ pub trait FromTLV<'a> {
where where
Self: Sized, Self: Sized,
{ {
Err(Error::TLVNotFound) Err(ErrorCode::TLVNotFound.into())
} }
} }
impl<'a, T: Default + FromTLV<'a> + Copy, const N: usize> FromTLV<'a> for [T; N] { impl<'a, T: FromTLV<'a> + Default, const N: usize> FromTLV<'a> for [T; N] {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error>
where where
Self: Sized, Self: Sized,
{ {
t.confirm_array()?; t.confirm_array()?;
let mut a: [T; N] = [Default::default(); N];
let mut index = 0; let mut a = heapless::Vec::<T, N>::new();
if let Some(tlv_iter) = t.enter() { if let Some(tlv_iter) = t.enter() {
for element in tlv_iter { for element in tlv_iter {
if index < N { a.push(T::from_tlv(&element)?)
a[index] = T::from_tlv(&element)?; .map_err(|_| ErrorCode::NoSpace)?;
index += 1;
} else {
error!("Received TLV Array with elements larger than current size");
break;
} }
} }
// TODO: This was the old behavior before rebasing the
// implementation on top of heapless::Vec (to avoid requiring Copy)
// Not sure why we actually need that yet, but without it unit tests fail
while a.len() < N {
a.push(Default::default()).map_err(|_| ErrorCode::NoSpace)?;
} }
Ok(a)
a.into_array().map_err(|_| ErrorCode::Invalid.into())
} }
} }
pub fn from_tlv<'a, T: FromTLV<'a>, const N: usize>(
vec: &mut heapless::Vec<T, N>,
t: &TLVElement<'a>,
) -> Result<(), Error> {
vec.clear();
t.confirm_array()?;
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?)
.map_err(|_| ErrorCode::NoSpace)?;
}
}
Ok(())
}
macro_rules! fromtlv_for { macro_rules! fromtlv_for {
($($t:ident)*) => { ($($t:ident)*) => {
$( $(
@ -70,12 +91,21 @@ macro_rules! fromtlv_for {
}; };
} }
fromtlv_for!(u8 u16 u32 u64 bool); fromtlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool);
pub trait ToTLV { pub trait ToTLV {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>; fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error>;
} }
impl<T> ToTLV for &T
where
T: ToTLV,
{
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
(**self).to_tlv(tw, tag)
}
}
macro_rules! totlv_for { macro_rules! totlv_for {
($($t:ident)*) => { ($($t:ident)*) => {
$( $(
@ -98,30 +128,39 @@ impl<T: ToTLV, const N: usize> ToTLV for [T; N] {
} }
} }
impl<'a, T: ToTLV> ToTLV for &'a [T] {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.start_array(tag)?;
for i in *self {
i.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
// Generate ToTLV for standard data types // Generate ToTLV for standard data types
totlv_for!(i8 u8 u16 u32 u64 bool); totlv_for!(i8 u8 i16 u16 i32 u32 i64 u64 bool);
// We define a few common data types that will be required here // We define a few common data types that will be required here
// //
// - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec // - UtfStr, OctetStr: These are versions that map to utfstr and ostr in the TLV spec
// - These only have references into the original list // - These only have references into the original list
// - String, Vec<u8>: Is the owned version of utfstr and ostr, data is cloned into this // - heapless::String<N>, Vheapless::ec<u8, N>: Is the owned version of utfstr and ostr, data is cloned into this
// - String is only partially implemented // - heapless::String is only partially implemented
// //
// - TLVArray: Is an array of entries, with reference within the original list // - TLVArray: Is an array of entries, with reference within the original list
// - TLVArrayOwned: Is the owned version of this, data is cloned into this
/// Implements UTFString from the spec /// Implements UTFString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct UtfStr<'a>(pub &'a [u8]); pub struct UtfStr<'a>(pub &'a [u8]);
impl<'a> UtfStr<'a> { impl<'a> UtfStr<'a> {
pub fn new(str: &'a [u8]) -> Self { pub const fn new(str: &'a [u8]) -> Self {
Self(str) Self(str)
} }
pub fn to_string(self) -> Result<String, Error> { pub fn as_str(&self) -> Result<&str, Error> {
String::from_utf8(self.0.to_vec()).map_err(|_| Error::Invalid) core::str::from_utf8(self.0).map_err(|_| ErrorCode::Invalid.into())
} }
} }
@ -138,7 +177,7 @@ impl<'a> FromTLV<'a> for UtfStr<'a> {
} }
/// Implements OctetString from the spec /// Implements OctetString from the spec
#[derive(Debug, Copy, Clone, PartialEq)] #[derive(Debug, Copy, Clone, PartialEq, Default)]
pub struct OctetStr<'a>(pub &'a [u8]); pub struct OctetStr<'a>(pub &'a [u8]);
impl<'a> OctetStr<'a> { impl<'a> OctetStr<'a> {
@ -160,35 +199,32 @@ impl<'a> ToTLV for OctetStr<'a> {
} }
/// Implements the Owned version of Octet String /// Implements the Owned version of Octet String
impl FromTLV<'_> for Vec<u8> { impl<const N: usize> FromTLV<'_> for heapless::Vec<u8, N> {
fn from_tlv(t: &TLVElement) -> Result<Vec<u8>, Error> { fn from_tlv(t: &TLVElement) -> Result<heapless::Vec<u8, N>, Error> {
t.slice().map(|x| x.to_owned()) heapless::Vec::from_slice(t.slice()?).map_err(|_| ErrorCode::NoSpace.into())
} }
} }
impl ToTLV for Vec<u8> { impl<const N: usize> ToTLV for heapless::Vec<u8, N> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.str16(tag, self.as_slice()) tw.str16(tag, self.as_slice())
} }
} }
/// Implements the Owned version of UTF String /// Implements the Owned version of UTF String
impl FromTLV<'_> for String { impl<const N: usize> FromTLV<'_> for heapless::String<N> {
fn from_tlv(t: &TLVElement) -> Result<String, Error> { fn from_tlv(t: &TLVElement) -> Result<heapless::String<N>, Error> {
match t.slice() { let mut string = heapless::String::new();
Ok(x) => {
if let Ok(s) = String::from_utf8(x.to_vec()) { string
Ok(s) .push_str(core::str::from_utf8(t.slice()?)?)
} else { .map_err(|_| ErrorCode::NoSpace)?;
Err(Error::Invalid)
} Ok(string)
}
Err(e) => Err(e),
}
} }
} }
impl ToTLV for String { impl<const N: usize> ToTLV for heapless::String<N> {
fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag: TagType) -> Result<(), Error> {
tw.utf16(tag, self.as_bytes()) tw.utf16(tag, self.as_bytes())
} }
@ -262,38 +298,7 @@ impl<T: ToTLV> ToTLV for Nullable<T> {
} }
} }
/// Owned version of a TLVArray #[derive(Clone)]
pub struct TLVArrayOwned<T>(Vec<T>);
impl<'a, T: FromTLV<'a>> FromTLV<'a> for TLVArrayOwned<T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
t.confirm_array()?;
let mut vec = Vec::<T>::new();
if let Some(tlv_iter) = t.enter() {
for element in tlv_iter {
vec.push(T::from_tlv(&element)?);
}
}
Ok(Self(vec))
}
}
impl<T: ToTLV> ToTLV for TLVArrayOwned<T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
tw.start_array(tag_type)?;
for t in &self.0 {
t.to_tlv(tw, TagType::Anonymous)?;
}
tw.end_container()
}
}
impl<T> TLVArrayOwned<T> {
pub fn iter(&self) -> Iter<T> {
self.0.iter()
}
}
#[derive(Copy, Clone)]
pub enum TLVArray<'a, T> { pub enum TLVArray<'a, T> {
// This is used for the to-tlv path // This is used for the to-tlv path
Slice(&'a [T]), Slice(&'a [T]),
@ -312,14 +317,14 @@ impl<'a, T: ToTLV> TLVArray<'a, T> {
} }
pub fn iter(&self) -> TLVArrayIter<'a, T> { pub fn iter(&self) -> TLVArrayIter<'a, T> {
match *self { match self {
Self::Slice(s) => TLVArrayIter::Slice(s.iter()), Self::Slice(s) => TLVArrayIter::Slice(s.iter()),
Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()), Self::Ptr(p) => TLVArrayIter::Ptr(p.enter()),
} }
} }
} }
impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> { impl<'a, T: ToTLV + FromTLV<'a> + Clone> TLVArray<'a, T> {
pub fn get_index(&self, index: usize) -> T { pub fn get_index(&self, index: usize) -> T {
for (curr, element) in self.iter().enumerate() { for (curr, element) in self.iter().enumerate() {
if curr == index { if curr == index {
@ -330,12 +335,12 @@ impl<'a, T: ToTLV + FromTLV<'a> + Copy> TLVArray<'a, T> {
} }
} }
impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> { impl<'a, T: FromTLV<'a> + Clone> Iterator for TLVArrayIter<'a, T> {
type Item = T; type Item = T;
/* Code for going to the next Element */ /* Code for going to the next Element */
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
match self { match self {
Self::Slice(s_iter) => s_iter.next().copied(), Self::Slice(s_iter) => s_iter.next().cloned(),
Self::Ptr(p_iter) => { Self::Ptr(p_iter) => {
if let Some(tlv_iter) = p_iter.as_mut() { if let Some(tlv_iter) = p_iter.as_mut() {
let e = tlv_iter.next(); let e = tlv_iter.next();
@ -354,7 +359,7 @@ impl<'a, T: FromTLV<'a> + Copy> Iterator for TLVArrayIter<'a, T> {
impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T> impl<'a, T> PartialEq<&[T]> for TLVArray<'a, T>
where where
T: ToTLV + FromTLV<'a> + Copy + PartialEq, T: ToTLV + FromTLV<'a> + Clone + PartialEq,
{ {
fn eq(&self, other: &&[T]) -> bool { fn eq(&self, other: &&[T]) -> bool {
let mut iter1 = self.iter(); let mut iter1 = self.iter();
@ -373,34 +378,46 @@ where
} }
} }
impl<'a, T: ToTLV> ToTLV for TLVArray<'a, T> { impl<'a, T: FromTLV<'a> + Clone + ToTLV> ToTLV for TLVArray<'a, T> {
fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> { fn to_tlv(&self, tw: &mut TLVWriter, tag_type: TagType) -> Result<(), Error> {
match *self {
Self::Slice(s) => {
tw.start_array(tag_type)?; tw.start_array(tag_type)?;
for a in s { for a in self.iter() {
a.to_tlv(tw, TagType::Anonymous)?; a.to_tlv(tw, TagType::Anonymous)?;
} }
tw.end_container() tw.end_container()
} // match *self {
Self::Ptr(t) => t.to_tlv(tw, tag_type), // Self::Slice(s) => {
} // tw.start_array(tag_type)?;
// for a in s {
// a.to_tlv(tw, TagType::Anonymous)?;
// }
// tw.end_container()
// }
// Self::Ptr(t) => t.to_tlv(tw, tag_type), <-- TODO: this fails the unit tests of Cert from/to TLV
// }
} }
} }
impl<'a, T> FromTLV<'a> for TLVArray<'a, T> { impl<'a, T> FromTLV<'a> for TLVArray<'a, T> {
fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> { fn from_tlv(t: &TLVElement<'a>) -> Result<Self, Error> {
t.confirm_array()?; t.confirm_array()?;
Ok(Self::Ptr(*t)) Ok(Self::Ptr(t.clone()))
} }
} }
impl<'a, T: Debug + ToTLV + FromTLV<'a> + Copy> Debug for TLVArray<'a, T> { impl<'a, T: Debug + ToTLV + FromTLV<'a> + Clone> Debug for TLVArray<'a, T> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "TLVArray [")?;
let mut first = true;
for i in self.iter() { for i in self.iter() {
writeln!(f, "{:?}", i)?; if !first {
write!(f, ", ")?;
} }
writeln!(f)
write!(f, "{:?}", i)?;
first = false;
}
write!(f, "]")
} }
} }
@ -423,7 +440,7 @@ impl<'a> ToTLV for TLVElement<'a> {
ElementType::EndCnt => tw.end_container(), ElementType::EndCnt => tw.end_container(),
_ => { _ => {
error!("ToTLV Not supported"); error!("ToTLV Not supported");
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
} }
@ -431,7 +448,7 @@ impl<'a> ToTLV for TLVElement<'a> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::{FromTLV, OctetStr, TLVElement, TLVWriter, TagType, ToTLV}; use super::{FromTLV, OctetStr, TLVWriter, TagType, ToTLV};
use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf}; use crate::{error::Error, tlv::TLVList, utils::writebuf::WriteBuf};
use matter_macro_derive::{FromTLV, ToTLV}; use matter_macro_derive::{FromTLV, ToTLV};
@ -442,9 +459,8 @@ mod tests {
} }
#[test] #[test]
fn test_derive_totlv() { fn test_derive_totlv() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let abc = TestDerive { let abc = TestDerive {
@ -525,9 +541,8 @@ mod tests {
#[test] #[test]
fn test_derive_totlv_fab_scoped() { fn test_derive_totlv_fab_scoped() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 }; let abc = TestDeriveFabScoped { a: 20, fab_idx: 3 };
@ -557,9 +572,8 @@ mod tests {
enum_val = TestDeriveEnum::ValueB(10); enum_val = TestDeriveEnum::ValueB(10);
// Test ToTLV // Test ToTLV
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap(); enum_val.to_tlv(&mut tw, TagType::Anonymous).unwrap();

View file

@ -50,11 +50,11 @@ enum WriteElementType {
} }
pub struct TLVWriter<'a, 'b> { pub struct TLVWriter<'a, 'b> {
buf: &'b mut WriteBuf<'a>, buf: &'a mut WriteBuf<'b>,
} }
impl<'a, 'b> TLVWriter<'a, 'b> { impl<'a, 'b> TLVWriter<'a, 'b> {
pub fn new(buf: &'b mut WriteBuf<'a>) -> Self { pub fn new(buf: &'a mut WriteBuf<'b>) -> Self {
TLVWriter { buf } TLVWriter { buf }
} }
@ -164,7 +164,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> {
pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> { pub fn str8(&mut self, tag_type: TagType, data: &[u8]) -> Result<(), Error> {
if data.len() > 256 { if data.len() > 256 {
error!("use str16() instead"); error!("use str16() instead");
return Err(Error::Invalid); return Err(ErrorCode::Invalid.into());
} }
self.put_control_tag(tag_type, WriteElementType::Str8l)?; self.put_control_tag(tag_type, WriteElementType::Str8l)?;
self.buf.le_u8(data.len() as u8)?; self.buf.le_u8(data.len() as u8)?;
@ -265,7 +265,7 @@ impl<'a, 'b> TLVWriter<'a, 'b> {
self.buf.rewind_tail_to(anchor); self.buf.rewind_tail_to(anchor);
} }
pub fn get_buf<'c>(&'c mut self) -> &'c mut WriteBuf<'a> { pub fn get_buf(&mut self) -> &mut WriteBuf<'b> {
self.buf self.buf
} }
} }
@ -277,9 +277,8 @@ mod tests {
#[test] #[test]
fn test_write_success() { fn test_write_success() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.start_struct(TagType::Anonymous).unwrap(); tw.start_struct(TagType::Anonymous).unwrap();
@ -299,9 +298,8 @@ mod tests {
#[test] #[test]
fn test_write_overflow() { fn test_write_overflow() {
let mut buf: [u8; 6] = [0; 6]; let mut buf = [0; 6];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Anonymous, 12).unwrap(); tw.u8(TagType::Anonymous, 12).unwrap();
@ -317,9 +315,8 @@ mod tests {
#[test] #[test]
fn test_put_str8() { fn test_put_str8() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Context(1), 13).unwrap(); tw.u8(TagType::Context(1), 13).unwrap();
@ -334,9 +331,8 @@ mod tests {
#[test] #[test]
fn test_put_str16_as() { fn test_put_str16_as() {
let mut buf: [u8; 20] = [0; 20]; let mut buf = [0; 20];
let buf_len = buf.len(); let mut writebuf = WriteBuf::new(&mut buf);
let mut writebuf = WriteBuf::new(&mut buf, buf_len);
let mut tw = TLVWriter::new(&mut writebuf); let mut tw = TLVWriter::new(&mut writebuf);
tw.u8(TagType::Context(1), 13).unwrap(); tw.u8(TagType::Context(1), 13).unwrap();

View file

@ -0,0 +1,770 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use core::borrow::Borrow;
use core::mem::MaybeUninit;
use core::pin::pin;
use embassy_futures::select::{select, select_slice, Either};
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, channel::Channel};
use embassy_time::{Duration, Timer};
use log::{error, info, warn};
use crate::utils::select::Notification;
use crate::CommissioningData;
use crate::{
alloc,
data_model::{core::DataModel, objects::DataModelHandler},
error::{Error, ErrorCode},
interaction_model::core::PROTO_ID_INTERACTION_MODEL,
secure_channel::{
common::{OpCode, PROTO_ID_SECURE_CHANNEL},
core::SecureChannel,
},
transport::packet::Packet,
utils::select::EitherUnwrap,
Matter,
};
use super::{
exchange::{
Exchange, ExchangeCtr, ExchangeCtx, ExchangeId, ExchangeState, Role, MAX_EXCHANGES,
},
mrp::ReliableMessage,
packet::{MAX_RX_BUF_SIZE, MAX_RX_STATUS_BUF_SIZE, MAX_TX_BUF_SIZE},
pipe::{Chunk, Pipe},
};
#[derive(Debug)]
enum OpCodeDescriptor {
SecureChannel(OpCode),
InteractionModel(crate::interaction_model::core::OpCode),
Unknown(u8),
}
impl From<u8> for OpCodeDescriptor {
fn from(value: u8) -> Self {
if let Some(opcode) = num::FromPrimitive::from_u8(value) {
Self::SecureChannel(opcode)
} else if let Some(opcode) = num::FromPrimitive::from_u8(value) {
Self::InteractionModel(opcode)
} else {
Self::Unknown(value)
}
}
}
type TxBuf = MaybeUninit<[u8; MAX_TX_BUF_SIZE]>;
type RxBuf = MaybeUninit<[u8; MAX_RX_BUF_SIZE]>;
type SxBuf = MaybeUninit<[u8; MAX_RX_STATUS_BUF_SIZE]>;
#[cfg(any(feature = "std", feature = "embassy-net"))]
pub struct RunBuffers {
udp_bufs: crate::transport::udp::UdpBuffers,
run_bufs: PacketBuffers,
tx_buf: TxBuf,
rx_buf: RxBuf,
}
#[cfg(any(feature = "std", feature = "embassy-net"))]
impl RunBuffers {
#[inline(always)]
pub const fn new() -> Self {
Self {
udp_bufs: crate::transport::udp::UdpBuffers::new(),
run_bufs: PacketBuffers::new(),
tx_buf: core::mem::MaybeUninit::uninit(),
rx_buf: core::mem::MaybeUninit::uninit(),
}
}
}
pub struct PacketBuffers {
tx: [TxBuf; MAX_EXCHANGES],
rx: [RxBuf; MAX_EXCHANGES],
sx: [SxBuf; MAX_EXCHANGES],
}
impl PacketBuffers {
const TX_ELEM: TxBuf = MaybeUninit::uninit();
const RX_ELEM: RxBuf = MaybeUninit::uninit();
const SX_ELEM: SxBuf = MaybeUninit::uninit();
const TX_INIT: [TxBuf; MAX_EXCHANGES] = [Self::TX_ELEM; MAX_EXCHANGES];
const RX_INIT: [RxBuf; MAX_EXCHANGES] = [Self::RX_ELEM; MAX_EXCHANGES];
const SX_INIT: [SxBuf; MAX_EXCHANGES] = [Self::SX_ELEM; MAX_EXCHANGES];
#[inline(always)]
pub const fn new() -> Self {
Self {
tx: Self::TX_INIT,
rx: Self::RX_INIT,
sx: Self::SX_INIT,
}
}
}
impl<'a> Matter<'a> {
#[cfg(any(feature = "std", feature = "embassy-net"))]
pub async fn run<D, H>(
&self,
stack: &crate::transport::network::NetworkStack<D>,
buffers: &mut RunBuffers,
dev_comm: CommissioningData,
handler: &H,
) -> Result<(), Error>
where
D: crate::transport::network::NetworkStackDriver,
H: DataModelHandler,
{
let udp = crate::transport::udp::UdpListener::new(
stack,
crate::transport::network::SocketAddr::new(
crate::transport::network::IpAddr::V6(
crate::transport::network::Ipv6Addr::UNSPECIFIED,
),
self.port,
),
&mut buffers.udp_bufs,
)
.await?;
let tx_pipe = Pipe::new(unsafe { buffers.tx_buf.assume_init_mut() });
let rx_pipe = Pipe::new(unsafe { buffers.rx_buf.assume_init_mut() });
let tx_pipe = &tx_pipe;
let rx_pipe = &rx_pipe;
let udp = &udp;
let run_bufs = &mut buffers.run_bufs;
let mut tx = pin!(async move {
loop {
{
let mut data = tx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
udp.send(chunk.addr.unwrap_udp(), &data.buf[chunk.start..chunk.end])
.await?;
data.chunk = None;
tx_pipe.data_consumed_notification.signal(());
}
}
tx_pipe.data_supplied_notification.wait().await;
}
});
let mut rx = pin!(async move {
loop {
{
let mut data = rx_pipe.data.lock().await;
if data.chunk.is_none() {
let (len, addr) = udp.recv(data.buf).await?;
data.chunk = Some(Chunk {
start: 0,
end: len,
addr: crate::transport::network::Address::Udp(addr),
});
rx_pipe.data_supplied_notification.signal(());
}
}
rx_pipe.data_consumed_notification.wait().await;
}
});
let mut run = pin!(async move {
self.run_piped(run_bufs, tx_pipe, rx_pipe, dev_comm, handler)
.await
});
embassy_futures::select::select3(&mut tx, &mut rx, &mut run)
.await
.unwrap()
}
pub async fn run_piped<H>(
&self,
buffers: &mut PacketBuffers,
tx_pipe: &Pipe<'_>,
rx_pipe: &Pipe<'_>,
dev_comm: CommissioningData,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
info!("Running Matter transport");
let buf = unsafe { buffers.rx[0].assume_init_mut() };
if self.start_comissioning(dev_comm, buf)? {
info!("Comissioning started");
}
let construction_notification = Notification::new();
let mut rx = pin!(self.handle_rx(buffers, rx_pipe, &construction_notification, handler));
let mut tx = pin!(self.handle_tx(tx_pipe));
select(&mut rx, &mut tx).await.unwrap()
}
#[inline(always)]
async fn handle_rx<H>(
&self,
buffers: &mut PacketBuffers,
rx_pipe: &Pipe<'_>,
construction_notification: &Notification,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
info!("Creating queue for {} exchanges", 1);
let channel = Channel::<NoopRawMutex, _, 1>::new();
info!("Creating {} handlers", MAX_EXCHANGES);
let mut handlers = heapless::Vec::<_, MAX_EXCHANGES>::new();
info!("Handlers size: {}", core::mem::size_of_val(&handlers));
// Unsafely allow mutable aliasing in the packet pools by different indices
let pools: *mut PacketBuffers = buffers;
for index in 0..MAX_EXCHANGES {
let channel = &channel;
let handler_id = index;
let pools = unsafe { pools.as_mut() }.unwrap();
let tx_buf = unsafe { pools.tx[handler_id].assume_init_mut() };
let rx_buf = unsafe { pools.rx[handler_id].assume_init_mut() };
let sx_buf = unsafe { pools.sx[handler_id].assume_init_mut() };
handlers
.push(self.exchange_handler(tx_buf, rx_buf, sx_buf, handler_id, channel, handler))
.map_err(|_| ())
.unwrap();
}
let mut rx = pin!(self.handle_rx_multiplex(rx_pipe, construction_notification, &channel));
let result = select(&mut rx, select_slice(&mut handlers)).await;
if let Either::First(result) = result {
if let Err(e) = &result {
error!("Exitting RX loop due to an error: {:?}", e);
}
result?;
}
Ok(())
}
#[inline(always)]
pub async fn handle_tx(&self, tx_pipe: &Pipe<'_>) -> Result<(), Error> {
loop {
loop {
{
let mut data = tx_pipe.data.lock().await;
if data.chunk.is_none() {
let mut tx = alloc!(Packet::new_tx(data.buf));
if self.pull_tx(&mut tx).await? {
data.chunk = Some(Chunk {
start: tx.get_writebuf()?.get_start(),
end: tx.get_writebuf()?.get_tail(),
addr: tx.peer,
});
tx_pipe.data_supplied_notification.signal(());
} else {
break;
}
}
}
tx_pipe.data_consumed_notification.wait().await;
}
self.wait_tx().await?;
}
}
#[inline(always)]
pub async fn handle_rx_multiplex<'t, 'e, const N: usize>(
&'t self,
rx_pipe: &Pipe<'_>,
construction_notification: &'e Notification,
channel: &Channel<NoopRawMutex, ExchangeCtr<'e>, N>,
) -> Result<(), Error>
where
't: 'e,
{
loop {
info!("Transport: waiting for incoming packets");
{
let mut data = rx_pipe.data.lock().await;
if let Some(chunk) = data.chunk {
let mut rx = alloc!(Packet::new_rx(&mut data.buf[chunk.start..chunk.end]));
rx.peer = chunk.addr;
if let Some(exchange_ctr) =
self.process_rx(construction_notification, &mut rx)?
{
let exchange_id = exchange_ctr.id().clone();
info!("Transport: got new exchange: {:?}", exchange_id);
channel.send(exchange_ctr).await;
info!("Transport: exchange sent");
self.wait_construction(construction_notification, &rx, &exchange_id)
.await?;
info!("Transport: exchange started");
}
data.chunk = None;
rx_pipe.data_consumed_notification.signal(());
}
}
rx_pipe.data_supplied_notification.wait().await
}
#[allow(unreachable_code)]
Ok::<_, Error>(())
}
#[inline(always)]
pub async fn exchange_handler<const N: usize, H>(
&self,
tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
rx_buf: &mut [u8; MAX_RX_BUF_SIZE],
sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
handler_id: impl core::fmt::Display,
channel: &Channel<NoopRawMutex, ExchangeCtr<'_>, N>,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
loop {
let exchange_ctr: ExchangeCtr<'_> = channel.recv().await;
info!(
"Handler {}: Got exchange {:?}",
handler_id,
exchange_ctr.id()
);
let result = self
.handle_exchange(tx_buf, rx_buf, sx_buf, exchange_ctr, handler)
.await;
if let Err(err) = result {
warn!(
"Handler {}: Exchange closed because of error: {:?}",
handler_id, err
);
} else {
info!("Handler {}: Exchange completed", handler_id);
}
}
}
#[inline(always)]
#[cfg_attr(feature = "nightly", allow(clippy::await_holding_refcell_ref))] // Fine because of the async mutex
pub async fn handle_exchange<H>(
&self,
tx_buf: &mut [u8; MAX_TX_BUF_SIZE],
rx_buf: &mut [u8; MAX_RX_BUF_SIZE],
sx_buf: &mut [u8; MAX_RX_STATUS_BUF_SIZE],
exchange_ctr: ExchangeCtr<'_>,
handler: &H,
) -> Result<(), Error>
where
H: DataModelHandler,
{
let mut tx = alloc!(Packet::new_tx(tx_buf.as_mut()));
let mut rx = alloc!(Packet::new_rx(rx_buf.as_mut()));
let mut exchange = alloc!(exchange_ctr.get(&mut rx).await?);
match rx.get_proto_id() {
PROTO_ID_SECURE_CHANNEL => {
let sc = SecureChannel::new(self);
sc.handle(&mut exchange, &mut rx, &mut tx).await?;
self.notify_changed();
}
PROTO_ID_INTERACTION_MODEL => {
let dm = DataModel::new(handler);
let mut rx_status = alloc!(Packet::new_rx(sx_buf));
dm.handle(&mut exchange, &mut rx, &mut tx, &mut rx_status)
.await?;
self.notify_changed();
}
other => {
error!("Unknown Proto-ID: {}", other);
}
}
Ok(())
}
pub fn reset_transport(&self) {
self.exchanges.borrow_mut().clear();
self.session_mgr.borrow_mut().reset();
}
pub fn process_rx<'r>(
&'r self,
construction_notification: &'r Notification,
src_rx: &mut Packet<'_>,
) -> Result<Option<ExchangeCtr<'r>>, Error> {
self.purge()?;
let mut exchanges = self.exchanges.borrow_mut();
let (ctx, new) = match self.post_recv(&mut exchanges, src_rx) {
Ok((ctx, new)) => (ctx, new),
Err(e) => match e.code() {
ErrorCode::Duplicate => {
self.send_notification.signal(());
return Ok(None);
}
_ => Err(e)?,
},
};
src_rx.log("Got packet");
if src_rx.proto.is_ack() {
if new {
Err(ErrorCode::Invalid)?;
} else {
let state = &mut ctx.state;
match state {
ExchangeState::ExchangeRecv {
tx_acknowledged, ..
} => {
*tx_acknowledged = true;
}
ExchangeState::CompleteAcknowledge { notification, .. } => {
unsafe { notification.as_ref() }.unwrap().signal(());
ctx.state = ExchangeState::Closed;
}
_ => {
// TODO: Error handling
todo!()
}
}
self.notify_changed();
}
}
if new {
let constructor = ExchangeCtr {
exchange: Exchange {
id: ctx.id.clone(),
matter: self,
notification: Notification::new(),
},
construction_notification,
};
self.notify_changed();
Ok(Some(constructor))
} else if src_rx.proto.proto_id == PROTO_ID_SECURE_CHANNEL
&& src_rx.proto.proto_opcode == OpCode::MRPStandAloneAck as u8
{
// Standalone ack, do nothing
Ok(None)
} else {
let state = &mut ctx.state;
match state {
ExchangeState::ExchangeRecv {
rx, notification, ..
} => {
let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?;
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
}
_ => {
// TODO: Error handling
todo!()
}
}
self.notify_changed();
Ok(None)
}
}
pub async fn wait_construction(
&self,
construction_notification: &Notification,
src_rx: &Packet<'_>,
exchange_id: &ExchangeId,
) -> Result<(), Error> {
construction_notification.wait().await;
let mut exchanges = self.exchanges.borrow_mut();
let ctx = ExchangeCtx::get(&mut exchanges, exchange_id).unwrap();
let state = &mut ctx.state;
match state {
ExchangeState::Construction { rx, notification } => {
let rx = unsafe { rx.as_mut() }.unwrap();
rx.load(src_rx)?;
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
}
_ => unreachable!(),
}
Ok(())
}
pub async fn wait_tx(&self) -> Result<(), Error> {
select(
self.send_notification.wait(),
Timer::after(Duration::from_millis(100)),
)
.await;
Ok(())
}
pub async fn pull_tx(&self, dest_tx: &mut Packet<'_>) -> Result<bool, Error> {
self.purge()?;
let mut exchanges = self.exchanges.borrow_mut();
let ctx = exchanges.iter_mut().find(|ctx| {
matches!(
&ctx.state,
ExchangeState::Acknowledge { .. }
| ExchangeState::ExchangeSend { .. }
// | ExchangeState::ExchangeRecv {
// tx_acknowledged: false,
// ..
// }
| ExchangeState::Complete { .. } // | ExchangeState::CompleteAcknowledge { .. }
) || ctx.mrp.is_ack_ready(*self.borrow())
});
if let Some(ctx) = ctx {
self.notify_changed();
let state = &mut ctx.state;
let send = match state {
ExchangeState::Acknowledge { notification } => {
ReliableMessage::prepare_ack(ctx.id.id, dest_tx);
unsafe { notification.as_ref() }.unwrap().signal(());
*state = ExchangeState::Active;
true
}
ExchangeState::ExchangeSend {
tx,
rx,
notification,
} => {
let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?;
*state = ExchangeState::ExchangeRecv {
_tx: tx,
tx_acknowledged: false,
rx: *rx,
notification: *notification,
};
true
}
// ExchangeState::ExchangeRecv { .. } => {
// // TODO: Re-send the tx package if due
// false
// }
ExchangeState::Complete { tx, notification } => {
let tx = unsafe { tx.as_ref() }.unwrap();
dest_tx.load(tx)?;
*state = ExchangeState::CompleteAcknowledge {
_tx: tx as *const _,
notification: *notification,
};
true
}
// ExchangeState::CompleteAcknowledge { .. } => {
// // TODO: Re-send the tx package if due
// false
// }
_ => {
ReliableMessage::prepare_ack(ctx.id.id, dest_tx);
true
}
};
if send {
dest_tx.log("Sending packet");
self.pre_send(ctx, dest_tx)?;
self.notify_changed();
return Ok(true);
}
}
Ok(false)
}
fn purge(&self) -> Result<(), Error> {
loop {
let mut exchanges = self.exchanges.borrow_mut();
if let Some(index) = exchanges.iter_mut().enumerate().find_map(|(index, ctx)| {
matches!(ctx.state, ExchangeState::Closed).then_some(index)
}) {
exchanges.swap_remove(index);
} else {
break;
}
}
Ok(())
}
fn post_recv<'r>(
&self,
exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
rx: &mut Packet<'_>,
) -> Result<(&'r mut ExchangeCtx, bool), Error> {
rx.plain_hdr_decode()?;
// Get the session
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr.post_recv(rx)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
// Decrypt the message
session.recv(self.epoch, rx)?;
// Get the exchange
// TODO: Handle out of space
let (exch, new) = Self::register(
exchanges,
ExchangeId::load(rx),
Role::complementary(rx.proto.is_initiator()),
// We create a new exchange, only if the peer is the initiator
rx.proto.is_initiator(),
)?;
// Message Reliability Protocol
exch.mrp.recv(rx, self.epoch)?;
Ok((exch, new))
}
fn pre_send(&self, ctx: &mut ExchangeCtx, tx: &mut Packet) -> Result<(), Error> {
let mut session_mgr = self.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
let session = session_mgr.mut_by_index(sess_index).unwrap();
tx.proto.exch_id = ctx.id.id;
if ctx.role == Role::Initiator {
tx.proto.set_initiator();
}
session.pre_send(tx)?;
ctx.mrp.pre_send(tx)?;
session_mgr.send(sess_index, tx)
}
fn register(
exchanges: &mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
id: ExchangeId,
role: Role,
create_new: bool,
) -> Result<(&mut ExchangeCtx, bool), Error> {
let exchange_index = exchanges
.iter_mut()
.enumerate()
.find_map(|(index, exchange)| (exchange.id == id).then_some(index));
if let Some(exchange_index) = exchange_index {
let exchange = &mut exchanges[exchange_index];
if exchange.role == role {
Ok((exchange, false))
} else {
Err(ErrorCode::NoExchange.into())
}
} else if create_new {
info!("Creating new exchange: {:?}", id);
let exchange = ExchangeCtx {
id,
role,
mrp: ReliableMessage::new(),
state: ExchangeState::Active,
};
exchanges.push(exchange).map_err(|_| ErrorCode::NoSpace)?;
Ok((exchanges.iter_mut().next_back().unwrap(), true))
} else {
Err(ErrorCode::NoExchange.into())
}
}
}

View file

@ -86,6 +86,8 @@ impl RxCtrState {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use log::info;
use super::RxCtrState; use super::RxCtrState;
const ENCRYPTED: bool = true; const ENCRYPTED: bool = true;
@ -194,10 +196,10 @@ mod tests {
#[test] #[test]
fn unencrypted_device_reboot() { fn unencrypted_device_reboot() {
println!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532)); info!("Sub 65532 is {:?}", 1_u16.overflowing_sub(65532));
println!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535)); info!("Sub 65535 is {:?}", 1_u16.overflowing_sub(65535));
println!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32); info!("Sub 11-13 is {:?}", 11_u32.wrapping_sub(13_u32) as i32);
println!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998)); info!("Sub regular is {:?}", 2000_u16.overflowing_sub(1998));
let mut s = RxCtrState::new(20010); let mut s = RxCtrState::new(20010);
assert_ndup(s.recv(20011, NOT_ENCRYPTED)); assert_ndup(s.recv(20011, NOT_ENCRYPTED));

View file

@ -1,604 +1,308 @@
/* use crate::{
* acl::Accessor,
* Copyright (c) 2020-2022 Project CHIP Authors error::{Error, ErrorCode},
* utils::select::Notification,
* Licensed under the Apache License, Version 2.0 (the "License"); Matter,
* you may not use this file except in compliance with the License. };
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use boxslab::{BoxSlab, Slab}; use super::{
use colored::*; mrp::ReliableMessage,
use log::{error, info, trace}; network::Address,
use std::any::Any; packet::Packet,
use std::fmt; session::{Session, SessionMgr},
use std::time::SystemTime; };
use crate::error::Error; pub const MAX_EXCHANGES: usize = 8;
use crate::secure_channel;
use heapless::LinearMap; #[derive(Debug, PartialEq, Eq, Copy, Clone, Default)]
pub(crate) enum Role {
use super::packet::PacketPool; #[default]
use super::session::CloneData;
use super::{mrp::ReliableMessage, packet::Packet, session::SessionHandle, session::SessionMgr};
pub struct ExchangeCtx<'a> {
pub exch: &'a mut Exchange,
pub sess: SessionHandle<'a>,
}
#[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub enum Role {
Initiator = 0, Initiator = 0,
Responder = 1, Responder = 1,
} }
impl Default for Role { impl Role {
fn default() -> Self { pub fn complementary(is_initiator: bool) -> Self {
Role::Initiator if is_initiator {
Self::Responder
} else {
Self::Initiator
}
} }
} }
/// State of the exchange
#[derive(Debug, PartialEq)]
enum State {
/// The exchange is open and active
Open,
/// The exchange is closed, but keys are active since retransmissions/acks may be pending
Close,
/// The exchange is terminated, keys are destroyed, no communication can happen
Terminate,
}
impl Default for State {
fn default() -> Self {
State::Open
}
}
// Instead of just doing an Option<>, we create some special handling
// where the commonly used higher layer data store does't have to do a Box
#[derive(Debug)] #[derive(Debug)]
pub enum DataOption { pub(crate) struct ExchangeCtx {
Boxed(Box<dyn Any>), pub(crate) id: ExchangeId,
Time(SystemTime), pub(crate) role: Role,
None, pub(crate) mrp: ReliableMessage,
pub(crate) state: ExchangeState,
} }
impl Default for DataOption { impl ExchangeCtx {
fn default() -> Self { pub(crate) fn get<'r>(
DataOption::None exchanges: &'r mut heapless::Vec<ExchangeCtx, MAX_EXCHANGES>,
id: &ExchangeId,
) -> Option<&'r mut ExchangeCtx> {
exchanges.iter_mut().find(|exchange| exchange.id == *id)
} }
} }
#[derive(Debug, Default)] #[derive(Debug, Clone)]
pub struct Exchange { pub(crate) enum ExchangeState {
id: u16, Construction {
sess_idx: usize, rx: *mut Packet<'static>,
role: Role, notification: *const Notification,
state: State,
mrp: ReliableMessage,
// Currently I see this primarily used in PASE and CASE. If that is the limited use
// of this, we might move this into a separate data structure, so as not to burden
// all 'exchanges'.
data: DataOption,
}
impl Exchange {
pub fn new(id: u16, sess_idx: usize, role: Role) -> Exchange {
Exchange {
id,
sess_idx,
role,
state: State::Open,
mrp: ReliableMessage::new(),
..Default::default()
}
}
pub fn terminate(&mut self) {
self.data = DataOption::None;
self.state = State::Terminate;
}
pub fn close(&mut self) {
self.data = DataOption::None;
self.state = State::Close;
}
pub fn is_state_open(&self) -> bool {
self.state == State::Open
}
pub fn is_purgeable(&self) -> bool {
// No Users, No pending ACKs/Retrans
self.state == State::Terminate || (self.state == State::Close && self.mrp.is_empty())
}
pub fn get_id(&self) -> u16 {
self.id
}
pub fn get_role(&self) -> Role {
self.role
}
pub fn is_data_none(&self) -> bool {
matches!(self.data, DataOption::None)
}
pub fn set_data_boxed(&mut self, data: Box<dyn Any>) {
self.data = DataOption::Boxed(data);
}
pub fn clear_data_boxed(&mut self) {
self.data = DataOption::None;
}
pub fn get_data_boxed<T: Any>(&mut self) -> Option<&mut T> {
if let DataOption::Boxed(a) = &mut self.data {
a.downcast_mut::<T>()
} else {
None
}
}
pub fn take_data_boxed<T: Any>(&mut self) -> Option<Box<T>> {
let old = std::mem::replace(&mut self.data, DataOption::None);
if let DataOption::Boxed(d) = old {
d.downcast::<T>().ok()
} else {
self.data = old;
None
}
}
pub fn set_data_time(&mut self, expiry_ts: Option<SystemTime>) {
if let Some(t) = expiry_ts {
self.data = DataOption::Time(t);
}
}
pub fn get_data_time(&self) -> Option<SystemTime> {
match self.data {
DataOption::Time(t) => Some(t),
_ => None,
}
}
pub fn send(
&mut self,
mut proto_tx: BoxSlab<PacketPool>,
session: &mut SessionHandle,
) -> Result<(), Error> {
if self.state == State::Terminate {
info!("Skipping tx for terminated exchange {}", self.id);
return Ok(());
}
trace!("payload: {:x?}", proto_tx.as_borrow_slice());
info!(
"{} with proto id: {} opcode: {}",
"Sending".blue(),
proto_tx.get_proto_id(),
proto_tx.get_proto_opcode(),
);
proto_tx.proto.exch_id = self.id;
if self.role == Role::Initiator {
proto_tx.proto.set_initiator();
}
session.pre_send(&mut proto_tx)?;
self.mrp.pre_send(&mut proto_tx)?;
session.send(proto_tx)
}
}
impl fmt::Display for Exchange {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"exch_id: {:?}, sess_index: {}, role: {:?}, data: {:?}, mrp: {:?}, state: {:?}",
self.id, self.sess_idx, self.role, self.data, self.mrp, self.state
)
}
}
pub fn get_role(is_initiator: bool) -> Role {
if is_initiator {
Role::Initiator
} else {
Role::Responder
}
}
pub fn get_complementary_role(is_initiator: bool) -> Role {
if is_initiator {
Role::Responder
} else {
Role::Initiator
}
}
const MAX_EXCHANGES: usize = 8;
#[derive(Default)]
pub struct ExchangeMgr {
// keys: exch-id
exchanges: LinearMap<u16, Exchange, MAX_EXCHANGES>,
sess_mgr: SessionMgr,
}
pub const MAX_MRP_ENTRIES: usize = 4;
impl ExchangeMgr {
pub fn new(sess_mgr: SessionMgr) -> Self {
Self {
sess_mgr,
exchanges: Default::default(),
}
}
pub fn get_sess_mgr(&mut self) -> &mut SessionMgr {
&mut self.sess_mgr
}
pub fn _get_with_id(
exchanges: &mut LinearMap<u16, Exchange, MAX_EXCHANGES>,
exch_id: u16,
) -> Option<&mut Exchange> {
exchanges.get_mut(&exch_id)
}
pub fn get_with_id(&mut self, exch_id: u16) -> Option<&mut Exchange> {
ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id)
}
fn _get(
exchanges: &mut LinearMap<u16, Exchange, MAX_EXCHANGES>,
sess_idx: usize,
id: u16,
role: Role,
create_new: bool,
) -> Result<&mut Exchange, Error> {
// I don't prefer that we scan the list twice here (once for contains_key and other)
if !exchanges.contains_key(&(id)) {
if create_new {
// If an exchange doesn't exist, create a new one
info!("Creating new exchange");
let e = Exchange::new(id, sess_idx, role);
if exchanges.insert(id, e).is_err() {
return Err(Error::NoSpace);
}
} else {
return Err(Error::NoSpace);
}
}
// At this point, we would either have inserted the record if 'create_new' was set
// or it existed already
if let Some(result) = exchanges.get_mut(&id) {
if result.get_role() == role && sess_idx == result.sess_idx {
Ok(result)
} else {
Err(Error::NoExchange)
}
} else {
error!("This should never happen");
Err(Error::NoSpace)
}
}
/// The Exchange Mgr receive is like a big processing function
pub fn recv(&mut self) -> Result<Option<(BoxSlab<PacketPool>, ExchangeCtx)>, Error> {
// Get the session
let (mut proto_rx, index) = self.sess_mgr.recv()?;
let index = if let Some(s) = index {
s
} else {
// The sessions were full, evict one session, and re-perform post-recv
let evict_index = self.sess_mgr.get_lru();
self.evict_session(evict_index)?;
info!("Reattempting session creation");
self.sess_mgr.post_recv(&proto_rx)?.ok_or(Error::Invalid)?
};
let mut session = self.sess_mgr.get_session_handle(index);
// Decrypt the message
session.recv(&mut proto_rx)?;
// Get the exchange
let exch = ExchangeMgr::_get(
&mut self.exchanges,
index,
proto_rx.proto.exch_id,
get_complementary_role(proto_rx.proto.is_initiator()),
// We create a new exchange, only if the peer is the initiator
proto_rx.proto.is_initiator(),
)?;
// Message Reliability Protocol
exch.mrp.recv(&proto_rx)?;
if exch.is_state_open() {
Ok(Some((
proto_rx,
ExchangeCtx {
exch,
sess: session,
}, },
))) Active,
} else { Acknowledge {
// Instead of an error, we send None here, because it is likely that notification: *const Notification,
// we just processed an acknowledgement that cleared the exchange },
Ok(None) ExchangeSend {
tx: *const Packet<'static>,
rx: *mut Packet<'static>,
notification: *const Notification,
},
ExchangeRecv {
_tx: *const Packet<'static>,
tx_acknowledged: bool,
rx: *mut Packet<'static>,
notification: *const Notification,
},
Complete {
tx: *const Packet<'static>,
notification: *const Notification,
},
CompleteAcknowledge {
_tx: *const Packet<'static>,
notification: *const Notification,
},
Closed,
}
pub struct ExchangeCtr<'a> {
pub(crate) exchange: Exchange<'a>,
pub(crate) construction_notification: &'a Notification,
}
impl<'a> ExchangeCtr<'a> {
pub const fn id(&self) -> &ExchangeId {
self.exchange.id()
}
pub async fn get(mut self, rx: &mut Packet<'_>) -> Result<Exchange<'a>, Error> {
let construction_notification = self.construction_notification;
self.exchange.with_ctx_mut(move |exchange, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
let rx: &'static mut Packet<'static> = unsafe { core::mem::transmute(rx) };
let notification: &'static Notification =
unsafe { core::mem::transmute(&exchange.notification) };
ctx.state = ExchangeState::Construction { rx, notification };
construction_notification.signal(());
Ok(())
})?;
self.exchange.notification.wait().await;
Ok(self.exchange)
} }
} }
pub fn send(&mut self, exch_id: u16, proto_tx: BoxSlab<PacketPool>) -> Result<(), Error> { #[derive(Debug, Clone, Eq, PartialEq)]
let exchange = pub struct ExchangeId {
ExchangeMgr::_get_with_id(&mut self.exchanges, exch_id).ok_or(Error::NoExchange)?; pub id: u16,
let mut session = self.sess_mgr.get_session_handle(exchange.sess_idx); pub session_id: SessionId,
exchange.send(proto_tx, &mut session)
} }
pub fn purge(&mut self) { impl ExchangeId {
let mut to_purge: LinearMap<u16, (), MAX_EXCHANGES> = LinearMap::new(); pub fn load(rx: &Packet) -> Self {
Self {
for (exch_id, exchange) in self.exchanges.iter() { id: rx.proto.exch_id,
if exchange.is_purgeable() { session_id: SessionId::load(rx),
let _ = to_purge.insert(*exch_id, ());
} }
} }
for (exch_id, _) in to_purge.iter() {
self.exchanges.remove(exch_id);
} }
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct SessionId {
pub id: u16,
pub peer_addr: Address,
pub peer_nodeid: Option<u64>,
pub is_encrypted: bool,
} }
pub fn pending_acks(&mut self, expired_entries: &mut LinearMap<u16, (), MAX_MRP_ENTRIES>) { impl SessionId {
for (exch_id, exchange) in self.exchanges.iter() { pub fn load(rx: &Packet) -> Self {
if exchange.mrp.is_ack_ready() { Self {
expired_entries.insert(*exch_id, ()).unwrap(); id: rx.plain.sess_id,
peer_addr: rx.peer,
peer_nodeid: rx.plain.get_src_u64(),
is_encrypted: rx.plain.is_encrypted(),
} }
} }
} }
pub struct Exchange<'a> {
pub(crate) id: ExchangeId,
pub(crate) matter: &'a Matter<'a>,
pub(crate) notification: Notification,
}
pub fn evict_session(&mut self, index: usize) -> Result<(), Error> { impl<'a> Exchange<'a> {
info!("Sessions full, vacating session with index: {}", index); pub const fn id(&self) -> &ExchangeId {
// If we enter here, we have an LRU session that needs to be reclaimed &self.id
// As per the spec, we need to send a CLOSE here
let mut session = self.sess_mgr.get_session_handle(index);
let mut tx = Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::NoSpace)?;
secure_channel::common::create_sc_status_report(
&mut tx,
secure_channel::common::SCStatusCodes::CloseSession,
None,
)?;
if let Some((_, exchange)) = self.exchanges.iter_mut().find(|(_, e)| e.sess_idx == index) {
// Send Close_session on this exchange, and then close the session
// Should this be done for all exchanges?
error!("Sending Close Session");
exchange.send(tx, &mut session)?;
// TODO: This wouldn't actually send it out, because 'transport' isn't owned yet.
} }
let remove_exchanges: Vec<u16> = self pub fn accessor(&self) -> Result<Accessor<'a>, Error> {
.exchanges self.with_session(|sess| Ok(Accessor::for_session(sess, &self.matter.acl_mgr)))
.iter()
.filter_map(|(eid, e)| {
if e.sess_idx == index {
Some(*eid)
} else {
None
} }
pub fn with_session_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut Session) -> Result<T, Error>,
{
self.with_ctx(|_self, ctx| {
let mut session_mgr = _self.matter.session_mgr.borrow_mut();
let sess_index = session_mgr
.get(
ctx.id.session_id.id,
ctx.id.session_id.peer_addr,
ctx.id.session_id.peer_nodeid,
ctx.id.session_id.is_encrypted,
)
.ok_or(ErrorCode::NoSession)?;
f(session_mgr.mut_by_index(sess_index).unwrap())
}) })
.collect();
info!(
"Terminating the following exchanges: {:?}",
remove_exchanges
);
for exch_id in remove_exchanges {
// Remove from exchange list
self.exchanges.remove(&exch_id);
} }
self.sess_mgr.remove(index);
pub fn with_session<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&Session) -> Result<T, Error>,
{
self.with_session_mut(|sess| f(sess))
}
pub fn with_session_mgr_mut<F, T>(&self, f: F) -> Result<T, Error>
where
F: FnOnce(&mut SessionMgr) -> Result<T, Error>,
{
let mut session_mgr = self.matter.session_mgr.borrow_mut();
f(&mut session_mgr)
}
pub async fn acknowledge(&mut self) -> Result<(), Error> {
let wait = self.with_ctx_mut(|_self, ctx| {
if !matches!(ctx.state, ExchangeState::Active) {
Err(ErrorCode::NoExchange)?;
}
if ctx.mrp.is_empty() {
Ok(false)
} else {
ctx.state = ExchangeState::Acknowledge {
notification: &_self.notification as *const _,
};
_self.matter.send_notification.signal(());
Ok(true)
}
})?;
if wait {
self.notification.wait().await;
}
Ok(()) Ok(())
} }
pub fn add_session(&mut self, clone_data: &CloneData) -> Result<SessionHandle, Error> { pub async fn exchange(&mut self, tx: &Packet<'_>, rx: &mut Packet<'_>) -> Result<(), Error> {
let sess_idx = match self.sess_mgr.clone_session(clone_data) { let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
Ok(idx) => idx, let rx: &mut Packet<'static> = unsafe { core::mem::transmute(rx) };
Err(Error::NoSpace) => {
let evict_index = self.sess_mgr.get_lru(); self.with_ctx_mut(|_self, ctx| {
self.evict_session(evict_index)?; if !matches!(ctx.state, ExchangeState::Active) {
self.sess_mgr.clone_session(clone_data)? Err(ErrorCode::NoExchange)?;
}
Err(e) => {
return Err(e);
} }
ctx.state = ExchangeState::ExchangeSend {
tx: tx as *const _,
rx: rx as *mut _,
notification: &_self.notification as *const _,
}; };
Ok(self.sess_mgr.get_session_handle(sess_idx)) _self.matter.send_notification.signal(());
}
Ok(())
})?;
self.notification.wait().await;
Ok(())
} }
impl fmt::Display for ExchangeMgr { pub async fn complete(mut self, tx: &Packet<'_>) -> Result<(), Error> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.send_complete(tx).await
writeln!(f, "{{ Session Mgr: {},", self.sess_mgr)?;
writeln!(f, " Exchanges: [")?;
for s in &self.exchanges {
writeln!(f, "{{ {}, }},", s.1)?;
}
writeln!(f, " ]")?;
write!(f, "}}")
}
} }
#[cfg(test)] pub async fn send_complete(&mut self, tx: &Packet<'_>) -> Result<(), Error> {
#[allow(clippy::bool_assert_comparison)] let tx: &Packet<'static> = unsafe { core::mem::transmute(tx) };
mod tests {
use crate::{ self.with_ctx_mut(|_self, ctx| {
error::Error, if !matches!(ctx.state, ExchangeState::Active) {
transport::{ Err(ErrorCode::NoExchange)?;
network::{Address, NetworkInterface}, }
session::{CloneData, SessionMgr, SessionMode, MAX_SESSIONS},
}, ctx.state = ExchangeState::Complete {
tx: tx as *const _,
notification: &_self.notification as *const _,
}; };
_self.matter.send_notification.signal(());
use super::{ExchangeMgr, Role}; Ok(())
})?;
#[test] self.notification.wait().await;
fn test_purge() {
let sess_mgr = SessionMgr::new();
let mut mgr = ExchangeMgr::new(sess_mgr);
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, true).unwrap();
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, true).unwrap();
mgr.purge(); Ok(())
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(),
true
);
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(),
true
);
// Close e1
let e1 = ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).unwrap();
e1.close();
mgr.purge();
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 2, Role::Responder, false).is_ok(),
false
);
assert_eq!(
ExchangeMgr::_get(&mut mgr.exchanges, 1, 3, Role::Responder, false).is_ok(),
true
);
} }
fn get_clone_data(peer_sess_id: u16, local_sess_id: u16) -> CloneData { fn with_ctx<F, T>(&self, f: F) -> Result<T, Error>
CloneData::new( where
12341234, F: FnOnce(&Self, &ExchangeCtx) -> Result<T, Error>,
43211234, {
peer_sess_id, let mut exchanges = self.matter.exchanges.borrow_mut();
local_sess_id,
Address::default(), let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO
SessionMode::Pase,
) f(self, exchange)
} }
fn fill_sessions(mgr: &mut ExchangeMgr, count: usize) { fn with_ctx_mut<F, T>(&mut self, f: F) -> Result<T, Error>
let mut local_sess_id = 1; where
let mut peer_sess_id = 100; F: FnOnce(&mut Self, &mut ExchangeCtx) -> Result<T, Error>,
for _ in 1..count { {
let clone_data = get_clone_data(peer_sess_id, local_sess_id); let mut exchanges = self.matter.exchanges.borrow_mut();
match mgr.add_session(&clone_data) {
Ok(s) => assert_eq!(peer_sess_id, s.get_peer_sess_id()), let exchange = ExchangeCtx::get(&mut exchanges, &self.id).ok_or(ErrorCode::NoExchange)?; // TODO
Err(Error::NoSpace) => break,
_ => { f(self, exchange)
panic!("Couldn't, create session");
}
}
local_sess_id += 1;
peer_sess_id += 1;
} }
} }
pub struct DummyNetwork; impl<'a> Drop for Exchange<'a> {
impl DummyNetwork { fn drop(&mut self) {
pub fn new() -> Self { let _ = self.with_ctx_mut(|_self, ctx| {
Self {} ctx.state = ExchangeState::Closed;
} _self.matter.send_notification.signal(());
}
impl NetworkInterface for DummyNetwork { Ok(())
fn recv(&self, _in_buf: &mut [u8]) -> Result<(usize, Address), Error> { });
Ok((0, Address::default()))
}
fn send(&self, _out_buf: &[u8], _addr: Address) -> Result<usize, Error> {
Ok(0)
}
}
#[test]
/// We purposefuly overflow the sessions
/// and when the overflow happens, we confirm that
/// - The sessions are evicted in LRU
/// - The exchanges associated with those sessions are evicted too
fn test_sess_evict() {
let mut sess_mgr = SessionMgr::new();
let transport = Box::new(DummyNetwork::new());
sess_mgr.add_network_interface(transport).unwrap();
let mut mgr = ExchangeMgr::new(sess_mgr);
fill_sessions(&mut mgr, MAX_SESSIONS + 1);
// Sessions are now full from local session id 1 to 16
// Create exchanges for sessions 2 (i.e. session index 1) and 3 (session index 2)
// Exchange IDs are 20 and 30 respectively
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 1, 20, Role::Responder, true).unwrap();
let _ = ExchangeMgr::_get(&mut mgr.exchanges, 2, 30, Role::Responder, true).unwrap();
// Confirm that session ids 1 to MAX_SESSIONS exists
for i in 1..(MAX_SESSIONS + 1) {
assert_eq!(mgr.sess_mgr.get_with_id(i as u16).is_none(), false);
}
// Confirm that the exchanges are around
assert_eq!(mgr.get_with_id(20).is_none(), false);
assert_eq!(mgr.get_with_id(30).is_none(), false);
let mut old_local_sess_id = 1;
let mut new_local_sess_id = 100;
let mut new_peer_sess_id = 200;
for i in 1..(MAX_SESSIONS + 1) {
// Now purposefully overflow the sessions by adding another session
let session = mgr
.add_session(&get_clone_data(new_peer_sess_id, new_local_sess_id))
.unwrap();
assert_eq!(session.get_peer_sess_id(), new_peer_sess_id);
// This should have evicted session with local sess_id
assert_eq!(mgr.sess_mgr.get_with_id(old_local_sess_id).is_none(), true);
new_local_sess_id += 1;
new_peer_sess_id += 1;
old_local_sess_id += 1;
match i {
1 => {
// Both exchanges should exist
assert_eq!(mgr.get_with_id(20).is_none(), false);
assert_eq!(mgr.get_with_id(30).is_none(), false);
}
2 => {
// Exchange 20 would have been evicted
assert_eq!(mgr.get_with_id(20).is_none(), true);
assert_eq!(mgr.get_with_id(30).is_none(), false);
}
3 => {
// Exchange 20 and 30 would have been evicted
assert_eq!(mgr.get_with_id(20).is_none(), true);
assert_eq!(mgr.get_with_id(30).is_none(), true);
}
_ => {}
}
}
// println!("Session mgr {}", mgr.sess_mgr);
} }
} }

View file

@ -1,175 +0,0 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use async_channel::Receiver;
use boxslab::{BoxSlab, Slab};
use heapless::LinearMap;
use log::{debug, error, info};
use crate::error::*;
use crate::transport::mrp::ReliableMessage;
use crate::transport::packet::PacketPool;
use crate::transport::{exchange, packet::Packet, proto_demux, queue, session, udp};
use super::proto_demux::ProtoCtx;
use super::queue::Msg;
pub struct Mgr {
exch_mgr: exchange::ExchangeMgr,
proto_demux: proto_demux::ProtoDemux,
rx_q: Receiver<Msg>,
}
impl Mgr {
pub fn new() -> Result<Mgr, Error> {
let mut sess_mgr = session::SessionMgr::new();
let udp_transport = Box::new(udp::UdpListener::new()?);
sess_mgr.add_network_interface(udp_transport)?;
Ok(Mgr {
proto_demux: proto_demux::ProtoDemux::new(),
exch_mgr: exchange::ExchangeMgr::new(sess_mgr),
rx_q: queue::WorkQ::init()?,
})
}
// Allows registration of different protocols with the Transport/Protocol Demux
pub fn register_protocol(
&mut self,
proto_id_handle: Box<dyn proto_demux::HandleProto>,
) -> Result<(), Error> {
self.proto_demux.register(proto_id_handle)
}
fn send_to_exchange(
&mut self,
exch_id: u16,
proto_tx: BoxSlab<PacketPool>,
) -> Result<(), Error> {
self.exch_mgr.send(exch_id, proto_tx)
}
fn handle_rxtx(&mut self) -> Result<(), Error> {
let result = self.exch_mgr.recv().map_err(|e| {
error!("Error in recv: {:?}", e);
e
})?;
if result.is_none() {
// Nothing to process, return quietly
return Ok(());
}
// result contains something worth processing, we can safely unwrap
// as we already checked for none above
let (rx, exch_ctx) = result.unwrap();
debug!("Exchange is {:?}", exch_ctx.exch);
let tx = Self::new_tx()?;
let mut proto_ctx = ProtoCtx::new(exch_ctx, rx, tx);
// Proto Dispatch
match self.proto_demux.handle(&mut proto_ctx) {
Ok(r) => {
if let proto_demux::ResponseRequired::No = r {
// We need to send the Ack if reliability is enabled, in this case
return Ok(());
}
}
Err(e) => {
error!("Error in proto_demux {:?}", e);
return Err(e);
}
}
let ProtoCtx {
exch_ctx,
rx: _,
tx,
} = proto_ctx;
// tx_ctx now contains the response payload, send the packet
let exch_id = exch_ctx.exch.get_id();
self.send_to_exchange(exch_id, tx).map_err(|e| {
error!("Error in sending msg {:?}", e);
e
})?;
Ok(())
}
fn handle_queue_msgs(&mut self) -> Result<(), Error> {
if let Ok(msg) = self.rx_q.try_recv() {
match msg {
Msg::NewSession(clone_data) => {
// If a new session was created, add it
let _ = self
.exch_mgr
.add_session(&clone_data)
.map_err(|e| error!("Error adding new session {:?}", e));
}
_ => {
error!("Queue Message Type not yet handled {:?}", msg);
}
}
}
Ok(())
}
pub fn start(&mut self) -> Result<(), Error> {
loop {
// Handle network operations
if self.handle_rxtx().is_err() {
error!("Error in handle_rxtx");
continue;
}
if self.handle_queue_msgs().is_err() {
error!("Error in handle_queue_msg");
continue;
}
// Handle any pending acknowledgement send
let mut acks_to_send: LinearMap<u16, (), { exchange::MAX_MRP_ENTRIES }> =
LinearMap::new();
self.exch_mgr.pending_acks(&mut acks_to_send);
for exch_id in acks_to_send.keys() {
info!("Sending MRP Standalone ACK for exch {}", exch_id);
let mut proto_tx = match Self::new_tx() {
Ok(p) => p,
Err(e) => {
error!("Error creating proto_tx {:?}", e);
break;
}
};
ReliableMessage::prepare_ack(*exch_id, &mut proto_tx);
if let Err(e) = self.send_to_exchange(*exch_id, proto_tx) {
error!("Error in sending Ack {:?}", e);
}
}
// Handle exchange purging
// This need not be done in each turn of the loop, maybe once in 5 times or so?
self.exch_mgr.purge();
info!("Exchange Mgr: {}", self.exch_mgr);
}
}
fn new_tx() -> Result<BoxSlab<PacketPool>, Error> {
Slab::<PacketPool>::try_new(Packet::new_tx()?).ok_or(Error::PacketPoolExhaust)
}
}

View file

@ -15,15 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
pub mod core;
mod dedup; mod dedup;
pub mod exchange; pub mod exchange;
pub mod mgr;
pub mod mrp; pub mod mrp;
pub mod network; pub mod network;
pub mod packet; pub mod packet;
pub mod pipe;
pub mod plain_hdr; pub mod plain_hdr;
pub mod proto_demux;
pub mod proto_hdr; pub mod proto_hdr;
pub mod queue;
pub mod session; pub mod session;
pub mod udp; pub mod udp;

View file

@ -15,8 +15,8 @@
* limitations under the License. * limitations under the License.
*/ */
use std::time::Duration; use crate::utils::epoch::Epoch;
use std::time::SystemTime; use core::time::Duration;
use crate::{error::*, secure_channel, transport::packet::Packet}; use crate::{error::*, secure_channel, transport::packet::Packet};
use log::error; use log::error;
@ -41,25 +41,25 @@ impl RetransEntry {
} }
} }
#[derive(Debug, Copy, Clone)] #[derive(Debug, Clone)]
pub struct AckEntry { pub struct AckEntry {
// The msg counter that we should acknowledge // The msg counter that we should acknowledge
msg_ctr: u32, msg_ctr: u32,
// The max time after which this entry must be ACK // The max time after which this entry must be ACK
ack_timeout: SystemTime, ack_timeout: Duration,
} }
impl AckEntry { impl AckEntry {
pub fn new(msg_ctr: u32) -> Result<Self, Error> { pub fn new(msg_ctr: u32, epoch: Epoch) -> Result<Self, Error> {
if let Some(ack_timeout) = if let Some(ack_timeout) =
SystemTime::now().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT)) epoch().checked_add(Duration::from_millis(MRP_STANDALONE_ACK_TIMEOUT))
{ {
Ok(Self { Ok(Self {
msg_ctr, msg_ctr,
ack_timeout, ack_timeout,
}) })
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
@ -67,8 +67,8 @@ impl AckEntry {
self.msg_ctr self.msg_ctr
} }
pub fn has_timed_out(&self) -> bool { pub fn has_timed_out(&self, epoch: Epoch) -> bool {
self.ack_timeout > SystemTime::now() self.ack_timeout > epoch()
} }
} }
@ -90,10 +90,10 @@ impl ReliableMessage {
} }
// Check any pending acknowledgements / retransmissions and take action // Check any pending acknowledgements / retransmissions and take action
pub fn is_ack_ready(&self) -> bool { pub fn is_ack_ready(&self, epoch: Epoch) -> bool {
// Acknowledgements // Acknowledgements
if let Some(ack_entry) = self.ack { if let Some(ack_entry) = &self.ack {
ack_entry.has_timed_out() ack_entry.has_timed_out(epoch)
} else { } else {
false false
} }
@ -107,7 +107,7 @@ impl ReliableMessage {
// Check if any acknowledgements are pending for this exchange, // Check if any acknowledgements are pending for this exchange,
// if so, piggy back in the encoded header here // if so, piggy back in the encoded header here
if let Some(ack_entry) = self.ack { if let Some(ack_entry) = &self.ack {
// Ack Entry exists, set ACK bit and remove from table // Ack Entry exists, set ACK bit and remove from table
proto_tx.proto.set_ack(ack_entry.get_msg_ctr()); proto_tx.proto.set_ack(ack_entry.get_msg_ctr());
self.ack = None; self.ack = None;
@ -120,7 +120,7 @@ impl ReliableMessage {
if self.retrans.is_some() { if self.retrans.is_some() {
// This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen
error!("Previous retrans entry for this exchange already exists"); error!("Previous retrans entry for this exchange already exists");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr)); self.retrans = Some(RetransEntry::new(proto_tx.plain.ctr));
@ -132,10 +132,10 @@ impl ReliableMessage {
* - there can be only one pending retransmission per exchange (so this is per-exchange) * - there can be only one pending retransmission per exchange (so this is per-exchange)
* - duplicate detection should happen per session (obviously), so that part is per-session * - duplicate detection should happen per session (obviously), so that part is per-session
*/ */
pub fn recv(&mut self, proto_rx: &Packet) -> Result<(), Error> { pub fn recv(&mut self, proto_rx: &Packet, epoch: Epoch) -> Result<(), Error> {
if proto_rx.proto.is_ack() { if proto_rx.proto.is_ack() {
// Handle received Acks // Handle received Acks
let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(Error::Invalid)?; let ack_msg_ctr = proto_rx.proto.get_ack_msg_ctr().ok_or(ErrorCode::Invalid)?;
if let Some(entry) = &self.retrans { if let Some(entry) = &self.retrans {
if entry.get_msg_ctr() != ack_msg_ctr { if entry.get_msg_ctr() != ack_msg_ctr {
// TODO: XXX Fix this // TODO: XXX Fix this
@ -150,10 +150,10 @@ impl ReliableMessage {
// This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen // This indicates there was some existing entry for same sess-id/exch-id, which shouldnt happen
// TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK // TODO: As per the spec if this happens, we need to send out the previous ACK and note this new ACK
error!("Previous ACK entry for this exchange already exists"); error!("Previous ACK entry for this exchange already exists");
return Err(Error::Invalid); Err(ErrorCode::Invalid)?;
} }
self.ack = Some(AckEntry::new(proto_rx.plain.ctr)?); self.ack = Some(AckEntry::new(proto_rx.plain.ctr, epoch)?);
} }
Ok(()) Ok(())
} }

View file

@ -15,18 +15,25 @@
* limitations under the License. * limitations under the License.
*/ */
use std::{ use core::fmt::{Debug, Display};
fmt::{Debug, Display}, #[cfg(not(feature = "std"))]
net::{IpAddr, Ipv4Addr, SocketAddr}, pub use no_std_net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
}; #[cfg(feature = "std")]
pub use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr};
use crate::error::Error; #[derive(Eq, PartialEq, Copy, Clone)]
#[derive(PartialEq, Copy, Clone)]
pub enum Address { pub enum Address {
Udp(SocketAddr), Udp(SocketAddr),
} }
impl Address {
pub fn unwrap_udp(self) -> SocketAddr {
match self {
Self::Udp(addr) => addr,
}
}
}
impl Default for Address { impl Default for Address {
fn default() -> Self { fn default() -> Self {
Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080)) Address::Udp(SocketAddr::new(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0)), 8080))
@ -34,7 +41,7 @@ impl Default for Address {
} }
impl Display for Address { impl Display for Address {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
Address::Udp(addr) => writeln!(f, "{}", addr), Address::Udp(addr) => writeln!(f, "{}", addr),
} }
@ -42,14 +49,36 @@ impl Display for Address {
} }
impl Debug for Address { impl Debug for Address {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self { match self {
Address::Udp(addr) => writeln!(f, "{}", addr), Address::Udp(addr) => writeln!(f, "{}", addr),
} }
} }
} }
pub trait NetworkInterface { #[cfg(all(feature = "std", not(feature = "embassy-net")))]
fn recv(&self, in_buf: &mut [u8]) -> Result<(usize, Address), Error>; pub use std_stack::*;
fn send(&self, out_buf: &[u8], addr: Address) -> Result<usize, Error>;
#[cfg(feature = "embassy-net")]
pub use embassy_net_stack::*;
#[cfg(feature = "std")]
pub mod std_stack {
pub trait NetworkStackDriver {}
impl NetworkStackDriver for () {}
pub struct NetworkStack<D>(D);
impl NetworkStack<()> {
pub const fn new() -> Self {
Self(())
}
}
}
#[cfg(feature = "embassy-net")]
pub mod embassy_net_stack {
pub use embassy_net::Stack as NetworkStack;
pub use embassy_net_driver::Driver as NetworkStackDriver;
} }

View file

@ -15,14 +15,14 @@
* limitations under the License. * limitations under the License.
*/ */
use log::{error, trace}; use log::{error, info, trace};
use std::sync::Mutex; use owo_colors::OwoColorize;
use boxslab::box_slab;
use crate::{ use crate::{
error::Error, error::{Error, ErrorCode},
sys::MAX_PACKET_POOL_SIZE, interaction_model::core::PROTO_ID_INTERACTION_MODEL,
secure_channel::common::PROTO_ID_SECURE_CHANNEL,
tlv,
utils::{parsebuf::ParseBuf, writebuf::WriteBuf}, utils::{parsebuf::ParseBuf, writebuf::WriteBuf},
}; };
@ -33,56 +33,10 @@ use super::{
}; };
pub const MAX_RX_BUF_SIZE: usize = 1583; pub const MAX_RX_BUF_SIZE: usize = 1583;
type Buffer = [u8; MAX_RX_BUF_SIZE]; pub const MAX_RX_STATUS_BUF_SIZE: usize = 100;
pub const MAX_TX_BUF_SIZE: usize = 1280 - 40/*IPV6 header size*/ - 8/*UDP header size*/;
// TODO: I am not very happy with this construction, need to find another way to do this #[derive(Debug, PartialEq, Eq, Copy, Clone)]
pub struct BufferPool {
buffers: [Option<Buffer>; MAX_PACKET_POOL_SIZE],
}
impl BufferPool {
const INIT: Option<Buffer> = None;
fn get() -> &'static Mutex<BufferPool> {
static mut BUFFER_HOLDER: Option<Mutex<BufferPool>> = None;
static ONCE: Once = Once::new();
unsafe {
ONCE.call_once(|| {
BUFFER_HOLDER = Some(Mutex::new(BufferPool {
buffers: [BufferPool::INIT; MAX_PACKET_POOL_SIZE],
}));
});
BUFFER_HOLDER.as_ref().unwrap()
}
}
pub fn alloc() -> Option<(usize, &'static mut Buffer)> {
trace!("Buffer Alloc called\n");
let mut pool = BufferPool::get().lock().unwrap();
for i in 0..MAX_PACKET_POOL_SIZE {
if pool.buffers[i].is_none() {
pool.buffers[i] = Some([0; MAX_RX_BUF_SIZE]);
// Sigh! to by-pass the borrow-checker telling us we are stealing a mutable reference
// from under the lock
// In this case the lock only protects against the setting of Some/None,
// the objects then are independently accessed in a unique way
let buffer = unsafe { &mut *(pool.buffers[i].as_mut().unwrap() as *mut Buffer) };
return Some((i, buffer));
}
}
None
}
pub fn free(index: usize) {
trace!("Buffer Free called\n");
let mut pool = BufferPool::get().lock().unwrap();
if pool.buffers[index].is_some() {
pool.buffers[index] = None;
}
}
}
#[derive(PartialEq)]
enum RxState { enum RxState {
Uninit, Uninit,
PlainDecode, PlainDecode,
@ -94,51 +48,95 @@ enum Direction<'a> {
Rx(ParseBuf<'a>, RxState), Rx(ParseBuf<'a>, RxState),
} }
impl<'a> Direction<'a> {
pub fn load(&mut self, direction: &Direction) -> Result<(), Error> {
if matches!(self, Self::Tx(_)) != matches!(direction, Direction::Tx(_)) {
Err(ErrorCode::Invalid)?;
}
match self {
Self::Tx(wb) => match direction {
Direction::Tx(src_wb) => wb.load(src_wb)?,
Direction::Rx(_, _) => Err(ErrorCode::Invalid)?,
},
Self::Rx(pb, state) => match direction {
Direction::Tx(_) => Err(ErrorCode::Invalid)?,
Direction::Rx(src_pb, src_state) => {
pb.load(src_pb)?;
*state = *src_state;
}
},
}
Ok(())
}
}
pub struct Packet<'a> { pub struct Packet<'a> {
pub plain: PlainHdr, pub plain: PlainHdr,
pub proto: ProtoHdr, pub proto: ProtoHdr,
pub peer: Address, pub peer: Address,
data: Direction<'a>, data: Direction<'a>,
buffer_index: usize,
} }
impl<'a> Packet<'a> { impl<'a> Packet<'a> {
const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len(); const HDR_RESERVE: usize = plain_hdr::max_plain_hdr_len() + proto_hdr::max_proto_hdr_len();
pub fn new_rx() -> Result<Self, Error> { pub fn new_rx(buf: &'a mut [u8]) -> Self {
let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; Self {
let buf_len = buffer.len();
Ok(Self {
plain: Default::default(), plain: Default::default(),
proto: Default::default(), proto: Default::default(),
buffer_index,
peer: Address::default(), peer: Address::default(),
data: Direction::Rx(ParseBuf::new(buffer, buf_len), RxState::Uninit), data: Direction::Rx(ParseBuf::new(buf), RxState::Uninit),
}) }
} }
pub fn new_tx() -> Result<Self, Error> { pub fn new_tx(buf: &'a mut [u8]) -> Self {
let (buffer_index, buffer) = BufferPool::alloc().ok_or(Error::NoSpace)?; let mut wb = WriteBuf::new(buf);
let buf_len = buffer.len(); wb.reserve(Packet::HDR_RESERVE).unwrap();
let mut wb = WriteBuf::new(buffer, buf_len); // Reliability on by default
wb.reserve(Packet::HDR_RESERVE)?; let mut proto: ProtoHdr = Default::default();
proto.set_reliable();
let mut p = Self { Self {
plain: Default::default(), plain: Default::default(),
proto: Default::default(), proto,
buffer_index,
peer: Address::default(), peer: Address::default(),
data: Direction::Tx(wb), data: Direction::Tx(wb),
}; }
// Reliability on by default
p.proto.set_reliable();
Ok(p)
} }
pub fn as_borrow_slice(&mut self) -> &mut [u8] { pub fn reset(&mut self) {
if let Direction::Tx(wb) = &mut self.data {
wb.reset();
wb.reserve(Packet::HDR_RESERVE).unwrap();
self.plain = Default::default();
self.proto = Default::default();
self.peer = Address::default();
self.proto.set_reliable();
}
}
pub fn load(&mut self, packet: &Packet) -> Result<(), Error> {
self.plain = packet.plain.clone();
self.proto = packet.proto.clone();
self.peer = packet.peer;
self.data.load(&packet.data)
}
pub fn as_slice(&self) -> &[u8] {
match &self.data {
Direction::Rx(pb, _) => pb.as_slice(),
Direction::Tx(wb) => wb.as_slice(),
}
}
pub fn as_mut_slice(&mut self) -> &mut [u8] {
match &mut self.data { match &mut self.data {
Direction::Rx(pb, _) => pb.as_borrow_slice(), Direction::Rx(pb, _) => pb.as_mut_slice(),
Direction::Tx(wb) => wb.as_mut_slice(), Direction::Tx(wb) => wb.as_mut_slice(),
} }
} }
@ -147,7 +145,7 @@ impl<'a> Packet<'a> {
if let Direction::Rx(pbuf, _) = &mut self.data { if let Direction::Rx(pbuf, _) = &mut self.data {
Ok(pbuf) Ok(pbuf)
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
@ -155,7 +153,7 @@ impl<'a> Packet<'a> {
if let Direction::Tx(wbuf) = &mut self.data { if let Direction::Tx(wbuf) = &mut self.data {
Ok(wbuf) Ok(wbuf)
} else { } else {
Err(Error::Invalid) Err(ErrorCode::Invalid.into())
} }
} }
@ -167,10 +165,22 @@ impl<'a> Packet<'a> {
self.proto.proto_id = proto_id; self.proto.proto_id = proto_id;
} }
pub fn get_proto_opcode(&self) -> u8 { pub fn get_proto_opcode<T: num::FromPrimitive>(&self) -> Result<T, Error> {
num::FromPrimitive::from_u8(self.proto.proto_opcode).ok_or(ErrorCode::Invalid.into())
}
pub fn get_proto_raw_opcode(&self) -> u8 {
self.proto.proto_opcode self.proto.proto_opcode
} }
pub fn check_proto_opcode(&self, opcode: u8) -> Result<(), Error> {
if self.proto.proto_opcode == opcode {
Ok(())
} else {
Err(ErrorCode::Invalid.into())
}
}
pub fn set_proto_opcode(&mut self, proto_opcode: u8) { pub fn set_proto_opcode(&mut self, proto_opcode: u8) {
self.proto.proto_opcode = proto_opcode; self.proto.proto_opcode = proto_opcode;
} }
@ -196,20 +206,66 @@ impl<'a> Packet<'a> {
.decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key) .decrypt_and_decode(&self.plain, pb, peer_nodeid, dec_key)
} else { } else {
error!("Invalid state for proto_decode"); error!("Invalid state for proto_decode");
Err(Error::InvalidState) Err(ErrorCode::InvalidState.into())
} }
} }
_ => Err(Error::InvalidState), _ => Err(ErrorCode::InvalidState.into()),
} }
} }
pub fn proto_encode(
&mut self,
peer: Address,
peer_nodeid: Option<u64>,
local_nodeid: u64,
plain_text: bool,
enc_key: Option<&[u8]>,
) -> Result<(), Error> {
self.peer = peer;
// Generate encrypted header
let mut tmp_buf = [0_u8; proto_hdr::max_proto_hdr_len()];
let mut write_buf = WriteBuf::new(&mut tmp_buf);
self.proto.encode(&mut write_buf)?;
self.get_writebuf()?.prepend(write_buf.as_slice())?;
// Generate plain-text header
if plain_text {
if let Some(d) = peer_nodeid {
self.plain.set_dest_u64(d);
}
}
let mut tmp_buf = [0_u8; plain_hdr::max_plain_hdr_len()];
let mut write_buf = WriteBuf::new(&mut tmp_buf);
self.plain.encode(&mut write_buf)?;
let plain_hdr_bytes = write_buf.as_slice();
trace!("unencrypted packet: {:x?}", self.as_mut_slice());
let ctr = self.plain.ctr;
if let Some(e) = enc_key {
proto_hdr::encrypt_in_place(
ctr,
local_nodeid,
plain_hdr_bytes,
self.get_writebuf()?,
e,
)?;
}
self.get_writebuf()?.prepend(plain_hdr_bytes)?;
trace!("Full encrypted packet: {:x?}", self.as_mut_slice());
Ok(())
}
pub fn is_plain_hdr_decoded(&self) -> Result<bool, Error> { pub fn is_plain_hdr_decoded(&self) -> Result<bool, Error> {
match &self.data { match &self.data {
Direction::Rx(_, state) => match state { Direction::Rx(_, state) => match state {
RxState::Uninit => Ok(false), RxState::Uninit => Ok(false),
_ => Ok(true), _ => Ok(true),
}, },
_ => Err(Error::InvalidState), _ => Err(ErrorCode::InvalidState.into()),
} }
} }
@ -221,19 +277,50 @@ impl<'a> Packet<'a> {
self.plain.decode(pb) self.plain.decode(pb)
} else { } else {
error!("Invalid state for plain_decode"); error!("Invalid state for plain_decode");
Err(Error::InvalidState) Err(ErrorCode::InvalidState.into())
} }
} }
_ => Err(Error::InvalidState), _ => Err(ErrorCode::InvalidState.into()),
}
}
}
impl<'a> Drop for Packet<'a> {
fn drop(&mut self) {
BufferPool::free(self.buffer_index);
trace!("Dropping Packet......");
} }
} }
box_slab!(PacketPool, Packet<'static>, MAX_PACKET_POOL_SIZE); pub fn log(&self, operation: &str) {
match self.get_proto_id() {
PROTO_ID_SECURE_CHANNEL => {
if let Ok(opcode) = self.get_proto_opcode::<crate::secure_channel::common::OpCode>()
{
info!("{} SC:{:?}: ", operation.cyan(), opcode);
} else {
info!(
"{} SC:{}??: ",
operation.cyan(),
self.get_proto_raw_opcode()
);
}
tlv::print_tlv_list(self.as_slice());
}
PROTO_ID_INTERACTION_MODEL => {
if let Ok(opcode) =
self.get_proto_opcode::<crate::interaction_model::core::OpCode>()
{
info!("{} IM:{:?}: ", operation.cyan(), opcode);
} else {
info!(
"{} IM:{}??: ",
operation.cyan(),
self.get_proto_raw_opcode()
);
}
tlv::print_tlv_list(self.as_slice());
}
other => info!(
"{} {}??:{}??: ",
operation.cyan(),
other,
self.get_proto_raw_opcode()
),
}
}
}

View file

@ -0,0 +1,94 @@
/*
*
* Copyright (c) 2020-2022 Project CHIP Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
use embassy_sync::{blocking_mutex::raw::NoopRawMutex, mutex::Mutex};
use crate::utils::select::Notification;
use super::network::Address;
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
pub struct Chunk {
pub start: usize,
pub end: usize,
pub addr: Address,
}
pub struct PipeData<'a> {
pub buf: &'a mut [u8],
pub chunk: Option<Chunk>,
}
pub struct Pipe<'a> {
pub data: Mutex<NoopRawMutex, PipeData<'a>>,
pub data_supplied_notification: Notification,
pub data_consumed_notification: Notification,
}
impl<'a> Pipe<'a> {
#[inline(always)]
pub fn new(buf: &'a mut [u8]) -> Self {
Self {
data: Mutex::new(PipeData { buf, chunk: None }),
data_supplied_notification: Notification::new(),
data_consumed_notification: Notification::new(),
}
}
pub async fn recv(&self, buf: &mut [u8]) -> (usize, Address) {
loop {
{
let mut data = self.data.lock().await;
if let Some(chunk) = data.chunk {
buf[..chunk.end - chunk.start]
.copy_from_slice(&data.buf[chunk.start..chunk.end]);
data.chunk = None;
self.data_consumed_notification.signal(());
return (chunk.end - chunk.start, chunk.addr);
}
}
self.data_supplied_notification.wait().await
}
}
pub async fn send(&self, addr: Address, buf: &[u8]) {
loop {
{
let mut data = self.data.lock().await;
if data.chunk.is_none() {
data.buf[..buf.len()].copy_from_slice(buf);
data.chunk = Some(Chunk {
start: 0,
end: buf.len(),
addr,
});
self.data_supplied_notification.signal(());
break;
}
}
self.data_consumed_notification.wait().await
}
}
}

Some files were not shown because too many files have changed in this diff Show more