from fastapi import FastAPI, Form
from fastapi.responses import FileResponse
from typing import Annotated
from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()

@app.get("/", response_class=FileResponse)
async def root():
    return "home.html"

@app.post("/hello/")
def say_hello(msg: Annotated[str, Form()]):
    # Note: the model and tokenizer are reloaded on every request; fine for a
    # demo, but slow in practice.
    print("loading model")
    checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
    device = "cpu"  # use "cuda" for GPU usage or "cpu" for CPU usage
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # For multiple GPUs, install accelerate and use
    # `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`.
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    # Wrap the user message in the chat template the model was instruction-tuned on.
    messages = [{"role": "user", "content": msg}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    print(input_text)

    print("generating output")
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    outputs = model.generate(inputs, max_new_tokens=256, temperature=0.6, top_p=0.92, do_sample=True)
    print("complete")
    return {"message": tokenizer.decode(outputs[0])}
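

# A minimal local-run sketch, not in the original file: assumes uvicorn is
# installed and that port 7860 (the usual Hugging Face Spaces port) is free.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request once the server is up (hypothetical message value):
#   curl -X POST -F "msg=Hello there!" http://localhost:7860/hello/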