# app.py
from fastapi import FastAPI, Form
from fastapi.responses import FileResponse
from typing import Annotated

from transformers import AutoModelForCausalLM, AutoTokenizer

app = FastAPI()
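
# A possible variant (an assumption, not in the original file): load the
# tokenizer and model once at import time so every request reuses them, e.g.
#   checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
#   tokenizer = AutoTokenizer.from_pretrained(checkpoint)
#   model = AutoModelForCausalLM.from_pretrained(checkpoint)
# The handler below keeps the original per-request loading.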

@app.get("/", response_class=FileResponse)
async def root():
    # Serve the static home page; home.html is expected next to app.py.
    return "home.html"

@app.post("/hello/")
def say_hello(msg: Annotated[str, Form()]):
    # NOTE: the model is reloaded on every request. That keeps the demo
    # simple but slow; see the module-level variant sketched above.
    print("loading model")
    checkpoint = "HuggingFaceTB/SmolLM-1.7B-Instruct"
    device = "cpu"  # use "cuda" here instead for GPU inference
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    # For multiple GPUs, install accelerate and load with
    # `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`.
    model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

    # Wrap the submitted form text in the chat template the model expects.
    messages = [{"role": "user", "content": msg}]
    input_text = tokenizer.apply_chat_template(messages, tokenize=False)
    print(input_text)
print("output")
inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
outputs = model.generate(inputs, max_new_tokens=32, temperature=0.6, top_p=0.92, do_sample=True)
print("complete")
return {"message": tokenizer.decode(outputs[0])}