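# Streamlit demo UI for GPT-JT: collects a prompt and sampling settings and
# requests a completion from the Together-gpt-JT-6B-v1 model over the
# together_web3 decentralized inference network.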
import streamlit as st
import asyncio
from ast import literal_eval
from dacite import from_dict
from together_web3.computer import LanguageModelInferenceRequest
from together_web3.together import TogetherWeb3

st.title("GPT-JT")
if 'together_web3' not in st.session_state:
    st.session_state.together_web3 = TogetherWeb3()
if 'loop' not in st.session_state:
    st.session_state.loop = asyncio.new_event_loop()
async def _inference(prompt, max_tokens):
    result = await st.session_state.together_web3.language_model_inference(
        from_dict(
            data_class=LanguageModelInferenceRequest,
            data={
                "model": "Together-gpt-JT-6B-v1",
                "max_tokens": max_tokens,
                "prompt": prompt,
            }
        ),
    )
    return result

@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    # Cache results so repeated submissions with identical settings skip the network call.
    # Note: only prompt and max_new_tokens are forwarded to _inference; the other
    # sampling parameters are accepted (and act as cache keys) but are not yet passed through.
    response = st.session_state.loop.run_until_complete(_inference(prompt, int(max_new_tokens)))
    return response.choices[0].text
    
col1, col2 = st.columns([1, 3])

with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop sequence (use \\n for newline)', r'\n')
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )
        
    generated_area = st.empty()
    generated_area.text("(Generate here)")
    
    button_submit = st.button("Submit")

    if button_submit:
        # Echo the prompt immediately, then append the completion when it arrives.
        generated_area.text(prompt)
        report_text = infer(
            prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
            # literal_eval turns the escaped input (e.g. r"\n") into the actual character.
            num_completions=num_completions, seed=seed, stop=literal_eval("'''" + stop + "'''"),
        )
        generated_area.text(prompt + report_text)
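
# To run locally (assuming this file is saved as app.py):
#   streamlit run app.py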