File size: 2,499 Bytes
0d215ca
 
 
 
7f5cbab
d7570a5
 
 
 
 
 
 
0d215ca
3442116
 
 
 
0d215ca
 
 
bbe538b
d7570a5
0d215ca
 
3442116
0d215ca
 
 
 
 
d7570a5
0d215ca
 
 
 
 
 
 
0b3be54
0d215ca
bbe538b
0b3be54
 
0d215ca
 
 
 
 
 
 
 
 
0c8c7dc
0d215ca
a531b86
0d215ca
d7570a5
 
 
157e1ad
d7570a5
 
 
 
a531b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5c2bbf9
157e1ad
5c2bbf9
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import streamlit as st
import requests
import time

# NOTE(review): `st.cache` is deprecated in newer Streamlit releases in favour
# of `st.cache_data`; kept as-is because the Streamlit version this app
# targets is not visible from this file.
@st.cache
def infer(prompt, 
          model_name, 
          max_new_tokens=10, 
          temperature=0.0, 
          top_p=1.0,
          num_completions=1,
          seed=42,):
    """Submit a text-completion job to the inference service and return its text.

    The job is posted to ``https://planetd.shift.ml/jobs`` and then polled
    once per second, for at most 100 polls, until its status is
    ``'finished'``.

    Args:
        prompt: Prompt text to complete.
        model_name: Display name of the model; must be a key of the internal
            name map (currently only ``"GPT-JT-6B-v1"``).
        max_new_tokens: Maximum number of tokens to generate (coerced to int).
        temperature: Sampling temperature (coerced to float).
        top_p: Nucleus-sampling probability mass (coerced to float).
        num_completions: Number of completions requested (coerced to int).
            Only the first completion is returned.
        seed: Random seed forwarded to the service (coerced to int).

    Returns:
        The ``text`` field of the first choice of the first inference result.

    Raises:
        KeyError: If ``model_name`` is not a known model, or the service
            response lacks an expected field.
        ValueError: If a numeric argument cannot be converted.
        requests.RequestException: On network failure, request timeout, or
            an HTTP error status when submitting the job.
        TimeoutError: If the job is not finished after the last poll.
    """
    # Polling budget and per-request network timeout (seconds).
    max_polls = 100
    poll_interval = 1
    request_timeout = 30

    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": None,
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    # Without an explicit timeout `requests` waits forever on a stalled
    # connection, which would freeze the Streamlit session.
    submit_response = requests.post(
        "https://planetd.shift.ml/jobs",
        json=my_post_dict,
        timeout=request_timeout,
    )
    # Fail with a clear HTTP error instead of a confusing KeyError on 'id'.
    submit_response.raise_for_status()
    job_id = submit_response.json()['id']

    for _ in range(max_polls):
        time.sleep(poll_interval)

        ret = requests.get(
            f"https://planetd.shift.ml/job/{job_id}",
            json={'id': job_id},
            timeout=request_timeout,
        ).json()

        if ret['status'] == 'finished':
            break
    else:
        # The loop ran out without seeing 'finished'; the payload is not
        # known to be complete, so report the timeout explicitly.
        raise TimeoutError(
            f"Job {job_id} did not finish within "
            f"{max_polls * poll_interval} seconds"
        )

    return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']
    
    
# ---------------------------------------------------------------------------
# Page layout: generation settings in a narrow left column, prompt and
# generated output in a wide right column.
# ---------------------------------------------------------------------------
st.title("GPT-JT")

col1, col2 = st.columns([1, 3])

with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    # Numeric settings are collected as free text and converted on submit.
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions', "1")
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )

    # Placeholder that is overwritten in place with the model output.
    generated_area = st.empty()
    generated_area.markdown("(Generate here)")

    button_submit = st.button("Submit")

    if button_submit:
        try:
            # Convert the settings up front so a typo is reported as a
            # readable message instead of an uncaught ValueError raised
            # from inside infer().
            settings = {
                "max_new_tokens": int(max_new_tokens),
                "temperature": float(temperature),
                "top_p": float(top_p),
                "num_completions": int(num_completions),
                "seed": int(seed),
            }
        except ValueError as err:
            st.error(f"Invalid generation setting: {err}")
        else:
            with st.spinner(text="In progress.."):
                generated_area.markdown("...")
                report_text = infer(prompt, model_name=model_name, **settings)
                generated_area.markdown(report_text)