JohnKouf commited on
Commit
817e62c
·
verified ·
1 Parent(s): 0abf1cd

Update Similarity.py

Browse files
Files changed (1) hide show
  1. Similarity.py +24 -42
Similarity.py CHANGED
@@ -1,47 +1,29 @@
1
- import nltk
2
- from sentence_transformers import util
 
3
 
4
- class Similarity:
5
- def __init__(self):
6
- self.model = SentenceTransformer("lighteternal/stsb-xlm-r-greek-transfer")
7
- # Make sure nltk punkt tokenizer is downloaded
8
- nltk.download('punkt')
9
 
10
- def chunk_text(self, text, chunk_size=1400, overlap_size=200):
11
- sentences = nltk.sent_tokenize(text)
12
- chunks = []
13
- current_chunk = ""
14
- for sentence in sentences:
15
- if len(current_chunk) + len(sentence) <= chunk_size:
16
- current_chunk += " " + sentence if current_chunk else sentence
17
- else:
18
- chunks.append(current_chunk)
19
- # Start the next chunk with overlap
20
- current_chunk = sentence[:overlap_size] + sentence[overlap_size:]
21
- if current_chunk:
22
- chunks.append(current_chunk)
23
- return chunks
24
 
25
- def get_sim_text(self, text, claim_embedding, min_threshold=0.4, chunk_size=1500):
26
- if not text:
27
- return []
28
-
29
- filtered_results = []
30
- chunks = self.chunk_text(text, chunk_size)
31
- if not chunks:
32
- return []
33
-
34
- chunk_embeddings = self.model.encode(
35
- chunks, convert_to_tensor=True, show_progress_bar=False
36
  )
37
- chunk_similarities = util.cos_sim(claim_embedding, chunk_embeddings)
38
-
39
- for chunk, similarity in zip(chunks, chunk_similarities[0]):
40
- if similarity >= min_threshold:
41
- print(chunk)
42
- print()
43
- print(similarity)
44
- print("--------------------------------------------------")
45
- filtered_results.append(chunk)
46
 
47
- return filtered_results
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from Similarity import Similarity
4
 
5
+ app = FastAPI()
6
+ similarity_model = Similarity()
 
 
 
7
 
8
+ class TextPairRequest(BaseModel):
9
+ text: str # The big text to chunk and search
10
+ claim: str # The claim text to embed and compare
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ @app.post("/get_sim_text")
13
+ def get_sim_text_endpoint(request: TextPairRequest):
14
+ try:
15
+ # Embed the claim text
16
+ claim_embedding = similarity_model.model.encode(
17
+ request.claim,
18
+ convert_to_tensor=True,
19
+ show_progress_bar=False
 
 
 
20
  )
21
+ # Call get_sim_text with defaults (min_threshold=0.4, chunk_size=1500)
22
+ results = similarity_model.get_sim_text(
23
+ request.text,
24
+ claim_embedding
25
+ )
26
+ return {"result": results}
 
 
 
27
 
28
+ except Exception as e:
29
+ return {"error": str(e)}