darkwing/server/extractors/
required_authentication_extractor.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
use async_trait::async_trait;
use axum::extract::FromRequestParts;
use axum::http::header::AUTHORIZATION;
use axum::http::request::Parts;
use axum::Extension;
use time::OffsetDateTime;
use tracing::error;

use crate::server::dtos::team_dto::ResponseTeamDto;
use crate::server::dtos::user_dto::ResponseUserDto;
use crate::server::error::Error;
use crate::server::services::Services;
use crate::unreachable_if_none;

/// An extractor that authenticates the caller of a protected route.
///
/// Extraction reads the `Authorization: Bearer <token>` header, resolves the
/// user ID encoded in the JWT, loads that user and the team referenced by
/// `user.team_id`, and rejects the request with `Error::PaymentRequired` when
/// the team's subscription has expired and the team is not on a fully free
/// plan. Add it as a handler argument to guard a route that requires an
/// authenticated caller.
///
/// # Fields
/// * `user` - The authenticated user
/// * `team` - The team referenced by the authenticated user's `team_id`
/// * `services` - The application [`Services`] taken from the request
///   extensions
/// * `token` - The raw bearer token from the `Authorization` header, which
///   `jwt_util` successfully decoded into a user ID
///
/// # Example
/// ```rust,ignore
/// use axum::response::IntoResponse;
///
/// async fn protected_route(
///   RequiredAuthentication { user, team, .. }: RequiredAuthentication,
/// ) -> Result<impl IntoResponse, Error> {
///   // Handle the authenticated request using `user` and `team`.
///   Ok(())
/// }
/// ```
pub struct RequiredAuthentication {
  pub user: ResponseUserDto,
  pub team: ResponseTeamDto,
  pub services: Services,
  pub token: String,
}

#[async_trait]
impl<S> FromRequestParts<S> for RequiredAuthentication
where
  S: Send + Sync,
{
  type Rejection = Error;

  /// Implementation of the FromRequestParts trait to enable automatic
  /// extraction of authentication information from incoming requests.
  ///
  /// This implementation:
  /// 1. Extracts the services from request extensions
  /// 2. Validates the Authorization header format (Bearer token)
  /// 3. Extracts and validates the JWT token
  /// 4. Retrieves the associated user and team information
  /// 5. Verifies subscription status
  ///
  /// # Errors
  /// Returns `Error::Unauthorized` in cases of:
  /// - Missing or malformed Authorization header
  /// - Invalid token format
  /// - Invalid or expired JWT token
  /// - User not found
  /// - Team not found
  ///
  /// Returns `Error::PaymentRequired` if:
  /// - Team's subscription has expired (except for free plan teams)
  async fn from_request_parts(
    parts: &mut Parts,
    state: &S,
  ) -> Result<Self, Self::Rejection> {
    let Extension(services): Extension<Services> =
      Extension::from_request_parts(parts, state)
        .await
        .map_err(|err| {
          Error::InternalServerErrorWithContext(err.to_string())
        })?;

    if let Some(authorization_header) = parts.headers.get(AUTHORIZATION) {
      let header_value = authorization_header
        .to_str()
        .map_err(|_| Error::Unauthorized)?;

      if !header_value.contains("Bearer") {
        error!(
          "request does not contain valid 'Bearer' prefix for authorization"
        );
        return Err(Error::Unauthorized);
      }

      let tokenized_value: Vec<_> = header_value.split(' ').collect();

      if tokenized_value.len() != 2 || tokenized_value.get(1).is_none() {
        error!(
          "request does not contain a valid token: {:?}",
          tokenized_value
        );
        return Err(Error::Unauthorized);
      }

      let token_value =
        unreachable_if_none!(tokenized_value.into_iter().nth(1));
      let user_id = services
        .jwt_util
        .get_user_id_from_token(String::from(token_value))
        .map_err(|err| {
          error!("could not validate user ID from token: {:?}", err);
          Error::Unauthorized
        })?;

      let user =
        services
          .users
          .get_user_by_id(user_id)
          .await
          .map_err(|err| {
            error!("invalid user ID from token: {:?}", err);
            Error::Unauthorized
          })?;

      let team =
        services
          .users
          .get_team_by_id(user.team_id)
          .await
          .map_err(|err| {
            error!("invalid user ID from token: {:?}", err);
            Error::Unauthorized
          })?;

      if team.subscription_expiration.assume_utc() < OffsetDateTime::now_utc()
        && !services.users.is_fully_free_plan(&team)
      {
        return Err(Error::PaymentRequired);
      }

      Ok(RequiredAuthentication {
        user,
        team,
        services,
        token: token_value.into(),
      })
    } else {
      Err(Error::Unauthorized)
    }
  }
}