import os
import asyncio
from concurrent.futures import ThreadPoolExecutor
from model2vec import StaticModel
from transformers import AutoConfig
from sentence_transformers import SentenceTransformer
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from src.utils.api_key_manager import APIKeyManager
from src.helpers.helper import chunk_text
class LateChunker:
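    """
    Select the parts of a long text that are most relevant to a query.

    The text is split into overlapping macro chunks, each chunk is embedded with a
    sentence-embedding model (SentenceTransformer, with Model2Vec as a fallback),
    and the chunks most similar to the query embedding are kept until a token
    budget is reached.
    """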
def __init__(
self,
model_name='minishlab/potion-base-8M',
        max_workers=(os.cpu_count() or 1) * 2,
verbose=False
):
self.verbose = verbose
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.llm = APIKeyManager().get_llm()
self.model_name = model_name
# Initialize model using the fallback strategy
self.model, self.context_length = self._initialize_model()
# Initialize ThreadPoolExecutor
self.executor = ThreadPoolExecutor(max_workers=max_workers)
def _initialize_model(self):
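        """
        Try to load the model as a SentenceTransformer first; if that fails, fall
        back to a Model2Vec StaticModel. Returns a (model, context_length) tuple,
        or raises with both error messages if neither loads.
        """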
sentence_transformer_error = None
model2vec_error = None
# First attempt: Try SentenceTransformer
try:
# Get the model config to check max context length
config = AutoConfig.from_pretrained(self.model_name)
max_length = config.max_position_embeddings
# Initialize SentenceTransformer model
model = SentenceTransformer(self.model_name, trust_remote_code=True)
model.max_seq_length = max_length # Set the correct max length
            model.to(self.device)
            if self.device == "cuda":
                model.half()  # use fp16 only on GPU; half precision is poorly supported on CPU
context_length = model.max_seq_length
return model, context_length
except Exception as e:
sentence_transformer_error = str(e)
# Second attempt: Try Model2Vec
try:
# Initialize Model2Vec model
model = StaticModel.from_pretrained(
self.model_name
)
# Get max sequence length from static model config
context_length = model.config['seq_length']
return model, context_length
except Exception as e:
model2vec_error = str(e)
error_msg = (
f"Failed to load model {self.model_name}.\n"
f"SentenceTransformer error: {sentence_transformer_error}\n"
f"Model2Vec error: {model2vec_error}"
)
raise Exception(error_msg) from e
async def late_chunking(self, text, span_annotations, current_chunk_idx=None, total_chunks=None):
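        """
        Embed each (start, end) span of `text` on the thread pool, off the event
        loop, and return the list of chunk embeddings (None if no spans were given).
        """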
        if self.verbose and current_chunk_idx is not None and total_chunks is not None:
            print(f"Processing chunk {current_chunk_idx + 1}/{total_chunks}...")
# Get the current running event loop
loop = asyncio.get_running_loop()
# Generate chunk embeddings
chunk_embeddings = []
        for start, end in span_annotations:
            chunk_str = text[start:end]  # local name; avoids shadowing the imported chunk_text helper
            print("Generating chunk embeddings...") if self.verbose else None
            chunk_embedding = await loop.run_in_executor(
                self.executor,
                lambda chunk=chunk_str: self.model.encode(
                    chunk,
                    convert_to_tensor=True
                )
            )
            if isinstance(chunk_embedding, torch.Tensor):
                chunk_embedding = chunk_embedding.clone().detach().to(self.device)
            else:
                # Handle encoders that return numpy arrays instead of tensors
                chunk_embedding = torch.tensor(chunk_embedding).to(self.device)
print(f"Chunk embedding shape: {chunk_embedding.shape}") if self.verbose else None
chunk_embeddings.append(chunk_embedding)
print("Late Chunking applied successfully!") if self.verbose else None
return chunk_embeddings if chunk_embeddings else None
def get_text_embedding(self, text):
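        """Encode `text` and return its embedding as a tensor on self.device."""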
embeddings = self.model.encode(text, convert_to_tensor=True)
if isinstance(embeddings, torch.Tensor):
return embeddings.clone().detach().to(self.device)
return torch.tensor(embeddings).to(self.device)
def calculate_embedding_similarities(self, text1_embedding, text2_embedding):
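        """
        Return the cosine similarities between two embeddings as a 1-D array.
        Both embeddings are moved to CPU, reshaped to 2-D, and transposed if
        their feature dimensions do not line up.
        """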
text1_embedding = text1_embedding.cpu().numpy()
text2_embedding = text2_embedding.cpu().numpy()
if text1_embedding.ndim == 1:
text1_embedding = text1_embedding.reshape(1, -1)
if text2_embedding.ndim == 1:
text2_embedding = text2_embedding.reshape(1, -1)
if text1_embedding.shape[1] != text2_embedding.shape[1]:
text1_embedding = text1_embedding.T
if text2_embedding.shape[1] != text1_embedding.shape[1]:
text2_embedding = text2_embedding.T
return cosine_similarity(text1_embedding, text2_embedding)[0]
def select_relevant_chunks(self, similarities, chunks, max_tokens):
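        """
        Greedily pick the chunks most similar to the query until adding another
        would exceed `max_tokens`, then return them joined in their original order.
        """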
sorted_indices = np.argsort(similarities)[::-1]
selected_chunks = []
total_tokens = 0
for i, idx in enumerate(sorted_indices):
            print(f"Evaluating chunk {i+1}/{len(sorted_indices)} with similarity {similarities[idx]:.2f}") \
                if self.verbose else None
chunk_tokens = self.llm.get_num_tokens(chunks[idx])
print(f"Chunk tokens: {chunk_tokens}") if self.verbose else None
            if total_tokens + chunk_tokens > max_tokens:
                print(f"Adding this chunk would exceed the max tokens allowed "
                      f"({total_tokens + chunk_tokens} > {max_tokens}). Stopping chunk selection.") \
                    if self.verbose else None
                break
selected_chunks.append((idx, chunks[idx]))
total_tokens += chunk_tokens
print("Sorting selected chunks...") if self.verbose else None
selected_chunks.sort(key=lambda x: x[0])
print("Selected chunks sorted successfully!") if self.verbose else None
return " ".join([chunk for _, chunk in selected_chunks])
async def chunker(self, text, query, max_chunk_length=1000, max_tokens=2048, overlap=200):
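        """
        Main entry point. Returns `text` unchanged if it already fits within
        `max_tokens`; otherwise splits it into macro chunks, embeds them via late
        chunking, ranks them against `query`, and returns the most relevant chunks.
        """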
# Tokenize the entire text to check its length
total_tokens = self.llm.get_num_tokens(text)
# If the text is less than max tokens, return the text as is
        if total_tokens <= max_tokens:
            print(f"Text is within the max tokens allowed ({total_tokens} <= {max_tokens}). "
                  "Returning original text.") if self.verbose else None
            return text
# Chunk the text if it exceeds max tokens
        print(f"Text exceeds the max tokens allowed ({total_tokens} > {max_tokens}). "
              "Chunking text...") if self.verbose else None
chunks, span_annotations = chunk_text(
text,
max_chunk_length=max_chunk_length,
overlap=overlap,
# Use the smaller of either context length or max tokens
context_length=min(self.context_length, max_tokens)
)
print(f"Text chunked into {len(chunks)} macro chunks.") if self.verbose else None
# Process each macro chunk individually
        tasks = []
for i, macro_chunk in enumerate(chunks):
# Adjust span annotations relative to the current macro chunk
start_offset = span_annotations[i][0]
adjusted_spans = [
(start - start_offset, end - start_offset)
for start, end in span_annotations
if start >= start_offset and end <= start_offset + len(macro_chunk)
]
# Apply late chunking for the current macro chunk
tasks.append(self.late_chunking(macro_chunk, adjusted_spans, i, len(chunks)))
# Aggregate embeddings asynchronously
results = await asyncio.gather(*tasks)
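        # Use the first span embedding of each macro chunk as that chunk's representative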
chunk_embeddings = torch.stack([result[0] for result in results])
# Generate query embedding
print("Generating query embedding...") if self.verbose else None
query_embedding = self.get_text_embedding(query)
print(f"Query embedding shape: {query_embedding.shape}") if self.verbose else None
# Calculate similarities between query embedding and chunk embeddings
print("Calculating embedding similarities...") if self.verbose else None
similarities = self.calculate_embedding_similarities(query_embedding, chunk_embeddings)
print(f"Similarities shape: {similarities.shape}") if self.verbose else None
# Select relevant chunks based on similarity
print("Selecting relevant chunks...") if self.verbose else None
return self.select_relevant_chunks(similarities, chunks, max_tokens)
if __name__ == "__main__":
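    # End-to-end demo: run several climate-change searches, crawl the results,
    # late-chunk the combined text against a query, and report token counts,
    # timing, and TF-IDF cosine similarity between the original and reduced text.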
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from src.reasoning.reasoner import Reasoner
from src.search.search_engine import SearchEngine
from src.crawl.crawler import CustomCrawler
import time
search_engine = SearchEngine()
crawler = CustomCrawler()
reasoner = Reasoner()
chunking = LateChunker(verbose=True)
loop = asyncio.new_event_loop()
    search1 = loop.run_until_complete(search_engine.search(
        "What is the history of climate change and pollution since the pre-industrial revolution?",
num_results=20,
exclude_filetypes=["pdf"]
))
urls = [result["link"] for result in search1]
search2 = loop.run_until_complete(search_engine.search(
"What is the impact of climate change on the Indian economy?",
num_results=20,
exclude_filetypes=["pdf"]
))
urls.extend([result["link"] for result in search2])
    search3 = loop.run_until_complete(search_engine.search(
        "What are some of the latest, state-of-the-art techniques used to fight climate change?",
num_results=20,
exclude_filetypes=["pdf"]
))
urls.extend([result["link"] for result in search3])
search4 = loop.run_until_complete(search_engine.search(
"What does the projection for climate change look like in the next 50 years?",
num_results=20,
exclude_filetypes=["pdf"]
))
urls.extend([result["link"] for result in search4])
search5 = loop.run_until_complete(search_engine.search(
"What efforts are being made by governments all around the world to combat climate change?",
num_results=20,
exclude_filetypes=["pdf"]
))
urls.extend([result["link"] for result in search5])
results = loop.run_until_complete(crawler.fetch_page_contents(
urls=urls,
max_attempts=1,
delay=0
))
text = "\n".join([f"Document {i}:\n{result}\n" for i, result in enumerate(results)])
num_tokens_before_chunking = chunking.llm.get_num_tokens(text)
start_time = time.perf_counter()
response = loop.run_until_complete(chunking.chunker(
text,
query="What is this text about? Give me a detailed answer",
max_tokens=128000
))
end_time = time.perf_counter()
num_tokens_after_chunking = chunking.llm.get_num_tokens(response)
print(f"\nResponse:\n{response}")
print(f"\nNumber of URLs: {len(urls)}")
print(f"\nNumber of tokens before late chunking: {num_tokens_before_chunking}")
print(f"\nNumber of tokens after late chunking: {num_tokens_after_chunking}")
print(f"\nTime taken: {end_time - start_time:.2f} seconds")
# Calculate cosine similarity between original text and response
def calculate_cosine_similarity(text1, text2):
vectorizer = TfidfVectorizer().fit_transform([text1, text2])
vectors = vectorizer.toarray()
return cosine_similarity(vectors)[0][1]
similarity = calculate_cosine_similarity(text, response)
print(f"\nCosine similarity between original text and late chunked text: {similarity * 100:.2f}%")