Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.templating import Jinja2Templates | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| from typing import List, Optional | |
| import uvicorn | |
| import torch | |
| from scripts.model import Net | |
| from scripts.training.train import train | |
| import json | |
| import os | |
| from pathlib import Path | |
| import asyncio | |
| from fastapi import BackgroundTasks | |
app = FastAPI()
templates = Jinja2Templates(directory="templates")

# Serve everything under ./static at the /static URL prefix, registered under
# the route name "static"; page templates are loaded from ./templates above.
# NOTE(review): no route decorators (@app.get / @app.post / @app.websocket) are
# visible on the handler functions in this copy of the file — confirm they are
# registered on `app`.
app.mount(path="/static", app=StaticFiles(directory="static"), name="static")
# Model configurations
class TrainingConfig(BaseModel):
    """Request-body schema describing how to build and train one model."""

    block1: int  # first entry of the `kernels` list that train_model passes to Net
    block2: int  # second entry of `kernels`
    block3: int  # third entry of `kernels`
    optimizer: str  # optimizer name; accepted values are up to the training code — TODO confirm
    batch_size: int
    epochs: int = 1  # defaults to a single epoch when the client omits it
class ComparisonConfig(BaseModel):
    """Request-body schema holding two training configurations to compare."""

    model1: TrainingConfig  # settings for the first model
    model2: TrainingConfig  # settings for the second model
def get_available_models(models_dir="scripts/training/models"):
    """Return the names of saved model checkpoints in *models_dir*, sorted.

    A model name is the stem of a ``*.pth`` file located directly inside
    *models_dir* (subdirectories are not searched, and directories whose name
    ends in ``.pth`` are ignored).

    Args:
        models_dir: Checkpoint directory, as a ``str`` or ``Path``. If it does
            not exist it is created, including missing parents.

    Returns:
        Alphabetically sorted list of model names, e.g. ``["a", "b"]`` for
        ``a.pth`` and ``b.pth``. Sorting gives the UI a stable order, since
        ``Path.glob`` yields entries in arbitrary filesystem order.
    """
    models_path = Path(models_dir)
    if not models_path.exists():
        models_path.mkdir(exist_ok=True, parents=True)
    return sorted(f.stem for f in models_path.glob("*.pth") if f.is_file())
# Module-level slot for a training task, initialised to None.
# NOTE(review): nothing visible in this file reads or reassigns it after this
# line — confirm whether it is still needed.
training_task = None
async def home(request: Request):
    """Render the landing page from the ``index.html`` template."""
    page_context = {"request": request}
    return templates.TemplateResponse("index.html", page_context)
async def train_page(request: Request):
    """Render the training page from the ``train.html`` template."""
    template_name = "train.html"
    return templates.TemplateResponse(template_name, {"request": request})
async def inference_page(request: Request):
    """Render ``inference.html`` together with the saved models to pick from.

    The template receives ``available_models``: the list returned by
    ``get_available_models()``.
    """
    context = {"request": request}
    context["available_models"] = get_available_models()
    return templates.TemplateResponse("inference.html", context)
async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Acknowledge a training configuration after checking that it builds a model.

    Only a ``Net`` is constructed, from ``block1``..``block3``. No training is
    started here. ``optimizer``, ``batch_size``, ``epochs`` and
    ``background_tasks`` are accepted but not used by this handler.

    Returns:
        ``{"status": "success", "message": "Training configuration received"}``
        when ``Net`` accepts the kernel configuration.

    Raises:
        HTTPException: status 500 carrying the error text if building ``Net`` fails.
    """
    try:
        # Built purely so an invalid kernel configuration is rejected up front;
        # the instance is not kept.
        Net(kernels=[config.block1, config.block2, config.block3])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": "Training configuration received"}
async def websocket_endpoint(websocket: WebSocket):
    """Run one training session over a WebSocket connection.

    Flow, as implemented here:

    1. Accept the connection and wait for one JSON message carrying
       ``block1``, ``block2``, ``block3``, ``optimizer`` and ``batch_size``.
    2. Build a ``Net`` from the three block values and a ``NetworkConfig``
       from all five values. ``epochs`` is fixed at 1, whatever the message says.
    3. Await ``train(model, config, websocket)``. The websocket is handed to
       ``train`` for real-time updates.
    4. Send ``{"type": "training_complete", ...}`` on success. Send
       ``{"type": "training_error", ...}`` if the configuration is missing a
       key or training raises.

    A client disconnect ends the session with a log line. Other errors are
    printed rather than propagated.
    """
    await websocket.accept()
    try:
        # Wait for configuration from client
        config_data = await websocket.receive_json()

        try:
            kernels = [config_data['block1'], config_data['block2'], config_data['block3']]
            optimizer = config_data['optimizer']
            batch_size = config_data['batch_size']
        except (KeyError, TypeError) as e:
            # Tell the client why nothing will start instead of only logging it.
            print(f"Invalid training configuration: {e!r}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Invalid training configuration: {e!r}"
                }
            })
            return

        model = Net(kernels=kernels)

        from scripts.training.config import NetworkConfig
        config = NetworkConfig()
        config.update(
            block1=kernels[0],
            block2=kernels[1],
            block3=kernels[2],
            optimizer=optimizer,
            batch_size=batch_size,
            epochs=1,  # fixed; an "epochs" key sent by the client is ignored
        )

        print(f"Starting training with config: {config_data}")
        # Start training with websocket for real-time updates
        try:
            await train(model, config, websocket)
            await websocket.send_json({
                "type": "training_complete",
                "data": {
                    "message": "Training completed successfully!"
                }
            })
        except WebSocketDisconnect:
            # The client left mid-training. The socket is gone, so sending a
            # training_error on it would fail; let the outer handler log it.
            raise
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        print("WebSocket connection closed")
| # @app.post("/api/train_single") | |
| # async def train_single_model(config: TrainingConfig): | |
| # try: | |
| # model = Net(kernels=config.kernels) | |
| # # Start training without passing the websocket | |
| # await train(model, config) | |
| # return {"status": "success"} | |
| # except Exception as e: | |
| # # Log the error for debugging | |
| # print(f"Error during training: {str(e)}") | |
| # # Return a JSON response with the error message | |
| # raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}") | |
async def _train_from_config(cfg):
    """Build a Net and NetworkConfig from one TrainingConfig *cfg*, train it, and return train()'s result."""
    from scripts.training.config import NetworkConfig

    model = Net(kernels=[cfg.block1, cfg.block2, cfg.block3])
    network_config = NetworkConfig()
    network_config.update(
        block1=cfg.block1,
        block2=cfg.block2,
        block3=cfg.block3,
        optimizer=cfg.optimizer,
        batch_size=cfg.batch_size,
        epochs=cfg.epochs,
    )
    # NOTE(review): called without a websocket, as the commented-out
    # train_single endpoint does; confirm train() supports that.
    return await train(model, network_config)


async def train_compare_models(config: ComparisonConfig):
    """Train the two models described by *config*, one after the other.

    Each model is built from its ``block1``..``block3`` values. Its training
    settings (``optimizer``, ``batch_size``, ``epochs``) are passed to
    ``train`` through a ``NetworkConfig``, the same way ``websocket_endpoint``
    does it.

    Returns:
        ``{"status": "success", "model1_results": ..., "model2_results": ...}``,
        where each result is whatever ``train`` returns. It must be
        JSON-serializable for the response — TODO confirm.

    Raises:
        HTTPException: status 500 with the error text if building or training
            either model fails.
    """
    try:
        # train() is a coroutine function; it must be awaited to actually run.
        results1 = await _train_from_config(config.model1)
        results2 = await _train_from_config(config.model2)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "model1_results": results1,
        "model2_results": results2,
    }
def _image_to_tensor(image_data):
    """Decode a base64 image string into a single-sample model input tensor.

    *image_data* may be a bare base64 payload or a data URL such as
    ``"data:image/png;base64,<payload>"``. The image is:

    - converted to grayscale;
    - resized to 28x28 with LANCZOS resampling;
    - turned into a tensor by ``ToTensor``;
    - normalized with mean 0.1307 and std 0.3081;
    - given a leading batch dimension.

    The result has shape (1, 1, 28, 28).
    """
    import base64
    import io

    import torchvision.transforms as transforms
    from PIL import Image

    # Keep only what follows the first comma (where a data-URL header ends).
    # Base64 contains no commas, so a bare payload passes through unchanged.
    payload = image_data.split(',', 1)[-1]
    image_bytes = base64.b64decode(payload)
    image = Image.open(io.BytesIO(image_bytes)).convert('L')  # grayscale
    image = image.resize((28, 28), Image.LANCZOS)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    return transform(image).unsqueeze(0)


async def perform_inference(data: dict):
    """Classify an image with a saved model checkpoint.

    Args:
        data: Parsed JSON payload with these keys:

            - ``model_name``: stem of a ``.pth`` file directly inside
              ``scripts/training/models``.
            - ``image``: base64-encoded image, optionally as a data URL.

    Returns:
        ``{"prediction": <int>}``, the arg-max index of the model output.

    Raises:
        HTTPException:
            - 400 if ``model_name`` or ``image`` is missing, or the name
              points outside the models directory;
            - 404 if the checkpoint file does not exist;
            - 500 if loading the model or processing the image fails.
    """
    try:
        model_name = data.get("model_name")
        if not model_name:
            raise HTTPException(status_code=400, detail="No model selected")
        models_dir = Path("scripts/training/models")
        model_path = models_dir / f"{model_name}.pth"
        # model_name comes from the client. Refuse anything ("../x", "a/b",
        # absolute paths) that would not be a file directly in models_dir.
        if model_path.parent != models_dir:
            raise HTTPException(status_code=400, detail="Invalid model name")
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")

        # Validate the payload before paying for a model load.
        image_data = data.get("image")
        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")

        # NOTE(review): Net() is called with no arguments, unlike the training
        # paths, which pass kernels=[...]. A checkpoint trained with other block
        # sizes may fail load_state_dict — confirm.
        model = Net()
        # torch.load unpickles the file. That is acceptable only because the path
        # is confined to the server-side models directory above.
        model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu')))
        model.eval()

        try:
            image_tensor = _image_to_tensor(image_data)
            with torch.no_grad():
                output = model(image_tensor)
            prediction = output.argmax(dim=1).item()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e
        return {"prediction": prediction}
    except HTTPException:
        # Propagate the deliberate 400/404/500 responses unchanged, so the
        # generic handler below does not turn them all into 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def train_single_page(request: Request):
    """Render the single-model training page (``train_single.html``)."""
    context = dict(request=request)
    return templates.TemplateResponse("train_single.html", context)
async def train_compare_page(request: Request):
    """Render the side-by-side model comparison page (``train_compare.html``)."""
    template_name = "train_compare.html"
    context = dict(request=request)
    return templates.TemplateResponse(template_name, context)
if __name__ == "__main__":
    # Running this file directly starts a uvicorn server for `app`,
    # listening on every interface at port 8000.
    server_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run(app, **server_options)