# transformer_api/Similarity.py
import nltk
from sentence_transformers import SentenceTransformer, util

class Similarity:
    def __init__(self, model=None):
        # Use the provided SentenceTransformer model if given, otherwise fall
        # back to the Greek STS transfer model.
        self.model = model or SentenceTransformer("lighteternal/stsb-xlm-r-greek-transfer")
        # Make sure the nltk punkt sentence tokenizer is downloaded.
        nltk.download('punkt', quiet=True)
    def chunk_text(self, text, chunk_size=1400, overlap_size=200):
        # Split the text into sentences and pack them into chunks of at most
        # chunk_size characters, carrying overlap_size characters from the end
        # of one chunk into the start of the next.
        sentences = nltk.sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                # Start the next chunk with the tail of the previous one as overlap.
                overlap = current_chunk[-overlap_size:]
                current_chunk = (overlap + " " + sentence) if overlap else sentence
        if current_chunk:
            chunks.append(current_chunk)
        return chunks
    def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
        # Return the chunks of `text` whose cosine similarity to the claim
        # embedding is at least min_threshold.
        if not text:
            return []
        filtered_results = []
        chunks = self.chunk_text(text, chunk_size)
        if not chunks:
            return []
        chunk_embeddings = self.model.encode(
            chunks, convert_to_tensor=True, show_progress_bar=False
        )
        chunk_similarities = util.cos_sim(claim_embedding, chunk_embeddings)
        for chunk, similarity in zip(chunks, chunk_similarities[0]):
            if similarity.item() >= min_threshold:
                # Debug output: show the matching chunk and its score.
                print(chunk)
                print()
                print(similarity.item())
                print("--------------------------------------------------")
                filtered_results.append(chunk)
        return filtered_results
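

# A minimal usage sketch, not part of the original module: it assumes the
# lighteternal/stsb-xlm-r-greek-transfer model can be downloaded and shows how
# a claim embedding is typically prepared before calling get_sim_text. The
# claim and document strings below are placeholder examples.
if __name__ == "__main__":
    sim = Similarity()
    claim = "Athens is the capital of Greece."
    document = (
        "Athens has been the capital of Greece since 1834. "
        "It is also one of the oldest cities in the world."
    )
    # Encode the claim once and reuse the embedding across documents.
    claim_embedding = sim.model.encode(claim, convert_to_tensor=True)
    matching_chunks = sim.get_sim_text(document, claim_embedding, min_threshold=0.4)
    print(f"{len(matching_chunks)} chunk(s) passed the similarity threshold.")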