gmerrill committed
Commit f423eb3
1 Parent(s): 3e57e69
Files changed (1)
  1. main.py +24 -18
main.py CHANGED
@@ -2,9 +2,13 @@ from fastapi import FastAPI, Request
 from fastapi.staticfiles import StaticFiles
 from fastapi.responses import FileResponse
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+import datetime
 import json
 import torch
 
+def log(msg):
+    print(str(datetime.datetime.now()) + ': ' + msg)
+
 def get_prompt(user_query: str, functions: list = []) -> str:
     """
     Generates a conversation prompt based on the user's query and a list of functions.
@@ -25,37 +29,38 @@ device : str = "cuda:0" if torch.cuda.is_available() else "cpu"
 torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
 
 model_id : str = "gorilla-llm/gorilla-openfunctions-v1"
-print('AutoTokenizer.from_pretrained ...')
+log('AutoTokenizer.from_pretrained ...')
 tokenizer = AutoTokenizer.from_pretrained(model_id)
-print('AutoModelForCausalLM.from_pretrained ...')
+log('AutoModelForCausalLM.from_pretrained ...')
 model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
 
-print('mode.to(device) ...')
+log('mode.to(device) ...')
 model.to(device)
 
-print('Pipeline setup ...')
-pipe = pipeline(
-    "text-generation",
-    model=model,
-    tokenizer=tokenizer,
-    max_new_tokens=128,
-    batch_size=16,
-    torch_dtype=torch_dtype,
-    device=device,
-)
-
-print('FastAPI setup ...')
+log('FastAPI setup ...')
 app = FastAPI()
 
 @app.post("/query_gorilla")
 async def query_gorilla(req: Request):
     body = await req.body()
     parsedBody = json.loads(body)
-    print(parsedBody['query'])
-    print(parsedBody['functions'])
+    log(parsedBody['query'])
+    log(parsedBody['functions'])
 
-    print('Generate prompt and obtain model output')
+    log('Generate prompt and obtain model output')
     prompt = get_prompt(parsedBody['query'], functions=parsedBody['functions'])
+
+    log('Pipeline setup ...')
+    pipe = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        max_new_tokens=128,
+        batch_size=16,
+        torch_dtype=torch_dtype,
+        device=device,
+    )
+
     output = pipe(prompt)
 
     return {
@@ -69,3 +74,4 @@ def index() -> FileResponse:
     return FileResponse(path="/app/static/index.html", media_type="text/html")
 
 
+log('Initialization done.')
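
For context, the /query_gorilla handler above reads the raw request body with json.loads and expects 'query' and 'functions' keys. A minimal client sketch is shown below; the host, port, and the example function schema are assumptions for illustration, not part of this commit.

# Hypothetical client call for the /query_gorilla endpoint defined above.
# The host/port and the example function description are assumed.
import requests

payload = {
    "query": "What is the weather in Boston?",  # free-form user query
    "functions": [{                              # function description passed through to get_prompt
        "name": "get_current_weather",
        "description": "Get the current weather for a location",
        "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
    }],
}

# The handler json.loads() the raw body, so a plain JSON POST body works.
resp = requests.post("http://localhost:8000/query_gorilla", json=payload)
print(resp.json())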