test / app.py
Chengxb888's picture
Update app.py
47c2354 verified
raw
history blame
842 Bytes
import asyncio
import threading
from typing import Annotated

import torch
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
# ASGI application object; the route decorators below register handlers on it.
app = FastAPI()
@app.get("/", response_class=FileResponse)
async def root():
    """Serve the landing page for ``GET /``.

    The handler returns a relative file path; the route is declared with
    ``response_class=FileResponse``, so the path is sent as a file response.
    """
    landing_page = "home.html"
    return landing_page
# Hugging Face Hub id of the chat model served by /hello/.
_GEMMA_ID = "google/gemma-2b-it"
# Serializes the one-time load so concurrent first requests do not each
# pull a separate copy of the weights into memory.
_GEMMA_LOCK = threading.Lock()
# Holds the loaded (tokenizer, model) tuple under the key "pair".
_GEMMA_CACHE = {}


def _load_gemma():
    """Return the shared ``(tokenizer, model)`` pair, loading it on first use."""
    with _GEMMA_LOCK:
        if "pair" not in _GEMMA_CACHE:
            tokenizer = AutoTokenizer.from_pretrained(_GEMMA_ID)
            model = AutoModelForCausalLM.from_pretrained(
                _GEMMA_ID,
                device_map="auto",
                torch_dtype=torch.bfloat16,
            )
            _GEMMA_CACHE["pair"] = (tokenizer, model)
        return _GEMMA_CACHE["pair"]


def _generate_reply(msg: str) -> str:
    """Tokenize *msg*, run generation, and return the decoded first sequence."""
    tokenizer, model = _load_gemma()
    # device_map="auto" may place the model on an accelerator, so the input
    # tensors are moved to wherever the model lives instead of a fixed "cpu".
    encoded = tokenizer(msg, return_tensors="pt").to(model.device)
    # max_length bounds prompt tokens plus generated tokens together.
    outputs = model.generate(**encoded, max_length=500)
    return tokenizer.decode(outputs[0])


@app.post("/hello/")
async def say_hello(msg: Annotated[str, Form()]):
    """Generate a model reply for the submitted form field ``msg``.

    Args:
        msg: Prompt text taken from the ``msg`` field of the form body.

    Returns:
        A dict ``{"message": <decoded text>}`` where the text is the first
        generated sequence decoded by the tokenizer.
    """
    # Generation is blocking, CPU/GPU-bound work; running it in a worker
    # thread keeps the event loop free to serve other requests meanwhile.
    reply = await asyncio.to_thread(_generate_reply, msg)
    return {"message": reply}