You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

96 lines
2.9 KiB

  1. from datetime import datetime
  2. from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
  3. from fastapi import Depends, APIRouter
  4. from sqlalchemy.orm import Session
  5. from fastapi import status, HTTPException
  6. from typing import Annotated
  7. from db.session import get_db
  8. from core.hashing import Hasher
  9. from core.config import settings
  10. from jose import JWTError, jwt
  11. from schemas.token import Token, TokenPayload
  12. from db.repository.user import get_user_by_email, get_user_by_phone
  13. from core.auth import create_access_token
  14. router = APIRouter()
  15. oauth2_scheme = OAuth2PasswordBearer(
  16. tokenUrl="/token",
  17. )
  18. def authenticate_user(login: str, password: str, db: Session):
  19. print("Trying to auth...")
  20. user = None
  21. if "@" in login:
  22. user = get_user_by_email(email=login, db=db)
  23. elif "+" in login:
  24. user = get_user_by_phone(phone=login, db=db)
  25. else:
  26. return False
  27. if not user:
  28. return False
  29. if not Hasher.verify_password(password, user.HashedPassword):
  30. return False
  31. return user
  32. def get_current_user(
  33. token: Annotated[str, Depends(oauth2_scheme)],
  34. db: Annotated[Session, Depends(get_db)],
  35. ):
  36. print("Getting current user...")
  37. try:
  38. payload = jwt.decode(
  39. token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
  40. )
  41. token_data = TokenPayload(**payload)
  42. if datetime.fromtimestamp(token_data.exp) < datetime.now():
  43. raise HTTPException(
  44. status_code=status.HTTP_401_UNAUTHORIZED,
  45. detail="Session expired. Please login again.",
  46. )
  47. username: str = payload.get("sub")
  48. if username is None:
  49. raise HTTPException(
  50. status_code=status.HTTP_401_UNAUTHORIZED,
  51. detail="Could not validate credentials",
  52. )
  53. except JWTError:
  54. raise HTTPException(
  55. status_code=status.HTTP_401_UNAUTHORIZED,
  56. detail="Could not validate credentials",
  57. )
  58. if "@" in username:
  59. user = get_user_by_email(email=username, db=db)
  60. elif "+" in username:
  61. user = get_user_by_phone(phone=username, db=db)
  62. else:
  63. raise HTTPException(
  64. status_code=status.HTTP_401_UNAUTHORIZED,
  65. detail="Could not validate credentials",
  66. )
  67. return user
  68. @router.post("/token", response_model=Token)
  69. def access_token(
  70. form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)
  71. ):
  72. print("Getting token...")
  73. user = authenticate_user(form_data.username, form_data.password, db)
  74. print(user)
  75. if not user:
  76. raise HTTPException(
  77. status_code=status.HTTP_401_UNAUTHORIZED,
  78. detail="Invalid username or password",
  79. )
  80. access_token = create_access_token(data={"sub": user.Email})
  81. print("TOKENS ARE: ")
  82. print(access_token)
  83. return {
  84. "access_token": access_token,
  85. "token_type": "bearer",
  86. }