Spaces:
Runtime error
| from typing import List | |
| import chromadb | |
| from transformers import AutoTokenizer, AutoModel | |
| from chromadb.config import Settings | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from tqdm import tqdm | |
| import os | |
| from hazm import * | |
class RAG:
    """Retrieval pipeline over a persistent ChromaDB collection of legal cases.

    A Persian query is normalized with hazm, embedded with a ParsBERT
    encoder (mean-pooled last hidden state), and the top-k nearest stored
    cases are retrieved. Case titles/texts are resolved from a pandas
    DataFrame loaded from CSV; ChromaDB holds the embeddings and string
    ids of the form "<prefix>_<row-index>".
    """

    def __init__(self,
                 model_name: str = "HooshvareLab/bert-base-parsbert-uncased",
                 collection_name: str = "legal_cases",
                 persist_directory: str = "chromadb_collections/",
                 top_k: int = 2,
                 cases_csv: str = 'processed_cases.csv'
                 ) -> None:
        """Load the encoder, the case metadata CSV, and open the collection.

        Args:
            model_name: Hugging Face id of the sentence encoder.
            collection_name: name of an existing ChromaDB collection.
            persist_directory: on-disk ChromaDB storage path.
            top_k: number of cases returned per query.
            cases_csv: path of the CSV with 'title' and 'text' columns.
                (Generalized: was hard-coded; default preserves old behavior.)
        """
        self.cases_df = pd.read_csv(cases_csv)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.normalizer = Normalizer()  # hazm Persian text normalizer
        self.top_k = top_k
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(self.device)
        self.client = chromadb.PersistentClient(path=persist_directory)
        # get_collection (unlike get_or_create_collection) raises if the
        # collection was never built — intentional fail-fast.
        self.collection = self.client.get_collection(name=collection_name)

    def query_pre_process(self, query: str) -> str:
        """Normalize Persian text with hazm before embedding."""
        return self.normalizer.normalize(query)

    def embed_single_text(self, text: str) -> np.ndarray:
        """Return a 1-D mean-pooled ParsBERT embedding for *text*."""
        inputs = self.tokenizer(text, return_tensors="pt", truncation=True,
                                padding=True, max_length=512)
        inputs = {key: value.to(self.device) for key, value in inputs.items()}
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mean over the sequence dimension -> (hidden_size,) vector on CPU.
        return outputs.last_hidden_state.mean(dim=1).squeeze().cpu().numpy()

    def _extract_case_field(self, case_id: str, column: str) -> str:
        """Resolve one DataFrame cell from a "<prefix>_<row-index>" case id.

        Shared helper (the title/text extractors were copy-paste twins).
        Returns the sentinel message instead of raising when the id is
        malformed or the row is absent.
        """
        try:
            # Fix: a malformed id (no "_<int>" suffix) previously escaped the
            # try block and crashed with IndexError/ValueError; now it is
            # handled like a missing row.
            row_index = int(case_id.split("_")[1])
            return self.cases_df.loc[row_index, column]
        except (KeyError, IndexError, ValueError):
            return "Case ID not found in DataFrame."

    def extract_case_title_from_df(self, case_id: str) -> str:
        """Title of the case referenced by *case_id* (sentinel if missing)."""
        return self._extract_case_field(case_id, 'title')

    def extract_case_text_from_df(self, case_id: str) -> str:
        """Full text of the case referenced by *case_id* (sentinel if missing)."""
        return self._extract_case_field(case_id, 'text')

    def retrieve_relevant_cases(self, query_text: str) -> List[dict]:
        """Return the top-k cases as [{"text": ..., "title": ...}, ...].

        Fix: the original annotation claimed List[str] although a list of
        dicts is returned.
        """
        normalized_query_text = self.query_pre_process(query_text)
        query_embedding = self.embed_single_text(normalized_query_text)
        results = self.collection.query(
            query_embeddings=[query_embedding.tolist()],
            n_results=self.top_k
        )
        # results['ids'][0] is the id list for the single query embedding;
        # iterate it directly instead of indexing via range(len(...)).
        return [
            {
                "text": self.extract_case_text_from_df(case_id),
                "title": self.extract_case_title_from_df(case_id)
            }
            for case_id in results['ids'][0]
        ]

    def get_information(self, query: str) -> List[dict]:
        """Public entry point: retrieve the cases relevant to *query*."""
        return self.retrieve_relevant_cases(query)
from typing import List


class RAG:
    """No-op placeholder RAG.

    NOTE(review): this redefinition shadows the fully implemented ``RAG``
    class declared earlier in this file — any code importing ``RAG`` after
    this point gets the stub, which retrieves nothing. Confirm whether the
    shadowing is intentional (e.g. a fallback pasted in during debugging).
    """

    def __init__(self) -> None:
        # Stateless stub: nothing to initialize.
        pass

    def query_pre_process(self, query: str):
        # Identity transform — no normalization is applied.
        return query

    def get_information(self, query: str) -> List[str]:
        # Always reports that no relevant information exists.
        no_results: List[str] = []
        return no_results