RefalMachine committed on
Commit
8baca64
·
verified ·
1 Parent(s): 00667c8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -7
app.py CHANGED
@@ -8,19 +8,20 @@ import requests
8
 
9
  from openai import OpenAI
10
 
11
- client = OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH'))
12
-
13
  #client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
14
 
15
 
16
  def respond(
 
17
  message,
18
  history: list[tuple[str, str]],
19
  system_message,
20
  max_tokens,
21
  temperature,
22
  top_p
23
- #repetition_penalty
24
  ):
25
  messages = []
26
  if len(system_message.strip()) > 0:
@@ -36,15 +37,15 @@ def respond(
36
 
37
  response = ""
38
 
39
- res = client.chat.completions.create(
40
- model='RefalMachine/ruadapt_qwen2.5_7B_ext_u48_instruct',
41
  messages=messages,
42
  temperature=temperature,
43
  top_p=top_p,
44
  max_tokens=max_tokens,
45
  stream=True,
46
  extra_body={
47
- "repetition_penalty": 1.0,
48
  "add_generation_prompt": True,
49
  }
50
  )
@@ -59,9 +60,11 @@ def respond(
59
  """
60
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
61
  """
 
62
  demo = gr.ChatInterface(
63
  respond,
64
  additional_inputs=[
 
65
  gr.Textbox(value="", label="System message"),
66
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
67
  gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
@@ -72,7 +75,7 @@ demo = gr.ChatInterface(
72
  step=0.05,
73
  label="Top-p (nucleus sampling)",
74
  ),
75
- #gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.05, label="repetition_penalty"),
76
  ],
77
  )
78
 
 
8
 
9
  from openai import OpenAI
10
 
11
+ clients = {'3B': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_3B')), 'RefalMachine/ruadapt_qwen2.5_3B_ext_u48_instruct'],
12
+ '7B (work in progress)': [OpenAI(api_key='123', base_url=os.getenv('MODEL_NAME_OR_PATH_7B')), 'RefalMachine/ruadapt_qwen2.5_7B_ext_u48_instruct']}
13
  #client = InferenceClient(os.getenv('MODEL_NAME_OR_PATH'))
14
 
15
 
16
  def respond(
17
  message,
18
  history: list[tuple[str, str]],
19
+ model_name,
20
  system_message,
21
  max_tokens,
22
  temperature,
23
  top_p,
24
+ repetition_penalty
25
  ):
26
  messages = []
27
  if len(system_message.strip()) > 0:
 
37
 
38
  response = ""
39
 
40
+ res = clients[model_name][0].chat.completions.create(
41
+ model=clients[model_name][1],
42
  messages=messages,
43
  temperature=temperature,
44
  top_p=top_p,
45
  max_tokens=max_tokens,
46
  stream=True,
47
  extra_body={
48
+ "repetition_penalty": repetition_penalty,
49
  "add_generation_prompt": True,
50
  }
51
  )
 
60
  """
61
  For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
62
  """
63
+ options = ["3B", "7B (work in progress)"]
64
  demo = gr.ChatInterface(
65
  respond,
66
  additional_inputs=[
67
+ gr.Radio(choices=options, label="Model:", value=options[0]),
68
  gr.Textbox(value="", label="System message"),
69
  gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
70
  gr.Slider(minimum=0.1, maximum=4.0, value=0.3, step=0.1, label="Temperature"),
 
75
  step=0.05,
76
  label="Top-p (nucleus sampling)",
77
  ),
78
+ gr.Slider(minimum=0.9, maximum=1.2, value=1.0, step=0.05, label="repetition_penalty"),
79
  ],
80
  )
81