# gorilla-test2 / main.py
# Last commit: 77485e2 "update" by gmerrill (file size 2.14 kB)
import json
import threading
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
def get_prompt(user_query: str, functions: Optional[list] = None) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list, optional): Function definitions to include in the prompt.
      They are serialized with json.dumps and appended after a <<function>> tag.
      None (the default) or an empty list yields a question-only prompt.

    Returns:
    - str: The formatted conversation prompt, always ending with "ASSISTANT: ".
    """
    # `None` default instead of `[]`: a mutable default is a single object
    # shared across every call.
    if not functions:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
# Check for a GPU once and derive both the target device and the weight
# dtype from it: fp16 on the GPU, fp32 on the CPU.
_use_cuda: bool = torch.cuda.is_available()
device: str = "cuda:0" if _use_cuda else "cpu"
torch_dtype = torch.float16 if _use_cuda else torch.float32
model_id: str = "gorilla-llm/gorilla-openfunctions-v1"

# Load tokenizer and model once at import time; every request reuses them.
print('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
print('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
print('model.to(device) ...')
model.to(device)

# Text-generation pipeline used by the request handler. Each call generates at
# most 128 new tokens. batch_size presumably only matters when a list of
# prompts is passed; the handler passes a single prompt — TODO confirm intent.
print('Pipeline setup ...')
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)
print('FastAPI setup ...')
app = FastAPI()

# Model calls are moved off the event loop (see query_gorilla). The pipeline
# and its tokenizer are not guaranteed safe for concurrent use, so this lock
# keeps generations one-at-a-time, as they were when run inline.
_pipe_lock = threading.Lock()


def _generate(prompt: str):
    """Run the text-generation pipeline on *prompt*, serialized by _pipe_lock."""
    with _pipe_lock:
        return pipe(prompt)


@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    """Build a Gorilla prompt from the JSON request body and return the model output.

    Expected body: a JSON object with a "query" field and an optional
    "functions" field (missing or null is treated as no functions).

    Returns:
        {"val": <raw pipeline output>}

    Raises:
        HTTPException(400): the body is not valid JSON, is not a JSON object,
        or has no "query" field.
    """
    body = await req.body()
    try:
        parsed_body = json.loads(body)
    except ValueError as err:  # covers json.JSONDecodeError and bad UTF-8
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from err
    if not isinstance(parsed_body, dict) or "query" not in parsed_body:
        raise HTTPException(
            status_code=400,
            detail="Request body must be a JSON object with a 'query' field",
        )
    query = parsed_body["query"]
    functions = parsed_body.get("functions") or []
    print(query)
    print(functions)
    print('Generate prompt and obtain model output')
    prompt = get_prompt(query, functions=functions)
    # pipe() is a blocking call; running it in the threadpool keeps the event
    # loop free to serve other requests (e.g. static files) meanwhile.
    output = await run_in_threadpool(_generate, prompt)
    return {
        "val": output
    }
# Directory holding the front-end assets; shared by the index route and the
# static mount so both read the same files.
STATIC_DIR = "static"


@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page at the site root.

    Registered before the "/" mount below: routes are matched in registration
    order, and a mount at "/" matches every path, so registering this route
    after it (as before) made it unreachable.
    """
    return FileResponse(path=f"{STATIC_DIR}/index.html", media_type="text/html")


# Catch-all static file serving; must stay the last registration so it does
# not shadow the API and index routes above.
app.mount("/", StaticFiles(directory=STATIC_DIR, html=True), name="static")