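"""Streamlit demo for GPT-JT: submits the prompt as a text-generation job to
the planetd.shift.ml queue and polls until the completion is ready."""
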
import streamlit as st
import requests
import time


@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42):
    # Map the UI-facing model name to the identifier the backend expects.
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }
    # Job description for the planetd.shift.ml queue; the sampling parameters
    # arrive as strings from the text inputs, hence the int()/float() casts.
    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": None,
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }
    # Submit the job, then poll once per second (up to 100 s) until it finishes.
    job_id = requests.post("https://planetd.shift.ml/jobs", json=my_post_dict).json()['id']
    for _ in range(100):
        time.sleep(1)
        ret = requests.get(f"https://planetd.shift.ml/job/{job_id}", json={'id': job_id}).json()
        if ret['status'] == 'finished':
            break
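    # Note: if the job never reaches 'finished' within the polling window, the
    # lookup below runs on the last polled response and may raise a KeyError.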
    return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
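

# Illustrative direct call (hypothetical prompt; requires network access to
# planetd.shift.ml, so it is left commented out):
#     print(infer("Where is Zurich?", model_name="GPT-JT-6B-v1"))
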
st.title("GPT-JT")

# Left column: sampling controls; right column: prompt, output, and submit.
col1, col2 = st.columns([1, 3])
with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions', "1")
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )
    # Placeholder that is later overwritten with the model output.
    generated_area = st.empty()
    generated_area.markdown("(Generate here)")
    button_submit = st.button("Submit")

    if button_submit:
        with st.spinner(text="In progress.."):
            # Clear the placeholder while waiting, then show the completion.
            generated_area.markdown("...")
            report_text = infer(
                prompt,
                model_name=model_name,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                num_completions=num_completions,
                seed=seed,
            )
            generated_area.markdown(report_text)
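
# To run locally: streamlit run app.py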