diff --git a/src/header.rs b/src/header.rs index ca695b0..a79ed35 100644 --- a/src/header.rs +++ b/src/header.rs @@ -18,6 +18,8 @@ pub struct Header { #[derive(Serialize, Deserialize, Debug)] struct ExHeader { + #[serde(with = "serde_bytes")] + ad: Vec, public_key: [u8; 32], pn: usize, n: usize @@ -34,8 +36,9 @@ impl Header { } } // #[doc(hidden)] - pub fn concat(&self) -> Vec { - let ex_header = ExHeader{ + pub fn concat(&self, ad: &[u8]) -> Vec { + let ex_header = ExHeader { + ad: ad.to_vec(), public_key: self.public_key.to_bytes(), pn: self.pn, n: self.n @@ -43,8 +46,8 @@ impl Header { bincode::serialize(&ex_header).expect("Failed to serialize Header") } - pub fn encrypt(&self, hk: &[u8; 32]) -> (Vec, [u8; 12]) { - let header_data = self.concat(); + pub fn encrypt(&self, hk: &[u8; 32], ad: &[u8]) -> (Vec, [u8; 12]) { + let header_data = self.concat(ad); encrypt(hk, &header_data, b"") } @@ -95,7 +98,7 @@ impl From<&[u8]> for Header { impl From
for Vec { fn from(s: Header) -> Self { - s.concat() + s.concat(b"") } } @@ -126,8 +129,9 @@ mod tests { #[test] fn ser_des() { + let ad = b""; let header = gen_header(); - let serialized = header.concat(); + let serialized = header.concat(ad); let created = Header::from(serialized); assert_eq!(header, created) } @@ -136,7 +140,7 @@ mod tests { fn enc_header() { let header = gen_header(); let mk = gen_mk(); - let header_data = header.concat(); + let header_data = header.concat(b""); let data = include_bytes!("aead.rs"); let (encrypted, nonce) = encrypt(&mk, data, &header_data); let decrypted = decrypt(&mk, &encrypted, &header_data, &nonce); diff --git a/src/lib.rs b/src/lib.rs index 1565c90..aedec8e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -17,9 +17,10 @@ //! let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); // Creating Bobs Ratchet (returns Bobs PublicKey) //! let mut alice_ratchet = Ratchet::init_alice(sk, public_key); // Creating Alice Ratchet with Bobs PublicKey //! let data = b"Hello World".to_vec(); // Data to be encrypted +//! let ad = b"Associated Data"; // Associated Data //! -//! let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data); // Encrypting message with Alice Ratchet (Alice always needs to send the first message) -//! let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce); // Decrypt message with Bobs Ratchet +//! let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data, ad); // Encrypting message with Alice Ratchet (Alice always needs to send the first message) +//! let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce, ad); // Decrypt message with Bobs Ratchet //! assert_eq!(data, decrypted) //! ``` //! @@ -31,12 +32,13 @@ //! let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); // Creating Bobs Ratchet (returns Bobs PublicKey) //! let mut alice_ratchet = Ratchet::init_alice(sk, public_key); // Creating Alice Ratchet with Bobs PublicKey //! let data = b"Hello World".to_vec(); // Data to be encrypted +//! let ad = b"Associated Data"; // Associated Data //! -//! let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); // Lost message -//! let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data); // Successful message +//! let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, ad); // Lost message +//! let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data, ad); // Successful message //! -//! let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); // Decrypting second message first -//! let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); // Decrypting latter message +//! let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, ad); // Decrypting second message first +//! let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, ad); // Decrypting latter message //! //! let comp = decrypted1 == data && decrypted2 == data; //! assert!(comp); @@ -47,11 +49,11 @@ //! ```should_panic //! use double_ratchet_2::ratchet::Ratchet; //! let sk = [1; 32]; -//! +//! let ad = b"Associated Data"; //! let (mut bob_ratchet, _) = Ratchet::init_bob(sk); //! let data = b"Hello World".to_vec(); //! -//! let (_, _, _) = bob_ratchet.ratchet_encrypt(&data); +//! let (_, _, _) = bob_ratchet.ratchet_encrypt(&data, ad); //! ``` //! //! ## Encryption after recieving initial message @@ -65,12 +67,13 @@ //! let mut alice_ratchet = Ratchet::init_alice(sk, public_key); //! //! let data = b"Hello World".to_vec(); +//! let ad = b"Associated Data"; //! -//! let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); -//! let _decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); +//! let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, ad); +//! let _decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, ad); //! -//! let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data); -//! let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); +//! let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data, ad); +//! let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, ad); //! //! assert_eq!(data, decrypted2); //! ``` @@ -83,7 +86,8 @@ //! # let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); //! # let mut alice_ratchet = Ratchet::init_alice(sk, public_key); //! # let data = b"hello World".to_vec(); -//! # let (header, _, _) = alice_ratchet.ratchet_encrypt(&data); +//! # let ad = b"Associated Data"; +//! # let (header, _, _) = alice_ratchet.ratchet_encrypt(&data, ad); //! let header_bytes: Vec = header.clone().into(); //! let header_const = Header::from(header_bytes); //! assert_eq!(header, header_const); @@ -100,9 +104,10 @@ //! let (mut bob_ratchet, public_key) = RatchetEncHeader::init_bob(sk, shared_hka, shared_nhkb); //! let mut alice_ratchet = RatchetEncHeader::init_alice(sk, public_key, shared_hka, shared_nhkb); //! let data = b"Hello World".to_vec(); +//! let ad = b"Associated Data"; //! -//! let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data); -//! let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce); +//! let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data, ad); +//! let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce, ad); //! assert_eq!(data, decrypted) //! ``` //! diff --git a/src/ratchet.rs b/src/ratchet.rs index 6f73526..4b31a28 100644 --- a/src/ratchet.rs +++ b/src/ratchet.rs @@ -65,21 +65,21 @@ impl Ratchet { } /// Encrypt Plaintext with [Ratchet]. Returns Message [Header] and ciphertext. - pub fn ratchet_encrypt(&mut self, plaintext: &[u8]) -> (Header, Vec, [u8; 12]) { + pub fn ratchet_encrypt(&mut self, plaintext: &[u8], ad: &[u8]) -> (Header, Vec, [u8; 12]) { let (cks, mk) = kdf_ck(&self.cks.unwrap()); self.cks = Some(cks); let header = Header::new(&self.dhs, self.pn, self.ns); self.ns += 1; - let (encrypted_data, nonce) = encrypt(&mk, plaintext, &header.concat()); + let (encrypted_data, nonce) = encrypt(&mk, plaintext, &header.concat(ad)); (header, encrypted_data, nonce) } - fn try_skipped_message_keys(&mut self, header: &Header, ciphertext: &[u8], nonce: &[u8; 12]) -> Option> { + fn try_skipped_message_keys(&mut self, header: &Header, ciphertext: &[u8], nonce: &[u8; 12], ad: &[u8]) -> Option> { if self.mkskipped.contains_key(&(header.public_key, header.n)) { let mk = *self.mkskipped.get(&(header.public_key, header.n)) .unwrap(); self.mkskipped.remove(&(header.public_key, header.n)).unwrap(); - Some(decrypt(&mk, ciphertext, &header.concat(), nonce)) + Some(decrypt(&mk, ciphertext, &header.concat(ad), nonce)) } else { None } @@ -104,8 +104,8 @@ impl Ratchet { } /// Decrypt ciphertext with ratchet. Requires Header. Returns plaintext. - pub fn ratchet_decrypt(&mut self, header: &Header, ciphertext: &[u8], nonce: &[u8; 12]) -> Vec { - let plaintext = self.try_skipped_message_keys(header, ciphertext, nonce); + pub fn ratchet_decrypt(&mut self, header: &Header, ciphertext: &[u8], nonce: &[u8; 12], ad: &[u8]) -> Vec { + let plaintext = self.try_skipped_message_keys(header, ciphertext, nonce, ad); match plaintext { Some(d) => d, None => { @@ -119,7 +119,7 @@ impl Ratchet { let (ckr, mk) = kdf_ck(&self.ckr.unwrap()); self.ckr = Some(ckr); self.nr += 1; - decrypt(&mk, ciphertext, &header.concat(), nonce) + decrypt(&mk, ciphertext, &header.concat(ad), nonce) } } } @@ -202,18 +202,18 @@ impl RatchetEncHeader { (ratchet, public_key) } - pub fn ratchet_encrypt(&mut self, plaintext: &[u8]) -> HeaderNonceCipherNonce { + pub fn ratchet_encrypt(&mut self, plaintext: &[u8], ad: &[u8]) -> HeaderNonceCipherNonce { let (cks, mk) = kdf_ck(&self.cks.unwrap()); self.cks = Some(cks); let header = Header::new(&self.dhs, self.pn, self.ns); - let enc_header = header.encrypt(&self.hks.unwrap()); + let enc_header = header.encrypt(&self.hks.unwrap(), ad); self.ns += 1; - let encrypted = encrypt(&mk, plaintext, &header.concat()); + let encrypted = encrypt(&mk, plaintext, &header.concat(ad)); (enc_header, encrypted.0, encrypted.1) } fn try_skipped_message_keys(&mut self, enc_header: &(Vec, [u8; 12]), - ciphertext: &[u8], nonce: &[u8; 12]) -> Option> { + ciphertext: &[u8], nonce: &[u8; 12], ad: &[u8]) -> Option> { let ret_data = self.mkskipped.clone().into_iter().find(|e| { let header = Header::decrypt(&e.0.0, &enc_header.0, &enc_header.1); @@ -228,7 +228,7 @@ impl RatchetEncHeader { let header = Header::decrypt(&data.0.0, &enc_header.0, &enc_header.1); let mk = data.1; self.mkskipped.remove(&(data.0.0, data.0.1)); - Some(decrypt(&mk, ciphertext, &header.unwrap().concat(), nonce)) + Some(decrypt(&mk, ciphertext, &header.unwrap().concat(ad), nonce)) } } } @@ -278,8 +278,8 @@ impl RatchetEncHeader { self.nhks = Some(nhks); } - pub fn ratchet_decrypt(&mut self, enc_header: &(Vec, [u8; 12]), ciphertext: &[u8], nonce: &[u8; 12]) -> Vec { - let plaintext = self.try_skipped_message_keys(enc_header, ciphertext, nonce); + pub fn ratchet_decrypt(&mut self, enc_header: &(Vec, [u8; 12]), ciphertext: &[u8], nonce: &[u8; 12], ad: &[u8]) -> Vec { + let plaintext = self.try_skipped_message_keys(enc_header, ciphertext, nonce, ad); if let Some(d) = plaintext { return d }; let (header, dh_ratchet) = self.decrypt_header(enc_header).unwrap(); if dh_ratchet { @@ -290,6 +290,6 @@ impl RatchetEncHeader { let (ckr, mk) = kdf_ck(&self.ckr.unwrap()); self.ckr = Some(ckr); self.nr += 1; - decrypt(&mk, ciphertext, &header.concat(), nonce) + decrypt(&mk, ciphertext, &header.concat(ad), nonce) } } \ No newline at end of file diff --git a/tests/mod.rs b/tests/mod.rs index 613779b..3a1073c 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -13,8 +13,8 @@ fn ratchet_enc_single() { let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); let mut alice_ratchet = Ratchet::init_alice(sk, public_key); let data = include_bytes!("../src/header.rs").to_vec(); - let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data); - let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce); + let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce, b""); assert_eq!(data, decrypted) } @@ -24,10 +24,10 @@ fn ratchet_enc_skip() { let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); let mut alice_ratchet = Ratchet::init_alice(sk, public_key); let data = include_bytes!("../src/header.rs").to_vec(); - let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); - let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data); - let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); - let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); + let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, b""); + let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, b""); + let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, b""); let comp_res = decrypted1 == data && decrypted2 == data; assert!(comp_res) } @@ -38,7 +38,7 @@ fn ratchet_panic_bob() { let sk = [1; 32]; let (mut bob_ratchet, _) = Ratchet::init_bob(sk); let data = include_bytes!("../src/header.rs").to_vec(); - let (_, _, _) = bob_ratchet.ratchet_encrypt(&data); + let (_, _, _) = bob_ratchet.ratchet_encrypt(&data, b""); } #[test] @@ -47,10 +47,10 @@ fn ratchet_encryt_decrypt_four() { let data = include_bytes!("../src/dh.rs").to_vec(); let (mut bob_ratchet, public_key) = Ratchet::init_bob(sk); let mut alice_ratchet = Ratchet::init_alice(sk, public_key); - let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); - let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); - let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data); - let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); + let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, b""); + let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data, b""); + let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, b""); let comp_res = decrypted1 == data && decrypted2 == data; assert!(comp_res) } @@ -80,8 +80,8 @@ fn ratchet_ench_enc_single() { shared_hka, shared_nhkb); let data = include_bytes!("../src/header.rs").to_vec(); - let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data); - let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce); + let (header, encrypted, nonce) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted = bob_ratchet.ratchet_decrypt(&header, &encrypted, &nonce, b""); assert_eq!(data, decrypted) } @@ -98,10 +98,10 @@ fn ratchet_ench_enc_skip() { shared_hka, shared_nhkb); let data = include_bytes!("../src/header.rs").to_vec(); - let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); - let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data); - let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); - let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); + let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, b""); + let (header2, encrypted2, nonce2) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted2 = bob_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, b""); + let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, b""); let comp_res = decrypted1 == data && decrypted2 == data; assert!(comp_res) } @@ -116,7 +116,7 @@ fn ratchet_ench_panic_bob() { shared_hka, shared_nhkb); let data = include_bytes!("../src/header.rs").to_vec(); - let (_, _, _) = bob_ratchet.ratchet_encrypt(&data); + let (_, _, _) = bob_ratchet.ratchet_encrypt(&data, b""); } #[test] @@ -129,10 +129,10 @@ fn ratchet_ench_decrypt_four() { shared_nhkb); let mut alice_ratchet = RatchetEncHeader::init_alice(sk, public_key, shared_hka, shared_nhkb); let data = include_bytes!("../src/dh.rs").to_vec(); - let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data); - let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1); - let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data); - let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2); + let (header1, encrypted1, nonce1) = alice_ratchet.ratchet_encrypt(&data, b""); + let decrypted1 = bob_ratchet.ratchet_decrypt(&header1, &encrypted1, &nonce1, b""); + let (header2, encrypted2, nonce2) = bob_ratchet.ratchet_encrypt(&data, b""); + let decrypted2 = alice_ratchet.ratchet_decrypt(&header2, &encrypted2, &nonce2, b""); let comp_res = decrypted1 == data && decrypted2 == data; assert!(comp_res) } \ No newline at end of file