Shreyas094 committed
Commit 8ec44be · verified · Parent(s): d1cccac

Update app.py

Files changed (1):
  1. app.py (+25, -1)
app.py CHANGED
@@ -17,6 +17,9 @@ from llama_cpp_agent.llm_output_settings import (
 )
 from llama_cpp_agent.tools import WebSearchTool
 from llama_cpp_agent.prompt_templates import web_search_system_prompt, research_system_prompt
+from langchain import HuggingFaceHub
+from llama_cpp_agent.llm_output_settings import LlmStructuredOutputSettings, LlmStructuredOutputType
+
 
 # UI related imports and definitions
 css = """
@@ -92,6 +95,26 @@ examples = [
     ["filetype:pdf intitle:python"]
 ]
 
+
+class HuggingFaceHubWrapper:
+    def __init__(self, repo_id, model_kwargs, huggingfacehub_api_token):
+        self.model = HuggingFaceHub(
+            repo_id=repo_id,
+            model_kwargs=model_kwargs,
+            huggingfacehub_api_token=huggingfacehub_api_token
+        )
+
+    def get_provider_default_settings(self):
+        return LlmStructuredOutputSettings(
+            output_type=LlmStructuredOutputType.JSON,
+            include_system_prompt=False,
+            include_user_prompt=False,
+            include_assistant_prompt=False,
+        )
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
 # Utility functions
 def get_server_time():
     utc_time = datetime.now(timezone.utc)
@@ -117,7 +140,7 @@ class CitingSources(BaseModel):
 
 # Model function
 def get_model(temperature, top_p, repetition_penalty):
-    return HuggingFaceHub(
+    return HuggingFaceHubWrapper(
         repo_id="mistralai/Mistral-7B-Instruct-v0.3",
         model_kwargs={
             "temperature": temperature,
@@ -127,6 +150,7 @@ def get_model(temperature, top_p, repetition_penalty):
         },
         huggingfacehub_api_token=huggingface_token
     )
+
 def get_messages_formatter_type(model_name):
     model_name = model_name.lower()
     if any(keyword in model_name for keyword in ["meta", "aya"]):
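
The added HuggingFaceHubWrapper is a thin adapter: it still builds the LangChain HuggingFaceHub client, but additionally exposes get_provider_default_settings() and a __call__ that delegates to the wrapped model, presumably so the object returned by get_model() can be used where a provider-style interface is expected. A minimal usage sketch, assuming huggingface_token is defined elsewhere in app.py (as the diff implies); the sampling values and prompt string below are placeholders, not taken from the commit:

# Hypothetical usage of the wrapper introduced in this commit (placeholder values).
model = get_model(temperature=0.7, top_p=0.95, repetition_penalty=1.1)

# Structured-output defaults exposed by the adapter method added in this diff.
settings = model.get_provider_default_settings()

# __call__ forwards its arguments to the underlying HuggingFaceHub model.
reply = model("Summarize the top search results for 'filetype:pdf intitle:python'.")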