Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,19 +8,20 @@ import requests
|
|
8 |
|
9 |
from openai import OpenAI
|
10 |
|
11 |
-
|
12 |
-
|
13 |
#client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
|
14 |
|
15 |
|
16 |
def respond(
|
|
|
17 |
message,
|
18 |
history: list[tuple[str, str]],
|
19 |
system_message,
|
20 |
max_tokens,
|
21 |
temperature,
|
22 |
top_p
|
23 |
-
|
24 |
):
|
25 |
messages = []
|
26 |
if len(system_message.strip()) > 0:
|
@@ -36,15 +37,15 @@ def respond(
|
|
36 |
|
37 |
response = ""
|
38 |
|
39 |
-
res =
|
40 |
-
model=
|
41 |
messages=messages,
|
42 |
temperature=temperature,
|
43 |
top_p=top_p,
|
44 |
max_tokens=max_tokens,
|
45 |
stream=True,
|
46 |
extra_body={
|
47 |
-
"repetition_penalty":
|
48 |
"add_generation_prompt": True,
|
49 |
}
|
50 |
)
|
@@ -59,9 +60,11 @@ def respond(
|
|
59 |
"""
|
60 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
61 |
"""
|
|
|
62 |
demo = gr.ChatInterface(
|
63 |
respond,
|
64 |
additional_inputs=[
|
|
|
65 |
gr.Textbox(value="", label="System message"),
|
66 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
67 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
|
@@ -72,7 +75,7 @@ demo = gr.ChatInterface(
|
|
72 |
step=0.05,
|
73 |
label="Top-p (nucleus sampling)",
|
74 |
),
|
75 |
-
|
76 |
],
|
77 |
)
|
78 |
|
|
|
8 |
|
9 |
from openai import OpenAI
|
10 |
|
11 |
+
clients{'3B': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_3B')), 'RefalMachine/ruadapt_qwen2.5_3B_ext_u48_instruct'],
|
12 |
+
'7B (work in progress)': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_7B')), 'RefalMachine/ruadapt_qwen2.5_7B_ext_u48_instruct']}
|
13 |
#client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
|
14 |
|
15 |
|
16 |
def respond(
|
17 |
+
model_name,
|
18 |
message,
|
19 |
history: list[tuple[str, str]],
|
20 |
system_message,
|
21 |
max_tokens,
|
22 |
temperature,
|
23 |
top_p
|
24 |
+
repetition_penalty
|
25 |
):
|
26 |
messages = []
|
27 |
if len(system_message.strip()) > 0:
|
|
|
37 |
|
38 |
response = ""
|
39 |
|
40 |
+
res = clients[model_name][0].chat.completions.create(
|
41 |
+
model=clients[model_name][1],
|
42 |
messages=messages,
|
43 |
temperature=temperature,
|
44 |
top_p=top_p,
|
45 |
max_tokens=max_tokens,
|
46 |
stream=True,
|
47 |
extra_body={
|
48 |
+
"repetition_penalty": repetition_penalty,
|
49 |
"add_generation_prompt": True,
|
50 |
}
|
51 |
)
|
|
|
60 |
"""
|
61 |
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
|
62 |
"""
|
63 |
+
options = ["3B", "7B (work in progress)"]
|
64 |
demo = gr.ChatInterface(
|
65 |
respond,
|
66 |
additional_inputs=[
|
67 |
+
gr.Radio(choices=options, label="Model:", value=options[0])
|
68 |
gr.Textbox(value="", label="System message"),
|
69 |
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
|
70 |
gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
|
|
|
75 |
step=0.05,
|
76 |
label="Top-p (nucleus sampling)",
|
77 |
),
|
78 |
+
gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.05, label="repetition_penalty"),
|
79 |
],
|
80 |
)
|
81 |
|