File size: 2,581 Bytes
0d215ca
 
 
 
7f5cbab
d7570a5
 
 
 
 
 
b2703de
 
0d215ca
3442116
 
 
 
0d215ca
 
 
bbe538b
d7570a5
0d215ca
 
3442116
0d215ca
 
b2703de
0d215ca
 
d7570a5
0d215ca
 
 
 
 
 
 
0b3be54
0d215ca
bbe538b
0b3be54
 
0d215ca
 
 
 
 
 
 
 
 
0c8c7dc
0d215ca
a531b86
0d215ca
d7570a5
 
 
157e1ad
d7570a5
b2703de
 
d7570a5
 
a531b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b2703de
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import streamlit as st
import requests
import time

_API_BASE_URL = "https://planetd.shift.ml"
# Per-request network timeout in seconds; without one, `requests` may block forever.
_REQUEST_TIMEOUT_S = 30
# Delay between job-status polls, and how many polls to make before giving up.
_POLL_INTERVAL_S = 1
_MAX_POLLS = 100


@st.cache
def infer(prompt, 
          model_name, 
          max_new_tokens=10, 
          temperature=0.0, 
          top_p=1.0,
          num_completions=1,
          seed=42,
          stop="\n"):
    """Submit a completion job to the planetd.shift.ml service and return its text.

    The job is posted to ``/jobs`` and then ``/job/<id>`` is polled every
    ``_POLL_INTERVAL_S`` seconds until its status is ``"finished"``.

    Args:
        prompt: Text prompt sent to the model.
        model_name: UI model name; must be a key of ``model_name_map``
            (otherwise ``KeyError``).
        max_new_tokens, num_completions, seed: Coerced with ``int()``, so
            numeric strings (as produced by the UI text inputs) are accepted.
        temperature, top_p: Coerced with ``float()``.
        stop: Stop sequences separated by ``";"``; empty entries are dropped.

    Returns:
        The ``text`` of the first choice of the first inference result.

    Raises:
        ValueError: A numeric argument cannot be converted.
        requests.HTTPError: The service answered with a non-2xx status.
        requests.RequestException: Network failure or request timeout.
        TimeoutError: The job did not finish within ``_MAX_POLLS`` polls.
    """
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    # An empty stop string would match immediately, so it is never sent.
    stop_sequences = [s for s in stop.split(";") if s]

    my_post_dict = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": stop_sequences,
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    submit = requests.post(
        f"{_API_BASE_URL}/jobs", json=my_post_dict, timeout=_REQUEST_TIMEOUT_S
    )
    submit.raise_for_status()
    job_id = submit.json()['id']

    last_status = None
    for _ in range(_MAX_POLLS):
        time.sleep(_POLL_INTERVAL_S)

        poll = requests.get(
            f"{_API_BASE_URL}/job/{job_id}",
            json={'id': job_id},
            timeout=_REQUEST_TIMEOUT_S,
        )
        poll.raise_for_status()
        ret = poll.json()
        last_status = ret['status']

        if last_status == 'finished':
            # NOTE(review): only the first choice is returned even when n > 1.
            return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']

    raise TimeoutError(
        f"Job {job_id} did not finish after {_MAX_POLLS} polls "
        f"(last status: {last_status!r})"
    )
    
    
def _highlight_completion(prompt, completion):
    """Return markdown showing *prompt* followed by *completion* in bold.

    CommonMark only treats ``**`` as strong emphasis when the markers touch
    non-whitespace characters. Model output often begins with a space (e.g.
    after "Answer:"), so leading/trailing whitespace is moved outside the
    markers. A completion that is empty or all whitespace is appended as-is.
    """
    core = completion.strip()
    if not core:
        return prompt + completion
    n_leading = len(completion) - len(completion.lstrip())
    leading = completion[:n_leading]
    trailing = completion[n_leading + len(core):]
    return f"{prompt}{leading}**{core}**{trailing}"


st.title("GPT-JT")

# Narrow settings column on the left, wide prompt/output column on the right.
col1, col2 = st.columns([1, 3])

with col1:
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    # Values are kept as strings here; `infer` converts them with int()/float().
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "0.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions (only the best one will be returned)', "1")
    stop = st.text_input('stop, split by;', "\n")
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )

    # Placeholder that is overwritten in place with the prompt and then the result.
    generated_area = st.empty()
    generated_area.markdown("(Generate here)")

    button_submit = st.button("Submit")

    if button_submit:
        generated_area.markdown(prompt)
        try:
            with st.spinner("Generating..."):
                report_text = infer(
                    prompt, model_name=model_name, max_new_tokens=max_new_tokens, temperature=temperature, top_p=top_p,
                    num_completions=num_completions, seed=seed, stop=stop,
                )
        except (ValueError, TimeoutError, requests.RequestException) as err:
            # Bad numeric input (int()/float() in `infer`), an unfinished job,
            # or a network/HTTP failure: show it instead of a raw traceback.
            st.error(f"Generation failed: {err}")
        else:
            generated_area.markdown(_highlight_completion(prompt, report_text))