Spaces:
Runtime error
Runtime error
import logging

import nltk
from sentence_transformers import util
class Similarity:
    """Find passages of a text that are semantically similar to a claim.

    The text is split into sentence-aligned, overlapping character chunks,
    each chunk is embedded with ``model``, and the chunks whose cosine
    similarity to the claim embedding reaches a threshold are returned.
    """

    def __init__(self, model):
        """Store the embedding model and make sure NLTK sentence data exists.

        Args:
            model: Object exposing ``encode(sentences, convert_to_tensor=...,
                show_progress_bar=...)``; presumably a
                ``sentence_transformers.SentenceTransformer`` (TODO confirm
                against callers).
        """
        self.model = model
        self._ensure_punkt()

    @staticmethod
    def _ensure_punkt():
        """Download NLTK's Punkt sentence-tokenizer data only if it is missing."""
        # Older NLTK releases load sent_tokenize's model from "punkt"; newer
        # releases load it from "punkt_tab". Probing locally first keeps
        # construction from contacting the download server every time.
        for resource in ("punkt", "punkt_tab"):
            try:
                nltk.data.find(f"tokenizers/{resource}")
            except LookupError:
                nltk.download(resource, quiet=True)

    @staticmethod
    def _overlap_tail(chunk, overlap_size):
        """Return roughly the last ``overlap_size`` characters of ``chunk``.

        If the cut would land inside a word, that partial word is dropped so
        the overlap starts on a word boundary. The result is stripped, and it
        is "" when ``overlap_size`` is not positive.
        """
        if overlap_size <= 0:
            return ""
        if len(chunk) <= overlap_size:
            return chunk.strip()
        start = len(chunk) - overlap_size
        tail = chunk[start:]
        if not chunk[start - 1].isspace() and not tail[0].isspace():
            # The cut split a word in two; skip the remainder of that word.
            parts = tail.split(None, 1)
            tail = parts[1] if len(parts) > 1 else ""
        return tail.strip()

    def chunk_text(self, text, chunk_size=1400, overlap_size=200, *,
                   sentence_splitter=None):
        """Split ``text`` into sentence-aligned chunks of about ``chunk_size`` chars.

        Sentences are joined with single spaces while the chunk stays within
        ``chunk_size`` characters. When the next sentence does not fit, the
        current chunk is emitted. The next chunk then starts with up to
        ``overlap_size`` trailing characters of the previous chunk (aligned to
        a word boundary), provided that overlap plus the sentence still fits.
        A single sentence longer than ``chunk_size`` is never split; it
        becomes its own oversized chunk.

        Args:
            text: The text to split. Empty or ``None`` yields ``[]``.
            chunk_size: Target maximum chunk length in characters.
            overlap_size: Maximum number of characters carried over from the
                end of one chunk into the start of the next; ``0`` disables
                overlap.
            sentence_splitter: Optional callable mapping ``text`` to a list of
                sentences. Defaults to ``nltk.sent_tokenize``.

        Returns:
            A list of non-empty chunk strings.
        """
        if not text:
            return []
        split = nltk.sent_tokenize if sentence_splitter is None else sentence_splitter

        chunks = []
        current = ""
        for sentence in split(text):
            if not current:
                # Never emit an empty chunk, even if this sentence is oversized.
                current = sentence
            elif len(current) + 1 + len(sentence) <= chunk_size:
                # The "+ 1" accounts for the joining space.
                current = f"{current} {sentence}"
            else:
                chunks.append(current)
                tail = self._overlap_tail(current, overlap_size)
                if tail and len(tail) + 1 + len(sentence) <= chunk_size:
                    current = f"{tail} {sentence}"
                else:
                    current = sentence
        if current:
            chunks.append(current)
        return chunks

    def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
        """Return the chunks of ``text`` whose similarity to a claim is high enough.

        Args:
            text: Text to search. Empty input yields ``[]``.
            claim_embedding: Embedding of the claim; presumably produced by the
                same model's ``encode`` (TODO confirm against callers).
            min_threshold: Inclusive lower bound on cosine similarity.
            chunk_size: Target chunk length passed to ``chunk_text``.

        Returns:
            The qualifying chunks, in their original order.
        """
        if not text:
            return []
        chunks = self.chunk_text(text, chunk_size)
        if not chunks:
            return []

        chunk_embeddings = self.model.encode(
            chunks, convert_to_tensor=True, show_progress_bar=False
        )
        # util.cos_sim returns a matrix with res[i][j] = sim(a[i], b[j]);
        # row 0 holds the claim's score against every chunk.
        similarities = util.cos_sim(claim_embedding, chunk_embeddings)[0]

        logger = logging.getLogger(__name__)
        filtered_results = []
        for chunk, similarity in zip(chunks, similarities):
            score = float(similarity)
            if score >= min_threshold:
                logger.debug("Chunk matched (similarity %.4f):\n%s", score, chunk)
                filtered_results.append(chunk)
        return filtered_results