gmerrill committed on
Commit
0bc8a9d
·
1 Parent(s): 7c3f842
Files changed (3) hide show
  1. main.py +28 -4
  2. static/index.html +18 -2
  3. static/script.js +7 -1
main.py CHANGED
@@ -6,14 +6,38 @@ from transformers import pipeline
6
 
7
  app = FastAPI()
8
 
9
- pipe_flan = pipeline("text2text-generation", model="google/flan-t5-small")
10
-
11
- @app.get("/query_gorilla")
12
  def query_gorilla(input):
13
- return {"output": "Test Result"}
14
 
15
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
16
 
17
  @app.get("/")
18
  def index() -> FileResponse:
19
  return FileResponse(path="/app/static/index.html", media_type="text/html")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  app = FastAPI()
8
 
9
+ @app.post("/query_gorilla")
 
 
10
  def query_gorilla(input):
11
+ return input
12
 
13
  app.mount("/", StaticFiles(directory="static", html=True), name="static")
14
 
15
  @app.get("/")
16
  def index() -> FileResponse:
17
  return FileResponse(path="/app/static/index.html", media_type="text/html")
18
+
19
+ TODO = '''
20
+ print('Device setup')
21
+ device : str = "cuda:0" if torch.cuda.is_available() else "cpu"
22
+ torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
23
+
24
+ print('Model and tokenizer setup')
25
+ model_id : str = "gorilla-llm/gorilla-openfunctions-v1"
26
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
27
+ model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch_dtype, low_cpu_mem_usage=True)
28
+
29
+ print('Move model to device')
30
+ model.to(device)
31
+
32
+ print('Pipeline setup')
33
+ pipe = pipeline(
34
+ "text-generation",
35
+ model=model,
36
+ tokenizer=tokenizer,
37
+ max_new_tokens=128,
38
+ batch_size=16,
39
+ torch_dtype=torch_dtype,
40
+ device=device,
41
+ )
42
+ '''
43
+
static/index.html CHANGED
@@ -15,14 +15,30 @@
15
  Prompt:<br/>
16
  <textarea
17
  id="text-gen-input"
18
- type="text">Sample</textarea>
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  <br/>
20
  <button id="text-gen-submit">Submit</button>
21
 
22
  <p/>
23
  Result:<br/>
24
  <p class="text-gen-output">
25
- axyz
26
  </p>
27
 
28
  </form>
 
15
  Prompt:<br/>
16
  <textarea
17
  id="text-gen-input"
18
+ type="text">
19
+
20
+ {
21
+ query: "Call me an Uber ride type \"Plus\" in Berkeley at zipcode 94704 in 10 minutes",
22
+ function: [
23
+ {
24
+ "name": "Uber Carpool",
25
+ "api_name": "uber.ride",
26
+ "description": "Find suitable ride for customers given the location, type of ride, and the amount of time the customer is willing to wait as parameters",
27
+ "parameters": [
28
+ {"name": "loc", "description": "Location of the starting place of the Uber ride"},
29
+ {"name": "type", "enum": ["plus", "comfort", "black"], "description": "Types of Uber ride user is ordering"},
30
+ {"name": "time", "description": "The amount of time in minutes the customer is willing to wait"}
31
+ ]
32
+ }
33
+ ]
34
+
35
+ </textarea>
36
  <br/>
37
  <button id="text-gen-submit">Submit</button>
38
 
39
  <p/>
40
  Result:<br/>
41
  <p class="text-gen-output">
 
42
  </p>
43
 
44
  </form>
static/script.js CHANGED
@@ -1,7 +1,13 @@
1
  const textGenForm = document.querySelector('.text-gen-form');
2
 
3
  const translateText = async (text) => {
4
- const inferResponse = await fetch(`query_gorilla?input=${text}`);
 
 
 
 
 
 
5
  const inferJson = await inferResponse.json();
6
 
7
  return inferJson.output;
 
1
  const textGenForm = document.querySelector('.text-gen-form');
2
 
3
  const translateText = async (text) => {
4
+ const inferResponse = await fetch(
5
+ `query_gorilla`,
6
+ {
7
+ method: 'POST',
8
+ body: text
9
+ }
10
+ );
11
  const inferJson = await inferResponse.json();
12
 
13
  return inferJson.output;