borodache committed
Commit fb0495b · verified · 1 Parent(s): 7074e7f

Upload 6 files


Uploading local files in order to get this space started

Files changed (6)
  1. generator.py +62 -0
  2. main.py +117 -0
  3. rag_agent.py +64 -0
  4. reranker.py +22 -0
  5. retriever.py +43 -0
  6. text_embedder_encoder.py +56 -0
generator.py ADDED
@@ -0,0 +1,62 @@
+ from retriever import Retriever
+ from reranker import Reranker
+ from anthropic import Anthropic
+ from typing import List
+
+
+ retriever = Retriever()
+ reranker = Reranker()
+
+
+ class RAGAgent:
+     def __init__(
+         self,
+         retriever=retriever,
+         reranker=reranker,
+         anthropic_api_key: str = "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",
+         model: str = "claude-3-5-sonnet-20241022",
+         max_tokens: int = 1024,
+         temperature: float = 0.0,
+     ):
+         self.retriever = retriever
+         self.reranker = reranker
+         self.client = Anthropic(api_key=anthropic_api_key)
+         self.model = model
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+
+     def get_context(self, query: str) -> List[str]:
+         # Get initial candidates from retriever
+         retrieved_docs = self.retriever.search_similar(query)
+
+         # Rerank the candidates
+         context = self.reranker.rerank(query, retrieved_docs)
+
+         return context
+
+     def generate_prompt(self, context: List[str]) -> str:
+         context = "\n".join(context)
+         # Hebrew instructions, roughly: "You are a dentist who speaks only Hebrew. Your name is
+         # 'the first Hebrew electronic dentist'. Answer the patient's question based on the
+         # following context: {context}. Add as much detail as possible and keep the phrasing
+         # correct and clear. Stop once you have covered everything. Do not make things up,
+         # and do not answer in any language other than Hebrew."
+         prompt = f"""
+         אתה רופא שיניים, דובר עברית בלבד. קוראים לך 'רופא השיניים העברי האלקטרוני הראשון'. ענה למטופל על השאלה שלו על סמך הקונטקסט הבא: {context}. הוסף כמה שיותר פרטים, ודאג שהתחביר יהיה תקין ויפה. תעצור כשאתה מרגיש שמיצית את עצמך. אל תמציא דברים. ואל תענה בשפות שהן לא עברית.
+         """
+         return prompt
+
+     def get_response(self, question: str) -> str:
+         # Get relevant context
+         context = self.get_context(question)
+
+         # Generate prompt with context
+         prompt = self.generate_prompt(context)
+
+         # Get response from Claude; the instructions go in the system parameter,
+         # since the Messages API expects the first entry in messages to use the "user" role
+         response = self.client.messages.create(
+             model=self.model,
+             max_tokens=self.max_tokens,
+             temperature=self.temperature,
+             system=prompt,
+             messages=[
+                 {"role": "user", "content": question}
+             ]
+         )
+
+         return response.content[0].text
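A quick sketch of the pipeline stages defined above (retrieve, rerank, build the prompt), assuming the hard-coded Anthropic and Pinecone credentials are still valid and the Pinecone index is populated; the query is a made-up example:

from generator import RAGAgent

agent = RAGAgent()
query = "האם הלבנת שיניים מזיקה?"  # hypothetical query: "Is teeth whitening harmful?"
context = agent.get_context(query)       # retrieve up to 50 candidates, rerank down to at most 5
prompt = agent.generate_prompt(context)  # Hebrew instructions with the reranked context inlined
print(len(context))
print(prompt)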
main.py ADDED
@@ -0,0 +1,117 @@
+ import gradio as gr
+ import time
+
+
+ from rag_agent import RAGAgent
+
+
+ rag_agent = RAGAgent()
+
+
+ class ChatBot:
+     def __init__(self, rag_agent):
+         self.message_history = []
+         self.rag_agent = rag_agent
+
+     def get_response(self, message):
+         return self.rag_agent.get_response(message)
+
+     def chat(self, message, history):
+         time.sleep(1)
+         bot_response = self.get_response(message)
+         self.message_history.append((message, bot_response))
+         return bot_response
+
+
+ def create_chat_interface(rag_agent=rag_agent):
+     chatbot = ChatBot(rag_agent=rag_agent)
+
+     # Right-to-left styling for the Hebrew chat UI
+     custom_css = """
+     #chatbot {
+         direction: rtl;
+         height: 400px;
+     }
+     .message {
+         font-size: 16px;
+         text-align: right;
+     }
+     .message-wrap {
+         direction: rtl !important;
+     }
+     .message-wrap > div {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .input-box {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .container {
+         direction: rtl;
+     }
+     .contain {
+         direction: rtl !important;
+     }
+     .bubble {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     textarea, input {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     .user-message, .bot-message {
+         direction: rtl !important;
+         text-align: right !important;
+     }
+     """
+
+     with gr.Blocks(css=custom_css) as interface:
+         with gr.Column(elem_classes="container"):
+             gr.Markdown("רופא שיניים אלקטרוני", rtl=True)  # title: "Electronic dentist"
+
+             chatbot_component = gr.Chatbot(
+                 [],
+                 elem_id="chatbot",
+                 height=400,
+                 rtl=True,
+                 elem_classes="message-wrap"
+             )
+
+             with gr.Row():
+                 submit_btn = gr.Button("שלח", variant="primary")  # "Send"
+                 txt = gr.Textbox(
+                     show_label=False,
+                     placeholder="הקלד את ההודעה שלך כאן...",  # "Type your message here..."
+                     container=False,
+                     elem_classes="input-box",
+                     rtl=True
+                 )
+
+             clear_btn = gr.Button("נקה צ'אט")  # "Clear chat"
+
+             def user_message(user_message, history):
+                 # Step 1: clear the textbox and append the user turn with an empty bot slot
+                 return "", history + [[user_message, None]]
+
+             def bot_message(history):
+                 # Step 2: fill the empty slot with the agent's answer
+                 user_message = history[-1][0]
+                 bot_response = chatbot.chat(user_message, history)
+                 history[-1][1] = bot_response
+                 return history
+
+             txt_msg = txt.submit(user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False).then(
+                 bot_message, chatbot_component, chatbot_component
+             )
+
+             submit_btn.click(user_message, [txt, chatbot_component], [txt, chatbot_component], queue=False).then(
+                 bot_message, chatbot_component, chatbot_component
+             )
+
+             clear_btn.click(lambda: None, None, chatbot_component, queue=False)
+
+     return interface
+
+
+ # Launch the interface
+ chat_interface = create_chat_interface(rag_agent=rag_agent)
+ chat_interface.launch(share=True)
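A self-contained sketch of the two-step update flow wired to txt.submit and submit_btn.click above, with the ChatBot call replaced by a stub so it runs without any API access; the echo responder and the helper names are purely illustrative:

history = []

def append_user_turn(msg, history):
    # Mirrors user_message: clear the textbox and add the user turn with an empty bot slot
    return "", history + [[msg, None]]

def fill_bot_turn(history, respond=lambda q: "echo: " + q):
    # Mirrors bot_message: fill the empty slot with the bot's reply (stubbed here)
    history[-1][1] = respond(history[-1][0])
    return history

_, history = append_user_turn("שלום", history)   # "hello"
history = fill_bot_turn(history)
print(history)  # [['שלום', 'echo: שלום']]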
rag_agent.py ADDED
@@ -0,0 +1,64 @@
+ from anthropic import Anthropic
+ from typing import List
+
+
+ from retriever import Retriever
+ from reranker import Reranker
+
+
+ retriever = Retriever()
+ reranker = Reranker()
+
+
+ class RAGAgent:
+     def __init__(
+         self,
+         retriever=retriever,
+         reranker=reranker,
+         anthropic_api_key: str = "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",
+         model: str = "claude-3-5-sonnet-20241022",
+         max_tokens: int = 1024,
+         temperature: float = 0.0,
+     ):
+         self.retriever = retriever
+         self.reranker = reranker
+         self.client = Anthropic(api_key=anthropic_api_key)
+         self.model = model
+         self.max_tokens = max_tokens
+         self.temperature = temperature
+
+     def get_context(self, query: str) -> List[str]:
+         # Get initial candidates from retriever
+         retrieved_docs = self.retriever.search_similar(query)
+
+         # Rerank the candidates
+         context = self.reranker.rerank(query, retrieved_docs)
+
+         return context
+
+     def generate_prompt(self, context: List[str]) -> str:
+         context = "\n".join(context)
+         # Hebrew instructions, roughly: "You are a dentist who speaks only Hebrew. Your name is
+         # 'the first electronic Hebrew dentist'. Answer the patient's question based on the
+         # following context: {context}. Add as much detail as possible and keep the phrasing
+         # correct and clear. Stop once you have covered everything. Do not make things up,
+         # and do not answer in any language other than Hebrew."
+         prompt = f"""
+         אתה רופא שיניים, דובר עברית בלבד. קוראים לך 'רופא השיניים האלקטרוני העברי הראשון'. ענה למטופל על השאלה שלו על סמך הקונטקסט הבא: {context}. הוסף כמה שיותר פרטים, ודאג שהתחביר יהיה תקין ויפה. תעצור כשאתה מרגיש שמיצית את עצמך. אל תמציא דברים. ואל תענה בשפות שהן לא עברית.
+         """
+         return prompt
+
+     def get_response(self, question: str) -> str:
+         # Get relevant context
+         context = self.get_context(question)
+
+         # Generate prompt with context
+         prompt = self.generate_prompt(context)
+
+         # Get response from Claude; the instructions go in the system parameter,
+         # since the Messages API expects the first entry in messages to use the "user" role
+         response = self.client.messages.create(
+             model=self.model,
+             max_tokens=self.max_tokens,
+             temperature=self.temperature,
+             system=prompt,
+             messages=[
+                 {"role": "user", "content": question}
+             ]
+         )
+
+         return response.content[0].text
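A minimal end-to-end usage sketch for RAGAgent, assuming the hard-coded Anthropic and Pinecone credentials are valid and the index is populated; the question is a made-up example:

from rag_agent import RAGAgent

agent = RAGAgent()
question = "מה עוזר נגד רגישות בשיניים?"  # hypothetical question: "What helps with tooth sensitivity?"
print(agent.get_response(question))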
reranker.py ADDED
@@ -0,0 +1,22 @@
+ from sklearn.metrics.pairwise import cosine_similarity
+
+
+ from text_embedder_encoder import TextEmbedder
+
+
+ class Reranker:
+     def __init__(self):
+         self.text_embedder = TextEmbedder()
+
+     def rerank(self, query, retrieved_docs, top_n=5):
+         # Encode query and documents
+         query_embedding = self.text_embedder.encode(query)
+         doc_embeddings = self.text_embedder.encode_many(retrieved_docs)
+         similarity_scores = cosine_similarity([query_embedding], doc_embeddings)[0]
+
+         # Sort by similarity (descending), keep the top_n candidates,
+         # and drop anything below a 0.7 cosine-similarity cutoff
+         similarity_scores_with_idxes = list(zip(similarity_scores, range(len(similarity_scores))))
+         similarity_scores_with_idxes.sort(reverse=True)
+         similarity_scores_with_idxes_final = similarity_scores_with_idxes[:top_n]
+         reranked_docs = [retrieved_docs[idx] for score, idx in similarity_scores_with_idxes_final if score >= 0.7]
+
+         return reranked_docs
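A self-contained sketch of the reranker on hypothetical passages; because anything scoring below the 0.7 cosine-similarity cutoff is dropped, fewer than top_n documents may come back:

from reranker import Reranker

reranker = Reranker()
docs = [
    "שטיפות מי מלח עוזרות לריפוי אחרי עקירת שן",  # "salt-water rinses help healing after an extraction"
    "מומלץ לצחצח שיניים פעמיים ביום",              # "brushing twice a day is recommended"
    "המרפאה פתוחה בימים ראשון עד חמישי",           # "the clinic is open Sunday through Thursday"
]
top_docs = reranker.rerank("מה עושים אחרי עקירת שן?", docs, top_n=2)  # "what should I do after a tooth extraction?"
print(top_docs)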
retriever.py ADDED
@@ -0,0 +1,43 @@
+ from pinecone import Pinecone
+
+
+ from text_embedder_encoder import TextEmbedder, encoder_model_name
+
+
+ class Retriever:
+     def __init__(self,
+                  pinecone_api_key="pcsk_468XZz_QfKbP3dWCh6nLatJjd882DGF5HDh6TupzEAeRpFLAMtDfDiPDNRC537Q4jAtxhV",
+                  index_name=f"hebrew-dentist-qa-{encoder_model_name.replace('/', '-')}".lower()):
+         # Initialize Pinecone connection
+         self.pc = Pinecone(api_key=pinecone_api_key)
+         self.index_name = index_name
+         self.text_embedder = TextEmbedder()
+         self.vector_dim = 768
+
+     def search_similar(self, query_text, top_k=50):
+         """
+         Search for similar content using vector similarity in Pinecone
+         """
+         try:
+             # Generate embedding for query
+             query_vector = self.text_embedder.encode(query_text)
+
+             # Get Pinecone index
+             index = self.pc.Index(self.index_name)
+
+             # Execute search
+             results = index.query(
+                 vector=query_vector,
+                 top_k=top_k,
+                 include_metadata=True,
+             )
+
+             answers = []
+             for match in results['matches']:
+                 answer = match['metadata']['answer']
+                 answers.append(answer)
+
+             return answers
+         except Exception as e:
+             print(f"Error performing similarity search: {e}")
+             return []
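A short sketch of querying the index directly, assuming the hard-coded Pinecone key is valid and the hebrew-dentist-qa-mpa-sambert index exists with an 'answer' field in each vector's metadata; the query is a made-up example:

from retriever import Retriever

retriever = Retriever()
answers = retriever.search_similar("כאב שיניים בלילה", top_k=10)  # hypothetical query: "toothache at night"
print(len(answers))
for answer in answers[:3]:
    print(answer)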
text_embedder_encoder.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from sentence_transformers import SentenceTransformer
+ from typing import List
+
+
+ encoder_model_name = 'MPA/sambert'
+
+
+ class TextEmbedder:
+     def __init__(self):
+         """
+         Initialize the Hebrew text embedder using the sentence-transformers model
+         named in encoder_model_name ('MPA/sambert')
+         """
+         # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+         self.model = SentenceTransformer(encoder_model_name)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def encode(self, text) -> List[float]:
+         """
+         Encode a single Hebrew text.
+
+         Args:
+             text (str): Hebrew text to encode
+
+         Returns:
+             List[float]: the text embedding as a plain list of floats
+         """
+         # Embed the text and convert the vector to plain Python floats
+         embeddings = [float(x) for x in self.model.encode([text])[0]]
+
+         return embeddings
+
+     def encode_many(self, texts: List[str]) -> List[List[float]]:
+         """
+         Encode a batch of Hebrew texts.
+
+         Args:
+             texts (List[str]): Hebrew texts to encode
+
+         Returns:
+             List[List[float]]: one embedding (list of floats) per input text
+         """
+         # Embed the texts and convert each vector to plain Python floats
+         embeddings = self.model.encode(texts)
+         embeddings = [[float(x) for x in embedding] for embedding in embeddings]
+
+         return embeddings
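A short sketch of the embedder on its own, comparing two made-up Hebrew phrases with the same cosine similarity the reranker uses; the first run downloads MPA/sambert from the Hugging Face Hub:

from sklearn.metrics.pairwise import cosine_similarity

from text_embedder_encoder import TextEmbedder

embedder = TextEmbedder()
vec = embedder.encode("כאב שיניים")                     # "toothache" -> a single embedding
vecs = embedder.encode_many(["כאב שיניים", "כאב ראש"])  # batch: "toothache", "headache"
print(len(vec))                   # embedding dimension (the retriever assumes 768)
print(cosine_similarity([vec], vecs)[0])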