auth.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150
  1. """
  2. 认证相关API
  3. """
  4. from fastapi import APIRouter, Depends, HTTPException, status
  5. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  6. from sqlalchemy.ext.asyncio import AsyncSession
  7. from sqlalchemy import select
  8. from app.core.database import get_db
  9. from app.core.security import (
  10. verify_password,
  11. get_password_hash,
  12. create_access_token,
  13. create_refresh_token,
  14. decode_token,
  15. encrypt_api_key
  16. )
  17. from app.models.user import User
  18. from app.schemas.user import UserCreate, UserResponse, TokenResponse, APIKeyCreate, APIKeyResponse
  19. from typing import List
  20. router = APIRouter()
  21. oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/auth/login")
  22. async def get_current_user(
  23. token: str = Depends(oauth2_scheme),
  24. db: AsyncSession = Depends(get_db)
  25. ) -> User:
  26. """获取当前用户(依赖注入)"""
  27. payload = decode_token(token)
  28. if not payload:
  29. raise HTTPException(
  30. status_code=status.HTTP_401_UNAUTHORIZED,
  31. detail="无效的令牌"
  32. )
  33. user_id = payload.get("user_id")
  34. result = await db.execute(select(User).where(User.id == user_id))
  35. user = result.scalar_one_or_none()
  36. if not user or not user.is_active:
  37. raise HTTPException(
  38. status_code=status.HTTP_401_UNAUTHORIZED,
  39. detail="用户不存在或已禁用"
  40. )
  41. return user
  42. @router.post("/register", response_model=UserResponse)
  43. async def register(user_data: UserCreate, db: AsyncSession = Depends(get_db)):
  44. """用户注册"""
  45. # 检查用户名是否存在
  46. result = await db.execute(select(User).where(User.username == user_data.username))
  47. if result.scalar_one_or_none():
  48. raise HTTPException(status_code=400, detail="用户名已存在")
  49. # 检查邮箱是否存在
  50. result = await db.execute(select(User).where(User.email == user_data.email))
  51. if result.scalar_one_or_none():
  52. raise HTTPException(status_code=400, detail="邮箱已被注册")
  53. # 创建用户
  54. user = User(
  55. username=user_data.username,
  56. email=user_data.email,
  57. password_hash=get_password_hash(user_data.password)
  58. )
  59. db.add(user)
  60. await db.commit()
  61. await db.refresh(user)
  62. return user
  63. @router.post("/login", response_model=TokenResponse)
  64. async def login(
  65. form_data: OAuth2PasswordRequestForm = Depends(),
  66. db: AsyncSession = Depends(get_db)
  67. ):
  68. """用户登录"""
  69. # 查找用户
  70. result = await db.execute(select(User).where(User.username == form_data.username))
  71. user = result.scalar_one_or_none()
  72. if not user or not verify_password(form_data.password, user.password_hash):
  73. raise HTTPException(
  74. status_code=status.HTTP_401_UNAUTHORIZED,
  75. detail="用户名或密码错误"
  76. )
  77. # 生成令牌
  78. access_token = create_access_token({"user_id": user.id})
  79. refresh_token = create_refresh_token({"user_id": user.id})
  80. return {
  81. "access_token": access_token,
  82. "refresh_token": refresh_token,
  83. "token_type": "bearer"
  84. }
  85. @router.get("/me", response_model=UserResponse)
  86. async def get_current_user_info(current_user: User = Depends(get_current_user)):
  87. """获取当前用户信息"""
  88. return current_user
  89. @router.post("/api-keys", response_model=APIKeyResponse)
  90. async def save_api_key(
  91. api_key_data: APIKeyCreate,
  92. current_user: User = Depends(get_current_user),
  93. db: AsyncSession = Depends(get_db)
  94. ):
  95. """保存用户的API Key"""
  96. # 加密API Key
  97. encrypted_key = encrypt_api_key(api_key_data.api_key)
  98. # 更新用户的加密API Keys
  99. if not current_user.encrypted_api_keys:
  100. current_user.encrypted_api_keys = {}
  101. current_user.encrypted_api_keys[api_key_data.provider] = {
  102. "key": encrypted_key,
  103. "model": api_key_data.model
  104. }
  105. await db.commit()
  106. return {
  107. "provider": api_key_data.provider,
  108. "model": api_key_data.model,
  109. "status": "active"
  110. }
  111. @router.get("/api-keys", response_model=List[APIKeyResponse])
  112. async def get_api_keys(current_user: User = Depends(get_current_user)):
  113. """获取用户配置的API Key列表(不返回实际key)"""
  114. if not current_user.encrypted_api_keys:
  115. return []
  116. return [
  117. {
  118. "provider": provider,
  119. "model": config.get("model"),
  120. "status": "active"
  121. }
  122. for provider, config in current_user.encrypted_api_keys.items()
  123. ]