@@ -1,8 +1,9 @@ | |||||
# Base API router -- collecting all APIs here to not clutter main.py | # Base API router -- collecting all APIs here to not clutter main.py | ||||
from fastapi import APIRouter | from fastapi import APIRouter | ||||
from apis.v1 import route_user, route_vehicle | |||||
from apis.v1 import route_user, route_vehicle, route_auth | |||||
api_router = APIRouter() | api_router = APIRouter() | ||||
api_router.include_router(route_user.router, prefix="/user", tags=["users"]) | api_router.include_router(route_user.router, prefix="/user", tags=["users"]) | ||||
api_router.include_router(route_vehicle.router, prefix="/vehicle", tags=["vehicles"]) | api_router.include_router(route_vehicle.router, prefix="/vehicle", tags=["vehicles"]) | ||||
api_router.include_router(route_auth.router, prefix="", tags=["auth"]) |
@@ -0,0 +1,60 @@ | |||||
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm | |||||
from fastapi import Depends, APIRouter | |||||
from sqlalchemy.orm import Session | |||||
from fastapi import status, HTTPException | |||||
from typing import Annotated | |||||
from db.session import get_db | |||||
from core.hashing import Hasher | |||||
from core.config import settings | |||||
from jose import JWTError, jwt | |||||
from schemas.token import Token | |||||
from db.repository.user import get_user_by_email, get_user_by_phone | |||||
from core.auth import create_access_token | |||||
router = APIRouter() | |||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/token") | |||||
def authenticate_user(login: str, password: str, db: Session): | |||||
print("Trying to auth...") | |||||
user = None | |||||
if ("@" in login): | |||||
user = get_user_by_email(email=login, db=db) | |||||
elif ("+" in login): | |||||
user = get_user_by_phone(phone=login, db=db) | |||||
else: | |||||
return False | |||||
if not user: | |||||
return False | |||||
if not Hasher.verify_password(password, user.HashedPassword): | |||||
return False | |||||
return user | |||||
def get_current_user(token: Annotated[str, Depends(oauth2_scheme)], db: Annotated[Session, Depends(get_db)]): | |||||
print("Getting current user...") | |||||
try: | |||||
payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]) | |||||
username: str = payload.get("sub") | |||||
if username is None: | |||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") | |||||
except JWTError: | |||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") | |||||
if ("@" in username): | |||||
user = get_user_by_email(email=username, db=db) | |||||
elif ("+" in username): | |||||
user = get_user_by_phone(phone=username, db=db) | |||||
else: | |||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials") | |||||
return user | |||||
@router.post("/token", response_model=Token) | |||||
def access_token(form_data: OAuth2PasswordRequestForm = Depends(), db: Session = Depends(get_db)): | |||||
print("Getting token...") | |||||
user = authenticate_user(form_data.username, form_data.password, db) | |||||
print(user) | |||||
if not user: | |||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid username or password") | |||||
access_token = create_access_token(data={"sub": user.Email}) | |||||
return {"access_token": access_token, "token_type": "bearer"} |
@@ -2,8 +2,9 @@ | |||||
from fastapi import APIRouter, HTTPException, status | from fastapi import APIRouter, HTTPException, status | ||||
from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
from fastapi import Depends | from fastapi import Depends | ||||
from typing import List | |||||
from typing import List, Annotated | |||||
from apis.v1.route_auth import get_current_user | |||||
from db.models.user import User | |||||
from schemas.user import UserCreate, ShowUser | from schemas.user import UserCreate, ShowUser | ||||
from db.session import get_db | from db.session import get_db | ||||
from db.repository.user import create_new_user, list_users, get_user_by_id | from db.repository.user import create_new_user, list_users, get_user_by_id | ||||
@@ -13,20 +14,28 @@ router = APIRouter() | |||||
@router.post("/", response_model=ShowUser, status_code=status.HTTP_201_CREATED) | @router.post("/", response_model=ShowUser, status_code=status.HTTP_201_CREATED) | ||||
def create_user(user: UserCreate, db: Session = Depends(get_db)): | |||||
def create_user(user: UserCreate, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): | |||||
if current_user.Role != "Admin": | |||||
raise HTTPException(status_code=403, detail="You are not authorized to perform this action") | |||||
user = create_new_user(user=user, db=db) | user = create_new_user(user=user, db=db) | ||||
return user | return user | ||||
@router.get("/", response_model=List[ShowUser], status_code=status.HTTP_200_OK) | @router.get("/", response_model=List[ShowUser], status_code=status.HTTP_200_OK) | ||||
def get_all_users(db: Session = Depends(get_db), role: str = None): | def get_all_users(db: Session = Depends(get_db), role: str = None): | ||||
if role == None: | |||||
if role is None: | |||||
users = list_users(db=db) | users = list_users(db=db) | ||||
return users | return users | ||||
users = list_users(db=db, role=role) | users = list_users(db=db, role=role) | ||||
return users | return users | ||||
@router.get("/me", response_model=ShowUser, status_code=status.HTTP_200_OK) | |||||
def get_user_me(current_user: Annotated[User, Depends(get_current_user)], db: Annotated[Session, Depends(get_db)]): | |||||
print("Getting current user...") | |||||
return current_user | |||||
@router.get("/{user_id}", response_model=ShowUser, status_code=status.HTTP_200_OK) | @router.get("/{user_id}", response_model=ShowUser, status_code=status.HTTP_200_OK) | ||||
def get_user(user_id: int, db: Session = Depends(get_db)): | def get_user(user_id: int, db: Session = Depends(get_db)): | ||||
user = get_user_by_id(user_id=user_id, db=db) | user = get_user_by_id(user_id=user_id, db=db) | ||||
@@ -10,13 +10,18 @@ from db.repository.vehicle import ( | |||||
list_vehicles, | list_vehicles, | ||||
get_vehicle_by_id, | get_vehicle_by_id, | ||||
replace_vehicle_data, | replace_vehicle_data, | ||||
delete_vehicle_data, | |||||
) | ) | ||||
from db.models.user import User | |||||
from apis.v1.route_auth import get_current_user | |||||
router = APIRouter() | router = APIRouter() | ||||
@router.post("/", response_model=OutputVehicle, status_code=status.HTTP_201_CREATED) | @router.post("/", response_model=OutputVehicle, status_code=status.HTTP_201_CREATED) | ||||
async def create_vehicle(vehicle: CreateVehicle, db: Session = Depends(get_db)): | |||||
async def create_vehicle(vehicle: CreateVehicle, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): | |||||
if current_user.Role != "Admin": | |||||
raise HTTPException(status_code=403, detail="You are not authorized to perform this action") | |||||
vehicle = create_new_vehicle(vehicle=vehicle, db=db) | vehicle = create_new_vehicle(vehicle=vehicle, db=db) | ||||
return vehicle | return vehicle | ||||
@@ -47,7 +52,9 @@ async def create_vehicle(vehicle: CreateVehicle, db: Session = Depends(get_db)): | |||||
response_model=OutputVehicle, | response_model=OutputVehicle, | ||||
status_code=status.HTTP_200_OK, | status_code=status.HTTP_200_OK, | ||||
) | ) | ||||
async def assign_driver(vehicle_id: int, driver_id: int, db: Session = Depends(get_db)): | |||||
async def assign_driver(vehicle_id: int, driver_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): | |||||
if current_user.Role != "Admin": | |||||
raise HTTPException(status_code=403, detail="You are not authorized to perform this action") | |||||
vehicle = assign_vehicle_driver(vehicle_id=vehicle_id, driver_id=driver_id, db=db) | vehicle = assign_vehicle_driver(vehicle_id=vehicle_id, driver_id=driver_id, db=db) | ||||
if vehicle == "nodriver": | if vehicle == "nodriver": | ||||
raise HTTPException( | raise HTTPException( | ||||
@@ -85,11 +92,23 @@ async def get_vehicle(vehicle_id: int, db: Session = Depends(get_db)): | |||||
"/{vehicle_id}", response_model=OutputVehicle, status_code=status.HTTP_200_OK | "/{vehicle_id}", response_model=OutputVehicle, status_code=status.HTTP_200_OK | ||||
) | ) | ||||
def update_vehicle( | def update_vehicle( | ||||
vehicle_id: int, vehicle: UpdateVehicle, db: Session = Depends(get_db) | |||||
vehicle_id: int, vehicle: UpdateVehicle, db: Session = Depends(get_db), current_user: User = Depends(get_current_user) | |||||
): | ): | ||||
if current_user.Role != "Admin": | |||||
raise HTTPException(status_code=403, detail="You are not authorized to perform this action") | |||||
vehicleRes = replace_vehicle_data(id=vehicle_id, vehicle=vehicle, db=db) | vehicleRes = replace_vehicle_data(id=vehicle_id, vehicle=vehicle, db=db) | ||||
if vehicleRes == "vehicleNotFound": | if vehicleRes == "vehicleNotFound": | ||||
raise HTTPException(status_code=404, detail="Vehicle not found") | raise HTTPException(status_code=404, detail="Vehicle not found") | ||||
elif vehicleRes == "badreq": | elif vehicleRes == "badreq": | ||||
raise HTTPException(status_code=502, detail="Bad request") | raise HTTPException(status_code=502, detail="Bad request") | ||||
return vehicleRes | return vehicleRes | ||||
@router.delete("/{vehicle_id}", status_code=status.HTTP_200_OK) | |||||
def delete_vehicle(vehicle_id: int, db: Session = Depends(get_db), current_user: User = Depends(get_current_user)): | |||||
if current_user.Role != "Admin": | |||||
raise HTTPException(status_code=403, detail="You are not authorized to perform this action") | |||||
result = delete_vehicle_data(id=vehicle_id, db=db) | |||||
if result == "vehicleNotFound": | |||||
raise HTTPException(status_code=404, detail="Vehicle not found") | |||||
return {"msg": "Vehicle deleted successfully"} |
@@ -0,0 +1,16 @@ | |||||
from datetime import datetime, timedelta | |||||
from typing import Optional | |||||
from jose import jwt | |||||
from core.config import settings | |||||
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None): | |||||
to_encode = data.copy() | |||||
if expires_delta: | |||||
expire = datetime.utcnow() + expires_delta | |||||
else: | |||||
expire = datetime.utcnow() + timedelta(minutes=settings.ACCESS_TOKEN_EXPIRE) | |||||
to_encode.update({"exp": expire}) | |||||
encoded_jwt = jwt.encode(to_encode, settings.SECRET_KEY, algorithm=settings.ALGORITHM) | |||||
return encoded_jwt |
@@ -1,12 +1,15 @@ | |||||
class Settings: | class Settings: | ||||
PROJECT_NAME:str = "VMS" | |||||
PROJECT_VERSION:str = "1.0.0" | |||||
POSTGRES_USER : str = "VMSBase" | |||||
PROJECT_NAME: str = "VMS" | |||||
PROJECT_VERSION: str = "1.0.0" | |||||
POSTGRES_USER: str = "VMSBase" | |||||
POSTGRES_PASSWORD = "VMSBasePass" | POSTGRES_PASSWORD = "VMSBasePass" | ||||
POSTGRES_SERVER : str = "localhost" | |||||
POSTGRES_PORT : str = "5432" | |||||
POSTGRES_DB : str = "VMSData" | |||||
POSTGRES_SERVER: str = "localhost" | |||||
POSTGRES_PORT: str = "5432" | |||||
POSTGRES_DB: str = "VMSData" | |||||
DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}" | DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}" | ||||
settings = Settings() | |||||
ACCESS_TOKEN_EXPIRE: int = 30 | |||||
SECRET_KEY: str = "tH357aC6oA7ofCaN3yTffYkRh" | |||||
ALGORITHM: str = "HS256" | |||||
settings = Settings() |
@@ -2,5 +2,3 @@ | |||||
from sqlalchemy.ext.declarative import declarative_base | from sqlalchemy.ext.declarative import declarative_base | ||||
Base = declarative_base() | Base = declarative_base() | ||||
from db.models.user import User | |||||
from db.models.vehicle import Vehicle |
@@ -1,6 +1,5 @@ | |||||
# PostgreSQL table model for users | # PostgreSQL table model for users | ||||
from sqlalchemy import Column, Integer, String, DateTime, Boolean, URL | |||||
from sqlalchemy.orm import relationship | |||||
from sqlalchemy import Column, Integer, String, DateTime | |||||
from db.base import Base | from db.base import Base | ||||
@@ -3,13 +3,8 @@ from sqlalchemy import ( | |||||
Column, | Column, | ||||
Integer, | Integer, | ||||
String, | String, | ||||
DateTime, | |||||
Boolean, | |||||
URL, | |||||
ARRAY, | ARRAY, | ||||
ForeignKey, | |||||
) | ) | ||||
from sqlalchemy.orm import relationship | |||||
from db.base import Base | from db.base import Base | ||||
@@ -28,6 +28,16 @@ def get_user_by_id(user_id: int, db: Session): | |||||
return user | return user | ||||
def get_user_by_email(email: str, db: Session): | |||||
user = db.query(User).filter(User.Email == email).first() | |||||
return user | |||||
def get_user_by_phone(phone: str, db: Session): | |||||
user = db.query(User).filter(User.ContactNumber == phone).first() | |||||
return user | |||||
def verify_driver_exists(driver_id: int, db: Session): | def verify_driver_exists(driver_id: int, db: Session): | ||||
driver = db.query(User).filter(User.Id == driver_id).first() | driver = db.query(User).filter(User.Id == driver_id).first() | ||||
if not driver: | if not driver: | ||||
@@ -1,5 +1,5 @@ | |||||
from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
from schemas.vehicle import CreateVehicle, OutputVehicle, UpdateVehicle | |||||
from schemas.vehicle import CreateVehicle, UpdateVehicle | |||||
from db.models.vehicle import Vehicle | from db.models.vehicle import Vehicle | ||||
from db.repository.user import verify_driver_exists | from db.repository.user import verify_driver_exists | ||||
@@ -27,7 +27,9 @@ def assign_vehicle_driver(vehicle_id: int, driver_id: int, db: Session): | |||||
return "alreadyassigned" | return "alreadyassigned" | ||||
if verify_driver_exists(driver_id=driver_id, db=db): | if verify_driver_exists(driver_id=driver_id, db=db): | ||||
print(vehicle.AssignedDriverIds) | print(vehicle.AssignedDriverIds) | ||||
vehicledb.update({"AssignedDriverIds": vehicle.AssignedDriverIds + [driver_id]}) | |||||
vehicledb.update( | |||||
{"AssignedDriverIds": vehicle.AssignedDriverIds + [driver_id]} | |||||
) | |||||
print(vehicle.AssignedDriverIds) | print(vehicle.AssignedDriverIds) | ||||
db.add(vehicle) | db.add(vehicle) | ||||
db.commit() | db.commit() | ||||
@@ -65,3 +67,13 @@ def replace_vehicle_data(id: int, vehicle: UpdateVehicle, db: Session): | |||||
db.add(vehicle_object) | db.add(vehicle_object) | ||||
db.commit() | db.commit() | ||||
return vehicle_object | return vehicle_object | ||||
def delete_vehicle_data(id: int, db: Session): | |||||
vehicle_db = db.query(Vehicle).filter(Vehicle.Id == id) | |||||
vehicle_object = vehicle_db.first() | |||||
if not vehicle_object: | |||||
return "vehiclenotfound" | |||||
db.delete(vehicle_object) | |||||
db.commit() | |||||
return vehicle_object |
@@ -0,0 +1,6 @@ | |||||
from pydantic import BaseModel | |||||
class Token(BaseModel): | |||||
access_token: str | |||||
token_type: str |
@@ -1,6 +1,5 @@ | |||||
from typing import Optional | from typing import Optional | ||||
from pydantic import BaseModel, root_validator | |||||
from datetime import datetime | |||||
from pydantic import BaseModel | |||||
class CreateVehicle(BaseModel): | class CreateVehicle(BaseModel): | ||||
@@ -3,4 +3,6 @@ pydantic | |||||
sqlalchemy | sqlalchemy | ||||
psycopg2 | psycopg2 | ||||
alembic==1.12.0 | alembic==1.12.0 | ||||
passlib | |||||
passlib | |||||
python-jose==3.3.0 | |||||
python-multipart==0.0.6 |