ofermend committed on
Commit
972d8c6
·
1 Parent(s): 1261947
Files changed (2) hide show
  1. agent.py +58 -45
  2. requirements.txt +1 -1
agent.py CHANGED
@@ -34,9 +34,17 @@ def extract_components_from_citation(citation: str) -> dict:
34
 
35
  return {'volume': int(volume_num), 'reporter': reporter, 'first_page': int(first_page)}
36
 
37
- def create_assistant_tools(cfg, agent_config):
 
 
 
 
 
 
 
38
 
39
  def get_opinion_text(
 
40
  case_citation: str = Field(description = citation_description),
41
  summarize: bool = Field(default=True, description="if True returns case summary, otherwise the full text of the case")
42
  ) -> str:
@@ -52,7 +60,7 @@ def create_assistant_tools(cfg, agent_config):
52
  citation_dict = extract_components_from_citation(case_citation)
53
  if not citation_dict:
54
  return f"Citation is invalid: {case_citation}."
55
- summarize_text = ToolsCatalog(agent_config).summarize_text
56
  reporter = citation_dict['reporter']
57
  volume_num = citation_dict['volume']
58
  first_page = citation_dict['first_page']
@@ -73,6 +81,7 @@ def create_assistant_tools(cfg, agent_config):
73
  return output
74
 
75
  def get_case_document_pdf(
 
76
  case_citation = Field(description = citation_description)
77
  ) -> str:
78
  """
@@ -92,6 +101,7 @@ def create_assistant_tools(cfg, agent_config):
92
  return f"https://static.case.law/{reporter}/{volume_num}.pdf#page={page_number}"
93
 
94
  def get_case_document_page(
 
95
  case_citation = Field(description = citation_description)
96
  ) -> str:
97
  """
@@ -110,6 +120,7 @@ def create_assistant_tools(cfg, agent_config):
110
  return url
111
 
112
  def get_case_name(
 
113
  case_citation = Field(description = citation_description)
114
  ) -> Tuple[str, str]:
115
  """
@@ -128,6 +139,7 @@ def create_assistant_tools(cfg, agent_config):
128
  return res["name"], res["name_abbreviation"]
129
 
130
  def get_cited_cases(
 
131
  case_citation = Field(description = citation_description)
132
  ) -> List[dict]:
133
  """
@@ -147,7 +159,7 @@ def create_assistant_tools(cfg, agent_config):
147
  citations = res["cites_to"]
148
  res = []
149
  for citation in citations[:10]:
150
- name, name_abbreviation = get_case_name(citation["cite"])
151
  res.append({
152
  "citation": citation["cite"],
153
  "name": name,
@@ -156,6 +168,7 @@ def create_assistant_tools(cfg, agent_config):
156
  return res
157
 
158
  def validate_url(
 
159
  url = Field(description = "A web url pointing to case-law document")
160
  ) -> str:
161
  """
@@ -166,50 +179,50 @@ def create_assistant_tools(cfg, agent_config):
166
  document_pattern = re.compile(r'^https://case.law/caselaw/?reporter=.*')
167
  return "URL is valid" if bool(pdf_pattern.match(url)) | bool(document_pattern.match(url)) else "URL is bad"
168
 
169
- class QueryCaselawArgs(BaseModel):
170
- query: str = Field(..., description="The user query.")
 
171
 
172
- vec_factory = VectaraToolFactory(vectara_api_key=cfg.api_key,
173
- vectara_corpus_key=cfg.corpus_key)
174
- summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'
175
 
176
- ask_caselaw = vec_factory.create_rag_tool(
177
- tool_name = "ask_caselaw",
178
- tool_description = "A tool for asking questions about case law in Alaska. ",
179
- tool_args_schema = QueryCaselawArgs,
180
- reranker = "chain", rerank_k = 100,
181
- rerank_chain = [
182
- {
183
- "type": "slingshot",
184
- "cutoff": 0.2
185
- },
186
- {
187
- "type": "mmr",
188
- "diversity_bias": 0.1
189
- },
190
- {
191
- "type": "userfn",
192
- "user_function": "max(1000 * get('$.score') - hours(seconds(to_unix_timestamp(now()) - to_unix_timestamp(datetime_parse(get('$.document_metadata.decision_date'), 'yyyy-MM-dd')))) / 24 / 365, 0)"
193
- }
194
- ],
195
- n_sentences_before = 2, n_sentences_after = 2, lambda_val = 0.005,
196
- summary_num_results = 15,
197
- vectara_summarizer = summarizer,
198
- include_citations = True,
199
- )
200
 
201
- tools_factory = ToolsFactory()
202
- return (
203
- [ask_caselaw] +
204
- [tools_factory.create_tool(tool) for tool in [
205
- get_opinion_text,
206
- get_case_document_pdf,
207
- get_case_document_page,
208
- get_cited_cases,
209
- get_case_name,
210
- validate_url
211
- ]]
212
- )
213
 
214
  def get_agent_config() -> OmegaConf:
215
  cfg = OmegaConf.create({
@@ -254,7 +267,7 @@ def initialize_agent(_cfg, agent_progress_callback=None):
254
  """
255
  agent_config = AgentConfig()
256
  agent = Agent(
257
- tools=create_assistant_tools(_cfg, agent_config=agent_config),
258
  topic="Case law in Alaska",
259
  custom_instructions=legal_assistant_instructions,
260
  agent_progress_callback=agent_progress_callback,
 
34
 
35
  return {'volume': int(volume_num), 'reporter': reporter, 'first_page': int(first_page)}
36
 
37
+ class AgentTools:
38
def __init__(self, _cfg, agent_config):
    """Keep the configuration objects and build the tool factories.

    Args:
        _cfg: Configuration object. Its ``api_key`` and ``corpus_key``
            attributes are handed to ``VectaraToolFactory``; the object
            itself is kept as ``self.cfg``.
        agent_config: Agent configuration, kept as ``self.agent_config``.
    """
    self.tools_factory = ToolsFactory()
    self.agent_config = agent_config
    self.cfg = _cfg

    # Pull the Vectara credentials out of the configuration first, then
    # build the factory that get_tools() relies on for the RAG tool.
    api_key = _cfg.api_key
    corpus_key = _cfg.corpus_key
    self.vec_factory = VectaraToolFactory(
        vectara_api_key=api_key,
        vectara_corpus_key=corpus_key,
    )
44
+
45
 
46
  def get_opinion_text(
47
+ self,
48
  case_citation: str = Field(description = citation_description),
49
  summarize: bool = Field(default=True, description="if True returns case summary, otherwise the full text of the case")
50
  ) -> str:
 
60
  citation_dict = extract_components_from_citation(case_citation)
61
  if not citation_dict:
62
  return f"Citation is invalid: {case_citation}."
63
+ summarize_text = ToolsCatalog(self.agent_config).summarize_text
64
  reporter = citation_dict['reporter']
65
  volume_num = citation_dict['volume']
66
  first_page = citation_dict['first_page']
 
81
  return output
82
 
83
  def get_case_document_pdf(
84
+ self,
85
  case_citation = Field(description = citation_description)
86
  ) -> str:
87
  """
 
101
  return f"https://static.case.law/{reporter}/{volume_num}.pdf#page={page_number}"
102
 
103
  def get_case_document_page(
104
+ self,
105
  case_citation = Field(description = citation_description)
106
  ) -> str:
107
  """
 
120
  return url
121
 
122
  def get_case_name(
123
+ self,
124
  case_citation = Field(description = citation_description)
125
  ) -> Tuple[str, str]:
126
  """
 
139
  return res["name"], res["name_abbreviation"]
140
 
141
  def get_cited_cases(
142
+ self,
143
  case_citation = Field(description = citation_description)
144
  ) -> List[dict]:
145
  """
 
159
  citations = res["cites_to"]
160
  res = []
161
  for citation in citations[:10]:
162
+ name, name_abbreviation = self.get_case_name(citation["cite"])
163
  res.append({
164
  "citation": citation["cite"],
165
  "name": name,
 
168
  return res
169
 
170
  def validate_url(
171
+ self,
172
  url = Field(description = "A web url pointing to case-law document")
173
  ) -> str:
174
  """
 
179
  document_pattern = re.compile(r'^https://case.law/caselaw/?reporter=.*')
180
  return "URL is valid" if bool(pdf_pattern.match(url)) | bool(document_pattern.match(url)) else "URL is bad"
181
 
182
def get_tools(self):
    """Build the list of tools exposed to the agent.

    The list starts with the ``ask_caselaw`` RAG tool created through
    ``self.vec_factory`` and is followed by this instance's helper methods
    (opinion text, document PDF/page URLs, cited cases, case name and URL
    validation), each wrapped with ``self.tools_factory.create_tool``.

    Returns:
        list: ``[ask_caselaw, <wrapped helper tools>...]``.
    """
    class QueryCaselawArgs(BaseModel):
        query: str = Field(..., description="The user query.")

    summarizer = 'vectara-experimental-summary-ext-2023-12-11-med-omni'

    # Reuse the factory built in __init__. The previous code constructed a
    # second factory from a bare `cfg` name, which is not defined in this
    # method's scope (the constructor argument is stored as `self.cfg`),
    # so the call would fail with NameError.
    ask_caselaw = self.vec_factory.create_rag_tool(
        tool_name="ask_caselaw",
        tool_description="A tool for asking questions about case law in Alaska. ",
        tool_args_schema=QueryCaselawArgs,
        reranker="chain",
        rerank_k=100,
        rerank_chain=[
            {
                "type": "slingshot",
                "cutoff": 0.2
            },
            {
                "type": "mmr",
                "diversity_bias": 0.1
            },
            {
                # NOTE(review): the expression subtracts the document's age
                # (derived from decision_date, roughly in years) from the
                # scaled score and clamps at 0 -- presumably to favor
                # recent decisions; confirm against the reranker docs.
                "type": "userfn",
                "user_function": "max(1000 * get('$.score') - hours(seconds(to_unix_timestamp(now()) - to_unix_timestamp(datetime_parse(get('$.document_metadata.decision_date'), 'yyyy-MM-dd')))) / 24 / 365, 0)"
            }
        ],
        n_sentences_before=2,
        n_sentences_after=2,
        lambda_val=0.005,
        summary_num_results=15,
        vectara_summarizer=summarizer,
        include_citations=True,
    )

    # Plain Python helpers are wrapped so the agent can call them as tools.
    helper_methods = [
        self.get_opinion_text,
        self.get_case_document_pdf,
        self.get_case_document_page,
        self.get_cited_cases,
        self.get_case_name,
        self.validate_url,
    ]
    return (
        [ask_caselaw] +
        [self.tools_factory.create_tool(tool) for tool in helper_methods]
    )
 
226
 
227
  def get_agent_config() -> OmegaConf:
228
  cfg = OmegaConf.create({
 
267
  """
268
  agent_config = AgentConfig()
269
  agent = Agent(
270
+ tools=AgentTools(_cfg, agent_config).get_tools(),
271
  topic="Case law in Alaska",
272
  custom_instructions=legal_assistant_instructions,
273
  agent_progress_callback=agent_progress_callback,
requirements.txt CHANGED
@@ -6,4 +6,4 @@ streamlit-feedback==0.1.3
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
- vectara-agentic==0.2.0
 
6
  uuid==1.30
7
  langdetect==1.0.9
8
  langcodes==3.4.0
9
+ vectara-agentic==0.2.1