Skip to content

jwt

mlte/backend/api/auth/jwt.py

Handling of JSON Web Tokens (JWTs).

ALGORITHM = 'HS256' module-attribute

Token signing algorithm (HMAC with SHA-256).

DEFAULT_EXPIRATION_MINS = 120 module-attribute

Default token lifetime, in minutes.

EXPIRATION_CLAIM_KEY = 'exp' module-attribute

Token claim key holding the expiration time.

DecodedToken

Bases: BaseModel

Model for the claims inside the token.

Source code in mlte/backend/api/auth/jwt.py
Lines 38–45
class DecodedToken(BaseModel):
    """Claims extracted from a decoded access token."""

    username: str
    """Name of the user the token was issued to."""

    expiration_time: datetime
    """Point in time after which the token is no longer valid."""

expiration_time instance-attribute

The date and time the token expires.

username instance-attribute

The user name.

EncodedToken

Bases: BaseModel

Model for the encoded token and additional metadata.

Source code in mlte/backend/api/auth/jwt.py
Lines 28–35
class EncodedToken(BaseModel):
    """A signed, encoded token together with its lifetime metadata."""

    encoded_token: str
    """The serialized token string, as produced by the encoder."""

    expires_in: int
    """How long the token remains valid, in whole seconds."""

encoded_token instance-attribute

The actual encoded token.

expires_in instance-attribute

Lifetime in seconds of the token.

check_expired_token(token)

Checks whether the provided token has expired.

Source code in mlte/backend/api/auth/jwt.py
Lines 92–94
def check_expired_token(token: DecodedToken) -> bool:
    """Report whether ``token`` is past its expiration time.

    The token's ``expiration_time`` is compared against the current UTC time.
    A token whose expiration equals the current instant is not yet expired.

    Returns:
        True if the expiration time lies strictly in the past, else False.
    """
    now_utc = datetime.now(timezone.utc)
    return now_utc > token.expiration_time

create_user_token(username, key, expires_delta=None)

Creates an access token containing a given username.

Source code in mlte/backend/api/auth/jwt.py
Lines 48–66
def create_user_token(
    username: str, key: str, expires_delta: Optional[timedelta] = None
) -> EncodedToken:
    """Build and sign an access token whose subject is ``username``.

    Args:
        username: Name of the user the token is issued to; stored under
            ``SUBJECT_CLAIM_KEY``.
        key: Secret used to sign the token with ``ALGORITHM``.
        expires_delta: How long the token stays valid. When omitted,
            ``DEFAULT_EXPIRATION_MINS`` minutes are used.

    Returns:
        The encoded token string together with its lifetime in whole seconds.
    """
    lifetime = (
        timedelta(minutes=DEFAULT_EXPIRATION_MINS)
        if expires_delta is None
        else expires_delta
    )
    expires_at = datetime.now(timezone.utc) + lifetime

    # The expiration claim is stored as an integer POSIX timestamp.
    claims: dict[str, Union[str, int]] = {
        SUBJECT_CLAIM_KEY: username,
        EXPIRATION_CLAIM_KEY: int(expires_at.timestamp()),
    }
    signed = jwt.encode(claims, key, algorithm=ALGORITHM)

    return EncodedToken(
        encoded_token=signed,
        expires_in=int(lifetime.total_seconds()),
    )

decode_user_token(encoded_token, key)

Decodes the provided user access token.

Source code in mlte/backend/api/auth/jwt.py
Lines 69–89
def decode_user_token(encoded_token: str, key: str) -> DecodedToken:
    """Decodes the provided user access token.

    Args:
        encoded_token: The encoded token string, as produced by
            ``create_user_token``.
        key: Secret key the token was signed with.

    Returns:
        The username and (UTC, timezone-aware) expiration time from the token.

    Raises:
        Exception: If the JWT library rejects the token (the original
            ``JWTError`` is chained as the cause), if the subject claim is
            missing or not a string, or if the expiration claim is missing,
            non-numeric, or not representable as a datetime.
    """
    try:
        payload = jwt.decode(encoded_token, key, algorithms=[ALGORITHM])
    except JWTError as ex:
        raise Exception(f"Error decoding token: {str(ex)}") from ex

    username = payload.get(SUBJECT_CLAIM_KEY)
    if not isinstance(username, str):
        raise Exception("No valid user in token")

    exp_timestamp = payload.get(EXPIRATION_CLAIM_KEY)
    # bool is a subclass of int, so exclude it explicitly. Floats are accepted
    # because a JWT NumericDate may be non-integral.
    if isinstance(exp_timestamp, bool) or not isinstance(
        exp_timestamp, (int, float)
    ):
        raise Exception("No valid expiration time in token")
    try:
        expiration_time = datetime.fromtimestamp(exp_timestamp, timezone.utc)
    except (OverflowError, OSError, ValueError) as ex:
        # Out-of-range timestamps cannot be converted to a datetime.
        raise Exception("No valid expiration time in token") from ex

    return DecodedToken(username=username, expiration_time=expiration_time)