Spaces:
Runtime error
Runtime error
File size: 2,611 Bytes
0d215ca e532db6 0d215ca 7f5cbab d7570a5 b2703de 0d215ca 3442116 0d215ca bbe538b d7570a5 0d215ca 3442116 0d215ca b2703de 0d215ca d7570a5 0d215ca 0b3be54 0d215ca bbe538b 0b3be54 0d215ca 0c8c7dc 0d215ca a531b86 0d215ca d7570a5 157e1ad d7570a5 b2703de 8ff5e07 d7570a5 a531b86 87df952 a531b86 87df952 b2703de 8ff5e07 b2703de 87df952 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import streamlit as st
import requests
import time
from ast import literal_eval
@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    """Submit a text-completion job to the inference service and return its text.

    The job is posted to ``https://planetd.shift.ml/jobs`` and then polled once
    per second, for at most 100 polls, until its status is ``finished``.

    Args:
        prompt: Prompt text; sent as a single-element list.
        model_name: Display name of the model; must be a key of
            ``model_name_map`` below.
        max_new_tokens: Sent as ``max_tokens``; coerced with ``int()``, so a
            numeric string is accepted.
        temperature: Sampling temperature; coerced with ``float()``.
        top_p: Sent as ``top_p``; coerced with ``float()``.
        num_completions: Sent as ``n``; coerced with ``int()``. Only the first
            returned choice is used.
        seed: Sent as ``seed``; coerced with ``int()``.
        stop: Stop sequences joined by ``;``; split on ``;`` before sending.

    Returns:
        The ``text`` of the first choice of the first inference result.

    Raises:
        KeyError: If ``model_name`` is unknown or a response lacks an expected
            field.
        ValueError: If a numeric argument cannot be converted.
        requests.RequestException: On a network failure, a request timeout, or
            an HTTP error status from the submit call.
        TimeoutError: If the job is still unfinished after the last poll.
    """
    request_timeout = 30  # seconds per HTTP call, so a stalled socket cannot hang the app
    max_polls = 100
    poll_interval = 1  # seconds between status checks

    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }
    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": stop.split(";"),
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    submit_response = requests.post(
        "https://planetd.shift.ml/jobs",
        json=my_post_dict,
        timeout=request_timeout,
    )
    # Surface a rejected submission as an HTTP error instead of a missing 'id' key.
    submit_response.raise_for_status()
    job_id = submit_response.json()['id']

    for _ in range(max_polls):
        time.sleep(poll_interval)
        ret = requests.get(
            f"https://planetd.shift.ml/job/{job_id}",
            json={'id': job_id},
            timeout=request_timeout,
        ).json()
        if ret['status'] == 'finished':
            break
    else:
        # The loop ran out without seeing 'finished'; there is no result to read.
        raise TimeoutError(
            f"Job {job_id} did not finish within {max_polls * poll_interval} seconds"
        )

    return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
# Page layout: a narrow settings column on the left, the prompt and output on the right.
st.title("GPT-JT")

col1, col2 = st.columns([1, 3])

# Generation settings. Every text input yields a string; infer() converts the
# numeric ones itself.
with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop, split by;', r'\n')
    seed = st.text_input('seed', "42")

# NOTE(review): the original indentation was lost; the output area, Submit
# button and handler are assumed to belong to the right-hand column — confirm.
with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )

    # Placeholder that is overwritten first with the prompt, then with prompt + completion.
    generated_area = st.empty()
    generated_area.text("(Generate here)")

    button_submit = st.button("Submit")

    if button_submit:
        generated_area.text(prompt)

        # The stop field holds escape sequences typed literally (e.g. \n); wrapping
        # it in triple quotes and parsing it as a string literal decodes them.
        # Input such as a trailing backslash or an embedded ''' is not a valid
        # literal, so report it instead of crashing the run.
        try:
            stop_sequences = literal_eval("'''" + stop + "'''")
        except (SyntaxError, ValueError) as err:
            st.error(f"Invalid stop sequence {stop!r}: {err}")
            st.stop()

        # ValueError covers non-numeric settings, KeyError an unexpected response
        # shape, the rest network trouble or a job that never finished.
        try:
            report_text = infer(
                prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
                num_completions=num_completions, seed=seed, stop=stop_sequences,
            )
        except (requests.RequestException, TimeoutError, KeyError, ValueError) as err:
            st.error(f"Generation failed: {type(err).__name__}: {err}")
            st.stop()

        generated_area.text(prompt + report_text)
|