Source code for app.domain.accounts.services._password_reset

from __future__ import annotations

import hashlib
import secrets
from datetime import UTC, datetime, timedelta
from typing import TYPE_CHECKING

from advanced_alchemy.extensions.litestar import repository, service
from litestar.exceptions import ClientException

from app.db import models as m

if TYPE_CHECKING:
    from uuid import UUID


class PasswordResetService(service.SQLAlchemyAsyncRepositoryService[m.PasswordResetToken]):
    """Handles database operations for password reset tokens.

    Tokens are looked up by the SHA-256 digest of the plain token
    (``token_hash``); the plain token itself is only handed back to the
    caller of :meth:`create_reset_token`.
    """

    class Repo(repository.SQLAlchemyAsyncRepository[m.PasswordResetToken]):
        """PasswordResetToken SQLAlchemy Repository."""

        model_type = m.PasswordResetToken

    repository_type = Repo
    match_fields = ["token_hash"]

    @staticmethod
    def _hash_token(token: str) -> str:
        """Return the hex-encoded SHA-256 digest of ``token``.

        Args:
            token: The plain reset token string

        Returns:
            The hexadecimal digest used as ``token_hash``.
        """
        return hashlib.sha256(token.encode()).hexdigest()

    async def to_model_on_create(
        self,
        data: service.ModelDictT[m.PasswordResetToken],
    ) -> service.ModelDictT[m.PasswordResetToken]:
        """Swap a plain ``token`` entry for its ``token_hash`` before creation.

        Args:
            data: The incoming creation payload

        Returns:
            The dumped payload; when it is a dict carrying ``token`` but no
            ``token_hash``, ``token`` is removed and ``token_hash`` is set to
            its digest.
        """
        data = service.schema_dump(data)
        if service.is_dict_with_field(data, "token") and service.is_dict_without_field(data, "token_hash"):
            # pop() so the plain token never reaches the model constructor.
            data["token_hash"] = self._hash_token(data.pop("token"))
        return data

    async def create_reset_token(
        self, user_id: UUID, ip_address: str | None = None, user_agent: str | None = None
    ) -> tuple[m.PasswordResetToken, str]:
        """Create a new password reset token for a user.

        Any still-active tokens of the user are invalidated first, so at most
        one token is usable at a time.

        Args:
            user_id: The user's UUID
            ip_address: IP address of the request
            user_agent: User agent string of the request

        Returns:
            Tuple of (PasswordResetToken, plain_token)
        """
        await self.invalidate_user_tokens(user_id)
        token = secrets.token_urlsafe(32)
        # NOTE(review): presumably this commit also persists the invalidations
        # above, since they go through the same service — confirm.
        created = await self.create(
            {
                "user_id": user_id,
                "token": token,  # hashed into "token_hash" by to_model_on_create
                "expires_at": m.PasswordResetToken.create_expires_at(hours=1),
                "ip_address": ip_address,
                "user_agent": user_agent,
            },
            auto_commit=True,
        )
        return created, token

    async def validate_reset_token(self, token: str) -> m.PasswordResetToken:
        """Validate a token without consuming it.

        Args:
            token: The reset token string

        Returns:
            The PasswordResetToken instance if valid.

        Raises:
            ClientException: If token is invalid, expired, or already used
        """
        reset_token = await self.get_one_or_none(token_hash=self._hash_token(token))
        if reset_token is None:
            raise ClientException(detail="Invalid reset token", status_code=400)
        if reset_token.is_expired:
            raise ClientException(detail="Reset token has expired", status_code=400)
        if reset_token.is_used:
            raise ClientException(detail="Reset token has already been used", status_code=400)
        return reset_token

    async def use_reset_token(self, token: str) -> m.PasswordResetToken:
        """Use a token to mark it as consumed.

        Args:
            token: The reset token string

        Returns:
            The PasswordResetToken instance

        Raises:
            ClientException: If token is invalid, expired, or already used
        """
        reset_token = await self.validate_reset_token(token)
        reset_token.used_at = datetime.now(UTC)
        await self.update(reset_token)
        return reset_token

    async def invalidate_user_tokens(self, user_id: UUID) -> None:
        """Invalidate all active tokens for a user.

        Args:
            user_id: The user's UUID
        """
        tokens = await self.list(
            m.PasswordResetToken.user_id == user_id,
            m.PasswordResetToken.used_at.is_(None),
        )
        current_time = datetime.now(UTC)
        for token in tokens:
            # NOTE(review): looks redundant with the used_at filter above, kept
            # because is_used is defined on the model outside this file.
            if not token.is_used:
                token.used_at = current_time
        if tokens:
            await self.update_many(tokens)

    async def cleanup_expired_tokens(self) -> int:
        """Remove expired tokens from the database.

        Returns:
            Number of tokens removed
        """
        current_time = datetime.now(UTC)
        expired_tokens = await self.list(m.PasswordResetToken.expires_at < current_time)
        if not expired_tokens:
            return 0
        # Pass IDs explicitly to delete_many, not model objects
        token_ids = [token.id for token in expired_tokens]
        await self.delete_many(token_ids)
        return len(expired_tokens)

    async def check_rate_limit(self, user_id: UUID, hours: float = 1, *, max_requests: int = 3) -> bool:
        """Check if user has exceeded reset token creation rate limit.

        Args:
            user_id: The user's UUID
            hours: Hours to look back for rate limiting
            max_requests: Number of tokens created within the window at which
                the limit counts as exceeded

        Returns:
            True if rate limit exceeded, False otherwise
        """
        cutoff_time = datetime.now(UTC) - timedelta(hours=hours)
        # Count in the database instead of loading every row just to len() it.
        recent_count = await self.count(
            m.PasswordResetToken.user_id == user_id,
            m.PasswordResetToken.created_at >= cutoff_time,
        )
        return recent_count >= max_requests