from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from util import get_client_id, get_trained_models, train_client_model, download_dataset_locally, predict_vendor_category
from typing import Optional

# Runs at import time, before the app object is created.
# NOTE(review): this executes on every import of this module (server start,
# tests, tooling); what it downloads and where is defined in util — confirm.
download_dataset_locally()

app = FastAPI()

# Required number of items per row for each endpoint. What each column means
# is defined by util.train_client_model / util.predict_vendor_category —
# TODO confirm and document there.
_TRAIN_ROW_LENGTH = 4
_PREDICT_ROW_LENGTH = 2


# Models

class TrainInput(BaseModel):
    """Request body for ``POST /train``."""

    # Client whose model is trained; must not contain a space character.
    client_id: str
    # Training rows; every row must contain exactly 4 string items.
    data: list[list[str]]
    # Forwarded unchanged to util.train_client_model as ``ignore_value``.
    # Presumably rows carrying this value are excluded from training — TODO confirm in util.
    ignore_value: Optional[str] = 'Need help from accountant'


class PredictInput(BaseModel):
    """Request body for ``POST /predict``."""

    # Client whose trained model is used; must not contain a space character.
    client_id: str
    # Rows to predict on; every row must contain exactly 2 string items.
    data: list[list[str]]


class UserInput(BaseModel):
    """Request body for ``POST /create-client``."""

    # Human-readable name from which util.get_client_id derives the client ID.
    client_name: str


# Validation helpers

def _validate_client_id(client_id: str) -> None:
    """Raise HTTP 400 if *client_id* contains a space character."""
    if ' ' in client_id:
        raise HTTPException(status_code=400, detail="client_id cannot contain space.")


def _validate_row_lengths(rows: list[list[str]], expected: int) -> None:
    """Raise HTTP 400 unless every row in *rows* has exactly *expected* items.

    An empty *rows* list passes (no row violates the constraint).
    """
    if any(len(row) != expected for row in rows):
        raise HTTPException(
            status_code=400,
            detail=f"Each row must contain exactly {expected} items.",
        )


# Endpoints

@app.get("/models")
def get_models():
    """Return all trained models plus a message saying whether any exist."""
    trained_models = get_trained_models()
    message = "List of trained models." if trained_models else "No models trained yet."
    return {"models": trained_models, "message": message}


@app.post("/create-client")
def create_username(user_input: UserInput):
    """Derive a client ID from ``client_name`` and return it.

    This endpoint itself stores nothing; it only checks that no trained
    model already uses the derived ID. Whether util.get_client_id has side
    effects is defined in util.

    Raises:
        HTTPException: 400 if a trained model already exists for the
            derived client ID.
    """
    client_name = user_input.client_name
    # Fetch existing models before deriving the ID (same call order as before).
    trained_models = get_trained_models()
    client_id = get_client_id(client_name)
    # Each entry from get_trained_models() is indexed by 'client_id'.
    if any(m['client_id'] == client_id for m in trained_models):
        raise HTTPException(status_code=400, detail=f"Model for {client_name}, {client_id} already exists.")
    return {"client_id": client_id, "message": "client created successfully."}


@app.post("/train")
def train_model(train_input: TrainInput):
    """Train (via util.train_client_model) a model for ``client_id``.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space or any row
            does not contain exactly 4 items.
    """
    _validate_client_id(train_input.client_id)
    _validate_row_lengths(train_input.data, _TRAIN_ROW_LENGTH)

    training_result = train_client_model(
        client_id=train_input.client_id,
        rows=train_input.data,
        ignore_value=train_input.ignore_value,
    )
    return {"message": f"Model '{train_input.client_id}' trained successfully.", "result": training_result}


@app.post("/predict")
def predict(predict_input: PredictInput):
    """Generate predictions (via util.predict_vendor_category) for ``client_id``.

    Raises:
        HTTPException: 400 if ``client_id`` contains a space or any row
            does not contain exactly 2 items.
    """
    _validate_client_id(predict_input.client_id)
    _validate_row_lengths(predict_input.data, _PREDICT_ROW_LENGTH)

    predictions = predict_vendor_category(client_id=predict_input.client_id, data=predict_input.data)
    return {"result": predictions, 'message': 'Predictions generated successfully.'}