File size: 2,432 Bytes
0d215ca
 
 
 
d7570a5
 
 
 
 
 
 
0d215ca
3442116
 
 
 
0d215ca
 
 
bbe538b
d7570a5
0d215ca
 
3442116
0d215ca
 
 
 
 
d7570a5
0d215ca
 
 
 
 
 
 
0b3be54
0d215ca
bbe538b
0b3be54
 
0d215ca
 
 
 
 
 
 
 
 
 
 
a531b86
0d215ca
d7570a5
 
 
 
 
 
 
 
a531b86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0d215ca
d7570a5
 
 
 
a531b86
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import streamlit as st
import requests
import time

def infer(prompt, 
          model_name, 
          max_new_tokens=10, 
          temperature=0.0, 
          top_p=1.0,
          num_completions=1,
          seed=42,
          poll_interval=1.0,
          max_polls=100,
          request_timeout=30.0,):
    """Submit a completion job to the remote job queue and return the generated text.

    The job is POSTed to ``https://planetd.shift.ml/jobs``; the job's status is
    then polled until it reports ``'finished'``.

    Args:
        prompt: Prompt text sent to the model.
        model_name: Display name of the model; must be a key of the internal
            ``model_name_map`` (currently only ``"GPT-JT-6B-v1"``).
        max_new_tokens: Sent as ``max_tokens``; coerced with ``int()``.
        temperature: Sampling temperature; coerced with ``float()``.
        top_p: Nucleus-sampling parameter; coerced with ``float()``.
        num_completions: Sent as ``n``; coerced with ``int()``.
        seed: Sent as ``seed``; coerced with ``int()``.
        poll_interval: Seconds to sleep before each status poll.
        max_polls: Maximum number of status polls before giving up.
        request_timeout: Timeout in seconds applied to each HTTP request.

    Returns:
        The ``text`` of the first choice of the first inference result. Only
        this single completion is returned, even when ``num_completions > 1``.

    Raises:
        KeyError: If ``model_name`` is not a known model.
        ValueError: If a numeric argument cannot be coerced.
        requests.RequestException: On network failure, timeout, or an HTTP
            error status.
        TimeoutError: If the job has not finished after ``max_polls`` polls.
    """
    model_name_map = {
        "GPT-JT-6B-v1": "Together-gpt-JT-6B-v1",
    }

    job_request = {
        "type": "general",
        "payload": {
            "max_tokens": int(max_new_tokens),
            "n": int(num_completions),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "model": model_name_map[model_name],
            "prompt": [prompt],
            "request_type": "language-model-inference",
            "stop": None,
            "best_of": 1,
            "echo": False,
            "seed": int(seed),
            "prompt_embedding": False,
        },
        "returned_payload": {},
        "status": "submitted",
        "source": "dalle",
    }

    # Always pass a timeout: without one, requests waits forever on a stalled server.
    submit_response = requests.post(
        "https://planetd.shift.ml/jobs",
        json=job_request,
        timeout=request_timeout,
    )
    submit_response.raise_for_status()
    job_id = submit_response.json()['id']

    for _ in range(max_polls):
        # Sleep first: the job was only just submitted.
        time.sleep(poll_interval)

        # NOTE(review): the GET carries the job id in a JSON body as well as in
        # the URL; kept as-is because the server's requirements are not visible here.
        poll_response = requests.get(
            f"https://planetd.shift.ml/job/{job_id}",
            json={'id': job_id},
            timeout=request_timeout,
        )
        poll_response.raise_for_status()
        ret = poll_response.json()

        if ret['status'] == 'finished':
            return ret['returned_payload']['result']['inference_result'][0]['choices'][0]['text']

    # Fail explicitly instead of indexing into an unfinished job's empty payload.
    raise TimeoutError(
        f"Job {job_id} did not finish after {max_polls} polls"
    )
    
    
# --- Streamlit page: generation settings on the left, prompt and output on the right ---
st.title("TOMA Application")

# Two columns with a 1:3 width ratio.
col1, col2 = st.columns([1, 3])

with col1:
    # Settings are collected as free text and validated on submit.
    model_name = st.selectbox("Model", ["GPT-JT-6B-v1"])
    max_new_tokens = st.text_input('Max new tokens', "10")
    temperature = st.text_input('temperature', "1.0")
    top_p = st.text_input('top_p', "1.0")
    num_completions = st.text_input('num_completions', "1")
    seed = st.text_input('seed', "42")

with col2:
    s_example = "Please answer the following question:\n\nQuestion: Where is Zurich?\nAnswer:"
    prompt = st.text_area(
        "Prompt",
        value=s_example,
        max_chars=4096,
        height=400,
    )

    # Placeholder that is overwritten with the model output after a submit.
    generated_area = st.empty()
    generated_area.markdown("(Generate here)")

    button_submit = st.button("Submit")

# The `if` must sit one level above its body; previously both were at the same
# indent, which is an IndentationError that prevented the app from loading.
if button_submit:
    try:
        # Validate the free-text settings up front so a typo produces a
        # readable message instead of a traceback from inside infer().
        settings = {
            "max_new_tokens": int(max_new_tokens),
            "temperature": float(temperature),
            "top_p": float(top_p),
            "num_completions": int(num_completions),
            "seed": int(seed),
        }
    except ValueError as err:
        st.error(f"Invalid generation setting: {err}")
    else:
        with st.spinner(text="In progress.."):
            report_text = infer(prompt, model_name=model_name, **settings)
            generated_area.markdown(report_text)