import os
import asyncio
from concurrent.futures import ThreadPoolExecutor

import numpy as np
import torch
from model2vec import StaticModel
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoConfig

from src.utils.api_key_manager import APIKeyManager
from src.helpers.helper import chunk_text


class LateChunker:
    """Selects the chunks of a long text that are most relevant to a query,
    using embedding similarity and an LLM token budget."""

    def __init__(
        self,
        model_name='minishlab/potion-base-8M',
        max_workers=(os.cpu_count() or 1) * 2,  # os.cpu_count() may return None
        verbose=False
    ):
        self.verbose = verbose

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.llm = APIKeyManager().get_llm()
        self.model_name = model_name

        # Load the embedding model and determine its maximum context length.
        self.model, self.context_length = self._initialize_model()

        # model.encode() is blocking, so it runs in a thread pool to keep the
        # asyncio event loop responsive while chunks are embedded.
        self.executor = ThreadPoolExecutor(max_workers=max_workers)

    def _initialize_model(self):
        """Load the embedding model, trying SentenceTransformer first and
        falling back to a Model2Vec StaticModel if that fails."""
        sentence_transformer_error = None
        model2vec_error = None

        # First attempt: load as a SentenceTransformer model.
        try:
            config = AutoConfig.from_pretrained(self.model_name)
            max_length = config.max_position_embeddings

            model = SentenceTransformer(self.model_name, trust_remote_code=True)
            model.max_seq_length = max_length
            model.to(self.device)
            if self.device == "cuda":
                # Use half precision only on GPU; many CPU ops do not support float16.
                model.half()
            context_length = model.max_seq_length
            return model, context_length
        except Exception as e:
            sentence_transformer_error = str(e)

        # Fallback: load as a Model2Vec static embedding model.
        try:
            model = StaticModel.from_pretrained(self.model_name)
            context_length = model.config['seq_length']
            return model, context_length
        except Exception as e:
            model2vec_error = str(e)
            error_msg = (
                f"Failed to load model {self.model_name}.\n"
                f"SentenceTransformer error: {sentence_transformer_error}\n"
                f"Model2Vec error: {model2vec_error}"
            )
            raise Exception(error_msg) from e

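    # Note (assumption): both loading paths yield an object exposing encode(),
    # which is the only model interface the rest of this class relies on, e.g.:
    #
    #     vec = self.model.encode("hello world")   # np.ndarray or torch.Tensor
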
    async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
        """Embed each annotated span of `text` and return the list of span embeddings."""
        if self.verbose and current_chunk_idx is not None and total_chunks is not None:
            print(f"Processing chunk {current_chunk_idx + 1}/{total_chunks}...")

        loop = asyncio.get_running_loop()

        chunk_embeddings = []
        for start, end in span_annotations:
            # Named `span_text` to avoid shadowing the imported chunk_text helper.
            span_text = text[start:end]
            if self.verbose:
                print("Generating chunk embeddings...")

            # Run the blocking encode() call in the thread pool.
            embedding = await loop.run_in_executor(
                self.executor, self.model.encode, span_text
            )
            # encode() may return a numpy array (Model2Vec) or a tensor
            # (SentenceTransformer); normalize to a tensor on the target device.
            chunk_embedding = torch.as_tensor(embedding).to(self.device)

            if self.verbose:
                print(f"Chunk embedding shape: {chunk_embedding.shape}")
            chunk_embeddings.append(chunk_embedding)

        if self.verbose:
            print("Late Chunking applied successfully!")
        return chunk_embeddings if chunk_embeddings else None

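    # Note: late_chunking() embeds each span independently. A pooling-based
    # "late chunking" variant would encode a whole macro chunk once and mean-pool
    # the token embeddings falling inside each span. Illustrative sketch only
    # (assumes a SentenceTransformer backend and token-level span indices):
    #
    #     token_embs = self.model.encode(text, output_value="token_embeddings")
    #     span_embedding = token_embs[tok_start:tok_end].mean(dim=0)
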
    def get_text_embedding(self, text):
        """Embed `text` and return the embedding as a tensor on the target device."""
        embeddings = self.model.encode(text)
        if isinstance(embeddings, torch.Tensor):
            return embeddings.detach().to(self.device)
        return torch.as_tensor(embeddings).to(self.device)

    def calculate_embedding_similarities(self, text1_embedding, text2_embedding):
        """Return the cosine similarities between `text1_embedding` and each row of `text2_embedding`."""
        text1_embedding = text1_embedding.cpu().numpy()
        text2_embedding = text2_embedding.cpu().numpy()

        # Ensure both arrays are 2-D: (n_samples, embedding_dim).
        if text1_embedding.ndim == 1:
            text1_embedding = text1_embedding.reshape(1, -1)
        if text2_embedding.ndim == 1:
            text2_embedding = text2_embedding.reshape(1, -1)

        # If the embedding dimensions still disagree, assume one array is transposed.
        if text1_embedding.shape[1] != text2_embedding.shape[1]:
            text1_embedding = text1_embedding.T
        if text2_embedding.shape[1] != text1_embedding.shape[1]:
            text2_embedding = text2_embedding.T

        # Row 0 holds the similarity of text1 against every row of text2.
        return cosine_similarity(text1_embedding, text2_embedding)[0]

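    # Shape example: a query embedding of shape (d,) and stacked chunk embeddings
    # of shape (n, d) yield a similarity vector of shape (n,), one score per chunk.
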
    def select_relevant_chunks(self, similarities, chunks, max_tokens):
        """Greedily pick the chunks most similar to the query until the token budget
        is reached, then return them joined in their original order."""
        sorted_indices = np.argsort(similarities)[::-1]
        selected_chunks = []
        total_tokens = 0

        for i, idx in enumerate(sorted_indices):
            if self.verbose:
                print(f"Considering chunk {i + 1}/{len(sorted_indices)} with similarity {similarities[idx]:.2f}")
            chunk_tokens = self.llm.get_num_tokens(chunks[idx])
            if self.verbose:
                print(f"Chunk tokens: {chunk_tokens}")

            if total_tokens + chunk_tokens > max_tokens:
                if self.verbose:
                    print(
                        f"Adding this chunk would exceed the max tokens allowed "
                        f"({total_tokens + chunk_tokens} > {max_tokens}). Stopping chunk selection."
                    )
                break

            selected_chunks.append((idx, chunks[idx]))
            total_tokens += chunk_tokens

        if self.verbose:
            print("Sorting selected chunks...")
        # Restore the original document order before joining.
        selected_chunks.sort(key=lambda x: x[0])
        if self.verbose:
            print("Selected chunks sorted successfully!")
        return " ".join(chunk for _, chunk in selected_chunks)

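    # Worked example (hypothetical numbers): with similarities [0.9, 0.2, 0.7],
    # per-chunk token counts [900, 500, 800] and max_tokens=2000, chunks are
    # visited in order 0, 2, 1; chunks 0 and 2 fit (1700 tokens), chunk 1 would
    # push the total to 2200, so selection stops and chunks 0 and 2 are returned
    # in their original order.
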
    async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
        """Reduce `text` to the chunks most relevant to `query`, staying within `max_tokens`."""
        total_tokens = self.llm.get_num_tokens(text)

        # If the text already fits in the budget, return it unchanged.
        if total_tokens <= max_tokens:
            if self.verbose:
                print(f"Text is within the max tokens allowed ({total_tokens} <= {max_tokens}). "
                      "Returning original text.")
            return text

        if self.verbose:
            print(f"Text exceeds the max tokens allowed ({total_tokens} > {max_tokens}). Chunking text...")
        chunks, span_annotations = chunk_text(
            text,
            max_chunk_length=max_chunk_length,
            overlap=overlap,
            context_length=min(self.context_length, max_tokens)
        )
        if self.verbose:
            print(f"Text chunked into {len(chunks)} macro chunks.")

        # Embed every macro chunk concurrently.
        tasks = []
        for i, macro_chunk in enumerate(chunks):
            # Re-express the global span annotations relative to this macro chunk.
            start_offset = span_annotations[i][0]
            adjusted_spans = [
                (start - start_offset, end - start_offset)
                for start, end in span_annotations
                if start >= start_offset and end <= start_offset + len(macro_chunk)
            ]
            tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))

        results = await asyncio.gather(*tasks)
        # Use the first span embedding of each macro chunk as that chunk's representative.
        chunk_embeddings = torch.stack([result[0] for result in results])

        if self.verbose:
            print("Generating query embedding...")
        query_embedding = self.get_text_embedding(query)
        if self.verbose:
            print(f"Query embedding shape: {query_embedding.shape}")

        if self.verbose:
            print("Calculating embedding similarities...")
        similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
        if self.verbose:
            print(f"Similarities shape: {similarities.shape}")

        if self.verbose:
            print("Selecting relevant chunks...")
        return self.select_relevant_chunks(similarities, chunks, max_tokens)

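    # Usage sketch (illustrative; assumes the API keys and the src.helpers.helper
    # chunk_text helper are configured as in this repository):
    #
    #     chunker = LateChunker(verbose=True)
    #     condensed = asyncio.run(
    #         chunker.chunker(long_text, query="What is this text about?", max_tokens=2048)
    #     )
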
if __name__ == "__main__":
    from sklearn.feature_extraction.text import TfidfVectorizer
    from src.reasoning.reasoner import Reasoner
    from src.search.search_engine import SearchEngine
    from src.crawl.crawler import CustomCrawler
    import time

    search_engine = SearchEngine()
    crawler = CustomCrawler()
    reasoner = Reasoner()
    chunking = LateChunker(verbose=True)

    loop = asyncio.new_event_loop()

    # Gather URLs from several climate-related searches.
    search1 = loop.run_until_complete(search_engine.search(
        "What is the history of climate change and pollution since the pre-industrial revolution?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls = [result["link"] for result in search1]
    search2 = loop.run_until_complete(search_engine.search(
        "What is the impact of climate change on the Indian economy?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search2])
    search3 = loop.run_until_complete(search_engine.search(
        "What are some of the latest, state-of-the-art techniques used to fight climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search3])
    search4 = loop.run_until_complete(search_engine.search(
        "What does the projection for climate change look like in the next 50 years?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search4])
    search5 = loop.run_until_complete(search_engine.search(
        "What efforts are being made by governments all around the world to combat climate change?",
        num_results=20,
        exclude_filetypes=["pdf"]
    ))
    urls.extend([result["link"] for result in search5])

    # Crawl every result page and concatenate the contents into one document.
    results = loop.run_until_complete(crawler.fetch_page_contents(
        urls=urls,
        max_attempts=1,
        delay=0
    ))
    text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])

    num_tokens_before_chunking = chunking.llm.get_num_tokens(text)
    start_time = time.perf_counter()
    response = loop.run_until_complete(chunking.chunker(
        text,
        query="What is this text about? Give me a detailed answer",
        max_tokens=128000
    ))
    end_time = time.perf_counter()
    num_tokens_after_chunking = chunking.llm.get_num_tokens(response)

    print(f"\nResponse:\n{response}")
    print(f"\nNumber of URLs: {len(urls)}")
    print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
    print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
    print(f"\nTime taken: {end_time - start_time:.2f} seconds")

    def calculate_cosine_similarity(text1, text2):
        vectorizer = TfidfVectorizer().fit_transform([text1, text2])
        vectors = vectorizer.toarray()
        return cosine_similarity(vectors)[0][1]

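    # Note: this TF-IDF cosine similarity measures lexical overlap between the
    # original text and the selected chunks, not semantic similarity; a high score
    # mainly indicates that the selected chunks reuse the original vocabulary.
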
    similarity = calculate_cosine_similarity(text, response)
    print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")