Shreyas094 committed
Commit 7bc11b6 · verified · 1 Parent(s): 3807145

Update app.py

Files changed (1)
  1. app.py +25 -19
app.py CHANGED
@@ -3,7 +3,7 @@ import logging
 import gradio as gr
 from transformers import pipeline
 
-from llama_cpp_agent.providers import LlamaCppPythonProvider
+from llama_cpp_agent.providers import LLMProvider
 from llama_cpp_agent import LlamaCppAgent, MessagesFormatterType
 from llama_cpp_agent.chat_history import BasicChatHistory
 from llama_cpp_agent.chat_history.messages import Roles
@@ -18,12 +18,8 @@ from trafilatura import fetch_url, extract
 import json
 from datetime import datetime, timezone
 from typing import List
-from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.llms import HuggingFaceHub
 
-llm = None
-llm_model = None
-
 huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
 
 examples = [
@@ -50,6 +46,17 @@ def get_messages_formatter_type(model_name):
     else:
         return MessagesFormatterType.CHATML
 
+class HuggingFaceHubProvider(LLMProvider):
+    def __init__(self, model):
+        self.model = model
+
+    def create_completion(self, prompt, **kwargs):
+        response = self.model(prompt)
+        return {'choices': [{'text': response}]}
+
+    def get_provider_default_settings(self):
+        return self.model.model_kwargs
+
 def get_model(temperature, top_p, repetition_penalty):
     return HuggingFaceHub(
         repo_id="mistralai/Mistral-7B-Instruct-v0.3",
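The new HuggingFaceHubProvider is a small adapter: it makes a callable LangChain model look like the provider object the agent framework expects. Below is a rough smoke-test sketch, not part of the commit, that exercises only what the adapter actually relies on: a callable model exposing a model_kwargs dict. FakeModel is a hypothetical stand-in, and if LLMProvider declares further abstract methods, a stand-in like this would need stubs for those as well.

```python
# Illustrative smoke test only; FakeModel is a hypothetical stand-in for
# the HuggingFaceHub instance. It mimics the two things the adapter uses:
# being callable with a prompt, and exposing a `model_kwargs` dict.
class FakeModel:
    model_kwargs = {"temperature": 0.7}

    def __call__(self, prompt):
        return f"echo: {prompt}"

provider = HuggingFaceHubProvider(FakeModel())

# create_completion() wraps the raw string in an OpenAI-style choices list.
completion = provider.create_completion("hello")
assert completion["choices"][0]["text"] == "echo: hello"

# get_provider_default_settings() hands back the model_kwargs dict itself.
assert provider.get_provider_default_settings() == {"temperature": 0.7}
```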
@@ -102,14 +109,13 @@ def respond(
     if model is None:
         logging.error("Model is None. Please select a valid model.")
         return "Error: No model selected. Please choose a valid model."
-
-    global llm
-    global llm_model
+
     chat_template = get_messages_formatter_type(model)
-    if llm is None or llm_model != model:
-        llm = get_model(temperature, top_p, repeat_penalty)
-        llm_model = model
-    provider = LlamaCppPythonProvider(llm)
+
+    # Create a new model instance for each request
+    llm = get_model(temperature, top_p, repeat_penalty)
+
+    provider = HuggingFaceHubProvider(llm)
     logging.info(f"Loaded chat examples: {chat_template}")
     search_tool = WebSearchTool(
         llm_provider=provider,
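With the module-level llm/llm_model cache gone, respond() now builds a fresh model and wrapper on every request. A minimal sketch of that wiring, assuming Hugging Face credentials are configured so the HuggingFaceHub constructor succeeds:

```python
# Per-request wiring as done in respond() above; parameter values are
# arbitrary examples.
llm = get_model(temperature=0.7, top_p=0.9, repetition_penalty=1.1)
provider = HuggingFaceHubProvider(llm)

# HuggingFaceHub is a thin client for the hosted Inference API, so
# re-creating it per request is cheap: no weights are loaded locally.
```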
@@ -133,12 +139,12 @@
     )
 
     settings = provider.get_provider_default_settings()
-    settings.stream = False
-    settings.temperature = temperature
-    settings.top_k = top_k
-    settings.top_p = top_p
-    settings.max_tokens = max_tokens
-    settings.repeat_penalty = repeat_penalty
+    settings['stream'] = False
+    settings['temperature'] = temperature
+    settings['top_k'] = top_k
+    settings['top_p'] = top_p
+    settings['max_tokens'] = max_tokens
+    settings['repeat_penalty'] = repeat_penalty
 
     output_settings = LlmStructuredOutputSettings.from_functions(
         [search_tool.get_tool()]
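The settings writes switch from attribute access to key access because get_provider_default_settings() now returns the model's model_kwargs dict directly. A short sketch of the consequence, assuming get_model() populates model_kwargs: the returned dict is the live configuration, not a copy.

```python
settings = provider.get_provider_default_settings()

# Key access now that `settings` is a plain dict
# (the old llama.cpp provider returned an object with attributes):
settings["max_tokens"] = 512

# The adapter returns the wrapped model's model_kwargs itself, so the
# write above persists into the next call through the provider.
assert settings is provider.model.model_kwargs
```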
@@ -163,7 +169,7 @@
 
     outputs = ""
 
-    settings.stream = True
+    settings['stream'] = True
     response_text = answer_agent.get_chat_response(
         f"Write a detailed and complete research document that fulfills the following user request: '{message}', based on the information from the web below.\n\n" +
         result[0]["return_value"],
 