xzyao committed on
Commit
f30382c
·
1 Parent(s): 58b7577

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -42
app.py CHANGED
@@ -1,59 +1,45 @@
1
  import streamlit as st
2
  import requests
 
3
  import time
4
  from ast import literal_eval
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    """Submit a language-model inference job and return the generated text.

    Posts a job to the planetd job queue, then polls once per second (for
    up to 100 seconds) until the job reports ``status == 'finished'``.

    Parameters
    ----------
    prompt : str
        The text prompt sent to the model.
    model_name : str
        UI-facing model name; mapped to the backend model identifier below.
    max_new_tokens, temperature, top_p, num_completions, seed : numeric
        Standard sampling parameters; coerced to int/float before sending.
    stop : str
        Semicolon-separated list of stop sequences.

    Returns
    -------
    str
        The text of the first choice of the first inference result.

    Raises
    ------
    TimeoutError
        If the job does not finish within the polling window.
    """
    # Map the UI-facing model name to the backend model identifier.
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": stop.split(";"),
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']

    # Poll for completion; return as soon as the job finishes.  Previously a
    # timeout fell through to indexing an unfinished payload, raising an
    # opaque KeyError — now we raise an explicit TimeoutError instead.
    for _ in range(100):
        time.sleep(1)
        ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
        if ret['status'] == 'finished':
            return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']

    raise TimeoutError(f"Inference job {job_id} did not finish within 100 seconds")
53
-
54
-
55
- st.title("GPT-JT")
56
-
57
  col1, col2 = st.columns([1, 3])
58
 
59
  with col1:
 
1
  import streamlit as st
2
  import requests
3
+ import asyncio
4
  import time
5
  from ast import literal_eval
6
+ import urllib.parse
7
+ from dacite import from_dict
8
+ from together_web3.computer import LanguageModelInferenceRequest
9
+ from together_web3.together import TogetherWeb3
10
+
11
st.title("GPT-JT")

# Create the Together client and a dedicated asyncio event loop exactly once
# per Streamlit session: script reruns reuse the objects stored in
# st.session_state instead of reconnecting each time.
if 'together_web3' not in st.session_state:
    st.session_state.together_web3 = TogetherWeb3()
if 'loop' not in st.session_state:
    st.session_state.loop = asyncio.new_event_loop()
16
async def _inference(prompt, max_tokens):
    """Run a single language-model inference request on the Together network.

    Builds a LanguageModelInferenceRequest for the fixed
    "Together-gpt-JT-6B-v1" model and awaits the session's client.
    Returns whatever result object the client produces.
    """
    request = from_dict(
        data_class=LanguageModelInferenceRequest,
        data={
            "model": "Together-gpt-JT-6B-v1",
            "max_tokens": max_tokens,
            "prompt": prompt,
        }
    )
    return await st.session_state.together_web3.language_model_inference(request)
28
 
29
@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    """Return the completion text for *prompt* from the GPT-JT model.

    Runs the async ``_inference`` coroutine on the session's event loop and
    returns the text of the first choice.

    NOTE(review): only ``prompt`` and ``max_new_tokens`` are forwarded to the
    backend; ``model_name``, ``temperature``, ``top_p``, ``num_completions``,
    ``seed`` and ``stop`` are accepted for interface compatibility but are
    currently ignored by the request — confirm whether that is intentional.
    """
    # Removed leftover print() debugging statements.
    response = st.session_state.loop.run_until_complete(
        _inference(prompt, int(max_new_tokens))
    )
    return response.choices[0].text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  col1, col2 = st.columns([1, 3])
44
 
45
  with col1: