ClemSummer commited on
Commit
2b3d156
Β·
1 Parent(s): 4bf21ae

revert back to non-lazy one

Browse files
Files changed (1) hide show
  1. main.py +42 -19
main.py CHANGED
@@ -1,21 +1,24 @@
1
- from fastapi import FastAPI
 
 
2
  from fastapi.responses import HTMLResponse
 
 
 
3
  import uvicorn
 
4
  from vit_captioning.generate import CaptionGenerator
5
 
6
  app = FastAPI()
7
- caption_generator = None # Lazy-load placeholder
8
-
9
- @app.on_event("startup")
10
- def startup_event():
11
- global caption_generator
12
- if caption_generator is None:
13
- print("Loading CaptionGenerator...")
14
- caption_generator = CaptionGenerator(
15
- model_type="CLIPEncoder",
16
- checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
17
- quantized=False
18
- )
19
 
20
  @app.get("/", response_class=HTMLResponse)
21
  def root():
@@ -25,12 +28,32 @@ def root():
25
  def health_check():
26
  return {"status": "ok"}
27
 
28
- # Example endpoint to trigger model
29
- @app.get("/caption")
30
- def caption():
31
- if caption_generator is None:
32
- return {"error": "Model not loaded"}
33
- return {"result": "dummy caption"} # Replace with real logic
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # if __name__ == "__main__":
36
  # uvicorn.run(app, host="0.0.0.0", port=8000)
 
1
+ # app/main.py
2
+
3
+ from fastapi import FastAPI, UploadFile, File
4
  from fastapi.responses import HTMLResponse
5
+ from fastapi.staticfiles import StaticFiles
6
+ import shutil
7
+ from pathlib import Path
8
  import uvicorn
9
+
10
  from vit_captioning.generate import CaptionGenerator
11
 
12
  app = FastAPI()
13
+
14
+ # Serve static files
15
+ static_dir = Path(__file__).parent / "vit_captioning" / "static"
16
+ app.mount("/static", StaticFiles(directory=static_dir), name="static")
17
+
18
+ # βœ… Landing page at `/`
19
+ # @app.get("/", response_class=HTMLResponse)
20
+ # async def landing():
21
+ # return Path("vit_captioning/static/landing.html").read_text()
 
 
 
22
 
23
  @app.get("/", response_class=HTMLResponse)
24
  def root():
 
28
  def health_check():
29
  return {"status": "ok"}
30
 
31
+ # βœ… Captioning page at `/captioning`
32
+ @app.get("/captioning", response_class=HTMLResponse)
33
+ async def captioning():
34
+ return Path("vit_captioning/static/captioning/index.html").read_text()
35
+
36
+ # βœ… Example: Project 2 placeholder
37
+ @app.get("/project2", response_class=HTMLResponse)
38
+ async def project2():
39
+ return "<h1>Coming Soon: Project 2</h1>"
40
+
41
+ # βœ… Caption generation endpoint for captioning app
42
+ # Keep the path consistent with your JS fetch()!
43
+ caption_generator = CaptionGenerator(
44
+ model_type="CLIPEncoder",
45
+ checkpoint_path="./vit_captioning/artifacts/CLIPEncoder_40epochs_unfreeze12.pth",
46
+ quantized=False
47
+ )
48
+
49
+ @app.post("/generate")
50
+ async def generate(file: UploadFile = File(...)):
51
+ temp_file = f"temp_{file.filename}"
52
+ with open(temp_file, "wb") as buffer:
53
+ shutil.copyfileobj(file.file, buffer)
54
+
55
+ captions = caption_generator.generate_caption(temp_file)
56
+ return captions
57
 
58
  # if __name__ == "__main__":
59
  # uvicorn.run(app, host="0.0.0.0", port=8000)