Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -100,13 +100,22 @@ examples = [
|
|
100 |
["filetype:pdf intitle:python"]
|
101 |
]
|
102 |
|
103 |
-
|
104 |
class CustomLLMSettings(BaseModel):
|
105 |
structured_output: LlmStructuredOutputSettings
|
106 |
temperature: float = Field(default=0.7)
|
107 |
top_p: float = Field(default=0.95)
|
108 |
repetition_penalty: float = Field(default=1.1)
|
109 |
-
top_k: int = Field(default=50)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
class HuggingFaceHubWrapper:
|
112 |
def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
|
@@ -118,7 +127,9 @@ class HuggingFaceHubWrapper:
|
|
118 |
self.temperature = model_kwargs.get('temperature', 0.7)
|
119 |
self.top_p = model_kwargs.get('top_p', 0.95)
|
120 |
self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
|
121 |
-
self.top_k = model_kwargs.get('top_k', 50)
|
|
|
|
|
122 |
|
123 |
def get_provider_default_settings(self):
|
124 |
return CustomLLMSettings(
|
@@ -131,7 +142,9 @@ class HuggingFaceHubWrapper:
|
|
131 |
temperature=self.temperature,
|
132 |
top_p=self.top_p,
|
133 |
repetition_penalty=self.repetition_penalty,
|
134 |
-
top_k=self.top_k
|
|
|
|
|
135 |
)
|
136 |
|
137 |
def get_provider_identifier(self):
|
@@ -172,7 +185,7 @@ class CitingSources(BaseModel):
|
|
172 |
)
|
173 |
|
174 |
# Model function
|
175 |
-
def get_model(temperature, top_p, repetition_penalty, top_k=50):
|
176 |
return HuggingFaceHubWrapper(
|
177 |
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
|
178 |
model_kwargs={
|
@@ -180,7 +193,8 @@ def get_model(temperature, top_p, repetition_penalty, top_k=50):
|
|
180 |
"top_p": top_p,
|
181 |
"repetition_penalty": repetition_penalty,
|
182 |
"top_k": top_k,
|
183 |
-
"max_length":
|
|
|
184 |
},
|
185 |
huggingfacehub_api_token=huggingface_token
|
186 |
)
|
@@ -207,10 +221,10 @@ def respond(
|
|
207 |
temperature,
|
208 |
top_p,
|
209 |
repeat_penalty,
|
210 |
-
top_k=50,
|
|
|
211 |
):
|
212 |
-
model = get_model(temperature, top_p, repeat_penalty, top_k)
|
213 |
-
|
214 |
chat_template = MessagesFormatterType.MISTRAL
|
215 |
|
216 |
search_tool = WebSearchTool(
|
@@ -262,7 +276,8 @@ demo = gr.ChatInterface(
|
|
262 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
|
263 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
|
264 |
gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
|
265 |
-
gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
|
|
|
266 |
],
|
267 |
theme=gr.themes.Soft(
|
268 |
primary_hue="orange",
|
|
|
100 |
["filetype:pdf intitle:python"]
|
101 |
]
|
102 |
|
|
|
103 |
class CustomLLMSettings(BaseModel):
|
104 |
structured_output: LlmStructuredOutputSettings
|
105 |
temperature: float = Field(default=0.7)
|
106 |
top_p: float = Field(default=0.95)
|
107 |
repetition_penalty: float = Field(default=1.1)
|
108 |
+
top_k: int = Field(default=50)
|
109 |
+
max_tokens: int = Field(default=1000)
|
110 |
+
stop: list[str] = Field(default_factory=list)
|
111 |
+
echo: bool = Field(default=False)
|
112 |
+
stream: bool = Field(default=False)
|
113 |
+
logprobs: int = Field(default=None)
|
114 |
+
presence_penalty: float = Field(default=0.0)
|
115 |
+
frequency_penalty: float = Field(default=0.0)
|
116 |
+
best_of: int = Field(default=1)
|
117 |
+
logit_bias: dict = Field(default_factory=dict)
|
118 |
+
max_tokens_per_summary: int = Field(default=2048)
|
119 |
|
120 |
class HuggingFaceHubWrapper:
|
121 |
def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
|
|
|
127 |
self.temperature = model_kwargs.get('temperature', 0.7)
|
128 |
self.top_p = model_kwargs.get('top_p', 0.95)
|
129 |
self.repetition_penalty = model_kwargs.get('repetition_penalty', 1.1)
|
130 |
+
self.top_k = model_kwargs.get('top_k', 50)
|
131 |
+
self.max_tokens = model_kwargs.get('max_length', 1000)
|
132 |
+
self.max_tokens_per_summary = model_kwargs.get('max_tokens_per_summary', 2048)
|
133 |
|
134 |
def get_provider_default_settings(self):
|
135 |
return CustomLLMSettings(
|
|
|
142 |
temperature=self.temperature,
|
143 |
top_p=self.top_p,
|
144 |
repetition_penalty=self.repetition_penalty,
|
145 |
+
top_k=self.top_k,
|
146 |
+
max_tokens=self.max_tokens,
|
147 |
+
max_tokens_per_summary=self.max_tokens_per_summary
|
148 |
)
|
149 |
|
150 |
def get_provider_identifier(self):
|
|
|
185 |
)
|
186 |
|
187 |
# Model function
|
188 |
+
def get_model(temperature, top_p, repetition_penalty, top_k=50, max_tokens=1000, max_tokens_per_summary=2048):
|
189 |
return HuggingFaceHubWrapper(
|
190 |
repo_id="mistralai/Mistral-7B-Instruct-v0.3",
|
191 |
model_kwargs={
|
|
|
193 |
"top_p": top_p,
|
194 |
"repetition_penalty": repetition_penalty,
|
195 |
"top_k": top_k,
|
196 |
+
"max_length": max_tokens,
|
197 |
+
"max_tokens_per_summary": max_tokens_per_summary
|
198 |
},
|
199 |
huggingfacehub_api_token=huggingface_token
|
200 |
)
|
|
|
221 |
temperature,
|
222 |
top_p,
|
223 |
repeat_penalty,
|
224 |
+
top_k=50,
|
225 |
+
max_tokens_per_summary=2048
|
226 |
):
|
227 |
+
model = get_model(temperature, top_p, repeat_penalty, top_k, max_tokens, max_tokens_per_summary)
|
|
|
228 |
chat_template = MessagesFormatterType.MISTRAL
|
229 |
|
230 |
search_tool = WebSearchTool(
|
|
|
276 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.7, step=0.1, label="Temperature"),
|
277 |
gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p"),
|
278 |
gr.Slider(minimum=0.0, maximum=2.0, value=1.1, step=0.1, label="Repetition penalty"),
|
279 |
+
gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k"),
|
280 |
+
gr.Slider(minimum=1, maximum=4096, value=2048, step=1, label="Max tokens per summary"),
|
281 |
],
|
282 |
theme=gr.themes.Soft(
|
283 |
primary_hue="orange",
|