@@ -3,12 +3,13 @@ from sqlalchemy.orm import Session | |||||
from fastapi import Depends | from fastapi import Depends | ||||
from typing import List | from typing import List | ||||
from db.session import get_db | from db.session import get_db | ||||
from schemas.vehicle import OutputVehicle, CreateVehicle | |||||
from schemas.vehicle import OutputVehicle, CreateVehicle, UpdateVehicle | |||||
from db.repository.vehicle import ( | from db.repository.vehicle import ( | ||||
create_new_vehicle, | create_new_vehicle, | ||||
assign_vehicle_driver, | assign_vehicle_driver, | ||||
list_vehicles, | list_vehicles, | ||||
get_vehicle_by_id, | get_vehicle_by_id, | ||||
replace_vehicle_data, | |||||
) | ) | ||||
router = APIRouter() | router = APIRouter() | ||||
@@ -78,3 +79,17 @@ async def get_vehicle(vehicle_id: int, db: Session = Depends(get_db)): | |||||
if not vehicle: | if not vehicle: | ||||
raise HTTPException(status_code=404, detail="Vehicle not found") | raise HTTPException(status_code=404, detail="Vehicle not found") | ||||
return vehicle | return vehicle | ||||
@router.put( | |||||
"/{vehicle_id}", response_model=OutputVehicle, status_code=status.HTTP_200_OK | |||||
) | |||||
def update_vehicle( | |||||
vehicle_id: int, vehicle: UpdateVehicle, db: Session = Depends(get_db) | |||||
): | |||||
vehicleRes = replace_vehicle_data(id=vehicle_id, vehicle=vehicle, db=db) | |||||
if vehicleRes == "vehicleNotFound": | |||||
raise HTTPException(status_code=404, detail="Vehicle not found") | |||||
elif vehicleRes == "badreq": | |||||
raise HTTPException(status_code=502, detail="Bad request") | |||||
return vehicleRes |
@@ -1,5 +1,5 @@ | |||||
from sqlalchemy.orm import Session | from sqlalchemy.orm import Session | ||||
from schemas.vehicle import CreateVehicle, OutputVehicle | |||||
from schemas.vehicle import CreateVehicle, OutputVehicle, UpdateVehicle | |||||
from db.models.vehicle import Vehicle | from db.models.vehicle import Vehicle | ||||
from db.repository.user import verify_driver_exists | from db.repository.user import verify_driver_exists | ||||
@@ -45,3 +45,23 @@ def list_vehicles(db: Session): | |||||
def get_vehicle_by_id(vehicle_id: int, db: Session): | def get_vehicle_by_id(vehicle_id: int, db: Session): | ||||
vehicle = db.query(Vehicle).filter(Vehicle.Id == vehicle_id).first() | vehicle = db.query(Vehicle).filter(Vehicle.Id == vehicle_id).first() | ||||
return vehicle | return vehicle | ||||
def replace_vehicle_data(id: int, vehicle: UpdateVehicle, db: Session): | |||||
vehicle_db = db.query(Vehicle).filter(Vehicle.Id == id) | |||||
vehicle_object = vehicle_db.first() | |||||
if not vehicle_object: | |||||
return "vehiclenotfound" | |||||
vehicle_object.AssignedDriverIds = vehicle.AssignedDriverIds | |||||
vehicle_object.CurrentLocation = vehicle.CurrentLocation | |||||
vehicle_object.Fuel = vehicle.Fuel | |||||
vehicle_object.LicensePlate = vehicle.LicensePlate | |||||
vehicle_object.MaintenanceNotes = vehicle.MaintenanceNotes | |||||
vehicle_object.Mileage = vehicle.Mileage | |||||
vehicle_object.Model = vehicle.Model | |||||
vehicle_object.Type = vehicle.Type | |||||
vehicle_object.Year = vehicle.Year | |||||
print(vehicle_object) | |||||
db.add(vehicle_object) | |||||
db.commit() | |||||
return vehicle_object |
@@ -23,3 +23,15 @@ class OutputVehicle(BaseModel): | |||||
Fuel: Optional[int] = 0 | Fuel: Optional[int] = 0 | ||||
MaintenanceNotes: Optional[list[str]] = None | MaintenanceNotes: Optional[list[str]] = None | ||||
AssignedDriverIds: Optional[list[int]] = None | AssignedDriverIds: Optional[list[int]] = None | ||||
class UpdateVehicle(BaseModel): | |||||
Model: str | |||||
Year: int | |||||
LicensePlate: str | |||||
Type: str | |||||
Mileage: int | |||||
CurrentLocation: Optional[list[str]] = None | |||||
Fuel: Optional[int] = 0 | |||||
MaintenanceNotes: Optional[list[str]] = None | |||||
AssignedDriverIds: Optional[list[int]] = None |