import streamlit as st
import requests
from ast import literal_eval

@st.cache
def infer(prompt,
          model_name,
          max_new_tokens=10,
          temperature=0.0,
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):

    # Map the UI-facing model name to the identifier expected by the API.
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    my_post_dict = {
        "model": model_name_map[model_name],
        "prompt": prompt,
        "top_p": float(top_p),
        "temperature": float(temperature),
        "max_tokens": int(max_new_tokens),
        # The API takes a list of stop sequences; the UI passes them as a
        # single ";"-separated string.
        "stop": stop.split(";")
    }
    # num_completions and seed are accepted for forward compatibility but are
    # not forwarded to the API in this request.
    # The response JSON is shaped roughly like:
    #   {"output": {"choices": [{"text": "..."}]}}
    response = requests.get("https://staging.together.xyz/api/inference", params=my_post_dict).json()
    return response['output']['choices'][0]['text']
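
# Example (assumed usage, not exercised by the app itself): infer can be
# called directly with a prompt and the UI-facing model name. This is a
# sketch; the output depends on the hosted API being reachable.
#
#   text = infer("Question: Where is Zurich?\nAnswer:",
#                model_name="GPT-JT-6B-v1", max_new_tokens="10",
#                temperature="0.0", top_p="1.0")
#   print(text)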
    
    
st.title("GPT-JT")

col1, col2 = st.columns([1, 3])

with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop sequences, separated by ";"', r'\n')
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )
        
    generated_area = st.empty()
    generated_area.text("(Generate here)")
    
    button_submit = st.button("Submit")

    if button_submit:
        generated_area.text(prompt)
        # literal_eval converts escape sequences typed into the UI (e.g. the
        # literal two characters "\n") into their actual characters before
        # infer splits them into stop sequences.
        report_text = infer(
            prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
            num_completions=num_completions, seed=seed, stop=literal_eval("'''"+stop+"'''"),
        )
        generated_area.text(prompt + report_text)
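
# To launch the UI locally (assuming this file is saved as app.py):
#   streamlit run app.py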