JohnKouf committed
Commit 58a769e · verified · 1 Parent(s): 34c4ca6

Update Similarity.py

Files changed (1)
  1. Similarity.py +42 -24
Similarity.py CHANGED
@@ -1,29 +1,47 @@
- from fastapi import FastAPI
- from pydantic import BaseModel
- from Similarity import Similarity
-
- app = FastAPI()
- similarity_model = Similarity()
-
- class TextPairRequest(BaseModel):
-     text: str  # The big text to chunk and search
-     claim: str  # The claim text to embed and compare
-
- @app.post("/get_sim_text")
- def get_sim_text_endpoint(request: TextPairRequest):
-     try:
-         # Embed the claim text
-         claim_embedding = similarity_model.model.encode(
-             request.claim,
-             convert_to_tensor=True,
-             show_progress_bar=False
-         )
-         # Call get_sim_text with defaults (min_threshold=0.4, chunk_size=1500)
-         results = similarity_model.get_sim_text(
-             request.text,
-             claim_embedding
-         )
-         return {"result": results}
-     except Exception as e:
-         return {"error": str(e)}
 
+ import nltk
+ from sentence_transformers import SentenceTransformer, util
+
+ class Similarity:
+     def __init__(self):
+         self.model = SentenceTransformer("lighteternal/stsb-xlm-r-greek-transfer")
+         # Make sure the nltk punkt sentence tokenizer is downloaded
+         nltk.download('punkt')
+
+     def chunk_text(self, text, chunk_size=1400, overlap_size=200):
+         # Split into sentences, then greedily pack sentences into chunks of
+         # at most chunk_size characters
+         sentences = nltk.sent_tokenize(text)
+         chunks = []
+         current_chunk = ""
+         for sentence in sentences:
+             if len(current_chunk) + len(sentence) <= chunk_size:
+                 current_chunk += " " + sentence if current_chunk else sentence
+             else:
+                 if current_chunk:
+                     chunks.append(current_chunk)
+                 # Carry the tail of the finished chunk into the next one so
+                 # that consecutive chunks overlap by up to overlap_size characters
+                 current_chunk = (current_chunk[-overlap_size:] + " " + sentence).strip()
+         if current_chunk:
+             chunks.append(current_chunk)
+         return chunks
+
+     def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
+         if not text:
+             return []
+
+         filtered_results = []
+         chunks = self.chunk_text(text, chunk_size)
+         if not chunks:
+             return []
+
+         chunk_embeddings = self.model.encode(
+             chunks, convert_to_tensor=True, show_progress_bar=False
+         )
+         # Cosine similarity between the claim and every chunk
+         chunk_similarities = util.cos_sim(claim_embedding, chunk_embeddings)
+
+         for chunk, similarity in zip(chunks, chunk_similarities[0]):
+             if similarity >= min_threshold:
+                 print(chunk)
+                 print()
+                 print(similarity)
+                 print("--------------------------------------------------")
+                 filtered_results.append(chunk)
+
+         return filtered_results
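
For reference, a minimal end-to-end sketch of how the refactored class could be driven, mirroring the endpoint removed above. It assumes the new file is importable as Similarity.py; the claim and document strings are made-up Greek examples, chosen only because the underlying model is a Greek sentence-similarity model:

    from Similarity import Similarity

    sim = Similarity()

    # Embed the claim once; get_sim_text scores it against every chunk of the text
    claim = "Η κλιματική αλλαγή επηρεάζει τη Μεσόγειο."  # hypothetical claim
    claim_embedding = sim.model.encode(
        claim, convert_to_tensor=True, show_progress_bar=False
    )

    document = "Μεγάλο κείμενο προς έλεγχο. " * 200  # stand-in long text
    matches = sim.get_sim_text(document, claim_embedding)  # defaults: min_threshold=0.4, chunk_size=1500
    print(f"{len(matches)} chunk(s) at or above the 0.4 cosine-similarity threshold")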
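
The chunking can also be exercised in isolation; a small sketch with made-up text and parameters, just to show the overlap (note that constructing the class loads the full transformer model, so a single shared instance is preferable in practice):

    from Similarity import Similarity

    sim = Similarity()
    text = "Αυτή είναι μία πρόταση για δοκιμή. " * 40  # ~1,400 characters
    chunks = sim.chunk_text(text, chunk_size=300, overlap_size=50)
    # Every chunk after the first starts with the last 50 characters of the
    # previous chunk, so sentences near a boundary are not lost between chunks
    print(len(chunks))
    print(chunks[1][:50])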