File size: 2,001 Bytes
0d215ca
 
 
e532db6
0d215ca
7f5cbab
d7570a5
 
 
af56d2c
d7570a5
 
b2703de
 
f203ba6
 
 
 
 
9cecd9c
f79758c
9cecd9c
f203ba6
 
 
 
 
 
 
 
 
 
 
0d215ca
f203ba6
 
a531b86
0d215ca
d7570a5
 
 
af56d2c
d7570a5
b2703de
8ff5e07
d7570a5
 
a531b86
 
 
 
 
 
 
 
 
 
87df952
a531b86
 
 
 
87df952
b2703de
 
8ff5e07
b2703de
f203ba6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import streamlit as st
import requests
import time
from ast import literal_eval

@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.8,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n",
          timeout=60):
    """Request a completion for *prompt* from the Together inference endpoint.

    Args:
        prompt: Text the model should continue.
        model_name: UI model name; translated to the API model id through
            ``model_name_map`` (names missing from the map are sent unchanged).
        max_new_tokens: Maximum number of tokens to generate; anything
            ``int()`` accepts (the UI passes strings).
        temperature: Sampling temperature; anything ``float()`` accepts.
            A value of exactly 0 is replaced by 0.01.
        top_p: Nucleus-sampling threshold; anything ``float()`` accepts.
        num_completions: Accepted for interface compatibility but currently
            NOT sent to the API.
        seed: Accepted for interface compatibility but currently NOT sent to
            the API.
        stop: Stop sequences separated by ``;``.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        The ``text`` field of the first entry in ``output.choices`` of the
        JSON response.

    Raises:
        ValueError: a numeric argument cannot be converted, or the response
            body is not valid JSON.
        requests.RequestException: network failure, timeout, or a non-2xx
            HTTP status.
        KeyError, IndexError: the response JSON lacks the expected fields.
    """
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    # NOTE(review): presumably the API rejects a temperature of exactly 0,
    # hence the small positive substitute — confirm against the API docs.
    if float(temperature) == 0:
        temperature = 0.01

    # Sent as query-string parameters of a GET request.
    params = {
        "model": model_name_map.get(model_name, model_name),
        "prompt": prompt,
        "top_p": float(top_p),
        "temperature": float(temperature),
        "max_tokens": int(max_new_tokens),
        "stop": stop.split(";"),
    }
    response = requests.get(
        "https://staging.together.xyz/api/inference",
        params=params,
        timeout=timeout,
    )
    # Fail with a clear HTTPError instead of a confusing KeyError on an error page.
    response.raise_for_status()
    return response.json()['output']['choices'][0]['text']
    
    
def _decode_stop(raw):
    """Interpret backslash escapes (e.g. ``\\n``) typed into the stop field.

    Returns the decoded string, or *raw* unchanged when it cannot be parsed
    as a single string literal (for example when it ends with a backslash,
    contains a quote run such as ``'''``, or would evaluate to a non-string).
    """
    try:
        decoded = literal_eval("'''" + raw + "'''")
    except (ValueError, SyntaxError):
        return raw
    # Input such as "''', '''" evaluates to a tuple; keep only real strings.
    return decoded if isinstance(decoded, str) else raw


st.title("GPT-JT")

col1, col2 = st.columns([1, 3])

with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.8")
    top_p = st.text_input('top_p', "1.0")
    # NOTE(review): infer() does not currently forward num_completions or
    # seed to the API, so these two fields have no effect yet.
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop, split by;', r'\n')
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )

    # Placeholder later overwritten with the prompt, then prompt + completion.
    generated_area = st.empty()
    generated_area.text("(Generate here)")

    button_submit = st.button("Submit")

    if button_submit:
        # Echo the prompt while the request is in flight.
        generated_area.text(prompt)
        try:
            report_text = infer(
                prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
                num_completions=num_completions, seed=seed, stop=_decode_stop(stop),
            )
        except (ValueError, KeyError, IndexError, requests.RequestException) as err:
            # Bad numeric input, network/HTTP failure, or an unexpected
            # response shape: report it instead of dumping a traceback.
            st.error(f"Generation failed: {err}")
        else:
            generated_area.text(prompt + report_text)