|
|
@@ -1,3 +1,4 @@ |
|
|
|
from datetime import datetime |
|
|
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm |
|
|
|
from fastapi import Depends, APIRouter |
|
|
|
from sqlalchemy.orm import Session |
|
|
@@ -7,20 +8,23 @@ from db.session import get_db |
|
|
|
from core.hashing import Hasher |
|
|
|
from core.config import settings |
|
|
|
from jose import JWTError, jwt |
|
|
|
from schemas.token import Token |
|
|
|
from schemas.token import Token, TokenPayload |
|
|
|
from db.repository.user import get_user_by_email, get_user_by_phone |
|
|
|
from core.auth import create_access_token |
|
|
|
|
|
|
|
|
|
|
|
router = APIRouter() |
|
|
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token") |
|
|
|
oauth2_scheme = OAuth2PasswordBearer( |
|
|
|
tokenUrl="/token", |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def authenticate_user(login: str, password: str, db: Session): |
|
|
|
print("Trying to auth...") |
|
|
|
user = None |
|
|
|
if ("@" in login): |
|
|
|
if "@" in login: |
|
|
|
user = get_user_by_email(email=login, db=db) |
|
|
|
elif ("+" in login): |
|
|
|
elif "+" in login: |
|
|
|
user = get_user_by_phone(phone=login, db=db) |
|
|
|
else: |
|
|
|
return False |
|
|
@@ -31,32 +35,61 @@ def authenticate_user(login: str, password: str, db: Session): |
|
|
|
return user |
|
|
|
|
|
|
|
|
|
|
|
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[Session, Depends(get_db)]): |
|
|
|
def get_current_user( |
|
|
|
token: Annotated[str, Depends(oauth2_scheme)], |
|
|
|
db: Annotated[Session, Depends(get_db)], |
|
|
|
): |
|
|
|
print("Getting current user...") |
|
|
|
try: |
|
|
|
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) |
|
|
|
payload = jwt.decode( |
|
|
|
token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM] |
|
|
|
) |
|
|
|
token_data = TokenPayload(**payload) |
|
|
|
|
|
|
|
if datetime.fromtimestamp(token_data.exp) < datetime.now(): |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Session expired. Please login again.", |
|
|
|
) |
|
|
|
username: str = payload.get("sub") |
|
|
|
if username is None: |
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Could not validate credentials", |
|
|
|
) |
|
|
|
except JWTError: |
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") |
|
|
|
if ("@" in username): |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Could not validate credentials", |
|
|
|
) |
|
|
|
if "@" in username: |
|
|
|
user = get_user_by_email(email=username, db=db) |
|
|
|
elif ("+" in username): |
|
|
|
elif "+" in username: |
|
|
|
user = get_user_by_phone(phone=username, db=db) |
|
|
|
else: |
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Could not validate credentials", |
|
|
|
) |
|
|
|
return user |
|
|
|
|
|
|
|
|
|
|
|
@router.post("/token", response_model=Token) |
|
|
|
def access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): |
|
|
|
def access_token( |
|
|
|
form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db) |
|
|
|
): |
|
|
|
print("Getting token...") |
|
|
|
user = authenticate_user(form_data.username, form_data.password, db) |
|
|
|
print(user) |
|
|
|
if not user: |
|
|
|
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password") |
|
|
|
raise HTTPException( |
|
|
|
status_code=status.HTTP_401_UNAUTHORIZED, |
|
|
|
detail="Invalid username or password", |
|
|
|
) |
|
|
|
access_token = create_access_token(data={"sub": user.Email}) |
|
|
|
return {"access_token": access_token, "token_type": "bearer"} |
|
|
|
|
|
|
|
|
|
|
|
print("TOKENS ARE: ") |
|
|
|
print(access_token) |
|
|
|
return { |
|
|
|
"access_token": access_token, |
|
|
|
"token_type": "bearer", |
|
|
|
} |