Gemma3n / app.py
ReactLover's picture
Update app.py
91cb1e2 verified
raw
history blame
1.69 kB
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoProcessor, AutoModelForVision2Seq
from PIL import Image
import io
import torch
import os
# Root for Hugging Face downloads; this path must be writable by the user the
# container runs as.
CACHE_DIR = "/app/cache"
HUB_CACHE_DIR = os.path.join(CACHE_DIR, "hub")

# NOTE: huggingface_hub (pulled in by the transformers import at the top of this
# file) resolves these variables when it is first imported, so assigning them
# here does not move the default cache for this process. They are still
# exported for child processes and any code that reads them later; the
# directory that actually matters is passed to from_pretrained() below.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = HUB_CACHE_DIR

app = FastAPI()

# Checkpoint served by this API.
MODEL_ID = "HuggingFaceTB/SmolVLM-256M-Instruct"

# Load processor and model once at startup; every request shares them.
# cache_dir is explicit because the env vars above are read too late to apply.
processor = AutoProcessor.from_pretrained(MODEL_ID, cache_dir=HUB_CACHE_DIR)
model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, cache_dir=HUB_CACHE_DIR)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
@app.get("/")
def home():
    """Landing route: confirm the service is up and point at the usable endpoints."""
    usage = (
        "API is running. Use POST /predict with an image, "
        "or visit /upload to test in browser."
    )
    return dict(message=usage)
# Static test page: a multipart form posting one field named "file" to /predict.
_UPLOAD_PAGE = """
<html>
<body>
<h2>Upload an ID Image</h2>
<form action="/predict" enctype="multipart/form-data" method="post">
<input name="file" type="file">
<input type="submit" value="Upload">
</form>
</body>
</html>
"""


@app.get("/upload", response_class=HTMLResponse)
def upload_form():
    """Serve the browser form used to try out POST /predict by hand."""
    page = _UPLOAD_PAGE
    return page
# Question sent to the model together with every uploaded image.
GENDER_PROMPT = "Is the person on this ID male or female?"


def _run_inference(image, prompt):
    """Ask the vision-language model *prompt* about *image* and return its reply.

    This is blocking (CPU/GPU-bound) work; callers on the event loop should
    run it in a worker thread.
    """
    # The SmolVLM processor expects an image placeholder token in the text for
    # each image passed; the chat template inserts it plus the role markers.
    messages = [
        {
            "role": "user",
            "content": [{"type": "image"}, {"type": "text", "text": prompt}],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=[image], return_tensors="pt").to(device)
    outputs = model.generate(**inputs, max_new_tokens=32)
    # generate() returns prompt + continuation; decode only the new tokens so
    # the echoed prompt does not end up in the answer.
    new_tokens = outputs[:, inputs["input_ids"].shape[1]:]
    return processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()


@app.post("/predict")
async def predict_gender(file: UploadFile = File(...)):
    """Answer GENDER_PROMPT about an uploaded image.

    Returns ``{"gender": <model reply>}`` on success, or a 400 response with
    ``{"error": ...}`` when the upload cannot be decoded as an image.
    """
    import asyncio  # local import: only this endpoint needs it

    image_bytes = await file.read()
    try:
        # Image.open is lazy; convert() forces the decode, so both must be
        # inside the try to catch truncated/corrupt uploads.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except (OSError, Image.DecompressionBombError):
        # UnidentifiedImageError (not an image / empty upload) is an OSError.
        return JSONResponse(
            {"error": "Uploaded file is not a readable image."}, status_code=400
        )
    # Run the model off the event loop so other requests are not stalled.
    answer = await asyncio.to_thread(_run_inference, image, GENDER_PROMPT)
    return JSONResponse({"gender": answer})