xzyao committed on
Commit
e3b7503
·
1 Parent(s): f30382c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -4
app.py CHANGED
@@ -13,7 +13,7 @@ if 'together_web3' not in st.session_state:
13
  st.session_state.together_web3 = TogetherWeb3()
14
  if 'loop' not in st.session_state:
15
  st.session_state.loop = asyncio.new_event_loop()
16
- async def _inference(prompt, max_tokens):
17
  result = await st.session_state.together_web3.language_model_inference(
18
  from_dict(
19
  data_class=LanguageModelInferenceRequest,
@@ -21,6 +21,10 @@ async def _inference(prompt, max_tokens):
21
  "model": "Together-gpt-JT-6B-v1",
22
  "max_tokens": max_tokens,
23
  "prompt": prompt,
 
 
 
 
24
  }
25
  ),
26
  )
@@ -30,13 +34,14 @@ async def _inference(prompt, max_tokens):
30
  def infer(prompt,
31
  model_name,
32
  max_new_tokens=10,
33
- temperature=0.0,
34
  top_p=1.0,
35
  num_completions=1,
36
  seed=42,
37
  stop="\n"):
38
  print("prompt", prompt)
39
- response = st.session_state.loop.run_until_complete(_inference(prompt, int(max_new_tokens)))
 
40
  print(response)
41
  return response.choices[0].text
42
 
@@ -45,7 +50,7 @@ col1, col2 = st.columns([1, 3])
45
  with col1:
46
  model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
47
  max_new_tokens = st.text_input('Max new tokens', "10")
48
- temperature = st.text_input('temperature', "0.0")
49
  top_p = st.text_input('top_p', "1.0")
50
  num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
51
  stop = st.text_input('stop, split by;', r'\n')
 
13
  st.session_state.together_web3 = TogetherWeb3()
14
  if 'loop' not in st.session_state:
15
  st.session_state.loop = asyncio.new_event_loop()
16
+ async def _inference(prompt, max_tokens, stop, top_p, temperature, seed):
17
  result = await st.session_state.together_web3.language_model_inference(
18
  from_dict(
19
  data_class=LanguageModelInferenceRequest,
 
21
  "model": "Together-gpt-JT-6B-v1",
22
  "max_tokens": max_tokens,
23
  "prompt": prompt,
24
+ "stop": stop,
25
+ "top_p": top_p,
26
+ "temperature": temperature,
27
+ "seed": seed,
28
  }
29
  ),
30
  )
 
34
  def infer(prompt,
35
  model_name,
36
  max_new_tokens=10,
37
+ temperature=1.0,
38
  top_p=1.0,
39
  num_completions=1,
40
  seed=42,
41
  stop="\n"):
42
  print("prompt", prompt)
43
+ stop = stop.split(";")
44
+ response = st.session_state.loop.run_until_complete(_inference(prompt, int(max_new_tokens), stop, float(top_p), float(temperature), int(seed)))
45
  print(response)
46
  return response.choices[0].text
47
 
 
50
  with col1:
51
  model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
52
  max_new_tokens = st.text_input('Max new tokens', "10")
53
+ temperature = st.text_input('temperature', "1.0")
54
  top_p = st.text_input('top_p', "1.0")
55
  num_completions = st.text_input('num_completions (only the best one will be returend)', "1")
56
  stop = st.text_input('stop, split by;', r'\n')