rishirajbal committed on
Commit 448426c · verified · 1 Parent(s): 755c3f6

Update app.py

Files changed (1)
  1. app.py +38 -26
app.py CHANGED
@@ -7,9 +7,10 @@ import requests
 from langchain_groq import ChatGroq
 from langchain.agents import initialize_agent
 from langchain.prompts import PromptTemplate
-from langchain_core.runnables import RunnableSequence  # modern
+from langchain_core.runnables import RunnableSequence
 from langchain.tools import StructuredTool
 
+
 IMG_HEIGHT = 256
 IMG_WIDTH = 256
 
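Note on the re-added import: with the `prompt | llm` composition used later in this file, LangChain already builds a RunnableSequence under the hood, so the explicit import only matters if the type is referenced by name. A minimal sketch of that behavior, using RunnableLambda stand-ins rather than this app's prompt and model:

from langchain_core.runnables import RunnableLambda, RunnableSequence

# Piping two runnables with `|` builds a RunnableSequence implicitly.
double = RunnableLambda(lambda x: x * 2)
add_one = RunnableLambda(lambda x: x + 1)
pipeline = double | add_one

assert isinstance(pipeline, RunnableSequence)
print(pipeline.invoke(3))  # 3 * 2 + 1 -> 7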
@@ -28,27 +29,31 @@ print("Loading model...")
 model = tf.keras.models.load_model(model_path, compile=False)
 
 
-# === Segmentation + Stats ===
+# === Segmentation + Stats + Overlay ===
 def classify_image_and_stats(image_input):
     img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
-    img = img / 255.0
-    img = np.expand_dims(img, axis=0)
+    img_norm = img / 255.0
+    img_batch = np.expand_dims(img_norm, axis=0)
 
-    prediction = model.predict(img)[0]  # (256, 256, 1)
-    mask = (prediction > 0.5).astype(np.uint8) * 255
+    prediction = model.predict(img_batch)[0]  # (256, 256, 1)
+    mask = (prediction > 0.5).astype(np.uint8)
 
     if mask.ndim == 3 and mask.shape[-1] == 1:
         mask = np.squeeze(mask, axis=-1)
 
     # Tumor stats
-    tumor_area = np.sum(prediction > 0.5)
+    tumor_area = np.sum(mask)
     total_area = IMG_HEIGHT * IMG_WIDTH
     tumor_ratio = tumor_area / total_area
 
-    if tumor_ratio > 0.01:
-        tumor_label = "Tumor Detected"
-    else:
-        tumor_label = "No Tumor Detected"
+    tumor_label = "Tumor Detected" if tumor_ratio > 0.005 else "No Tumor Detected"
+
+    # === Overlay mask on original ===
+    overlay = np.array(img)  # original resized input
+    red_mask = np.zeros_like(overlay)
+    red_mask[..., 0] = mask * 255  # Red channel
+
+    overlay_img = np.clip(0.6 * overlay + 0.4 * red_mask, 0, 255).astype(np.uint8)
 
     stats = {
         "tumor_area": int(tumor_area),
@@ -57,51 +62,58 @@ def classify_image_and_stats(image_input):
         "tumor_label": tumor_label
     }
 
-    return mask, stats
+    return overlay_img, stats
 
 
 # === Gradio handler ===
 def rishigpt_handler(image_input, groq_api_key):
     os.environ["GROQ_API_KEY"] = groq_api_key
 
-    mask, stats = classify_image_and_stats(image_input)
+    overlay_img, stats = classify_image_and_stats(image_input)
 
     def segment_brain_tool(input_text: str) -> str:
-        return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."
+        return (
+            f"Tumor label: {stats['tumor_label']}. "
+            f"Tumor area: {stats['tumor_area']}. "
+            f"Ratio: {stats['tumor_ratio']:.4f}."
+        )
 
     tool = StructuredTool.from_function(
         segment_brain_tool,
         name="segment_brain",
-        description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
+        description="Provide tumor segmentation stats for the MRI image."
     )
 
     llm = ChatGroq(
         model="meta-llama/llama-4-scout-17b-16e-instruct",
-        temperature=0.3
+        temperature=0.4
     )
 
     agent = initialize_agent(
         tools=[tool],
         llm=llm,
         agent="zero-shot-react-description",
-        verbose=True
+        verbose=False
     )
 
-    user_query = "Get segmentation details"
+    user_query = "Give me the segmentation details"
     classification = agent.run(user_query)
 
+    # Better prompt + output parser
     prompt = PromptTemplate(
         input_variables=["result"],
         template=(
-            "You are a medical imaging expert. Based on this tumor analysis result: {result}, "
-            "explain what this means for the patient in simple language."
+            "You are a compassionate AI radiologist. "
+            "Read this tumor analysis result: {result}. "
+            "Summarize the situation like you're talking to the patient in calm, clear language. "
+            "Add any recommendations for next steps too, but keep it easy to understand."
        )
    )
 
     chain = prompt | llm
-    description = chain.invoke({"result": classification})
+    description = chain.invoke({"result": classification}).content.strip()
 
-    return mask, description
+    return overlay_img, description
 
 
 # === Gradio UI ===
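The segmentation hunk above blends the binary mask into the red channel of the resized input at 60/40 weights. A minimal, self-contained sketch of that blending step, assuming a uint8 RGB image and a {0, 1} mask of matching size (blend_red_mask and the toy arrays are illustrative, not part of the app):

import numpy as np

def blend_red_mask(frame: np.ndarray, mask: np.ndarray, alpha: float = 0.4) -> np.ndarray:
    # Paint the binary mask into the red channel, then alpha-blend with the image.
    red = np.zeros_like(frame)
    red[..., 0] = mask * 255
    blended = (1 - alpha) * frame + alpha * red
    return np.clip(blended, 0, 255).astype(np.uint8)

# Toy usage: a gray 4x4 image with a 2x2 "tumor" region.
frame = np.full((4, 4, 3), 128, dtype=np.uint8)
mask = np.zeros((4, 4), dtype=np.uint8)
mask[1:3, 1:3] = 1
print(blend_red_mask(frame, mask)[1, 1])  # red-tinted pixel inside the mask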
@@ -111,8 +123,8 @@ inputs = [
 ]
 
 outputs = [
-    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
-    gr.Textbox(label="Medical Explanation")
+    gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask"),
+    gr.Textbox(label="Doctor's Explanation")
 ]
 
 if __name__ == "__main__":
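On the `.content.strip()` change in the handler hunk above: `prompt | llm` pipes the formatted prompt into a chat model, and chat models return a message object whose text lives on `.content`. A runnable sketch of the same pattern, with langchain_core's FakeListChatModel standing in for ChatGroq so no API key is needed:

from langchain.prompts import PromptTemplate
from langchain_core.language_models import FakeListChatModel

prompt = PromptTemplate(
    input_variables=["result"],
    template="Explain this tumor analysis result in plain language: {result}",
)
llm = FakeListChatModel(responses=["No tumor region was found in the scan."])

chain = prompt | llm  # formatted prompt feeds the chat model
message = chain.invoke({"result": "Tumor label: No Tumor Detected. Ratio: 0.0000."})
print(message.content.strip())  # the reply text lives on .content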
@@ -120,6 +132,6 @@ if __name__ == "__main__":
         fn=rishigpt_handler,
         inputs=inputs,
         outputs=outputs,
-        title="RishiGPT Medical Brain Segmentation",
-        description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation."
+        title="🧠 RishiGPT Medical Brain Segmentation",
+        description="UNet++ Brain Tumor Segmentation with mask overlay, detailed stats, and human-like explanation."
     ).launch()
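A minimal sketch of the resulting Interface wiring, with a stub in place of rishigpt_handler so it launches without the model or a Groq key; the two inputs are assumed from the handler's signature, since the `inputs = [...]` block is not shown in this diff:

import gradio as gr
import numpy as np

def stub_handler(image_input, groq_api_key):
    # Placeholder: the real handler returns (overlay image, LLM explanation).
    return np.zeros((256, 256, 3), dtype=np.uint8), "stub explanation"

gr.Interface(
    fn=stub_handler,
    inputs=[
        gr.Image(type="numpy", label="Brain MRI"),
        gr.Textbox(label="Groq API Key"),
    ],
    outputs=[
        gr.Image(type="numpy", label="Overlay: Brain MRI + Tumor Mask"),
        gr.Textbox(label="Doctor's Explanation"),
    ],
    title="RishiGPT Medical Brain Segmentation",
).launch()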
 