Shilpaj commited on
Commit
0d84fb8
·
1 Parent(s): 4c1a791

Feat: Python server for the project

Browse files
Files changed (1) hide show
  1. app.py +236 -0
app.py ADDED
@@ -0,0 +1,236 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request, HTTPException, WebSocket, WebSocketDisconnect
2
+ from fastapi.templating import Jinja2Templates
3
+ from fastapi.staticfiles import StaticFiles
4
+ from fastapi.responses import HTMLResponse
5
+ from pydantic import BaseModel
6
+ from typing import List, Optional
7
+ import uvicorn
8
+ import torch
9
+ from scripts.model import Net
10
+ from scripts.training.train import train
11
+ import json
12
+ import os
13
+ from pathlib import Path
14
+ import asyncio
15
+ from fastapi import BackgroundTasks
16
+
17
+ app = FastAPI()
18
+
19
+ # Mount static files and templates
20
+ app.mount("/static", StaticFiles(directory="static"), name="static")
21
+ templates = Jinja2Templates(directory="templates")
22
+
23
+ # Model configurations
24
+ class TrainingConfig(BaseModel):
25
+ block1: int
26
+ block2: int
27
+ block3: int
28
+ optimizer: str
29
+ batch_size: int
30
+ epochs: int = 1
31
+
32
+ class ComparisonConfig(BaseModel):
33
+ model1: TrainingConfig
34
+ model2: TrainingConfig
35
+
36
+ def get_available_models():
37
+ models_dir = Path("scripts/training/models")
38
+ if not models_dir.exists():
39
+ models_dir.mkdir(exist_ok=True, parents=True)
40
+ return [f.stem for f in models_dir.glob("*.pth")]
41
+
42
+ # Add a global variable to store training task
43
+ training_task = None
44
+
45
+ @app.get("/", response_class=HTMLResponse)
46
+ async def home(request: Request):
47
+ return templates.TemplateResponse("index.html", {"request": request})
48
+
49
+ @app.get("/train", response_class=HTMLResponse)
50
+ async def train_page(request: Request):
51
+ return templates.TemplateResponse("train.html", {"request": request})
52
+
53
+ @app.get("/inference", response_class=HTMLResponse)
54
+ async def inference_page(request: Request):
55
+ available_models = get_available_models()
56
+ return templates.TemplateResponse(
57
+ "inference.html",
58
+ {
59
+ "request": request,
60
+ "available_models": available_models
61
+ }
62
+ )
63
+
64
+ @app.post("/train")
65
+ async def train_model(config: TrainingConfig, background_tasks: BackgroundTasks):
66
+ try:
67
+ # Create model instance with the configuration
68
+ model = Net(
69
+ kernels=[config.block1, config.block2, config.block3]
70
+ )
71
+
72
+ # Store training configuration
73
+ training_config = {
74
+ "optimizer": config.optimizer,
75
+ "batch_size": config.batch_size
76
+ }
77
+
78
+ return {"status": "success", "message": "Training configuration received"}
79
+ except Exception as e:
80
+ raise HTTPException(status_code=500, detail=str(e))
81
+
82
+ @app.websocket("/ws/train")
83
+ async def websocket_endpoint(websocket: WebSocket):
84
+ await websocket.accept()
85
+ try:
86
+ # Wait for configuration from client
87
+ config_data = await websocket.receive_json()
88
+
89
+ # Create model instance with the configuration
90
+ model = Net(
91
+ kernels=[
92
+ config_data['block1'],
93
+ config_data['block2'],
94
+ config_data['block3']
95
+ ]
96
+ )
97
+
98
+ # Create config object
99
+ from scripts.training.config import NetworkConfig
100
+ config = NetworkConfig()
101
+ config.update(
102
+ block1=config_data['block1'],
103
+ block2=config_data['block2'],
104
+ block3=config_data['block3'],
105
+ optimizer=config_data['optimizer'],
106
+ batch_size=config_data['batch_size'],
107
+ epochs=1
108
+ )
109
+
110
+ print(f"Starting training with config: {config_data}")
111
+
112
+ # Start training with websocket for real-time updates
113
+ try:
114
+ await train(model, config, websocket)
115
+ await websocket.send_json({
116
+ "type": "training_complete",
117
+ "data": {
118
+ "message": "Training completed successfully!"
119
+ }
120
+ })
121
+ except Exception as e:
122
+ print(f"Training error: {str(e)}")
123
+ await websocket.send_json({
124
+ "type": "training_error",
125
+ "data": {
126
+ "message": f"Training failed: {str(e)}"
127
+ }
128
+ })
129
+
130
+ except WebSocketDisconnect:
131
+ print("WebSocket disconnected")
132
+ except Exception as e:
133
+ print(f"WebSocket error: {str(e)}")
134
+ finally:
135
+ print("WebSocket connection closed")
136
+
137
+ # @app.post("/api/train_single")
138
+ # async def train_single_model(config: TrainingConfig):
139
+ # try:
140
+ # model = Net(kernels=config.kernels)
141
+ # # Start training without passing the websocket
142
+ # await train(model, config)
143
+ # return {"status": "success"}
144
+ # except Exception as e:
145
+ # # Log the error for debugging
146
+ # print(f"Error during training: {str(e)}")
147
+ # # Return a JSON response with the error message
148
+ # raise HTTPException(status_code=500, detail=f"Error during training: {str(e)}")
149
+
150
+ @app.post("/api/train_compare")
151
+ async def train_compare_models(config: ComparisonConfig):
152
+ try:
153
+ # Train both models
154
+ model1 = Net(kernels=config.model1.kernels)
155
+ model2 = Net(kernels=config.model2.kernels)
156
+
157
+ results1 = train(model1, config.model1)
158
+ results2 = train(model2, config.model2)
159
+
160
+ return {
161
+ "status": "success",
162
+ "model1_results": results1,
163
+ "model2_results": results2
164
+ }
165
+ except Exception as e:
166
+ raise HTTPException(status_code=500, detail=str(e))
167
+
168
+ @app.post("/api/inference")
169
+ async def perform_inference(data: dict):
170
+ try:
171
+ model_name = data.get("model_name")
172
+ if not model_name:
173
+ raise HTTPException(status_code=400, detail="No model selected")
174
+
175
+ model_path = Path("scripts/training/models") / f"{model_name}.pth"
176
+ if not model_path.exists():
177
+ raise HTTPException(status_code=404, detail=f"Model not found: {model_path}")
178
+
179
+ # Load model and perform inference
180
+ model = Net()
181
+ model.load_state_dict(torch.load(str(model_path), map_location=torch.device('cpu')))
182
+ model.eval()
183
+
184
+ # Process image data and get prediction
185
+ image_data = data.get("image")
186
+ if not image_data:
187
+ raise HTTPException(status_code=400, detail="No image data provided")
188
+
189
+ # Convert base64 image to tensor and process
190
+ try:
191
+ # Remove the data URL prefix
192
+ image_data = image_data.split(',')[1]
193
+ import base64
194
+ import io
195
+ from PIL import Image
196
+ import torchvision.transforms as transforms
197
+
198
+ # Decode base64 to image
199
+ image_bytes = base64.b64decode(image_data)
200
+ image = Image.open(io.BytesIO(image_bytes)).convert('L') # Convert to grayscale
201
+
202
+ # Resize using PIL directly with LANCZOS
203
+ image = image.resize((28, 28), Image.LANCZOS)
204
+
205
+ # Preprocess image
206
+ transform = transforms.Compose([
207
+ transforms.ToTensor(),
208
+ transforms.Normalize((0.1307,), (0.3081,))
209
+ ])
210
+
211
+ # Convert to tensor and add batch dimension
212
+ image_tensor = transform(image).unsqueeze(0)
213
+
214
+ # Get prediction
215
+ with torch.no_grad():
216
+ output = model(image_tensor)
217
+ prediction = output.argmax(dim=1).item()
218
+
219
+ return {"prediction": prediction}
220
+
221
+ except Exception as e:
222
+ raise HTTPException(status_code=500, detail=f"Error processing image: {str(e)}")
223
+
224
+ except Exception as e:
225
+ raise HTTPException(status_code=500, detail=str(e))
226
+
227
+ @app.get("/train/single", response_class=HTMLResponse)
228
+ async def train_single_page(request: Request):
229
+ return templates.TemplateResponse("train_single.html", {"request": request})
230
+
231
+ @app.get("/train/compare", response_class=HTMLResponse)
232
+ async def train_compare_page(request: Request):
233
+ return templates.TemplateResponse("train_compare.html", {"request": request})
234
+
235
+ if __name__ == "__main__":
236
+ uvicorn.run(app, host="0.0.0.0", port=8000)