gmerrill committed
Commit 759408c · 1 Parent(s): 4c215e2
Files changed (1)
  1. main.py +42 -21
main.py CHANGED
@@ -5,38 +5,35 @@ import json
 
 from transformers import pipeline
 
-app = FastAPI()
-
-@app.post("/query_gorilla")
-async def query_gorilla(req: Request):
-    body = await req.body()
-    parsedBody = json.loads(body)
-    print(parsedBody['query'])
-    print(parsedBody['functions'])
-    return {
-        "val": body
-    }
-
-app.mount("/", StaticFiles(directory="static", html=True), name="static")
-
-@app.get("/")
-def index() -> FileResponse:
-    return FileResponse(path="/app/static/index.html", media_type="text/html")
-
-TODO = '''
-print('Device setup')
+def get_prompt(user_query: str, functions: list = []) -> str:
+    """
+    Generates a conversation prompt based on the user's query and a list of functions.
+
+    Parameters:
+    - user_query (str): The user's query.
+    - functions (list): A list of functions to include in the prompt.
+
+    Returns:
+    - str: The formatted conversation prompt.
+    """
+    if len(functions) == 0:
+        return f"USER: <<question>> {user_query}\nASSISTANT: "
+    functions_string = json.dumps(functions)
+    return f"USER: <<question>> {user_query} <<function>> {functions_string}\nASSISTANT: "
+
 device : str = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
-print('Model and tokenizer setup')
 model_id : str = "gorilla-llm/gorilla-openfunctions-v1"
+print('AutoTokenizer.from_pretrained ...')
 tokenizer = AutoTokenizer.from_pretrained(model_id)
+print('AutoModelForCausalLM.from_pretrained ...')
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
 
-print('Move model to device')
+print('mode.to(device) ...')
 model.to(device)
 
-print('Pipeline setup')
+print('Pipeline setup ...')
 pipe = pipeline(
     "text-generation",
     model=model,
@@ -46,5 +43,29 @@ pipe = pipeline(
     torch_dtype=torch_dtype,
     device=device,
 )
-'''
+
+print('FastAPI setup ...')
+app = FastAPI()
+
+@app.post("/query_gorilla")
+async def query_gorilla(req: Request):
+    body = await req.body()
+    parsedBody = json.loads(body)
+    print(parsedBody['query'])
+    print(parsedBody['functions'])
+
+    print('Generate prompt and obtain model output')
+    prompt = get_prompt(query, functions=functions)
+    output = pipe(prompt)
+
+    return {
+        "val": output
+    }
+
+app.mount("/", StaticFiles(directory="static", html=True), name="static")
+
+@app.get("/")
+def index() -> FileResponse:
+    return FileResponse(path="/app/static/index.html", media_type="text/html")
+
 
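Note: as committed, the new query_gorilla handler calls get_prompt(query, functions=functions), but query and functions are never defined inside the function, so the endpoint would raise a NameError at request time; they presumably should be read from parsedBody. (The log string 'mode.to(device) ...' also looks like a typo for 'model.to(device)'.) Below is a minimal sketch of a corrected handler under that assumption; the get_prompt and pipe definitions here are placeholders standing in for the objects built earlier in main.py, included only so the snippet runs on its own.

import json
from fastapi import FastAPI, Request

app = FastAPI()

def get_prompt(user_query: str, functions: list = []) -> str:
    # Same prompt format as the get_prompt added in this commit.
    if len(functions) == 0:
        return f"USER: <<question>> {user_query}\nASSISTANT: "
    return f"USER: <<question>> {user_query} <<function>> {json.dumps(functions)}\nASSISTANT: "

def pipe(prompt: str):
    # Placeholder for the transformers text-generation pipeline built in main.py.
    return [{"generated_text": prompt}]

@app.post("/query_gorilla")
async def query_gorilla(req: Request):
    body = await req.body()
    parsedBody = json.loads(body)

    # Read the fields from the parsed request body instead of referencing
    # the undefined names `query` and `functions`.
    query = parsedBody['query']
    functions = parsedBody['functions']

    prompt = get_prompt(query, functions=functions)
    output = pipe(prompt)

    return {"val": output}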
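For reference, a request against the endpoint might look like the sketch below. The server address and the example function schema are illustrative assumptions, not taken from this repo; the route only requires a JSON body with 'query' and 'functions' keys.

import requests

# Assumes the app is served locally, e.g. `uvicorn main:app --port 8000`.
resp = requests.post(
    "http://localhost:8000/query_gorilla",
    json={
        "query": "What is the weather like in Boston?",
        "functions": [
            {
                "name": "get_current_weather",
                "description": "Get the current weather for a location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            }
        ],
    },
)
print(resp.json()["val"])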