/*
 *
 *    Copyright (c) 2020-2022 Project CHIP Authors
 *
 *    Licensed under the Apache License, Version 2.0 (the "License");
 *    you may not use this file except in compliance with the License.
 *    You may obtain a copy of the License at
 *
 *        http://www.apache.org/licenses/LICENSE-2.0
 *
 *    Unless required by applicable law or agreed to in writing, software
 *    distributed under the License is distributed on an "AS IS" BASIS,
 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *    See the License for the specific language governing permissions and
 *    limitations under the License.
 */

use core::{fmt::Write, time::Duration};

use super::{
    common::{create_sc_status_report, SCStatusCodes},
    spake2p::{Spake2P, VerifierData},
};
use crate::{
    crypto,
    error::{Error, ErrorCode},
    mdns::{Mdns, ServiceMode},
    secure_channel::common::OpCode,
    tlv::{self, get_root_node_struct, FromTLV, OctetStr, TLVWriter, TagType, ToTLV},
    transport::{
        exchange::ExchangeCtx,
        network::Address,
        proto_ctx::ProtoCtx,
        session::{CloneData, SessionMode},
    },
    utils::{epoch::Epoch, rand::Rand},
};
use log::{error, info};

/// Whether PASE session establishment is currently accepted by [`PaseMgr`].
// The large `Pake` payload is stored inline rather than boxed, hence the clippy allow.
// NOTE(review): presumably to stay allocation-free (no_std/heapless) — confirm.
#[allow(clippy::large_enum_variant)]
enum PaseMgrState {
    /// PASE is enabled: the PAKE state machine plus the mDNS service name that was
    /// registered for the commissionable advertisement.
    Enabled(Pake, heapless::String<16>),
    /// PASE is disabled; incoming PASE messages are answered with a status report.
    Disabled,
}

/// Owns the PASE lifecycle: enabling/disabling session establishment (together with the
/// mDNS commissionable advertisement) and dispatching incoming PASE messages to [`Pake`].
pub struct PaseMgr {
    /// Current enabled/disabled state.
    state: PaseMgrState,
    /// Time source, handed to the PAKE state machine when PASE is enabled.
    epoch: Epoch,
    /// Random source, used for the mDNS service name and handed to the PAKE state machine.
    rand: Rand,
}

impl PaseMgr {
    /// Creates a manager with PASE session establishment disabled.
    #[inline(always)]
    pub fn new(epoch: Epoch, rand: Rand) -> Self {
        Self {
            state: PaseMgrState::Disabled,
            epoch,
            rand,
        }
    }

    /// Returns `true` if PASE session establishment is currently enabled.
    pub fn is_pase_session_enabled(&self) -> bool {
        matches!(&self.state, PaseMgrState::Enabled(_, _))
    }

    /// Enables PASE session establishment against `verifier` and advertises the device
    /// over mDNS as commissionable with `discriminator`.
    ///
    /// A random 16-hex-digit mDNS service name is generated for the advertisement. If PASE
    /// was already enabled, the previously registered service is removed once the new one
    /// has been added (otherwise it would stay advertised with no record of its name), and
    /// any in-progress session attempt of the old state machine is discarded.
    ///
    /// # Errors
    /// Propagates a failure of `mdns.add` (the state is then left unchanged), or a failure
    /// to remove the previous service (the new state is already in effect at that point).
    pub fn enable_pase_session(
        &mut self,
        verifier: VerifierData,
        discriminator: u16,
        mdns: &dyn Mdns,
    ) -> Result<(), Error> {
        let mut buf = [0; 8];
        (self.rand)(&mut buf);
        let num = u64::from_be_bytes(buf);

        let mut mdns_service_name = heapless::String::<16>::new();
        write!(&mut mdns_service_name, "{:016X}", num)
            .expect("a u64 formatted as 16 hex digits always fits in a 16-byte string");

        mdns.add(
            &mdns_service_name,
            ServiceMode::Commissionable(discriminator),
        )?;

        let new_state = PaseMgrState::Enabled(
            Pake::new(verifier, self.epoch, self.rand),
            mdns_service_name,
        );
        let previous = core::mem::replace(&mut self.state, new_state);

        // Don't leak the advertisement of a previous enablement.
        if let PaseMgrState::Enabled(_, old_service_name) = previous {
            mdns.remove(&old_service_name)?;
        }

        Ok(())
    }

    /// Disables PASE session establishment, removing the mDNS advertisement if PASE was
    /// enabled. Calling this while already disabled is a no-op apart from the state reset.
    ///
    /// # Errors
    /// Propagates a failure of `mdns.remove`; in that case the state is left unchanged.
    pub fn disable_pase_session(&mut self, mdns: &dyn Mdns) -> Result<(), Error> {
        if let PaseMgrState::Enabled(_, mdns_service_name) = &self.state {
            mdns.remove(mdns_service_name)?;
        }

        self.state = PaseMgrState::Disabled;

        Ok(())
    }

    /// If the PASE Session is enabled, execute the closure and return `Some(result)`;
    /// if not enabled, write an `InvalidParameter` SC status report and return `None`.
    fn if_enabled<F, T>(&mut self, ctx: &mut ProtoCtx, f: F) -> Result<Option<T>, Error>
    where
        F: FnOnce(&mut Pake, &mut ProtoCtx) -> Result<T, Error>,
    {
        if let PaseMgrState::Enabled(pake, _) = &mut self.state {
            let data = f(pake, ctx)?;

            Ok(Some(data))
        } else {
            error!("PASE Not enabled");
            create_sc_status_report(ctx.tx, SCStatusCodes::InvalidParameter, None)?;
            Ok(None)
        }
    }

    /// Handles a PBKDFParamRequest; the response opcode is set to PBKDFParamResponse.
    ///
    /// Always returns `Ok(true)` when no error occurs.
    /// NOTE(review): presumably `true` means "a response is ready to send" — confirm with caller.
    pub fn pbkdfparamreq_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
        ctx.tx.set_proto_opcode(OpCode::PBKDFParamResponse as u8);
        self.if_enabled(ctx, |pake, ctx| pake.handle_pbkdfparamrequest(ctx))?;
        Ok(true)
    }

    /// Handles a Pake1 message; the response opcode is set to Pake2.
    ///
    /// Always returns `Ok(true)` when no error occurs.
    pub fn pasepake1_handler(&mut self, ctx: &mut ProtoCtx) -> Result<bool, Error> {
        ctx.tx.set_proto_opcode(OpCode::PASEPake2 as u8);
        self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake1(ctx))?;
        Ok(true)
    }

    /// Handles a Pake3 message and then disables PASE session establishment.
    ///
    /// Returns the data for the newly established session, or `None` if establishment did
    /// not succeed (or PASE was not enabled). PASE is disabled whenever the message was
    /// processed without error; if processing returns an error, it is propagated and
    /// PASE stays enabled.
    pub fn pasepake3_handler(
        &mut self,
        ctx: &mut ProtoCtx,
        mdns: &dyn Mdns,
    ) -> Result<(bool, Option<CloneData>), Error> {
        let clone_data = self.if_enabled(ctx, |pake, ctx| pake.handle_pasepake3(ctx))?;
        self.disable_pase_session(mdns)?;
        Ok((true, clone_data.flatten()))
    }
}

// This file implements the handlers for the PASE (Passcode-Authenticated Session
// Establishment) secure channel protocol, including the TLV extraction and encoding of
// its messages. A `Spake2P` object is created for each session attempt and kept in the
// `Pake` state machine (`PakeState::InProgress`); that object performs the
// SPAKE2+-specific computations.

/// How long an unfinished PASE session attempt is protected: a new PBKDFParamRequest
/// arriving later than this discards it (despite the name, the value is a `Duration`).
const PASE_DISCARD_TIMEOUT_SECS: Duration = Duration::from_millis(60_000);

/// HKDF `info` input used when deriving the PASE session keys from `Ke`.
const SPAKE2_SESSION_KEYS_INFO: [u8; 11] = *b"SessionKeys";

/// Bookkeeping for one in-progress PASE session attempt.
struct SessionData {
    /// `epoch()` reading taken when the attempt started; used to detect expiry.
    start_time: Duration,
    /// Exchange the attempt is bound to; follow-up messages must arrive on it.
    exch_id: u16,
    /// Peer address the attempt is bound to; follow-up messages must come from it.
    peer_addr: Address,
    /// SPAKE2+ protocol state for this attempt.
    spake2p: Spake2P,
}

impl SessionData {
    /// Returns whether more than `PASE_DISCARD_TIMEOUT_SECS` has elapsed since this
    /// session attempt started, according to `epoch`.
    ///
    /// Saturating subtraction is used so that a clock reading earlier than `start_time`
    /// (the time source moved backwards) counts as "not expired" instead of panicking
    /// on `Duration` underflow.
    fn is_sess_expired(&self, epoch: Epoch) -> Result<bool, Error> {
        let elapsed = epoch().saturating_sub(self.start_time);
        Ok(elapsed > PASE_DISCARD_TIMEOUT_SECS)
    }
}

/// State of the PAKE state machine: either no attempt, or exactly one attempt in progress.
// `SessionData` is stored inline rather than boxed, hence the clippy allow.
// NOTE(review): presumably to stay allocation-free — confirm.
#[allow(clippy::large_enum_variant)]
enum PakeState {
    /// No session attempt in progress; a PBKDFParamRequest may start one.
    Idle,
    /// A session attempt is in progress.
    InProgress(SessionData),
}

impl PakeState {
    /// Creates the initial state: no session attempt in progress.
    const fn new() -> Self {
        Self::Idle
    }

    /// Moves the in-progress session data out, leaving the state `Idle`.
    ///
    /// # Errors
    /// Returns `ErrorCode::InvalidState` if no session attempt is in progress.
    fn take(&mut self) -> Result<SessionData, Error> {
        match core::mem::replace(self, PakeState::Idle) {
            PakeState::InProgress(sd) => Ok(sd),
            // Nothing to hand out: a state error (not a signature failure).
            PakeState::Idle => Err(ErrorCode::InvalidState.into()),
        }
    }

    /// Returns `true` if no session attempt is in progress.
    fn is_idle(&self) -> bool {
        matches!(self, PakeState::Idle)
    }

    /// Takes the in-progress session data, provided the current message arrived on the
    /// same exchange and from the same peer address that started the attempt.
    ///
    /// On a mismatch the session data is put back, so a message on an unrelated exchange
    /// (or from another peer) cannot tear down an in-progress attempt it does not own.
    ///
    /// # Errors
    /// Returns `ErrorCode::InvalidState` if no attempt is in progress or if the
    /// exchange/peer do not match.
    fn take_sess_data(&mut self, exch_ctx: &ExchangeCtx) -> Result<SessionData, Error> {
        let sd = self.take()?;
        if sd.exch_id == exch_ctx.exch.get_id() && sd.peer_addr == exch_ctx.sess.get_peer_addr() {
            Ok(sd)
        } else {
            self.set_sess_data(sd);
            Err(ErrorCode::InvalidState.into())
        }
    }

    /// Starts a new session attempt bound to the exchange and peer of `exch_ctx`,
    /// timestamped with the current `epoch()` reading. Replaces any previous state.
    fn make_in_progress(&mut self, epoch: Epoch, spake2p: Spake2P, exch_ctx: &ExchangeCtx) {
        *self = PakeState::InProgress(SessionData {
            start_time: epoch(),
            spake2p,
            exch_id: exch_ctx.exch.get_id(),
            peer_addr: exch_ctx.sess.get_peer_addr(),
        });
    }

    /// Stores `sd` back as the in-progress session attempt.
    fn set_sess_data(&mut self, sd: SessionData) {
        *self = PakeState::InProgress(sd);
    }
}

impl Default for PakeState {
    fn default() -> Self {
        Self::new()
    }
}

/// PAKE (SPAKE2+) state machine used by the commissionee side of PASE.
struct Pake {
    /// Verifier data (PBKDF iteration count, salt, …) that initiators authenticate against.
    verifier: VerifierData,
    /// Current session attempt, if any.
    state: PakeState,
    /// Time source used to timestamp and expire session attempts.
    epoch: Epoch,
    /// Random source used for the responder random and passed to SPAKE2+.
    rand: Rand,
}

impl Pake {
    /// Creates a PAKE state machine that authenticates initiators against `verifier`,
    /// with no session attempt in progress.
    pub fn new(verifier: VerifierData, epoch: Epoch, rand: Rand) -> Self {
        // TODO: Can any PBKDF2 calculation be pre-computed here
        Self {
            verifier,
            state: PakeState::new(),
            epoch,
            rand,
        }
    }

    /// Handles a Pake3 message (carrying `cA`) and finishes the session attempt.
    ///
    /// If SPAKE2+ reports success, 48 bytes of session keys are derived with HKDF-SHA256
    /// over `Ke` and returned as the `CloneData` for the new PASE session; otherwise
    /// `None` is returned. A status report is written to `ctx.tx` and the exchange is
    /// closed. Once looked up, the session data is consumed regardless of the outcome.
    ///
    /// # Errors
    /// Fails if there is no matching in-progress attempt, if the message cannot be
    /// parsed, or if key derivation or status report encoding fails.
    #[allow(non_snake_case)]
    pub fn handle_pasepake3(&mut self, ctx: &mut ProtoCtx) -> Result<Option<CloneData>, Error> {
        let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?;

        let cA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?;
        let (status_code, ke) = sd.spake2p.handle_cA(cA);

        let clone_data = if status_code == SCStatusCodes::SessionEstablishmentSuccess {
            // Get the keys
            let ke = ke.ok_or(ErrorCode::Invalid)?;
            let mut session_keys: [u8; 48] = [0; 48];
            crypto::hkdf_sha256(&[], ke, &SPAKE2_SESSION_KEYS_INFO, &mut session_keys)
                .map_err(|_x| ErrorCode::NoSpace)?;

            // Create a session. The app data packed in `handle_pbkdfparamrequest` holds the
            // local session id in the upper 16 bits and the peer's in the lower 16 bits.
            let data = sd.spake2p.get_app_data();
            let peer_sessid: u16 = (data & 0xffff) as u16;
            let local_sessid: u16 = ((data >> 16) & 0xffff) as u16;
            let mut clone_data = CloneData::new(
                0,
                0,
                peer_sessid,
                local_sessid,
                ctx.exch_ctx.sess.get_peer_addr(),
                SessionMode::Pase,
            );
            // Key material layout: [0..16) decryption key, [16..32) encryption key,
            // [32..48) attestation challenge.
            clone_data.dec_key.copy_from_slice(&session_keys[0..16]);
            clone_data.enc_key.copy_from_slice(&session_keys[16..32]);
            clone_data
                .att_challenge
                .copy_from_slice(&session_keys[32..48]);

            // Queue a transport mgr request to add a new session
            Some(clone_data)
        } else {
            None
        };

        create_sc_status_report(ctx.tx, status_code, None)?;
        ctx.exch_ctx.exch.close();
        Ok(clone_data)
    }

    /// Handles a Pake1 message (carrying `pA`): computes `pB` and `cB` with SPAKE2+ and
    /// writes them as a TLV response into `ctx.tx`.
    ///
    /// On success the session data is stored back so the Pake3 step can find it; if any
    /// step after the lookup fails, the taken session data is dropped and the attempt ends.
    ///
    /// # Errors
    /// Fails if there is no matching in-progress attempt, or if parsing, the SPAKE2+
    /// computation, or encoding the response fails.
    #[allow(non_snake_case)]
    pub fn handle_pasepake1(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> {
        let mut sd = self.state.take_sess_data(&ctx.exch_ctx)?;

        let pA = extract_pasepake_1_or_3_params(ctx.rx.as_slice())?;
        let mut pB: [u8; 65] = [0; 65];
        let mut cB: [u8; 32] = [0; 32];
        sd.spake2p.start_verifier(&self.verifier)?;
        sd.spake2p.handle_pA(pA, &mut pB, &mut cB, self.rand)?;

        let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?);
        let resp = Pake1Resp {
            pb: OctetStr(&pB),
            cb: OctetStr(&cB),
        };
        resp.to_tlv(&mut tw, TagType::Anonymous)?;

        self.state.set_sess_data(sd);

        Ok(())
    }

    /// Handles a PBKDFParamRequest, starting a new session attempt.
    ///
    /// If an attempt is already in progress and has not expired, it is kept intact and a
    /// `Busy` status report (with a 500 ms wait hint) is written instead. An expired attempt
    /// is discarded. Otherwise a local session id is reserved, the PBKDF parameters
    /// response is written to `ctx.tx` (including count and salt unless the initiator
    /// already has them), and the request/response are fed into the SPAKE2+ context.
    ///
    /// # Errors
    /// Fails if the request cannot be parsed, uses a non-zero passcode id, or if encoding
    /// the response or setting up the SPAKE2+ context fails.
    pub fn handle_pbkdfparamrequest(&mut self, ctx: &mut ProtoCtx) -> Result<(), Error> {
        if !self.state.is_idle() {
            let sd = self.state.take()?;
            if sd.is_sess_expired(self.epoch)? {
                info!("Previous session expired, clearing it");
                self.state = PakeState::Idle;
            } else {
                info!("Previous session in-progress, denying new request");
                // `take()` reset the state to `Idle`; put the attempt back so the
                // session we are protecting by answering "Busy" is not lost.
                self.state.set_sess_data(sd);
                // little-endian timeout (here we've hardcoded 500ms)
                create_sc_status_report(ctx.tx, SCStatusCodes::Busy, Some(&[0xf4, 0x01]))?;
                return Ok(());
            }
        }

        let root = tlv::get_root_node(ctx.rx.as_slice())?;
        let a = PBKDFParamReq::from_tlv(&root)?;
        if a.passcode_id != 0 {
            error!("Can't yet handle passcode_id != 0");
            return Err(ErrorCode::Invalid.into());
        }

        let mut our_random: [u8; 32] = [0; 32];
        (self.rand)(&mut our_random);

        // Pack both session ids into the SPAKE2+ app data: local id in the upper 16 bits,
        // the initiator's id in the lower 16 bits (unpacked in `handle_pasepake3`).
        let local_sessid = ctx.exch_ctx.sess.reserve_new_sess_id();
        let spake2p_data: u32 = ((local_sessid as u32) << 16) | a.initiator_ssid as u32;
        let mut spake2p = Spake2P::new();
        spake2p.set_app_data(spake2p_data);

        // Generate response
        let mut tw = TLVWriter::new(ctx.tx.get_writebuf()?);
        let mut resp = PBKDFParamResp {
            init_random: a.initiator_random,
            our_random: OctetStr(&our_random),
            local_sessid,
            params: None,
        };
        if !a.has_params {
            let params_resp = PBKDFParamRespParams {
                count: self.verifier.count,
                salt: OctetStr(&self.verifier.salt),
            };
            resp.params = Some(params_resp);
        }
        resp.to_tlv(&mut tw, TagType::Anonymous)?;

        spake2p.set_context(ctx.rx.as_slice(), ctx.tx.as_mut_slice())?;
        self.state
            .make_in_progress(self.epoch, spake2p, &ctx.exch_ctx);

        Ok(())
    }
}

// Response to a Pake1 message, carrying the responder's `pB` and `cB`.
// NOTE(review): fields presumably receive consecutive context tags starting at 1
// (per `tlvargs(start = 1)`), so field order is wire-significant — do not reorder.
#[derive(ToTLV)]
#[tlvargs(start = 1)]
struct Pake1Resp<'a> {
    // SPAKE2+ share `pB` (65 bytes as produced in `handle_pasepake1`).
    pb: OctetStr<'a>,
    // Confirmation value `cB` (32 bytes as produced in `handle_pasepake1`).
    cb: OctetStr<'a>,
}

// PBKDF parameters sent to an initiator that does not already have them.
// NOTE(review): field order presumably determines the context tags (from 1) — do not reorder.
#[derive(ToTLV)]
#[tlvargs(start = 1)]
struct PBKDFParamRespParams<'a> {
    // PBKDF iteration count, taken from the verifier data.
    count: u32,
    // PBKDF salt, taken from the verifier data.
    salt: OctetStr<'a>,
}

// Response to a PBKDFParamRequest.
// NOTE(review): field order presumably determines the context tags (from 1) — do not reorder.
#[derive(ToTLV)]
#[tlvargs(start = 1)]
struct PBKDFParamResp<'a> {
    // The initiator's random, echoed back from the request.
    init_random: OctetStr<'a>,
    // The responder's freshly generated 32-byte random.
    our_random: OctetStr<'a>,
    // Session id reserved locally for the session being established.
    local_sessid: u16,
    // PBKDF parameters; `None` when the request indicated the initiator already has them.
    params: Option<PBKDFParamRespParams<'a>>,
}

/// Returns the octet-string payload stored under context tag 1 of the root TLV
/// structure in `buf` — `pA` for a Pake1 message, `cA` for a Pake3 message.
///
/// # Errors
/// Fails if `buf` does not hold a root structure, tag 1 is missing, or it is not an
/// octet string.
fn extract_pasepake_1_or_3_params(buf: &[u8]) -> Result<&[u8], Error> {
    let root = get_root_node_struct(buf)?;
    let element = root.find_tag(1)?;
    let payload = element.slice()?;
    Ok(payload)
}

// Decoded PBKDFParamRequest sent by the initiator.
// NOTE(review): field order presumably determines the context tags (from 1) — do not reorder.
#[derive(FromTLV)]
#[tlvargs(lifetime = "'a", start = 1)]
struct PBKDFParamReq<'a> {
    // Initiator's random, echoed back in the response.
    initiator_random: OctetStr<'a>,
    // Initiator's session id; packed into the SPAKE2+ app data as the peer session id.
    initiator_ssid: u16,
    // Passcode id; only 0 is currently supported.
    passcode_id: u16,
    // Whether the initiator already has the PBKDF parameters (count/salt).
    has_params: bool,
}