Shreyas094 committed on
Commit
03ea444
·
verified ·
1 Parent(s): 0e78477

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -19,7 +19,7 @@ from llama_cpp_agent.tools import WebSearchTool
19
  from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
20
  from langchain_community.llms import HuggingFaceHub
21
  from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
22
- from pydantic import BaseModel
23
  from llama_cpp_agent.llm_output_settings import LlmStructuredOutputType
24
 
25
  print("Available LlmStructuredOutputType options:")
@@ -103,10 +103,11 @@ examples = [
103
 
104
  class CustomLLMSettings(BaseModel):
105
  structured_output: LlmStructuredOutputSettings
106
- temperature: float
107
- top_p: float
108
- repetition_penalty: float
109
-
 
110
  class HuggingFaceHubWrapper:
111
  def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
112
  self.model = HuggingFaceHub(
@@ -117,7 +118,7 @@ class HuggingFaceHubWrapper:
117
  self.temperature = model_kwargs.get('temperature', 0.7)
118
  self.top_p = model_kwargs.get('top_p', 0.95)
119
  self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
120
-
121
 
122
  def get_provider_default_settings(self):
123
  return CustomLLMSettings(
@@ -129,7 +130,8 @@ class HuggingFaceHubWrapper:
129
  ),
130
  temperature=self.temperature,
131
  top_p=self.top_p,
132
- repetition_penalty=self.repetition_penalty
 
133
  )
134
 
135
  def get_provider_identifier(self):
@@ -170,13 +172,14 @@ class CitingSources(BaseModel):
170
  )
171
 
172
  # Model function
173
- def get_model(temperature, top_p, repetition_penalty):
174
  return HuggingFaceHubWrapper(
175
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
176
  model_kwargs={
177
  "temperature": temperature,
178
  "top_p": top_p,
179
  "repetition_penalty": repetition_penalty,
 
180
  "max_length": 1000
181
  },
182
  huggingfacehub_api_token=huggingface_token
@@ -204,8 +207,9 @@ def respond(
204
  temperature,
205
  top_p,
206
  repeat_penalty,
 
207
  ):
208
- model = get_model(temperature, top_p, repeat_penalty)
209
 
210
  chat_template = MessagesFormatterType.MISTRAL
211
 
@@ -258,6 +262,7 @@ demo = gr.ChatInterface(
258
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
259
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
260
  gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
 
261
  ],
262
  theme=gr.themes.Soft(
263
  primary_hue="orange",
 
19
  from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
20
  from langchain_community.llms import HuggingFaceHub
21
  from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
22
+ from pydantic import BaseModel, Field
23
  from llama_cpp_agent.llm_output_settings import LlmStructuredOutputType
24
 
25
  print("Available LlmStructuredOutputType options:")
 
103
 
104
  class CustomLLMSettings(BaseModel):
105
  structured_output: LlmStructuredOutputSettings
106
+ temperature: float = Field(default=0.7)
107
+ top_p: float = Field(default=0.95)
108
+ repetition_penalty: float = Field(default=1.1)
109
+ top_k: int = Field(default=50) # Added top_k parameter
110
+
111
  class HuggingFaceHubWrapper:
112
  def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
113
  self.model = HuggingFaceHub(
 
118
  self.temperature = model_kwargs.get('temperature', 0.7)
119
  self.top_p = model_kwargs.get('top_p', 0.95)
120
  self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
121
+ self.top_k = model_kwargs.get('top_k', 50) # Added top_k
122
 
123
  def get_provider_default_settings(self):
124
  return CustomLLMSettings(
 
130
  ),
131
  temperature=self.temperature,
132
  top_p=self.top_p,
133
+ repetition_penalty=self.repetition_penalty,
134
+ top_k=self.top_k # Added top_k
135
  )
136
 
137
  def get_provider_identifier(self):
 
172
  )
173
 
174
  # Model function
175
+ def get_model(temperature, top_p, repetition_penalty, top_k=50):
176
  return HuggingFaceHubWrapper(
177
  repo_id="mistralai/Mistral-7B-Instruct-v0.3",
178
  model_kwargs={
179
  "temperature": temperature,
180
  "top_p": top_p,
181
  "repetition_penalty": repetition_penalty,
182
+ "top_k": top_k,
183
  "max_length": 1000
184
  },
185
  huggingfacehub_api_token=huggingface_token
 
207
  temperature,
208
  top_p,
209
  repeat_penalty,
210
+ top_k=50, # Added top_k parameter
211
  ):
212
+ model = get_model(temperature, top_p, repeat_penalty, top_k)
213
 
214
  chat_template = MessagesFormatterType.MISTRAL
215
 
 
262
  gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
263
  gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
264
  gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
265
+ gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"), # Added top_k slider
266
  ],
267
  theme=gr.themes.Soft(
268
  primary_hue="orange",