Spaces:
Running
Running
from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect | |
from fastapi.templating import Jinja2Templates | |
from fastapi.staticfiles import StaticFiles | |
from fastapi.responses import HTMLResponse | |
from pydantic import BaseModel | |
from typing import List, Optional | |
import uvicorn | |
import torch | |
from scripts.model import Net | |
from scripts.training.train import train | |
from pathlib import Path | |
from fastapi import BackgroundTasks | |
import warnings | |
# Ignore UserWarning messages attributed to modules matching "torchvision.transforms".
warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module="torchvision.transforms",
)

app = FastAPI()

# Serve files from ./static under the /static URL prefix and load page
# templates from ./templates (both paths are relative to the working directory).
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
# Model configurations


# Request body for one network plus its training settings.
# block1..block3 are the values this module passes to ``Net(kernels=[...])``.
class TrainingConfig(BaseModel):
    # Architecture: one entry per block of ``Net``'s ``kernels`` list.
    block1: int
    block2: int
    block3: int
    # Optimizer name; forwarded as a plain string, not validated here.
    optimizer: str
    batch_size: int
    # Number of training epochs; a single epoch unless the client says otherwise.
    epochs: int = 1


# Request body pairing two TrainingConfig entries so they can be compared.
class ComparisonConfig(BaseModel):
    model1: TrainingConfig
    model2: TrainingConfig
def get_available_models(models_dir="scripts/training/models") -> List[str]:
    """Return the names (file stems) of the saved ``*.pth`` models, sorted.

    The directory is created when it does not exist yet, so a fresh checkout
    yields an empty list instead of an error.

    Args:
        models_dir: Directory to scan, as a ``str`` or ``Path``. Defaults to the
            directory ``perform_inference`` loads weights from; relative paths
            resolve against the current working directory.

    Returns:
        Model names without the ``.pth`` extension, in alphabetical order.
        Sorting matters because ``Path.glob`` yields entries in an arbitrary,
        filesystem-dependent order, which made the UI list unstable.
    """
    models_path = Path(models_dir)
    if not models_path.exists():
        models_path.mkdir(parents=True, exist_ok=True)
    return sorted(f.stem for f in models_path.glob("*.pth"))
# Module-level slot intended to hold a training task handle. Nothing in the code
# visible alongside this assignment reads or writes it — NOTE(review): confirm
# whether it is still needed.
training_task = None
async def home(request: Request):
    """Render the landing page (``index.html``)."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm how it is registered with ``app``.
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
async def train_page(request: Request):
    """Render the training page (``train.html``)."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm how it is registered with ``app``.
    context = {"request": request}
    return templates.TemplateResponse("train.html", context)
async def inference_page(request: Request):
    """Render ``inference.html`` with the saved models the user can pick from."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm how it is registered with ``app``.
    models = get_available_models()
    context = {"request": request, "available_models": models}
    return templates.TemplateResponse("inference.html", context)
async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks):
    """Acknowledge a training configuration after checking that ``Net`` accepts it.

    No training is started here. The network is instantiated only as a
    validation step and then discarded. ``optimizer``, ``batch_size`` and
    ``epochs`` are not consumed, and ``background_tasks`` is accepted for
    interface compatibility but unused. In the visible code, training itself
    runs through the websocket handlers, which call ``train``.

    Returns:
        A status dict confirming the configuration was received.

    Raises:
        HTTPException: 500 if constructing ``Net`` raises.
    """
    try:
        # Built purely to validate the block sizes; the instance is not kept.
        Net(kernels=[config.block1, config.block2, config.block3])
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"status": "success", "message": "Training configuration received"}
async def websocket_endpoint(websocket: WebSocket):
    """Train a single model configured by the first JSON message on the socket.

    The first message must be a JSON object with ``block1``, ``block2``,
    ``block3``, ``optimizer``, ``batch_size`` and, optionally, ``epochs``.
    ``epochs`` defaults to 1, the same default as ``TrainingConfig``.

    The socket is handed to ``train`` (presumably so it can stream progress).
    Afterwards the server sends ``{"type": "training_complete", ...}`` on
    success or ``{"type": "training_error", ...}`` on failure. A first message
    that lacks required fields also gets a ``training_error`` reply, instead of
    the connection ending silently.
    """
    await websocket.accept()
    try:
        config_data = await websocket.receive_json()
        try:
            kernels = [config_data['block1'], config_data['block2'], config_data['block3']]
            optimizer = config_data['optimizer']
            batch_size = config_data['batch_size']
            epochs = config_data.get('epochs', 1)
        except (KeyError, TypeError) as e:
            # KeyError: a required field is missing; TypeError: payload is not a JSON object.
            print(f"Invalid training config {config_data!r}: {e!r}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Invalid training configuration: {e}"
                }
            })
            return

        model = Net(kernels=kernels)

        from scripts.training.config import NetworkConfig
        config = NetworkConfig()
        config.update(
            block1=kernels[0],
            block2=kernels[1],
            block3=kernels[2],
            optimizer=optimizer,
            batch_size=batch_size,
            epochs=epochs
        )
        print(f"Starting training with config: {config_data}")
        try:
            # Pass "single" as model_type for single model training
            await train(model, config, websocket, model_type="single")
        except WebSocketDisconnect:
            # The client left mid-training. That is not a training failure and
            # nobody is left to notify, so defer to the outer handler.
            raise
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })
        else:
            await websocket.send_json({
                "type": "training_complete",
                "data": {
                    "message": "Training completed successfully!"
                }
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        print("WebSocket connection closed")
async def websocket_compare_endpoint(websocket: WebSocket):
    """Train two models back to back and report the outcome over the socket.

    The first message is expected to look like::

        {"type": "start_comparison", "model1": {...}, "model2": {...}}

    Each ``modelN`` dict is passed as keyword arguments to
    ``NetworkConfig.update``, and the updated ``block1``..``block3`` values size
    the corresponding ``Net``. ``train`` receives the socket and a
    ``model_type`` of ``"model_1"`` or ``"model_2"``. Afterwards the server
    sends ``comparison_complete`` on success or ``training_error`` on failure.
    A first message of any other type is logged and no training happens.
    """
    await websocket.accept()
    try:
        data = await websocket.receive_json()
        if data.get("type") != "start_comparison":
            print(f"Ignoring unexpected message type: {data.get('type')!r}")
            return

        from scripts.training.config import NetworkConfig

        def _build_net(net_config):
            # Size the network from the updated config's block attributes.
            return Net(kernels=[net_config.block1, net_config.block2, net_config.block3])

        # Create and configure both models before any training starts.
        model1_config = NetworkConfig()
        model2_config = NetworkConfig()
        model1_config.update(**data["model1"])
        model2_config.update(**data["model2"])
        model1 = _build_net(model1_config)
        model2 = _build_net(model2_config)

        try:
            await train(model1, model1_config, websocket, model_type="model_1")
            await train(model2, model2_config, websocket, model_type="model_2")
        except WebSocketDisconnect:
            # The client left mid-training. That is not a training failure and
            # nobody is left to notify, so defer to the outer handler.
            raise
        except Exception as e:
            print(f"Training error: {str(e)}")
            await websocket.send_json({
                "type": "training_error",
                "data": {
                    "message": f"Training failed: {str(e)}"
                }
            })
        else:
            await websocket.send_json({
                "type": "comparison_complete",
                "data": {
                    "message": "Training completed successfully!"
                }
            })
    except WebSocketDisconnect:
        print("WebSocket disconnected")
    except Exception as e:
        print(f"WebSocket error: {str(e)}")
    finally:
        print("WebSocket connection closed")
# @app.post("/api/train_single") | |
# async def train_single_model(config: TrainingConfig): | |
# try: | |
# model = Net(kernels=config.kernels) | |
# # Start training without passing the websocket | |
# await train(model, config) | |
# return {"status": "success"} | |
# except Exception as e: | |
# # Log the error for debugging | |
# print(f"Error during training: {str(e)}") | |
# # Return a JSON response with the error message | |
# raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}") | |
async def train_compare_models(config: ComparisonConfig):
    """Train the two models described by ``config`` one after the other.

    Each ``TrainingConfig`` is copied into a ``NetworkConfig``, the config
    object the websocket handlers in this module pass to ``train``. Its network
    is built from the explicit ``block1``..``block3`` fields, because
    ``TrainingConfig`` has no ``kernels`` attribute.

    NOTE(review): ``train`` is called without a websocket argument, as the
    original code did; confirm its signature allows that.

    Returns:
        ``{"status": "success", "model1_results": ..., "model2_results": ...}``
        holding whatever ``train`` returns for each model.

    Raises:
        HTTPException: 500 wrapping any error raised while building or training.
    """
    try:
        from scripts.training.config import NetworkConfig

        async def _train_one(settings, model_type):
            net = Net(kernels=[settings.block1, settings.block2, settings.block3])
            net_config = NetworkConfig()
            net_config.update(
                block1=settings.block1,
                block2=settings.block2,
                block3=settings.block3,
                optimizer=settings.optimizer,
                batch_size=settings.batch_size,
                epochs=settings.epochs,
            )
            # ``train`` is awaited everywhere else in this module. Calling it
            # without ``await`` only creates a coroutine that never runs.
            return await train(net, net_config, model_type=model_type)

        results1 = await _train_one(config.model1, "model_1")
        results2 = await _train_one(config.model2, "model_2")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {
        "status": "success",
        "model1_results": results1,
        "model2_results": results2
    }
def parse_model_filename(filename):
    """Extract the model configuration encoded in a saved-model filename.

    Expected layout: tokens separated by ``_``, with an optional trailing
    ``.pth``::

        single_arch_32_64_128_opt_adam_batch_64_20240322_123456.pth

    The three integers after ``arch`` are the block sizes. The token after
    ``opt`` is the optimizer name, and the integer after ``batch`` is the batch
    size. Other tokens, such as the prefix and timestamp, are ignored.

    Returns:
        A dict with keys ``block1``, ``block2``, ``block3``, ``optimizer`` and
        ``batch_size``, or ``None`` if the name does not follow this layout.
    """
    try:
        # Strip the extension so a trailing "batch_64.pth" still parses as 64.
        stem = filename[:-len('.pth')] if filename.endswith('.pth') else filename
        parts = stem.split('_')
        # Find architecture values
        arch_index = parts.index('arch')
        block1 = int(parts[arch_index + 1])
        block2 = int(parts[arch_index + 2])
        block3 = int(parts[arch_index + 3])
        # Find optimizer
        opt_index = parts.index('opt')
        optimizer = parts[opt_index + 1]
        # Find batch size
        batch_index = parts.index('batch')
        batch_size = int(parts[batch_index + 1])
    except (ValueError, IndexError, AttributeError) as e:
        # ValueError: marker token missing or non-integer value;
        # IndexError: name ends right after a marker;
        # AttributeError: filename is not a string.
        print(f"Error parsing model filename: {e}")
        return None
    return {
        'block1': block1,
        'block2': block2,
        'block3': block3,
        'optimizer': optimizer,
        'batch_size': batch_size
    }
def _image_to_tensor(image_data):
    """Decode base64 image data into a (1, 1, 28, 28) normalized float tensor.

    Accepts either a data URL (``data:image/png;base64,...``) or bare base64.
    The image is converted to grayscale, resized to 28x28 with LANCZOS,
    converted with ``ToTensor``, normalized, and given a leading batch dimension.
    """
    import base64
    import io
    from PIL import Image
    import torchvision.transforms as transforms

    # Everything after the first comma is the payload; with no comma the whole
    # string is treated as base64.
    payload = image_data.split(',', 1)[-1]
    image = Image.open(io.BytesIO(base64.b64decode(payload))).convert('L')  # grayscale
    image = image.resize((28, 28), Image.LANCZOS)
    transform = transforms.Compose([
        transforms.ToTensor(),
        # NOTE(review): (0.1307, 0.3081) are the widely used MNIST mean/std;
        # presumably they match the training-time normalization — confirm.
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    return transform(image).unsqueeze(0)


async def perform_inference(data: dict):
    """Classify an image with a saved model and report the model's configuration.

    Args:
        data: Request payload. Reads ``model_name``, the stem of a ``.pth`` file
            in ``scripts/training/models``. Also reads ``image``: base64 image
            data, with or without a ``data:...;base64,`` prefix.

    Returns:
        ``{"prediction": <argmax class index>, "model_config": {...}}``. The
        config is recovered from the model's filename.

    Raises:
        HTTPException:
            - 400 for a missing or invalid model name, or missing image data.
            - 404 if the model file does not exist.
            - 500 for other failures: unparseable filename, weight loading,
              image decoding, or the forward pass.
    """
    try:
        model_name = data.get("model_name")
        if not model_name:
            raise HTTPException(status_code=400, detail="No model selected")
        # model_name is client-supplied. Accept only a bare file stem so it
        # cannot point outside the models directory (e.g. "../../x").
        model_name = str(model_name)
        if Path(model_name).name != model_name:
            raise HTTPException(status_code=400, detail="Invalid model name")
        model_path = Path("scripts/training/models") / f"{model_name}.pth"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")

        # Parse model configuration from filename
        config = parse_model_filename(model_name)
        if not config:
            raise HTTPException(status_code=500, detail="Could not parse model configuration")

        # Validate the whole request before paying for weight loading.
        image_data = data.get("image")
        if not image_data:
            raise HTTPException(status_code=400, detail="No image data provided")

        model = Net(kernels=[config['block1'], config['block2'], config['block3']])
        # weights_only=True restricts unpickling to tensors, primitive types and dicts.
        model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu'), weights_only=True))
        model.eval()

        try:
            image_tensor = _image_to_tensor(image_data)
            with torch.no_grad():
                output = model(image_tensor)
            prediction = output.argmax(dim=1).item()
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}") from e

        # Add configuration info to response
        return {
            "prediction": prediction,
            "model_config": {
                "architecture": f"{config['block1']}-{config['block2']}-{config['block3']}",
                "optimizer": config['optimizer'],
                "batch_size": config['batch_size']
            }
        }
    except HTTPException:
        # Already carries the intended status code (400/404/500); re-wrapping it
        # would turn every client error into a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
async def train_single_page(request: Request):
    """Render the single-model training page (``train_single.html``)."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm how it is registered with ``app``.
    context = {"request": request}
    return templates.TemplateResponse("train_single.html", context)
async def train_compare_page(request: Request):
    """Render the two-model comparison page (``train_compare.html``)."""
    # NOTE(review): no route decorator is visible for this handler in this copy
    # of the file — confirm how it is registered with ``app``.
    context = {"request": request}
    return templates.TemplateResponse("train_compare.html", context)
if __name__ == "__main__":
    # Serve the app directly when run as a script, on port 8000 of all interfaces.
    uvicorn.run(app, port=8000, host="0.0.0.0")