JohnKouf committed on
Commit
545c2d2
·
verified ·
1 Parent(s): a0b7ba4

Update Similarity.py

Browse files
Files changed (1) hide show
  1. Similarity.py +47 -0
Similarity.py CHANGED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import nltk
2
+ from sentence_transformers import util
3
+
4
class Similarity:
    """Pick the passages of a text that are semantically close to a claim.

    The text is split into sentence-aligned chunks, every chunk is embedded
    with the supplied model, and the chunks whose cosine similarity to a
    claim embedding reaches a threshold are returned.
    """

    def __init__(self, model):
        """Store the embedding model and fetch the NLTK sentence tokenizer data.

        Args:
            model: Object exposing ``encode(texts, convert_to_tensor=...,
                show_progress_bar=...)``. Presumably a sentence-transformers
                model; confirm against the caller.
        """
        self.model = model
        # Make sure nltk punkt tokenizer is downloaded
        # NOTE(review): recent NLTK releases load the 'punkt_tab' resource for
        # sent_tokenize; confirm 'punkt' is sufficient for the pinned version.
        nltk.download('punkt')

    @staticmethod
    def _overlap_tail(sentences, budget):
        """Return the longest run of trailing sentences that fits in *budget* chars."""
        if budget <= 0:
            return []
        tail = []
        used = 0
        for sentence in reversed(sentences):
            # A joining space is needed for every sentence after the first.
            cost = len(sentence) + (1 if tail else 0)
            if used + cost > budget:
                break
            tail.append(sentence)
            used += cost
        tail.reverse()
        return tail

    @staticmethod
    def _pack_sentences(sentences, chunk_size, overlap_size):
        """Greedily pack already-tokenized sentences into space-joined chunks."""
        chunks = []
        current = []  # sentences of the chunk being built
        current_len = 0  # len(" ".join(current))
        for sentence in sentences:
            if not sentence:
                continue
            # Flush only a non-empty chunk: an oversized first sentence must
            # not produce a leading "" chunk.
            if current and current_len + 1 + len(sentence) > chunk_size:
                chunks.append(" ".join(current))
                # Carry whole trailing sentences into the next chunk, but
                # never so many that the new chunk would overflow chunk_size.
                budget = min(overlap_size, chunk_size - len(sentence) - 1)
                current = Similarity._overlap_tail(current, budget)
                current_len = len(" ".join(current))
            current_len += len(sentence) + (1 if current else 0)
            current.append(sentence)
        if current:
            chunks.append(" ".join(current))
        return chunks

    def chunk_text(self, text, chunk_size=1400, overlap_size=200):
        """Split *text* into sentence-aligned chunks with optional overlap.

        Sentences come from ``nltk.sent_tokenize`` and are joined with single
        spaces. A chunk is closed when adding the next sentence (including
        the joining space) would push it past ``chunk_size`` characters. The
        next chunk then starts with the trailing whole sentences of the
        previous one, up to ``overlap_size`` characters, provided they still
        fit together with the new sentence.

        Args:
            text: The text to split.
            chunk_size: Maximum chunk length in characters. A single sentence
                longer than this is kept intact as its own chunk.
            overlap_size: Maximum number of characters of trailing sentences
                repeated at the start of the next chunk; ``0`` disables
                overlap.

        Returns:
            A list of non-empty chunk strings, in document order.
        """
        sentences = nltk.sent_tokenize(text)
        return self._pack_sentences(sentences, chunk_size, overlap_size)

    def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
        """Return the chunks of *text* that are similar enough to a claim.

        Args:
            text: Source text; an empty or ``None`` value yields ``[]``.
            claim_embedding: Embedding of the claim. Only the first row of the
                similarity matrix is read, so a single claim is assumed.
            min_threshold: Minimum cosine similarity for a chunk to be kept.
            chunk_size: Chunk length passed to ``chunk_text`` (overlap keeps
                that method's default).

        Returns:
            The matching chunks in document order. Each match is also printed
            to stdout together with its similarity score.
        """
        if not text:
            return []

        filtered_results = []
        chunks = self.chunk_text(text, chunk_size)
        if not chunks:
            return []

        chunk_embeddings = self.model.encode(
            chunks, convert_to_tensor=True, show_progress_bar=False
        )
        chunk_similarities = util.cos_sim(claim_embedding, chunk_embeddings)

        # Row 0 holds the similarities of the (single) claim to every chunk.
        for chunk, similarity in zip(chunks, chunk_similarities[0]):
            if similarity >= min_threshold:
                print(chunk)
                print()
                print(similarity)
                print("--------------------------------------------------")
                filtered_results.append(chunk)

        return filtered_results