import nltk
from sentence_transformers import SentenceTransformer, util

# Look for NLTK data (the punkt sentence tokenizer) in a local directory first.
nltk.data.path.append("./nltk_data")

class Similarity:
    def __init__(self):
        self.model = None
        # Download the punkt tokenizer once; skip if it is already present.
        try:
            nltk.data.find('tokenizers/punkt')
        except LookupError:
            nltk.download('punkt', download_dir='./nltk_data')

    def load_model(self):
        if self.model is None:
            print("Loading SentenceTransformer model...")
            self.model = SentenceTransformer("lighteternal/stsb-xlm-r-greek-transfer")
            print("Model loaded.")

    def chunk_text(self, text, chunk_size=1400, overlap_size=200):
        """Split text into sentence-aligned chunks of roughly chunk_size
        characters, carrying overlap_size characters of context between
        consecutive chunks."""
        sentences = nltk.sent_tokenize(text)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= chunk_size:
                current_chunk += " " + sentence if current_chunk else sentence
            else:
                if current_chunk:
                    chunks.append(current_chunk)
                # Seed the next chunk with the tail of the previous one so
                # context carries across the chunk boundary.
                overlap = current_chunk[-overlap_size:]
                current_chunk = (overlap + " " + sentence).strip()
        if current_chunk:
            chunks.append(current_chunk)
        return chunks

    def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
        """Return the chunks of text whose cosine similarity to
        claim_embedding meets min_threshold."""
        self.load_model()

        if not text:
            return []

        chunks = self.chunk_text(text, chunk_size)
        if not chunks:
            return []

        chunk_embeddings = self.model.encode(
            chunks, convert_to_tensor=True, show_progress_bar=False
        )
        chunk_similarities = util.cos_sim(claim_embedding, chunk_embeddings)

        filtered_results = []
        for chunk, similarity in zip(chunks, chunk_similarities[0]):
            if similarity.item() >= min_threshold:
                # Debug output: show each retained chunk and its score.
                print(chunk)
                print()
                print(f"similarity: {similarity.item():.4f}")
                print("-" * 50)
                filtered_results.append(chunk)

        return filtered_results
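
# Example usage: a minimal sketch, not part of the original module. The claim
# string and placeholder document below are illustrative assumptions; the
# claim must be encoded with the same model so util.cos_sim compares vectors
# from the same embedding space.
if __name__ == "__main__":
    sim = Similarity()
    sim.load_model()

    claim_embedding = sim.model.encode(
        "An example claim to check against the document.",  # hypothetical claim
        convert_to_tensor=True,
    )
    document = "One more sentence of a long test document. " * 100  # placeholder
    matches = sim.get_sim_text(document, claim_embedding, min_threshold=0.4)
    print(f"{len(matches)} chunk(s) passed the threshold")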