from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import datetime
import json
import torch

def log(msg):
    print(str(datetime.datetime.now()) + ': ' + str(msg))

def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    if len(functions) == 0:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

model_id: str = "gorilla-llm/gorilla-openfunctions-v1"
log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
log('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

log('model.to(device) ...')
model.to(device)

log('FastAPI setup ...')
app = FastAPI()

@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    body = await req.body()
    parsedBody = json.loads(body)
    log(parsedBody['query'])
    log(parsedBody['functions'])

    log('Generate prompt and obtain model output')
    prompt = get_prompt(parsedBody['query'], functions=parsedBody['functions'])

    log('Pipeline setup ...')
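    # Note: the pipeline is constructed on every request here; it could equally be
    # built once at startup (next to the model and tokenizer) and reused, which
    # would avoid repeating this setup per call.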
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    output = pipe(prompt)
    
    return {
        "val": output
    }
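
# Illustrative call to the endpoint above (host/port depend on how the server is
# run; the function spec is a made-up example, not part of the original code):
#   curl -X POST http://localhost:8000/query_gorilla \
#     -H "Content-Type: application/json" \
#     -d '{"query": "What is the weather in Boston?", "functions": [{"name": "get_current_weather", "description": "Get the current weather for a location", "parameters": {"location": "string"}}]}'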

app.mount("/", StaticFiles(directory="static", html=True), name="static")
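# Note: this mount on "/" is registered before the index() route below and, with
# html=True, serves static/index.html for "/" itself, so the explicit route below
# is normally not reached.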

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")


log('Initialization done.')
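
# To serve the app (assuming this file is saved as main.py and the static assets
# live in ./static, as the mount above expects):
#   uvicorn main:app --host 0.0.0.0 --port 8000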