// crypto/sha2.rs

// SPDX-FileCopyrightText: 2024 Foundation Devices, Inc. <hello@foundation.xyz>
// SPDX-License-Identifier: GPL-3.0-or-later
3
4#[cfg(keyos)]
5use server::xous::MemoryRange;
6use server::{CheckedConn, CheckedPermissions, MessageAllowed};
7
8use crate::error::CryptoError;
9use crate::messages::{ShaDrop, ShaGetContext, ShaSetContext, ShaUpdate};
10use crate::CryptoApi;
11
12pub trait ShaPermissions:
13    CheckedPermissions
14    + MessageAllowed<ShaSetContext>
15    + MessageAllowed<ShaUpdate>
16    + MessageAllowed<ShaGetContext>
17    + MessageAllowed<ShaDrop>
18{
19}
20
21impl<P> ShaPermissions for P where
22    P: CheckedPermissions
23        + MessageAllowed<ShaSetContext>
24        + MessageAllowed<ShaUpdate>
25        + MessageAllowed<ShaGetContext>
26        + MessageAllowed<ShaDrop>
27{
28}
29
/// Digest length of SHA-224 in bytes (224 bits).
pub const SHA224_HASH_SIZE: usize = 224 / 8;
/// Digest length of SHA-256 in bytes (256 bits).
pub const SHA256_HASH_SIZE: usize = 256 / 8;
/// Digest length of SHA-384 in bytes (384 bits).
pub const SHA384_HASH_SIZE: usize = 384 / 8;
/// Digest length of SHA-512 in bytes (512 bits).
pub const SHA512_HASH_SIZE: usize = 512 / 8;
34
/// Once the buffered bytes plus an incoming update reach this many bytes, `update` hands the
/// work to the crypto server instead of hashing in software.
#[cfg(keyos)]
const SW_THRESHOLD: usize = 4 * 1024;

/// Size in bytes of the lazily mapped scratch buffer used to stage unaligned input; it also
/// caps how many bytes a single scratch-staged `ShaUpdate` request carries.
#[cfg(keyos)]
const SCRATCH_CAP: usize = 0x2_0000;
40
41#[derive(Debug, Clone, Copy, rkyv::Archive, rkyv::Serialize, rkyv::Deserialize)]
42pub enum ShaAlgo {
43    Sha224,
44    Sha256,
45    Sha384,
46    Sha512,
47}
48
49impl From<usize> for ShaAlgo {
50    fn from(value: usize) -> Self {
51        match value {
52            0 => ShaAlgo::Sha224,
53            1 => ShaAlgo::Sha256,
54            2 => ShaAlgo::Sha384,
55            3 => ShaAlgo::Sha512,
56            _ => unreachable!(),
57        }
58    }
59}
60
61impl From<ShaAlgo> for usize {
62    fn from(value: ShaAlgo) -> Self {
63        match value {
64            ShaAlgo::Sha224 => 0,
65            ShaAlgo::Sha256 => 1,
66            ShaAlgo::Sha384 => 2,
67            ShaAlgo::Sha512 => 3,
68        }
69    }
70}
71
72impl ShaAlgo {
73    pub fn block_size(self) -> usize {
74        match self {
75            ShaAlgo::Sha224 | ShaAlgo::Sha256 => 64,
76            ShaAlgo::Sha384 | ShaAlgo::Sha512 => 128,
77        }
78    }
79
80    pub fn hash_size(self) -> usize {
81        match self {
82            ShaAlgo::Sha224 => SHA224_HASH_SIZE,
83            ShaAlgo::Sha256 => SHA256_HASH_SIZE,
84            ShaAlgo::Sha384 => SHA384_HASH_SIZE,
85            ShaAlgo::Sha512 => SHA512_HASH_SIZE,
86        }
87    }
88
89    pub fn initial_hash_state(self) -> [u8; 64] {
90        match self {
91            ShaAlgo::Sha224 => [
92                0xc1, 0x05, 0x9e, 0xd8, 0x36, 0x7c, 0xd5, 0x07, 0x30, 0x70, 0xdd, 0x17, 0xf7, 0x0e, 0x59,
93                0x39, 0xff, 0xc0, 0x0b, 0x31, 0x68, 0x58, 0x15, 0x11, 0x64, 0xf9, 0x8f, 0xa7, 0xbe, 0xfa,
94                0x4f, 0xa4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
95                0, 0, 0, 0, 0,
96            ],
97            ShaAlgo::Sha256 => [
98                0x6a, 0x09, 0xe6, 0x67, 0xbb, 0x67, 0xae, 0x85, 0x3c, 0x6e, 0xf3, 0x72, 0xa5, 0x4f, 0xf5,
99                0x3a, 0x51, 0x0e, 0x52, 0x7f, 0x9b, 0x05, 0x68, 0x8c, 0x1f, 0x83, 0xd9, 0xab, 0x5b, 0xe0,
100                0xcd, 0x19, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
101                0, 0, 0, 0, 0,
102            ],
103            ShaAlgo::Sha384 => [
104                0xcb, 0xbb, 0x9d, 0x5d, 0xc1, 0x05, 0x9e, 0xd8, 0x62, 0x9a, 0x29, 0x2a, 0x36, 0x7c, 0xd5,
105                0x07, 0x91, 0x59, 0x01, 0x5a, 0x30, 0x70, 0xdd, 0x17, 0x15, 0x2f, 0xec, 0xd8, 0xf7, 0x0e,
106                0x59, 0x39, 0x67, 0x33, 0x26, 0x67, 0xff, 0xc0, 0x0b, 0x31, 0x8e, 0xb4, 0x4a, 0x87, 0x68,
107                0x58, 0x15, 0x11, 0xdb, 0x0c, 0x2e, 0x0d, 0x64, 0xf9, 0x8f, 0xa7, 0x47, 0xb5, 0x48, 0x1d,
108                0xbe, 0xfa, 0x4f, 0xa4,
109            ],
110            ShaAlgo::Sha512 => [
111                0x6a, 0x09, 0xe6, 0x67, 0xf3, 0xbc, 0xc9, 0x08, 0xbb, 0x67, 0xae, 0x85, 0x84, 0xca, 0xa7,
112                0x3b, 0x3c, 0x6e, 0xf3, 0x72, 0xfe, 0x94, 0xf8, 0x2b, 0xa5, 0x4f, 0xf5, 0x3a, 0x5f, 0x1d,
113                0x36, 0xf1, 0x51, 0x0e, 0x52, 0x7f, 0xad, 0xe6, 0x82, 0xd1, 0x9b, 0x05, 0x68, 0x8c, 0x2b,
114                0x3e, 0x6c, 0x1f, 0x1f, 0x83, 0xd9, 0xab, 0xfb, 0x41, 0xbd, 0x6b, 0x5b, 0xe0, 0xcd, 0x19,
115                0x13, 0x7e, 0x21, 0x79,
116            ],
117        }
118    }
119}
120
/// Incremental SHA-2 hasher created by `CryptoApi::sha_init`.
///
/// Input is hashed in software; on keyos, large updates move the chaining state to the crypto
/// server, and `server_authoritative` records which side currently holds the latest state.
pub struct ShaStreamingContext<P: ShaPermissions> {
    /// Connection to the crypto server. It is only exercised by keyos-only code paths, hence the
    /// dead-code allowance on other targets.
    #[cfg_attr(not(keyos), allow(dead_code))]
    conn: CheckedConn<P>,
    /// SHA-2 variant computed by this context.
    algo: ShaAlgo,
    /// Input bytes not yet compressed; shorter than one block between `update` calls.
    accumulator: Vec<u8>,
    /// Count of message bytes already compressed (locally or by the server); added to
    /// `accumulator.len()` to form the message-length field in `finalize`.
    bytes_compressed: u64,
    /// Big-endian chaining value (first 32 bytes for SHA-224/256, all 64 for SHA-384/512).
    /// Stale while `server_authoritative` is true.
    hash_state: [u8; 64],
    /// Server-side context id, assigned the first time the state is pushed to the server.
    #[cfg(keyos)]
    server_id: Option<usize>,
    /// True while the server, rather than `hash_state`, holds the current chaining value.
    #[cfg(keyos)]
    server_authoritative: bool,
    /// Lazily mapped staging buffer for input that cannot be lent to the server in place.
    #[cfg(keyos)]
    scratch: Option<xous::DropDeallocate>,
}
135
136pub type Sha256StreamingContext<P> = ShaStreamingContext<P>;
137
138impl<P: ShaPermissions> CryptoApi<P> {
139    pub fn sha2(&self, data: &[u8], algo: ShaAlgo) -> Result<Vec<u8>, CryptoError> {
140        let mut ctx = self.sha_init(algo);
141        ctx.update(data)?;
142        ctx.finalize()
143    }
144
145    pub fn sha224(&self, data: &[u8]) -> Result<[u8; SHA224_HASH_SIZE], CryptoError> {
146        Ok(self.sha2(data, ShaAlgo::Sha224)?.try_into().unwrap())
147    }
148
149    pub fn sha256(&self, data: &[u8]) -> Result<[u8; SHA256_HASH_SIZE], CryptoError> {
150        Ok(self.sha2(data, ShaAlgo::Sha256)?.try_into().unwrap())
151    }
152
153    pub fn sha384(&self, data: &[u8]) -> Result<[u8; SHA384_HASH_SIZE], CryptoError> {
154        Ok(self.sha2(data, ShaAlgo::Sha384)?.try_into().unwrap())
155    }
156
157    pub fn sha512(&self, data: &[u8]) -> Result<[u8; SHA512_HASH_SIZE], CryptoError> {
158        Ok(self.sha2(data, ShaAlgo::Sha512)?.try_into().unwrap())
159    }
160
161    pub fn sha_init(&self, algo: ShaAlgo) -> ShaStreamingContext<P> {
162        ShaStreamingContext {
163            conn: self.conn.clone(),
164            algo,
165            accumulator: Vec::new(),
166            bytes_compressed: 0,
167            hash_state: algo.initial_hash_state(),
168            #[cfg(keyos)]
169            server_id: None,
170            #[cfg(keyos)]
171            server_authoritative: false,
172            #[cfg(keyos)]
173            scratch: None,
174        }
175    }
176
177    pub fn sha256_init(&self) -> ShaStreamingContext<P> { self.sha_init(ShaAlgo::Sha256) }
178}
179
180impl<P: ShaPermissions> ShaStreamingContext<P> {
181    pub fn hash_size(&self) -> usize { self.algo.hash_size() }
182
183    pub fn update(&mut self, data: &[u8]) -> Result<(), CryptoError> {
184        if data.is_empty() {
185            return Ok(());
186        }
187        #[cfg(keyos)]
188        if self.accumulator.len() + data.len() >= SW_THRESHOLD {
189            self.push_state_to_server()?;
190            self.update_hw(data)?;
191            return Ok(());
192        }
193        self.update_sw(data)
194    }
195
196    fn update_sw(&mut self, data: &[u8]) -> Result<(), CryptoError> {
197        let bs = self.algo.block_size();
198        self.accumulator.extend_from_slice(data);
199        let num_blocks = self.accumulator.len() / bs;
200        if num_blocks > 0 {
201            #[cfg(keyos)]
202            self.fetch_server_state()?;
203            let blocks_end = num_blocks * bs;
204            sw_compress_blocks(&mut self.hash_state, self.algo, &self.accumulator[..blocks_end]);
205            self.accumulator.drain(..blocks_end);
206            self.bytes_compressed += blocks_end as u64;
207        }
208        Ok(())
209    }
210
211    #[cfg_attr(not(keyos), allow(unused_mut))]
212    pub fn finalize(mut self) -> Result<Vec<u8>, CryptoError> {
213        #[cfg(keyos)]
214        self.fetch_server_state()?;
215
216        let total_bits = (self.bytes_compressed + self.accumulator.len() as u64) * 8;
217        let pad = sha_padding(&self.accumulator, self.algo, total_bits);
218        sw_compress_blocks(&mut self.hash_state, self.algo, &pad);
219
220        Ok(self.hash_state[..self.algo.hash_size()].to_vec())
221    }
222
223    #[cfg(keyos)]
224    fn push_state_to_server(&mut self) -> Result<(), CryptoError> {
225        if !self.server_authoritative {
226            let id = self.conn.send_blocking_archive(ShaSetContext {
227                context_id: self.server_id,
228                algo: self.algo,
229                hash_state: self.hash_state,
230            })?;
231            self.server_id = Some(id);
232            self.server_authoritative = true;
233        }
234        Ok(())
235    }
236
237    #[cfg(keyos)]
238    fn fetch_server_state(&mut self) -> Result<(), CryptoError> {
239        if self.server_authoritative {
240            let snap =
241                self.conn.send_blocking_archive(ShaGetContext { context_id: self.server_id.unwrap() })?;
242            self.hash_state = snap.hash_state;
243            self.server_authoritative = false;
244        }
245        Ok(())
246    }
247
248    #[cfg(keyos)]
249    fn scratch_range(&mut self) -> Result<MemoryRange, CryptoError> {
250        if self.scratch.is_none() {
251            let mem = xous::map_memory(None, None, SCRATCH_CAP, xous::MemoryFlags::W)?;
252            self.scratch = Some(xous::DropDeallocate::new(mem));
253        }
254        Ok(**self.scratch.as_ref().unwrap())
255    }
256
257    #[cfg(keyos)]
258    fn update_hw(&mut self, data: &[u8]) -> Result<(), CryptoError> {
259        use xous::keyos::PAGE_SIZE;
260
261        let bs = self.algo.block_size();
262        // Fast path: directly lend the page-aligned prefix to HW, accumulate the tail in SW.
263        if self.accumulator.is_empty() && (data.as_ptr() as usize) % PAGE_SIZE == 0 && data.len() >= PAGE_SIZE
264        {
265            let hw_len = (data.len() / PAGE_SIZE) * PAGE_SIZE;
266            let mr = unsafe { MemoryRange::new(data.as_ptr() as usize, hw_len)? };
267            self.conn.lend_mut(ShaUpdate { context_id: self.server_id.unwrap(), buf: mr, length: hw_len })?;
268            self.bytes_compressed += hw_len as u64;
269            self.update_sw(&data[hw_len..])?;
270            return Ok(());
271        }
272
273        let mut scratch = self.scratch_range()?;
274
275        let mut remaining = data;
276        while self.accumulator.len() + remaining.len() >= bs {
277            let acc_len = self.accumulator.len();
278            let from_data = (SCRATCH_CAP - acc_len).min(remaining.len());
279            let send_len = (acc_len + from_data) / bs * bs;
280
281            scratch.as_slice_mut::<u8>()[..acc_len].copy_from_slice(&self.accumulator);
282            scratch.as_slice_mut::<u8>()[acc_len..send_len].copy_from_slice(&remaining[..send_len - acc_len]);
283
284            remaining = &remaining[send_len - acc_len..];
285            self.accumulator.clear();
286
287            self.conn.lend_mut(ShaUpdate {
288                context_id: self.server_id.unwrap(),
289                buf: scratch.subrange(0, send_len.next_multiple_of(PAGE_SIZE)).unwrap(),
290                length: send_len,
291            })?;
292            self.bytes_compressed += send_len as u64;
293        }
294        self.accumulator = remaining.to_vec();
295
296        Ok(())
297    }
298}
299
300impl<P: ShaPermissions> Drop for ShaStreamingContext<P> {
301    fn drop(&mut self) {
302        #[cfg(keyos)]
303        if let Some(id) = self.server_id {
304            self.conn.try_send_scalar(ShaDrop(id)).ok();
305        }
306    }
307}
308
309fn sw_compress_blocks(hash_state: &mut [u8; 64], algo: ShaAlgo, data: &[u8]) {
310    match algo {
311        ShaAlgo::Sha224 | ShaAlgo::Sha256 => {
312            let mut state = [0u32; 8];
313            for (i, chunk) in hash_state[..32].chunks_exact(4).enumerate() {
314                state[i] = u32::from_be_bytes(chunk.try_into().unwrap());
315            }
316            for block in data.chunks_exact(64) {
317                sha2::compress256(&mut state, core::slice::from_ref(block.try_into().unwrap()));
318            }
319            for (i, w) in state.iter().enumerate() {
320                hash_state[i * 4..i * 4 + 4].copy_from_slice(&w.to_be_bytes());
321            }
322        }
323        ShaAlgo::Sha384 | ShaAlgo::Sha512 => {
324            let mut state = [0u64; 8];
325            for (i, chunk) in hash_state[..64].chunks_exact(8).enumerate() {
326                state[i] = u64::from_be_bytes(chunk.try_into().unwrap());
327            }
328            for block in data.chunks_exact(128) {
329                sha2::compress512(&mut state, core::slice::from_ref(block.try_into().unwrap()));
330            }
331            for (i, w) in state.iter().enumerate() {
332                hash_state[i * 8..i * 8 + 8].copy_from_slice(&w.to_be_bytes());
333            }
334        }
335    }
336}
337
338fn sha_padding(acc: &[u8], algo: ShaAlgo, total_bits: u64) -> Vec<u8> {
339    let (bs, len_field): (usize, usize) = match algo {
340        ShaAlgo::Sha224 | ShaAlgo::Sha256 => (64, 8),
341        ShaAlgo::Sha384 | ShaAlgo::Sha512 => (128, 16),
342    };
343
344    let mut pad = Vec::with_capacity(bs * 2);
345    pad.extend_from_slice(acc);
346    pad.push(0x80);
347
348    let len_mod = (pad.len() + len_field) % bs;
349    let zeros = if len_mod == 0 { 0 } else { bs - len_mod };
350    pad.extend(core::iter::repeat(0u8).take(zeros));
351
352    if len_field == 8 {
353        pad.extend_from_slice(&total_bits.to_be_bytes());
354    } else {
355        pad.extend_from_slice(&(total_bits as u128).to_be_bytes());
356    }
357
358    debug_assert_eq!(pad.len() % bs, 0);
359    pad
360}