// darkwing/server/services/database_encryption_services.rs
use anyhow::Context;
use base64::prelude::*;
use libaes::Cipher;
use mockall::automock;
use redis::Commands;
use sha2::{Digest, Sha256};
use sqlx::prelude::FromRow;
use std::sync::Arc;
use async_trait::async_trait;
use crate::{
cache::Cache,
config::DarkwingConfig,
database::Database,
server::error::{AppResult, Error},
};
/// Shared handle to any [`DatabaseEncryptionTrait`] implementation that can be
/// sent and shared across threads (`Send + Sync`).
pub type DynDatabaseEncryption =
    Arc<dyn DatabaseEncryptionTrait + Send + Sync>;

/// Decryption of database values using the encryption key identified by a
/// numeric `key_id`.
///
/// `#[automock]` must stay above `#[async_trait]` so mockall sees the
/// original `async fn` signatures.
#[automock]
#[async_trait]
pub trait DatabaseEncryptionTrait {
    /// Decrypts a textual encrypted payload with key `key_id` and returns the
    /// plaintext as a `String`.
    async fn decrypt_string(
        &self,
        key_id: usize,
        encrypted_string: String,
    ) -> AppResult<String>;

    /// Decrypts an encrypted payload supplied as raw bytes with key `key_id`
    /// and returns the plaintext as a `String`.
    async fn decrypt_data(
        &self,
        key_id: usize,
        encrypted_data: Data,
    ) -> AppResult<String>;
}
/// Service that decrypts database values using keys stored (encrypted) in the
/// `crypto_keys` table.
///
/// Decrypted keys are cached in Redis so repeated lookups can skip the
/// database; the configuration supplies the master key parts and the Redis
/// timeout/expiry settings.
#[derive(Clone)]
pub struct DatabaseEncryptionService {
    /// Application configuration: master key parts and Redis settings.
    config: Arc<DarkwingConfig>,
    /// Database handle used to read rows from `crypto_keys`.
    database: Arc<Database>,
    /// Redis connection pool used to cache decrypted keys.
    cache: Arc<Cache>,
}
/// Raw byte buffer used for keys, initialization vectors and plaintext.
type Data = Vec<u8>;

/// One row of the `crypto_keys` lookup performed when fetching a key.
///
/// Field names must match the column aliases in that query (`key`,
/// `init_vector`), since they are mapped by `FromRow`.
#[derive(FromRow)]
struct CryptoKey {
    /// Encrypted key as base64 text (converted with `CHAR(1000)` in SQL).
    key: Option<String>,
    /// Initialization vector, hex-encoded by the query's `HEX(...)`.
    init_vector: Option<String>,
}
impl DatabaseEncryptionService {
const IV_SIZE: usize = 16;
pub fn new(
database: Arc<Database>,
cache: Arc<Cache>,
config: Arc<DarkwingConfig>,
) -> Self {
Self {
database,
cache,
config,
}
}
fn check_cache(&self, key: String) -> AppResult<Option<String>> {
let mut redis = self
.cache
.pool
.get_timeout(self.config.redis_timeout())
.context("Failed to get redis pool instance")?;
redis
.get(format!("db_encryption_key:{}", key))
.context("Failed to get key from Redis")
.map_err(Error::AnyhowError)
}
fn save_cache(&self, key: String, value: String) -> AppResult<()> {
let mut redis = self
.cache
.pool
.get_timeout(self.config.redis_timeout())
.context("Failed to get redis pool instance")?;
redis
.set_ex(
format!("db_encryption_key:{}", key),
value,
self.config.redis_expiration_seconds,
)
.context("Failed to save key to Redis")
.map_err(Error::AnyhowError)
}
fn combine_keys(key1: Data, key2: Data) -> String {
let mut combined_keys = Vec::new();
combined_keys.extend_from_slice(&key1);
combined_keys.extend_from_slice(&key2);
let mut sha256 = Sha256::new();
sha256.update(combined_keys);
let sha256_result = sha256.finalize();
let sha256_hex = hex::encode(sha256_result);
let md5_result = md5::compute(sha256_hex);
hex::encode(md5_result.0)
}
fn general_decrypt(
config: Arc<DarkwingConfig>,
key: Option<Data>,
init_vector: Data,
encrypted_string: String,
) -> AppResult<Data> {
let key = Self::combine_keys(
config.database_encryption_key_part_1.as_bytes().to_vec(),
key.unwrap_or(config.database_encryption_key_part_2.as_bytes().to_vec()),
);
let key_array: [u8; 32] = key
.as_bytes()
.try_into()
.map_err(|_| anyhow::anyhow!("Failed to convert key to [u8; 32]"))?;
let cipher = Cipher::new_256(&key_array);
let decrypted = cipher.cbc_decrypt(
&init_vector,
&BASE64_STANDARD
.decode(&encrypted_string)
.context("Failed to decode base64 in decrypt")?,
);
Ok(decrypted)
}
async fn get_key(&self, key_id: usize) -> AppResult<Data> {
if let Ok(Some(cached_key)) = self.check_cache(key_id.to_string()) {
if let Ok(decoded_key) = BASE64_STANDARD.decode(&cached_key) {
return Ok(decoded_key);
}
}
let key = sqlx::query_as::<_, CryptoKey>(
"select convert(`crypto_keys`.`key`, CHAR(1000)) as `key`, HEX(crypto_keys.init_vector) as init_vector from crypto_keys where id = ?",
)
.bind(key_id as u64)
.fetch_one(&self.database.pool)
.await
.context("Failed to fetch key from database")?;
let (key, init_vector) = match (key.key, key.init_vector) {
(Some(key), Some(init_vector)) => (key, init_vector),
_ => return Err(anyhow::anyhow!("Key or init vector is none").into()),
};
let init_vector = hex::decode(init_vector)
.context("Failed to decode init vector from database")?
.to_vec();
let key =
Self::general_decrypt(self.config.clone(), None, init_vector, key)?;
let base64_key = BASE64_STANDARD.encode(&key);
let _ = self.save_cache(key_id.to_string(), base64_key);
Ok(key)
}
fn split(input: String) -> AppResult<(String, String)> {
let iv = input
.get(..16)
.context("Failed to get iv from encrypted data")?;
let encrypted_data = input
.get(Self::IV_SIZE..)
.context("Failed to get encrypted data")?;
Ok((iv.to_string(), encrypted_data.to_string()))
}
pub async fn decrypt_data(
&self,
key_id: usize,
encrypted_data: String,
) -> AppResult<String> {
let key = self.get_key(key_id).await?;
let (iv, encrypted_data) = Self::split(encrypted_data)?;
let decrypted_data = Self::general_decrypt(
self.config.clone(),
Some(key),
iv.into(),
encrypted_data,
)?;
String::from_utf8(decrypted_data).map_err(|e| {
Error::AnyhowError(anyhow::anyhow!(
"Failed to convert decrypted data to string: {}",
e
))
})
}
}
#[async_trait]
impl DatabaseEncryptionTrait for DatabaseEncryptionService {
    /// Delegates straight to the inherent `String`-based `decrypt_data`.
    ///
    /// Inherent methods take priority over trait methods during method
    /// resolution, so `self.decrypt_data` here is the inherent one, not a
    /// recursive call into this trait impl.
    async fn decrypt_string(
        &self,
        key_id: usize,
        encrypted_string: String,
    ) -> AppResult<String> {
        self.decrypt_data(key_id, encrypted_string).await
    }

    /// Base64-encodes `encrypted_data` and decrypts the result via the
    /// inherent `decrypt_data`.
    ///
    /// NOTE(review): the inherent method treats the first 16 characters of
    /// its input as the raw IV and base64-decodes the rest. After encoding,
    /// those 16 characters are the base64 of the first 12 raw bytes, which
    /// only round-trips if the data was produced that way. If this column
    /// stores the same `<IV><base64>` text that `decrypt_string` receives,
    /// `String::from_utf8(encrypted_data)` is probably what was intended.
    /// Confirm against whatever writes this data.
    async fn decrypt_data(
        &self,
        key_id: usize,
        encrypted_data: Data,
    ) -> AppResult<String> {
        self.decrypt_data(key_id, BASE64_STANDARD.encode(encrypted_data))
            .await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// IV length used by the service, reused so fixtures cannot drift from it.
    const IV_SIZE: usize = DatabaseEncryptionService::IV_SIZE;

    /// First master key part shared by all fixtures.
    const KEY_PART_1: &str =
        "qpe9IKZ5MrwJRnAL1fY6YTxKalUs8doUeV9O0v266dcL0NhNc5gUJVtAfW6V3unj";

    /// Second master key part shared by all fixtures.
    const KEY_PART_2: &str = "brmSRBfC2gni6Jxc2TVonrC3gwrJ2LA9SFntshhN2iVF5nsZp5u9MZY4wcxEknHGkd3Yqy59Fn97j4V6bd2mfxYV96299FMgE6484Kx3AbiGgfP4Byt32o7uq27HscfB";

    /// Encrypted form of `"фыва"`: 16-char IV followed by base64 ciphertext.
    const ENCRYPTED_MESSAGE: &str = "3956740162767105+M70abcv2/HeLXq7ZlU6JA==";

    /// Default configuration with the fixture master key parts installed.
    fn test_config() -> Arc<DarkwingConfig> {
        let mut config = DarkwingConfig::default();
        config.database_encryption_key_part_1 = KEY_PART_1.into();
        config.database_encryption_key_part_2 = KEY_PART_2.into();
        Arc::new(config)
    }

    #[test]
    fn test_decrypt_data() {
        let config = test_config();
        let message = "фыва";
        let (message_iv, message_data) = ENCRYPTED_MESSAGE.split_at(IV_SIZE);
        // The encrypted intermediate key is ASCII base64 text, stored here as
        // hex; a checked conversion keeps the test free of `unsafe`.
        let intermediate_key_encrypted = String::from_utf8(hex::decode("6a4d38706e5667654841345974686b4d564c4161596d426352645571334942554530306d4b41796b4a4754733843344e6570344547467048507670726b61493248666d59774c61397235766b426972315a7a693050757675544c7a4e51695249463259527151443171364f6e624741394874354d6d66476c2f724e6e663439754b555431497031715235514f59636b2f7a6247734a5370653463595132413456482b323575775135714f5a743457334f51514248437867466265626a62733457").unwrap())
            .expect("intermediate key fixture is valid UTF-8");
        let intermediate_key_iv =
            hex::decode("a8fe85edd6b2da6c01da5eb37ad606f2").unwrap();
        let intermediate_key = DatabaseEncryptionService::general_decrypt(
            Arc::clone(&config),
            None,
            intermediate_key_iv,
            intermediate_key_encrypted,
        )
        .expect("intermediate key should decrypt");
        assert_eq!(intermediate_key.len(), 128);
        let decrypted_message = DatabaseEncryptionService::general_decrypt(
            config,
            Some(intermediate_key),
            message_iv.as_bytes().to_vec(),
            message_data.to_string(),
        )
        .expect("message should decrypt");
        assert_eq!(message, String::from_utf8(decrypted_message).unwrap());
    }

    #[test]
    fn test_combine_keys() {
        let combined_key = DatabaseEncryptionService::combine_keys(
            KEY_PART_1.as_bytes().to_vec(),
            KEY_PART_2.as_bytes().to_vec(),
        );
        assert_eq!(combined_key.len(), 32);
        assert_eq!(combined_key, "a900095d5dff10d6ca818856c5cabfaf");
    }

    #[test]
    fn test_split() {
        let (iv, encrypted_data) =
            DatabaseEncryptionService::split(ENCRYPTED_MESSAGE.to_string())
                .unwrap();
        assert_eq!(iv.len(), IV_SIZE);
        assert_eq!(encrypted_data.len(), 24);
        assert_eq!(iv, "3956740162767105");
        assert_eq!(encrypted_data, "+M70abcv2/HeLXq7ZlU6JA==");
    }

    #[test]
    fn test_split_rejects_input_shorter_than_iv() {
        assert!(
            DatabaseEncryptionService::split("too short".to_string()).is_err()
        );
    }
}