ReactLover committed
Commit 91cb1e2 · verified · 1 Parent(s): b854c2c

Update app.py

Files changed (1)
  1. app.py +25 -5
app.py CHANGED
@@ -1,9 +1,15 @@
  from fastapi import FastAPI, File, UploadFile
- from fastapi.responses import JSONResponse
+ from fastapi.responses import JSONResponse, HTMLResponse
  from transformers import AutoProcessor, AutoModelForVision2Seq
  from PIL import Image
  import io
  import torch
+ import os
+
+ # Make sure cache is writable
+ os.environ["HF_HOME"] = "/app/cache"
+ os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
+ os.environ["HF_HUB_CACHE"] = "/app/cache/hub"

  app = FastAPI()

@@ -13,17 +19,31 @@ model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-256M-Instr
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = model.to(device)

+ @app.get("/")
+ def home():
+     return {"message": "API is running. Use POST /predict with an image, or visit /upload to test in browser."}
+
+ @app.get("/upload", response_class=HTMLResponse)
+ def upload_form():
+     return """
+     <html>
+         <body>
+             <h2>Upload an ID Image</h2>
+             <form action="/predict" enctype="multipart/form-data" method="post">
+                 <input name="file" type="file">
+                 <input type="submit" value="Upload">
+             </form>
+         </body>
+     </html>
+     """
+
  @app.post("/predict")
  async def predict_gender(file: UploadFile = File(...)):
      image_bytes = await file.read()
      image = Image.open(io.BytesIO(image_bytes)).convert("RGB")

      prompt = "Is the person on this ID male or female?"
-
-     # Prepare inputs
      inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
-
-     # Generate response
      outputs = model.generate(**inputs, max_new_tokens=32)
      answer = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
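For reference, a minimal sketch of how the routes added in this commit could be exercised once the app is running. The base URL, port 7860, and the image filename are assumptions for illustration, not part of the commit; adjust them for your deployment.

# Hypothetical smoke test for the new endpoints (not part of the commit).
# Assumes the FastAPI server is reachable at http://localhost:7860 and
# that "id_card.jpg" exists locally.
import requests

BASE_URL = "http://localhost:7860"  # assumed host/port

# The new GET / route should report that the API is running.
print(requests.get(f"{BASE_URL}/").json())

# POST an image to /predict, mirroring what the /upload HTML form submits.
with open("id_card.jpg", "rb") as f:
    resp = requests.post(f"{BASE_URL}/predict", files={"file": f})
print(resp.status_code, resp.text)

The same /predict request can also be made interactively by opening /upload in a browser and submitting the form added in this commit.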