Shreyas094 commited on
Commit
3c6b68b
·
verified ·
1 Parent(s): 03ea444

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -10
app.py CHANGED
@@ -100,13 +100,22 @@ examples = [
100
  ["filetype:pdf intitle:python"]
101
  ]
102
 
103
-
104
  class CustomLLMSettings(BaseModel):
105
  structured_output: LlmStructuredOutputSettings
106
  temperature: float = Field(default=0.7)
107
  top_p: float = Field(default=0.95)
108
  repetition_penalty: float = Field(default=1.1)
109
- top_k: int = Field(default=50) # Added top_k parameter
 
 
 
 
 
 
 
 
 
 
110
 
111
  class HuggingFaceHubWrapper:
112
  def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
@@ -118,7 +127,9 @@ class HuggingFaceHubWrapper:
118
  self.temperature = model_kwargs.get('temperature', 0.7)
119
  self.top_p = model_kwargs.get('top_p', 0.95)
120
  self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
121
- self.top_k = model_kwargs.get('top_k', 50) # Added top_k
 
 
122
 
123
  def get_provider_default_settings(self):
124
  return CustomLLMSettings(
@@ -131,7 +142,9 @@ class HuggingFaceHubWrapper:
131
  temperature=self.temperature,
132
  top_p=self.top_p,
133
  repetition_penalty=self.repetition_penalty,
134
- top_k=self.top_k # Added top_k
 
 
135
  )
136
 
137
  def get_provider_identifier(self):
@@ -172,7 +185,7 @@ class CitingSources(BaseModel):
172
  )
173
 
174
  # Model function
175
- def get_model(temperature, top_p, repetition_penalty, top_k=50):
176
  return HuggingFaceHubWrapper(
177
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
178
  model_kwargs={
@@ -180,7 +193,8 @@ def get_model(temperature, top_p, repetition_penalty, top_k=50):
180
  "top_p": top_p,
181
  "repetition_penalty": repetition_penalty,
182
  "top_k": top_k,
183
- "max_length": 1000
 
184
  },
185
  huggingfacehub_api_token=huggingface_token
186
  )
@@ -207,10 +221,10 @@ def respond(
207
  temperature,
208
  top_p,
209
  repeat_penalty,
210
- top_k=50, # Added top_k parameter
 
211
  ):
212
- model = get_model(temperature, top_p, repeat_penalty, top_k)
213
-
214
  chat_template = MessagesFormatterType.MISTRAL
215
 
216
  search_tool = WebSearchTool(
@@ -262,7 +276,8 @@ demo = gr.ChatInterface(
262
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
263
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
264
  gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
265
- gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"), # Added top_k slider
 
266
  ],
267
  theme=gr.themes.Soft(
268
  primary_hue="orange",
 
100
  ["filetype:pdf intitle:python"]
101
  ]
102
 
 
103
  class CustomLLMSettings(BaseModel):
104
  structured_output: LlmStructuredOutputSettings
105
  temperature: float = Field(default=0.7)
106
  top_p: float = Field(default=0.95)
107
  repetition_penalty: float = Field(default=1.1)
108
+ top_k: int = Field(default=50)
109
+ max_tokens: int = Field(default=1000)
110
+ stop: list[str] = Field(default_factory=list)
111
+ echo: bool = Field(default=False)
112
+ stream: bool = Field(default=False)
113
+ logprobs: int = Field(default=None)
114
+ presence_penalty: float = Field(default=0.0)
115
+ frequency_penalty: float = Field(default=0.0)
116
+ best_of: int = Field(default=1)
117
+ logit_bias: dict = Field(default_factory=dict)
118
+ max_tokens_per_summary: int = Field(default=2048)
119
 
120
  class HuggingFaceHubWrapper:
121
  def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
 
127
  self.temperature = model_kwargs.get('temperature', 0.7)
128
  self.top_p = model_kwargs.get('top_p', 0.95)
129
  self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
130
+ self.top_k = model_kwargs.get('top_k', 50)
131
+ self.max_tokens = model_kwargs.get('max_length', 1000)
132
+ self.max_tokens_per_summary = model_kwargs.get('max_tokens_per_summary', 2048)
133
 
134
  def get_provider_default_settings(self):
135
  return CustomLLMSettings(
 
142
  temperature=self.temperature,
143
  top_p=self.top_p,
144
  repetition_penalty=self.repetition_penalty,
145
+ top_k=self.top_k,
146
+ max_tokens=self.max_tokens,
147
+ max_tokens_per_summary=self.max_tokens_per_summary
148
  )
149
 
150
  def get_provider_identifier(self):
 
185
  )
186
 
187
  # Model function
188
+ def get_model(temperature, top_p, repetition_penalty, top_k=50, max_tokens=1000, max_tokens_per_summary=2048):
189
  return HuggingFaceHubWrapper(
190
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
191
  model_kwargs={
 
193
  "top_p": top_p,
194
  "repetition_penalty": repetition_penalty,
195
  "top_k": top_k,
196
+ "max_length": max_tokens,
197
+ "max_tokens_per_summary": max_tokens_per_summary
198
  },
199
  huggingfacehub_api_token=huggingface_token
200
  )
 
221
  temperature,
222
  top_p,
223
  repeat_penalty,
224
+ top_k=50,
225
+ max_tokens_per_summary=2048
226
  ):
227
+ model = get_model(temperature, top_p, repeat_penalty, top_k, max_tokens, max_tokens_per_summary)
 
228
  chat_template = MessagesFormatterType.MISTRAL
229
 
230
  search_tool = WebSearchTool(
 
276
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
277
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
278
  gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
279
+ gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
280
+ gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens per summary"),
281
  ],
282
  theme=gr.themes.Soft(
283
  primary_hue="orange",