# Streamlit demo app: GPT-JT (Together-gpt-JT-6B-v1) inference via together_web3.
import streamlit as st
import requests
import asyncio
import time
from ast import literal_eval
import urllib.parse
from dacite import from_dict
from together_web3.computer import LanguageModelInferenceRequest
from together_web3.together import TogetherWeb3
# Page title and one-time, per-session setup.
st.title("GPT-JT")

# Streamlit re-executes the whole script on every interaction, so the
# Together gateway client and the asyncio event loop are created once and
# kept in session_state rather than rebuilt on each rerun.
if "together_web3" not in st.session_state:
    st.session_state["together_web3"] = TogetherWeb3()
if "loop" not in st.session_state:
    st.session_state["loop"] = asyncio.new_event_loop()
async def _inference(prompt, max_tokens):
    """Send one language-model inference request through the session's
    TogetherWeb3 client and return the raw response object.

    Args:
        prompt: text to complete.
        max_tokens: maximum number of tokens to generate.
    """
    request = from_dict(
        data_class=LanguageModelInferenceRequest,
        data={
            "model": "Together-gpt-JT-6B-v1",
            "max_tokens": max_tokens,
            "prompt": prompt,
        },
    )
    return await st.session_state.together_web3.language_model_inference(request)
@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    # Run a cached inference call and return the generated text.
    #
    # NOTE(review): only `prompt` and `max_new_tokens` are actually forwarded
    # to _inference; model_name, temperature, top_p, num_completions, seed
    # and stop are accepted (and participate in the st.cache key) but are
    # otherwise ignored — confirm whether they should be wired into the
    # request.
    print("prompt", prompt)  # debug: log the incoming prompt
    # Drive the async client from this synchronous Streamlit script using
    # the session-scoped event loop created at startup.
    response = st.session_state.loop.run_until_complete(_inference(prompt, int(max_new_tokens)))
    print(response)  # debug: log the full response object
    # Assumes the response carries at least one choice with a .text field —
    # TODO confirm against together_web3's response schema.
    return response.choices[0].text
# Two-column layout: generation controls on the left, prompt/output on the right.
col1, col2 = st.columns([1, 3])
with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    # All values below are collected as raw text-input strings.
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    # Typo fix in the user-facing label: "returend" -> "returned".
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    # NOTE(review): the label suggests multiple stop sequences separated by
    # ";", but the submit handler passes the raw string through literal_eval
    # without splitting — confirm intended behavior.
    stop = st.text_input('stop, split by;', r'\n')
    seed = st.text_input('seed', "42")
with col2:
    # Example prompt pre-filled into the text area.
    default_prompt = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=default_prompt,
        max_chars=4096,
        height=400,
    )

    # Placeholder that is later overwritten with prompt + completion.
    generated_area = st.empty()
    generated_area.text("(Generate here)")

    if st.button("Submit"):
        # Echo the prompt immediately while inference runs.
        generated_area.text(prompt)
        # NOTE: wrapping the user string in triple quotes and literal_eval'ing
        # it turns escape sequences such as \n into real characters; it raises
        # if the input itself contains a triple quote.
        stop_sequence = literal_eval("'''" + stop + "'''")
        report_text = infer(
            prompt,
            model_name=model_name,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            num_completions=num_completions,
            seed=seed,
            stop=stop_sequence,
        )
        generated_area.text(prompt + report_text)