test / app.py
Chengxb888's picture
Update app.py
c6b6d52 verified
raw
history blame
840 Bytes
import functools
from typing import Annotated

import torch
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# Application instance; the @app.get / @app.post decorators in this file
# register their handlers on it.
app = FastAPI()
@app.get("/", response_class=FileResponse)
async def root():
    """Serve the landing page for ``GET /``.

    The route declares ``FileResponse`` as its response class, so the value
    returned here is the path of the file to send back to the client.
    """
    landing_page = "home.html"
    return landing_page
_MODEL_ID = "google/gemma-2b-it"


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the tokenizer and model once and reuse them across requests.

    Loading the checkpoint is by far the most expensive step, so the result
    is cached: the first call pays the cost, later calls return the same
    ``(tokenizer, model)`` pair.
    """
    print("model")
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    return tokenizer, model


@app.post("/hello/")
async def say_hello(msg: Annotated[str, Form()]):
    """Generate a model reply for the submitted form field ``msg``.

    Args:
        msg: Prompt text, read from the ``msg`` field of the posted form.

    Returns:
        A dict ``{"message": <decoded text>}`` holding the decoded first
        generated sequence.
    """
    tokenizer, model = _load_model_and_tokenizer()

    print("token & msg")
    # device_map="auto" may place the model on an accelerator, so the inputs
    # must follow the model's device instead of being pinned to "cpu".
    input_ids = tokenizer(msg, return_tensors="pt").to(model.device)

    print("output")
    # NOTE(review): generate() is a blocking call inside an async handler, so
    # the event loop is stalled while a reply is produced. Consider
    # offloading it to a worker thread if concurrent requests matter.
    outputs = model.generate(**input_ids, max_length=500)

    print("complete")
    return {"message": tokenizer.decode(outputs[0])}