import datetime
import json
import subprocess
import threading
from typing import Optional

import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline


def log(msg):
    """Print *msg* prefixed with the current local timestamp, flushing immediately."""
    print(str(datetime.datetime.now()) + ': ' + str(msg), flush=True)


def get_prompt(user_query: str, functions: Optional[list] = None) -> str:
    """
    Build a Gorilla OpenFunctions v1 prompt from the user's query and function specs.

    Parameters:
    - user_query (str): The user's query.
    - functions (list | None): JSON-serializable function specifications to expose
      to the model. ``None`` or an empty list omits the function section.

    Returns:
    - str: The formatted conversation prompt.
    """
    # ``None`` replaces the former mutable default ``[]`` (one list shared across calls);
    # the truthiness test also tolerates an explicit ``None`` from the request body.
    if not functions:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "


device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
log('Device: ' + device)
# Half precision only on GPU; CPU inference stays in float32.
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Diagnostic dump of the OS, working directory and model caches before loading.
# The command is a fixed string with no user input, so shell=True is not an injection risk.
result = subprocess.run('cat /etc/os-release && pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)

model_id: str = "gorilla-llm/gorilla-openfunctions-v1"
log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
log('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Same diagnostic dump after loading, to show what the download added to the caches.
result = subprocess.run('pwd && ls -lH && find /.cache/huggingface/hub && find /.cache/gorilla', shell=True, capture_output=True, text=True)
log(result.stdout)

log('model.to(device) ...')
model.to(device)

log('Pipeline setup ...')
# Built once at startup instead of per request: none of these arguments depend on
# the request, so re-creating the pipeline on every call was pure overhead.
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=128,
    batch_size=16,
    torch_dtype=torch_dtype,
    device=device,
)

# Generation now runs in a worker thread (see query_gorilla). Previously it ran on
# the event loop, which implicitly serialized requests; this lock keeps exactly one
# generation in flight at a time while leaving the event loop free.
_pipe_lock = threading.Lock()


def _generate(prompt: str):
    """Run the shared text-generation pipeline on *prompt*, one call at a time."""
    with _pipe_lock:
        return pipe(prompt)


log('FastAPI setup ...')
app = FastAPI()


@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    """
    Answer a function-calling query with the Gorilla model.

    Expects a JSON object body with a ``query`` field and an optional ``functions``
    field (function specifications passed through to ``get_prompt``).

    Returns:
    - dict: ``{"val": <raw text-generation pipeline output>}``.

    Raises:
    - HTTPException(400): the body is not valid JSON, is not a JSON object, or
      lacks ``query``.
    """
    body = await req.body()
    try:
        parsed_body = json.loads(body)
    except ValueError as err:  # covers JSONDecodeError and undecodable bytes
        raise HTTPException(status_code=400, detail='Request body must be valid JSON.') from err
    if not isinstance(parsed_body, dict) or 'query' not in parsed_body:
        raise HTTPException(status_code=400, detail="Request body must be a JSON object with a 'query' field.")

    query = parsed_body['query']
    functions = parsed_body.get('functions')
    log(query)
    log(functions)
    log('Generate prompt and obtain model output')
    prompt = get_prompt(query, functions=functions)
    log('Get answer ...')
    # Model inference is blocking; run it off the event loop so other requests
    # (including static files) are still served while generating.
    output = await run_in_threadpool(_generate, prompt)
    return {"val": output}


app.mount("/", StaticFiles(directory="static", html=True), name="static")


# NOTE(review): Starlette matches routes in registration order, and the mount at "/"
# above matches every path, so GET "/" is served by StaticFiles (static/index.html via
# html=True) and this handler is never reached. Left in place so the served file does
# not change — confirm whether it should move above the mount or be removed.
@app.get("/")
def index() -> FileResponse:
    """Serve the front-end entry page from /app/static/index.html."""
    return FileResponse(path="/app/static/index.html", media_type="text/html")


log('Initialization done.')