burhan112 committed
Commit c477934 · verified · 1 Parent(s): ca2f154

Update app.py

Files changed (1):
  1. app.py +57 -72
app.py CHANGED
@@ -7,98 +7,83 @@ import google.generativeai as genai
 import re
 import os
 
-# Load data and FAISS index
-def load_data_and_index():
-    docs_df = pd.read_pickle("data.pkl")  # Adjust path for HF Spaces
-    embeddings = np.array(docs_df['embeddings'].tolist(), dtype=np.float32)
-    dimension = embeddings.shape[1]
-    index = faiss.IndexFlatL2(dimension)
-    index.add(embeddings)
-    return docs_df, index
 
-docs_df, index = load_data_and_index()
 
-# Load SentenceTransformer
-minilm = SentenceTransformer('all-MiniLM-L6-v2')
 
-# Configure Gemini API using Hugging Face Secrets
-GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
-if not GEMINI_API_KEY:
-    raise ValueError("Gemini API key not found. Please set it in Hugging Face Spaces secrets.")
-genai.configure(api_key=GEMINI_API_KEY)
-model = genai.GenerativeModel('gemini-2.0-flash')
 
-# Preprocess text function
-def preprocess_text(text):
     text = text.lower()
-    text = text.replace('\n', ' ').replace('\t', ' ')
-    text = re.sub(r'[^\w\s.,;:>-]', ' ', text)
-    text = ' '.join(text.split()).strip()
-    return text
 
-# Retrieve documents
-def retrieve_docs(query, k=5):
-    query_embedding = minilm.encode([query], show_progress_bar=False)[0].astype(np.float32)
-    distances, indices = index.search(np.array([query_embedding]), k)
-    retrieved_docs = docs_df.iloc[indices[0]][['label', 'text', 'source']]
-    retrieved_docs['distance'] = distances[0]
-    return retrieved_docs
 
-# RAG pipeline integrated into respond function
-def respond(message, system_message, max_tokens, temperature):
-    # Preprocess the user message
-    preprocessed_query = preprocess_text(message)
 
-    # Retrieve relevant documents
-    retrieved_docs = retrieve_docs(preprocessed_query, k=5)
-    context = "\n".join(retrieved_docs['text'].tolist())
-
-    # Construct the prompt with system message and RAG context, asking for structured response
-    prompt = f"{system_message}\n\n"
-    prompt += (
-        f"Query: {message}\n"
-        f"Relevant Context: {context}\n"
-        f"Generate a short, concise response to the query based only on the provided context. "
-        f"Format the response as a structured list (e.g., bullet points or numbered items) instead of a paragraph."
     )
-
-    # Generate response with Gemini
-    response = model.generate_content(
         prompt,
         generation_config=genai.types.GenerationConfig(
             max_output_tokens=max_tokens,
-            temperature=temperature
         )
     )
-    answer = response.text.strip()
-    if not answer.endswith('.'):
-        last_period = answer.rfind('.')
-        if last_period != -1:
-            answer = answer[:last_period + 1]
-        else:
-            answer += "."
-
-    return answer
-
-# Simple Gradio Interface
-def chatbot_interface(message, system_message, max_tokens, temperature):
-    return respond(message, system_message, max_tokens, temperature)
 
 demo = gr.Interface(
-    fn=chatbot_interface,
     inputs=[
-        gr.Textbox(label="Your Query", placeholder="Enter your medical question here..."),
         gr.Textbox(
-            value="You are a medical AI assistant diagnosing patients based on their query, using relevant context from past records of other patients.",
-            label="System Message"
         ),
-        gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max Tokens"),
-        gr.Slider(minimum=0.1, maximum=4.0, value=0.75, step=0.1, label="Temperature"),
     ],
-    outputs=gr.Textbox(label="Response"),
-    title="🏥 Medical Chat Assistant",
-    description="A simple medical assistant that diagnoses patient queries using AI and past records, providing structured responses."
 )
 
 if __name__ == "__main__":
-    demo.launch()
 
 import re
 import os
 
+# Load documents and FAISS index
+def load_index_and_data():
+    df = pd.read_pickle("data.pkl")
+    vecs = np.array(df['embeddings'].tolist(), dtype=np.float32)
+    idx = faiss.IndexFlatL2(vecs.shape[1])
+    idx.add(vecs)
+    return df, idx
 
+docs_df, index = load_index_and_data()
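
For reference, data.pkl is read as a pandas DataFrame carrying an 'embeddings' column plus the 'text', 'label', and 'source' columns this app uses. A minimal sketch of how such a file could be produced with the same encoder; the sample records below are hypothetical:

    import pandas as pd
    from sentence_transformers import SentenceTransformer

    # Hypothetical records; the column names match what app.py reads from data.pkl.
    records = pd.DataFrame({
        "text": [
            "patient reports persistent dry cough and mild fever for two weeks",
            "follow-up visit for seasonal allergies, symptoms well controlled",
        ],
        "label": ["respiratory", "allergy"],
        "source": ["record_001", "record_002"],
    })

    encoder = SentenceTransformer("all-MiniLM-L6-v2")
    records["embeddings"] = encoder.encode(records["text"].tolist()).tolist()
    records.to_pickle("data.pkl")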
 
+# Embedding model and Gemini setup
+encoder = SentenceTransformer("all-MiniLM-L6-v2")
 
+API_KEY = os.getenv("GEMINI_API_KEY")
+if not API_KEY:
+    raise EnvironmentError("Missing Gemini API key.")
+genai.configure(api_key=API_KEY)
+llm = genai.GenerativeModel("gemini-2.0-flash")
 
+# Clean text input
+def clean_text(text):
     text = text.lower()
+    text = re.sub(r"[^\w\s.,]", " ", text)
+    return " ".join(text.split())
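
As a quick illustration of the cleaning step (hypothetical input): the regex keeps only word characters, whitespace, periods, and commas, and the final join collapses runs of whitespace:

    clean_text("BP: 120/80;\nTemp > 99F")  # -> "bp 120 80 temp 99f"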
 
 
+# Retrieve relevant document context
+def get_context(query, k=5):
+    q_vec = encoder.encode([query])[0].astype(np.float32)
+    _, indices = index.search(np.array([q_vec]), k)
+    return "\n".join(docs_df.iloc[indices[0]]["text"].tolist())
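
IndexFlatL2 performs exact brute-force search, and index.search returns two (1, k) arrays: squared L2 distances and row positions into docs_df (get_context keeps only the positions). A standalone sketch with toy vectors:

    import numpy as np
    import faiss

    vecs = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0], [1, 1, 0]], dtype=np.float32)
    idx = faiss.IndexFlatL2(vecs.shape[1])  # exact (brute-force) L2 index
    idx.add(vecs)

    dists, ids = idx.search(vecs[:1], k=2)  # query with the first stored vector
    print(ids)    # [[0 1]] -- the query's own row comes back first
    print(dists)  # [[0. 2.]] -- IndexFlatL2 reports squared L2 distances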
 
 
+# RAG-based Gemini response generation
+def generate_answer(user_input, system_note, max_tokens, temp):
+    query = clean_text(user_input)
+    context = get_context(query)
 
+    prompt = (
+        f"Role Description:\n{system_note}\n\n"
+        f"User Question:\n{user_input}\n\n"
+        f"Knowledge Extracted From Records:\n{context}\n\n"
+        f"Instructions:\n"
+        f"- Analyze the user's query using ONLY the above context.\n"
+        f"- Do NOT add external or made-up information.\n"
+        f"- Begin with a brief summary of the identified condition or concern.\n"
+        f"- Provide detailed reasoning and explanation in bullet points:\n"
+        f" • Include possible causes, symptoms, and diagnostic considerations.\n"
+        f" • Mention relevant terms or observations from context.\n"
+        f" • Explain how the context supports the conclusions.\n"
+        f"- End with a short, clear recommendation (if context permits).\n"
+        f"- Avoid medical advice unless the context contains it."
     )
+
+    result = llm.generate_content(
         prompt,
         generation_config=genai.types.GenerationConfig(
             max_output_tokens=max_tokens,
+            temperature=temp
         )
     )
+    return result.text.strip()
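
With data.pkl in place and GEMINI_API_KEY set, the pipeline can be exercised directly, e.g. with the interface's default settings (hypothetical query):

    answer = generate_answer(
        user_input="I have had a dry cough and mild fever for two weeks.",
        system_note="You are a virtual medical assistant using past medical records to respond intelligently.",
        max_tokens=300,
        temp=0.4,
    )
    print(answer)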
 
+# Gradio interface
 demo = gr.Interface(
+    fn=generate_answer,
     inputs=[
+        gr.Textbox(label="Ask Something", placeholder="Describe your symptom or condition..."),
         gr.Textbox(
+            value="You are a virtual medical assistant using past medical records to respond intelligently.",
+            label="System Role"
         ),
+        gr.Slider(50, 500, value=300, step=10, label="Max Tokens"),
+        gr.Slider(0.0, 1.0, value=0.4, step=0.1, label="Creativity (Temperature)")
     ],
+    outputs=gr.Textbox(label="AI Diagnosis"),
+    title="🩺 Smart Medical Query Assistant",
+    description="Submit a health-related question. The assistant analyzes similar past records to respond accurately and clearly."
 )
 
 if __name__ == "__main__":
+    demo.launch()
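
Once the Space is running, the same interface can also be called programmatically; a sketch assuming the gradio_client package and a placeholder Space id:

    from gradio_client import Client

    client = Client("burhan112/space-name")  # hypothetical Space id
    result = client.predict(
        "I have had a dry cough and mild fever for two weeks.",  # Ask Something
        "You are a virtual medical assistant using past medical records to respond intelligently.",  # System Role
        300,  # Max Tokens
        0.4,  # Creativity (Temperature)
        api_name="/predict",
    )
    print(result)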