use std::borrow::Cow;
use std::{collections::HashMap, fmt::Debug};
use axum::extract::rejection::JsonRejection;
use axum::response::Response;
use axum::{http::StatusCode, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use serde_json::json;
use thiserror::Error;
use tracing::log::error;
use validator::{ValidationErrors, ValidationErrorsKind};
/// JSON error payload returned to API clients.
///
/// Serialized as `{"errors": {"<key>": ["message", ...], ...}}`. The `"error"`
/// key is used for a short summary message.
#[derive(Debug, Deserialize, Serialize)]
pub struct ApiError {
    /// Messages grouped by key (e.g. a field name, or `"error"` for the summary).
    pub errors: HashMap<String, Vec<String>>,
}

impl ApiError {
    /// Key under which the summary message is stored.
    const SUMMARY_KEY: &'static str = "error";

    /// Creates a payload holding a single message under `"error"`.
    pub fn new(error: String) -> Self {
        Self {
            errors: HashMap::from([(Self::SUMMARY_KEY.to_owned(), vec![error])]),
        }
    }

    /// Wraps a map of per-key messages, adding an `"error"` summary if absent.
    ///
    /// If `errors` already contains `"error"`, it is left untouched. Otherwise
    /// the first message of the lexicographically smallest key that has at
    /// least one message is copied into `"error"`. Choosing by key rather than
    /// by `HashMap` iteration order makes the summary identical across runs,
    /// and skipping empty lists means an empty entry cannot suppress it. An
    /// empty map (or one with only empty lists) gets no summary.
    pub fn from_map(mut errors: HashMap<String, Vec<String>>) -> Self {
        if !errors.contains_key(Self::SUMMARY_KEY) {
            let summary = errors
                .iter()
                .filter_map(|(key, messages)| messages.first().map(|message| (key, message)))
                .min_by_key(|&(key, _)| key)
                .map(|(_, message)| message.clone());
            if let Some(summary) = summary {
                errors.insert(Self::SUMMARY_KEY.to_owned(), vec![summary]);
            }
        }
        Self { errors }
    }
}
pub type AppResult<T> = Result<T, Error>;
pub type ErrorMap = HashMap<Cow<'static, str>, Vec<Cow<'static, str>>>;
#[derive(Error, Debug)]
pub enum Error {
#[error("authentication is required to access this resource")]
Unauthorized,
#[error("payment is required to access this resource")]
PaymentRequired,
#[error("user does not have privilege to access this resource")]
Forbidden,
#[error("plan does not allow to access this resource")]
PlanDoesNotAllowAccess,
#[error("{0}")]
NotFound(
String,
),
#[error("{0}")]
ApplicationStartup(
String,
),
#[error("{0}")]
BadRequest(
String,
),
#[error("unexpected error has occurred")]
InternalServerError,
#[error("{0}")]
InternalServerErrorWithContext(
String,
),
#[error("{0}")]
ObjectConflict(
String,
),
#[error("unprocessable request has occurred")]
UnprocessableEntity {
errors: ErrorMap,
},
#[error(transparent)]
Validation(
#[from]
ValidationErrors,
),
#[error(transparent)]
AxumJsonRejection(
#[from]
JsonRejection,
),
#[error(transparent)]
AnyhowError(
#[from]
anyhow::Error,
),
#[error("error while encrypting config")]
AesGcmError,
#[error("{0}")]
ConfigForming(
String,
),
#[error("{0}")]
DatabaseParsing(
String,
),
#[error(transparent)]
StdIO(
#[from]
std::io::Error,
),
#[error(transparent)]
Zip(
#[from]
zip::result::ZipError,
),
#[error(transparent)]
Rusqlite(
#[from]
rusqlite::Error,
),
#[cfg(test)]
#[error(transparent)]
SqlxSqlite(
#[from]
sqlx::sqlite::SqliteError,
),
#[error(transparent)]
S3(
#[from]
aws_sdk_s3::Error,
),
#[error(transparent)]
S3Client(
#[from]
aws_sdk_s3::error::SdkError<
aws_sdk_s3::operation::get_object::GetObjectError,
>,
),
#[error("S3 GetObjectError")]
S3GetObject(
#[from]
aws_sdk_s3::operation::get_object::GetObjectError,
),
#[error("base file hash does not match")]
BaseFileHashMismatch,
#[error("resulting file hash does not match")]
ResultingFileHashMismatch,
#[error("diff error: {0}")]
Diff(
String,
),
#[error(transparent)]
SerdeJson(
#[from]
serde_json::Error,
),
#[error("no previous datadir, send full zip")]
NoPreviousDatadir,
#[error("unknown mode: {0}")]
UnknownMode(
String,
),
#[error(transparent)]
TimeFormatDescription(
#[from]
time::error::InvalidFormatDescription,
),
#[error(transparent)]
TimeFormat(
#[from]
time::error::Format,
),
#[error("unknown override mode: {0}")]
UnknownOverrideMode(
String,
),
#[error("Error while decoding {0} field from database: {1}")]
DatabaseDecoding(
String,
String,
),
}
impl From<darkwing_diff::Error> for Error {
fn from(value: darkwing_diff::Error) -> Self {
Self::Diff(format!("{:?}", value))
}
}
impl Error {
pub fn unprocessable_entity(errors: ValidationErrors) -> Response {
let mut validation_errors = ErrorMap::new();
for (field_property, error_kind) in errors.into_errors() {
if let ValidationErrorsKind::Field(field_meta) = error_kind.clone() {
for error in field_meta.into_iter() {
validation_errors
.entry(Cow::from(field_property))
.or_default()
.push(error.message.unwrap_or_else(|| {
let params: Vec<Cow<'static, str>> = error
.params
.iter()
.filter(|(key, _value)| *key.to_owned() != *"value")
.map(|(key, value)| {
Cow::from(format!("{} value is {}", key, value))
})
.collect();
if !params.is_empty() {
Cow::from(params.join(", "))
} else {
Cow::from(format!("{} is required", field_property))
}
}));
}
}
if let ValidationErrorsKind::Struct(meta) = error_kind.clone() {
for (struct_property, struct_error_kind) in meta.into_errors() {
if let ValidationErrorsKind::Field(field_meta) = struct_error_kind {
for error in field_meta.into_iter() {
validation_errors
.entry(Cow::from(struct_property))
.or_default()
.push(error.message.unwrap_or_else(|| {
Cow::from(format!("{} is required", struct_property))
}));
}
}
}
}
}
let mut errors_map: HashMap<String, Vec<String>> = validation_errors
.into_iter()
.map(|(k, v)| {
(
k.into_owned(),
v.into_iter().map(|cow| cow.into_owned()).collect(),
)
})
.collect();
if let Some((_key, messages)) = errors_map.iter().next() {
if let Some(first_message) = messages.first() {
errors_map.insert("error".to_string(), vec![first_message.clone()]);
}
}
let body = Json(json!({
"errors": errors_map,
}));
(StatusCode::BAD_REQUEST, body).into_response()
}
}
impl IntoResponse for Error {
fn into_response(self) -> Response {
if let Self::Validation(e) = self {
return Self::unprocessable_entity(e);
}
let (status, error_message) = match self {
Self::InternalServerErrorWithContext(ref err) => {
(StatusCode::INTERNAL_SERVER_ERROR, err.to_string())
}
Self::NotFound(ref err) => (StatusCode::NOT_FOUND, err.to_string()),
Self::PaymentRequired => (
StatusCode::PAYMENT_REQUIRED,
Self::PaymentRequired.to_string(),
),
Self::ObjectConflict(ref err) => (StatusCode::CONFLICT, err.to_string()),
Self::Unauthorized => {
(StatusCode::UNAUTHORIZED, Self::Unauthorized.to_string())
}
Self::Forbidden => (StatusCode::FORBIDDEN, Self::Forbidden.to_string()),
Self::AxumJsonRejection(ref err) => {
(StatusCode::BAD_REQUEST, err.body_text())
}
Self::NoPreviousDatadir => (
StatusCode::FAILED_DEPENDENCY,
Self::NoPreviousDatadir.to_string(),
),
_ => (
StatusCode::INTERNAL_SERVER_ERROR,
String::from("unexpected error occurred"),
),
};
if status != StatusCode::UNAUTHORIZED {
sentry::capture_error(&self);
}
let body = Json(ApiError::new(error_message));
(status, body).into_response()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;
    use serde_json::Value;
    use validator::Validate;

    /// Fixture whose two rules are both violated by the values used in
    /// `test_validation_error_response`.
    #[derive(Debug, Validate)]
    struct TestStruct {
        #[validate(length(min = 3))]
        name: String,
        #[validate(range(min = 18))]
        age: i32,
    }

    #[test]
    fn test_api_error_new() {
        let error = ApiError::new("test error".to_string());
        assert!(error.errors.contains_key("error"));
        assert_eq!(error.errors["error"], vec!["test error"]);
    }

    #[test]
    fn test_api_error_from_map() {
        let mut errors = HashMap::new();
        errors.insert("name".to_string(), vec!["name is invalid".to_string()]);
        errors.insert("age".to_string(), vec!["age is too low".to_string()]);
        let api_error = ApiError::from_map(errors);
        assert!(api_error.errors.contains_key("error"));
        assert!(api_error.errors.contains_key("name"));
        assert!(api_error.errors.contains_key("age"));
        assert_eq!(api_error.errors["error"].len(), 1);
        // The summary is copied from one of the field messages.
        let summary = &api_error.errors["error"][0];
        assert!(summary == "name is invalid" || summary == "age is too low");
    }

    #[test]
    fn test_api_error_from_map_keeps_existing_summary() {
        let mut errors = HashMap::new();
        errors.insert("error".to_string(), vec!["explicit summary".to_string()]);
        errors.insert("name".to_string(), vec!["name is invalid".to_string()]);
        let api_error = ApiError::from_map(errors);
        assert_eq!(api_error.errors["error"], vec!["explicit summary"]);
    }

    #[test]
    fn test_api_error_from_map_empty() {
        let api_error = ApiError::from_map(HashMap::new());
        assert!(api_error.errors.is_empty());
    }

    #[tokio::test]
    async fn test_validation_error_response() {
        let test_struct = TestStruct {
            name: "a".to_string(),
            age: 15,
        };
        let validation_errors = test_struct
            .validate()
            .expect_err("Should have validation errors");
        let response = Error::unprocessable_entity(validation_errors);
        assert_eq!(response.status(), StatusCode::BAD_REQUEST);
        let bytes = axum::body::to_bytes(response.into_body(), usize::MAX)
            .await
            .expect("cannot read body");
        // Fail loudly on a malformed body instead of silently skipping the checks.
        let json: Value = serde_json::from_slice(&bytes).expect("Should be valid JSON");
        let errors_obj = json
            .get("errors")
            .and_then(Value::as_object)
            .expect("Should have errors object");
        assert!(errors_obj.contains_key("error"));
        assert!(errors_obj.contains_key("name"));
        assert!(errors_obj.contains_key("age"));
        // The summary must be one of the field messages.
        let summary = &errors_obj["error"][0];
        let from_field = |field: &str| {
            errors_obj[field]
                .as_array()
                .map_or(false, |messages| messages.contains(summary))
        };
        assert!(from_field("name") || from_field("age"));
    }

    #[test]
    fn test_error_into_response() {
        let cases = [
            (
                Error::NotFound("Resource not found".to_string()),
                StatusCode::NOT_FOUND,
            ),
            (Error::Unauthorized, StatusCode::UNAUTHORIZED),
            (Error::Forbidden, StatusCode::FORBIDDEN),
            (Error::PaymentRequired, StatusCode::PAYMENT_REQUIRED),
            (
                Error::ObjectConflict("duplicate".to_string()),
                StatusCode::CONFLICT,
            ),
            (Error::NoPreviousDatadir, StatusCode::FAILED_DEPENDENCY),
            (Error::InternalServerError, StatusCode::INTERNAL_SERVER_ERROR),
        ];
        for (error, expected) in cases {
            let description = error.to_string();
            assert_eq!(error.into_response().status(), expected, "{}", description);
        }
    }

    #[test]
    fn test_error_messages() {
        let error = Error::NotFound("test resource not found".to_string());
        assert_eq!(error.to_string(), "test resource not found");
        let error = Error::Unauthorized;
        assert_eq!(
            error.to_string(),
            "authentication is required to access this resource"
        );
        let error = Error::Forbidden;
        assert_eq!(
            error.to_string(),
            "user does not have privilege to access this resource"
        );
    }
}