"""MFA Management Controller."""
from __future__ import annotations
import base64
import logging
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from litestar import Controller, delete, get, post
from litestar.di import Provide
from litestar.exceptions import ClientException
from sqlalchemy.orm import undefer_group
from app.domain.accounts.deps import provide_users_service
from app.domain.accounts.schemas import (
MfaBackupCodes,
MfaConfirm,
MfaDisable,
MfaSetup,
MfaStatus,
)
from app.domain.admin.deps import provide_audit_log_service
from app.lib.crypt import (
generate_backup_codes,
generate_totp_qr_code,
generate_totp_secret,
get_totp_provisioning_uri,
verify_password,
verify_totp_code,
)
from app.lib.schema import Message
if TYPE_CHECKING:
from litestar import Request
from litestar.security.jwt import Token
from app.db import models as m
from app.domain.accounts.services import UserService
from app.domain.admin.services import AuditLogService
from app.lib.settings import AppSettings
# Look-back window, in minutes, passed to the audit service when counting
# recent failed MFA setup confirmations.
MFA_RATE_LIMIT_WINDOW_MINUTES = 15
# Number of failed confirmations inside the window at which further
# confirmation attempts are rejected with HTTP 429.
MFA_RATE_LIMIT_MAX_ATTEMPTS = 5
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# [docs]
class MfaController(Controller):
    """MFA management endpoints for setting up and managing two-factor authentication.

    Every route is mounted under ``/api/mfa`` and acts on the authenticated
    user found on ``request.user``. The user record is always re-loaded with
    the ``security_sensitive`` column group undeferred so that the TOTP
    secret, backup codes and password hash are available to the handler.
    """

    tags = ["MFA"]
    path = "/api/mfa"
    dependencies = {
        "users_service": Provide(provide_users_service),
        "audit_service": Provide(provide_audit_log_service),
    }

    async def _get_mfa_user_after_password_check(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
        data: MfaDisable,
    ) -> m.User:
        """Load the current user and re-verify their password before a sensitive MFA change.

        Shared by the handlers that disable MFA and regenerate backup codes so
        both apply exactly the same preconditions.

        Args:
            request: Request with authenticated user
            users_service: User service
            data: Payload carrying the password to verify

        Returns:
            The freshly loaded user with security-sensitive columns available

        Raises:
            ClientException: If MFA is not enabled or the password is incorrect
        """
        user = await users_service.get(request.user.id, load=[undefer_group("security_sensitive")])
        if not user.is_two_factor_enabled:
            raise ClientException(detail="MFA is not enabled", status_code=400)
        # Accounts without a stored password hash can never pass this check.
        if not user.hashed_password or not await verify_password(data.password, user.hashed_password):
            raise ClientException(detail="Invalid password", status_code=400)
        return user

    @get(operation_id="GetMfaStatus", path="/status")
    async def get_mfa_status(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
    ) -> MfaStatus:
        """Get current MFA status for the authenticated user.

        Args:
            request: Request with authenticated user
            users_service: User service

        Returns:
            Current MFA status. ``backup_codes_remaining`` is ``None`` when no
            backup codes are stored at all, otherwise the number of non-empty
            entries (which may be ``0``).
        """
        user = await users_service.get(request.user.id, load=[undefer_group("security_sensitive")])
        backup_codes_remaining = None
        # Compare against None rather than truthiness so an empty list is
        # reported as 0 remaining codes instead of "unknown".
        if user.backup_codes is not None:
            backup_codes_remaining = sum(1 for code in user.backup_codes if code)
        return MfaStatus(
            enabled=user.is_two_factor_enabled,
            confirmed_at=user.two_factor_confirmed_at,
            backup_codes_remaining=backup_codes_remaining,
        )

    @post(operation_id="InitiateMfaSetup", path="/enable")
    async def initiate_setup(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
        settings: AppSettings,
    ) -> MfaSetup:
        """Initiate MFA setup - generates TOTP secret and QR code.

        The secret is stored but MFA is not enabled until confirmed with a valid code.

        Args:
            request: Request with authenticated user
            users_service: User service
            settings: Application settings

        Returns:
            TOTP secret and QR code for authenticator app

        Raises:
            ClientException: If MFA is already enabled
        """
        user = await users_service.get(request.user.id, load=[undefer_group("security_sensitive")])
        if user.is_two_factor_enabled:
            raise ClientException(detail="MFA is already enabled", status_code=400)

        secret = generate_totp_secret()
        issuer = settings.slug

        # Build everything the client needs before touching the database, so a
        # failure here cannot leave behind a stored secret the user never saw.
        qr_code_bytes = await generate_totp_qr_code(secret, user.email, issuer=issuer)
        qr_code_base64 = base64.b64encode(qr_code_bytes).decode("utf-8")
        provisioning_uri = get_totp_provisioning_uri(secret, user.email, issuer=issuer)

        # Persist the pending secret; MFA stays disabled until /confirm succeeds.
        await users_service.update({"totp_secret": secret}, item_id=user.id)

        return MfaSetup(
            secret=secret,
            qr_code=f"data:image/png;base64,{qr_code_base64}",
            provisioning_uri=provisioning_uri,
        )

    @post(operation_id="ConfirmMfaSetup", path="/confirm")
    async def confirm_setup(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
        audit_service: AuditLogService,
        data: MfaConfirm,
    ) -> MfaBackupCodes:
        """Confirm MFA setup with a valid TOTP code.

        Verifies the code, enables MFA, and returns backup codes.

        Args:
            request: Request with authenticated user
            users_service: User service
            audit_service: Audit log service
            data: TOTP code from authenticator app

        Returns:
            Backup recovery codes (shown only once)

        Raises:
            ClientException: If too many recent attempts failed (429), MFA is
                already enabled, no setup is in progress, or the code is invalid
        """
        user = await users_service.get(request.user.id, load=[undefer_group("security_sensitive")])

        # Rate limit first: recent failures recorded in the audit log block
        # further guessing regardless of the rest of the account state.
        failed_attempts = await audit_service.count_recent_actions(
            action="mfa.setup.failed",
            actor_id=user.id,
            window_minutes=MFA_RATE_LIMIT_WINDOW_MINUTES,
        )
        if failed_attempts >= MFA_RATE_LIMIT_MAX_ATTEMPTS:
            raise ClientException(detail="Too many verification attempts. Please try again later.", status_code=429)

        if user.is_two_factor_enabled:
            raise ClientException(detail="MFA is already enabled", status_code=400)
        if not user.totp_secret:
            raise ClientException(detail="No MFA setup in progress. Call /enable first.", status_code=400)

        if not verify_totp_code(user.totp_secret, data.code):
            # NOTE(review): the rate limit above only works if this audit row
            # survives the 400 response raised below - confirm that log_action
            # commits independently of the request transaction.
            await audit_service.log_action(
                action="mfa.setup.failed",
                actor_id=user.id,
                actor_email=user.email,
                target_type="user",
                target_id=str(user.id),
                request=request,
            )
            raise ClientException(detail="Invalid verification code", status_code=400)

        plaintext_codes = generate_backup_codes(count=8)
        # NOTE(review): the codes are handed to the service exactly as they are
        # returned to the user - confirm the model/service hashes or encrypts
        # ``backup_codes`` before it reaches storage.
        await users_service.update(
            {
                "is_two_factor_enabled": True,
                "two_factor_confirmed_at": datetime.now(UTC),
                "backup_codes": plaintext_codes,
            },
            item_id=user.id,
        )
        await audit_service.log_action(
            action="mfa.setup.confirmed",
            actor_id=user.id,
            actor_email=user.email,
            target_type="user",
            target_id=str(user.id),
            request=request,
        )
        return MfaBackupCodes(codes=plaintext_codes)

    @delete(operation_id="DisableMfa", path="/disable", status_code=200)
    async def disable_mfa(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
        data: MfaDisable,
        audit_service: AuditLogService | None = None,
    ) -> Message:
        """Disable MFA for the authenticated user.

        Requires password verification for security. Clears the TOTP secret,
        confirmation timestamp and backup codes along with the enabled flag.

        Args:
            request: Request with authenticated user
            users_service: User service
            data: Password for verification
            audit_service: Audit log service. Supplied by the controller's
                dependencies; when omitted (direct call) no audit entry is written.

        Returns:
            Success message

        Raises:
            ClientException: If password is incorrect or MFA not enabled
        """
        user = await self._get_mfa_user_after_password_check(request, users_service, data)
        await users_service.update(
            {
                "is_two_factor_enabled": False,
                "totp_secret": None,
                "two_factor_confirmed_at": None,
                "backup_codes": None,
            },
            item_id=user.id,
        )
        # Removing a second factor is security-relevant, so record it in the
        # audit trail the same way a confirmed setup is recorded.
        if audit_service is not None:
            await audit_service.log_action(
                action="mfa.disabled",
                actor_id=user.id,
                actor_email=user.email,
                target_type="user",
                target_id=str(user.id),
                request=request,
            )
        logger.info("MFA disabled for user %s", user.email)
        return Message(message="MFA has been disabled")

    @post(operation_id="RegenerateMfaBackupCodes", path="/regenerate-codes")
    async def regenerate_backup_codes(
        self,
        request: Request[m.User, Token, Any],
        users_service: UserService,
        data: MfaDisable,
        audit_service: AuditLogService | None = None,
    ) -> MfaBackupCodes:
        """Generate new backup codes (invalidates old ones).

        Requires password verification for security. The stored codes are
        overwritten, so previously issued codes stop being available.

        Args:
            request: Request with authenticated user
            users_service: User service
            data: Password for verification
            audit_service: Audit log service. Supplied by the controller's
                dependencies; when omitted (direct call) no audit entry is written.

        Returns:
            New backup codes (shown only once)

        Raises:
            ClientException: If password is incorrect or MFA not enabled
        """
        user = await self._get_mfa_user_after_password_check(request, users_service, data)
        plaintext_codes = generate_backup_codes(count=8)
        # NOTE(review): same storage concern as in confirm_setup - verify the
        # codes are not persisted in plaintext.
        await users_service.update({"backup_codes": plaintext_codes}, item_id=user.id)
        # Replacing recovery codes is security-relevant; keep it in the audit trail.
        if audit_service is not None:
            await audit_service.log_action(
                action="mfa.backup_codes.regenerated",
                actor_id=user.id,
                actor_email=user.email,
                target_type="user",
                target_id=str(user.id),
                request=request,
            )
        logger.info("Backup codes regenerated for user %s", user.email)
        return MfaBackupCodes(codes=plaintext_codes)