from fastapi import FastAPI, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import datetime
import json
import subprocess
import torch

def log(msg):
    print(str(datetime.datetime.now()) + ': ' + str(msg), flush=True)

def get_prompt(user_query: str, functions: list = []) -> str:
    """
    Generates a conversation prompt based on the user's query and a list of functions.

    Parameters:
    - user_query (str): The user's query.
    - functions (list): A list of functions to include in the prompt.

    Returns:
    - str: The formatted conversation prompt.
    """
    if len(functions) == 0:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    functions_string = json.dumps(functions)
    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "

# Prefer GPU with fp16 weights when CUDA is available; otherwise fall back to CPU with fp32.
device: str = "cuda:0" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

# Debug: log the container's filesystem contents before loading the model.
result = subprocess.run('find /', shell=True, capture_output=True, text=True)
log('Files: ' + result.stdout)

model_id: str = "gorilla-llm/gorilla-openfunctions-v1"
log('AutoTokenizer.from_pretrained ...')
tokenizer = AutoTokenizer.from_pretrained(model_id)
log('AutoModelForCausalLM.from_pretrained ...')
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)

# Debug: log the filesystem contents again after loading the model.
result = subprocess.run('find /', shell=True, capture_output=True, text=True)
log('Files: ' + result.stdout)

log('model.to(device) ...')
model.to(device)

log('FastAPI setup ...')
app = FastAPI()

@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    body = await req.body()
    parsed_body = json.loads(body)
    log(parsed_body['query'])
    log(parsed_body['functions'])

    log('Generate prompt and obtain model output')
    prompt = get_prompt(parsed_body['query'], functions=parsed_body['functions'])

    log('Pipeline setup ...')
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_new_tokens=128,
        batch_size=16,
        torch_dtype=torch_dtype,
        device=device,
    )

    output = pipe(prompt)
    
    return {
      "val": output
    }

# Serve the static frontend; with html=True the mount also returns static/index.html for "/".
app.mount("/", StaticFiles(directory="static", html=True), name="static")

@app.get("/")
def index() -> FileResponse:
    return FileResponse(path="/app/static/index.html", media_type="text/html")


log('Initialization done.')
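
# Example request (a sketch; assumes the app is served by uvicorn on port 8000, which
# depends on how the container is actually launched):
#   curl -X POST http://localhost:8000/query_gorilla \
#     -H 'Content-Type: application/json' \
#     -d '{"query": "Call me an Uber ride", "functions": [{"name": "uber.ride"}]}'
# The handler responds with a JSON object of the form {"val": <pipeline output>}.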