File size: 3,512 Bytes
7831eba
 
9d49e57
7831eba
 
 
a7d91d4
 
37a3c87
a7d91d4
b752df1
7831eba
b752df1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7831eba
555ac42
7831eba
 
c7fd9ac
7831eba
 
 
b5fab19
8baca64
7831eba
0cd27a0
 
 
7831eba
 
 
 
 
 
 
 
 
 
 
555ac42
8baca64
408d3e1
7831eba
 
408d3e1
 
 
8baca64
408d3e1
 
 
 
 
7831eba
 
 
 
 
 
 
 
 
8dbb362
d0520f9
7831eba
 
 
e3d87c2
f2e6053
257a390
793da93
7831eba
 
 
 
 
 
 
793da93
7831eba
555ac42
7831eba
 
 
 
408d3e1
d8d19ad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import gradio as gr
from huggingface_hub import InferenceClient
import os
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
import requests

from openai import OpenAI, AsyncOpenAI

# Registry of usable chat backends, keyed by the label shown in the UI.
# Each value is a two-item list: [OpenAI client, model id reported by the server].
clients = {}

# UI label -> environment variable that holds the endpoint's base URL.
_MODEL_ENDPOINTS = {
    '32B (work in progress)': 'MODEL_NAME_OR_PATH_32B',
    '32B QWQ (experimental, without any additional tuning after LEP!)': 'MODEL_NAME_OR_PATH_QWQ',
    '7B (work in progress)': 'MODEL_NAME_OR_PATH_7B',
    '3B': 'MODEL_NAME_OR_PATH_3B',
}


def _register_client(label, env_var, timeout=10):
    """Try to add one backend to ``clients``; report and skip it on failure.

    Args:
        label: Key under which the backend is stored in ``clients``.
        env_var: Name of the environment variable holding the base URL.
        timeout: Seconds to wait for the ``/models`` request before giving up,
            so an unresponsive endpoint cannot hang startup indefinitely.

    Registration is best-effort: any failure (unset variable, network error,
    unexpected ``/models`` payload) leaves ``clients`` unchanged for ``label``
    and prints the reason instead of raising.
    """
    base_url = os.getenv(env_var)
    if not base_url:
        print(f"[clients] {env_var} is not set; skipping model {label!r}")
        return
    try:
        listing = requests.get(base_url + '/models', timeout=timeout).json()
        # The first model listed by the server is the one used for requests.
        model_id = listing['data'][0]['id']
        clients[label] = [OpenAI(api_key='123', base_url=base_url), model_id]
    except Exception as exc:
        # Deliberately broad: one broken endpoint must not prevent the app
        # from starting. Unlike a bare ``except``, this still lets
        # KeyboardInterrupt/SystemExit propagate.
        print(f"[clients] could not register model {label!r} from {env_var}: {exc!r}")


for _label, _env_var in _MODEL_ENDPOINTS.items():
    _register_client(_label, _env_var)

def respond(
    message,
    history: list[tuple[str, str]],
    model_name,
    system_message,
    max_tokens,
    temperature,
    top_p,
    repetition_penalty
):
    """Stream the selected model's reply to ``message``.

    Args:
        message: The new user message.
        history: Earlier turns as ``(user_text, assistant_text)`` pairs; a
            falsy element in a pair is skipped.
        model_name: Key into the module-level ``clients`` registry.
        system_message: Optional system prompt; omitted when blank.
        max_tokens: Forwarded as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        top_p: Forwarded as ``top_p``.
        repetition_penalty: Forwarded via ``extra_body`` together with
            ``add_generation_prompt=True``.

    Yields:
        The reply accumulated so far, once per streamed chunk that carries
        text content.

    Raises:
        KeyError: If ``model_name`` has no entry in ``clients``.
    """
    try:
        client, model_id = clients[model_name]
    except KeyError:
        raise KeyError(
            f"model {model_name!r} is not available; no client is registered for it"
        ) from None

    messages = []
    if system_message and system_message.strip():
        messages.append({"role": "system", "content": system_message})

    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})

    stream = client.chat.completions.create(
        model=model_id,
        messages=messages,
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        stream=True,
        extra_body={
            "repetition_penalty": repetition_penalty,
            "add_generation_prompt": True,
        }
    )

    response = ""
    for chunk in stream:
        # Some chunks carry no choices at all, and a chunk's delta may have
        # no text content (None); neither contributes to the reply, and
        # concatenating None would raise TypeError.
        if not chunk.choices:
            continue
        token = chunk.choices[0].delta.content
        if token is None:
            continue

        response += token
        yield response


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Model labels offered in the selector: only the first two of the known
# labels are exposed.
options = [
    "32B (work in progress)",
    "32B QWQ (experimental, without any additional tuning after LEP!)",
    "7B (work in progress)",
    "3B",
][:2]

# Extra controls, created in the same order as respond()'s parameters that
# follow (message, history).
_model_radio = gr.Radio(choices=options, label="Model:", value=options[1])
_system_box = gr.Textbox(
    value="You are a helpful and harmless assistant. You should think step-by-step. First, reason (the user does not see your reasoning), then give your final answer.",
    label="System message",
)
_max_tokens_slider = gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max new tokens")
_temperature_slider = gr.Slider(minimum=0.0, maximum=2.0, value=0.3, step=0.1, label="Temperature")
_top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
_repetition_slider = gr.Slider(minimum=0.9, maximum=1.5, value=1.05, step=0.05, label="repetition_penalty")

demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        _model_radio,
        _system_box,
        _max_tokens_slider,
        _temperature_slider,
        _top_p_slider,
        _repetition_slider,
    ],
    concurrency_limit=10,
)


if __name__ == "__main__":
    # Disabled debug probe: it would drop the last three characters of a base
    # URL and request its '/docs' page. NOTE(review): MODEL_NAME_OR_PATH is not
    # one of the environment variables read in the visible part of this file.
    #print(requests.get(os.getenv('MODEL_NAME_OR_PATH')[:-3] + '/docs'))
    # share=True asks Gradio for a public share link in addition to serving
    # locally (see the Gradio launch() documentation).
    demo.launch(share=True)