rishirajbal commited on
Commit
755c3f6
·
verified ·
1 Parent(s): 64183cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -20
app.py CHANGED
@@ -7,9 +7,12 @@ import requests
7
  from langchain_groq import ChatGroq
8
  from langchain.agents import initialize_agent
9
  from langchain.prompts import PromptTemplate
10
- from langchain_core.runnables import RunnableSequence # Modern replacement for LLMChain
11
  from langchain.tools import StructuredTool
12
 
 
 
 
13
  # === Download model if not exists ===
14
  model_path = "unet_model.h5"
15
  if not os.path.exists(model_path):
@@ -24,36 +27,52 @@ if not os.path.exists(model_path):
24
  print("Loading model...")
25
  model = tf.keras.models.load_model(model_path, compile=False)
26
 
27
- # === Segmentation ===
28
- def classify_image(image_input):
29
- img = tf.image.resize(image_input, (256, 256))
 
30
  img = img / 255.0
31
  img = np.expand_dims(img, axis=0)
32
 
33
- prediction = model.predict(img)[0] # (256, 256, 1) maybe
34
-
35
- mask = (prediction > 0.5).astype(np.uint8) * 255 # binary mask
36
 
37
- # Squeeze to (H, W) if needed
38
  if mask.ndim == 3 and mask.shape[-1] == 1:
39
  mask = np.squeeze(mask, axis=-1)
40
 
41
- return mask
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # === Main handler ===
 
 
 
44
  def rishigpt_handler(image_input, groq_api_key):
45
  os.environ["GROQ_API_KEY"] = groq_api_key
46
 
47
- mask = classify_image(image_input)
48
 
49
- # Dummy tool for LangChain agent
50
- def segment_brain_tool():
51
- return "A brain tumor mask was generated."
52
 
53
  tool = StructuredTool.from_function(
54
  segment_brain_tool,
55
  name="segment_brain",
56
- description="Segment brain MRI for tumor detection."
57
  )
58
 
59
  llm = ChatGroq(
@@ -68,20 +87,23 @@ def rishigpt_handler(image_input, groq_api_key):
68
  verbose=True
69
  )
70
 
71
- user_query = "I uploaded a brain MRI. What does the segmentation say?"
72
  classification = agent.run(user_query)
73
 
74
- # New style: RunnableSequence
75
  prompt = PromptTemplate(
76
  input_variables=["result"],
77
- template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
 
 
 
78
  )
79
- chain = prompt | llm
80
 
 
81
  description = chain.invoke({"result": classification})
82
 
83
  return mask, description
84
 
 
85
  # === Gradio UI ===
86
  inputs = [
87
  gr.Image(type="numpy", label="Upload Brain MRI Slice"),
@@ -99,5 +121,5 @@ if __name__ == "__main__":
99
  inputs=inputs,
100
  outputs=outputs,
101
  title="RishiGPT Medical Brain Segmentation",
102
- description="UNet++ Brain Tumor Segmentation with LangChain integration"
103
  ).launch()
 
7
  from langchain_groq import ChatGroq
8
  from langchain.agents import initialize_agent
9
  from langchain.prompts import PromptTemplate
10
+ from langchain_core.runnables import RunnableSequence # modern
11
  from langchain.tools import StructuredTool
12
 
13
+ IMG_HEIGHT = 256
14
+ IMG_WIDTH = 256
15
+
16
  # === Download model if not exists ===
17
  model_path = "unet_model.h5"
18
  if not os.path.exists(model_path):
 
27
  print("Loading model...")
28
  model = tf.keras.models.load_model(model_path, compile=False)
29
 
30
+
31
+ # === Segmentation + Stats ===
32
+ def classify_image_and_stats(image_input):
33
+ img = tf.image.resize(image_input, [IMG_HEIGHT, IMG_WIDTH])
34
  img = img / 255.0
35
  img = np.expand_dims(img, axis=0)
36
 
37
+ prediction = model.predict(img)[0] # (256, 256, 1)
38
+ mask = (prediction > 0.5).astype(np.uint8) * 255
 
39
 
 
40
  if mask.ndim == 3 and mask.shape[-1] == 1:
41
  mask = np.squeeze(mask, axis=-1)
42
 
43
+ # Tumor stats
44
+ tumor_area = np.sum(prediction > 0.5)
45
+ total_area = IMG_HEIGHT * IMG_WIDTH
46
+ tumor_ratio = tumor_area / total_area
47
+
48
+ if tumor_ratio > 0.01:
49
+ tumor_label = "Tumor Detected"
50
+ else:
51
+ tumor_label = "No Tumor Detected"
52
+
53
+ stats = {
54
+ "tumor_area": int(tumor_area),
55
+ "total_area": total_area,
56
+ "tumor_ratio": tumor_ratio,
57
+ "tumor_label": tumor_label
58
+ }
59
 
60
+ return mask, stats
61
+
62
+
63
+ # === Gradio handler ===
64
  def rishigpt_handler(image_input, groq_api_key):
65
  os.environ["GROQ_API_KEY"] = groq_api_key
66
 
67
+ mask, stats = classify_image_and_stats(image_input)
68
 
69
+ def segment_brain_tool(input_text: str) -> str:
70
+ return f"Tumor label: {stats['tumor_label']}. Tumor area: {stats['tumor_area']}. Ratio: {stats['tumor_ratio']:.4f}."
 
71
 
72
  tool = StructuredTool.from_function(
73
  segment_brain_tool,
74
  name="segment_brain",
75
+ description="Provide tumor segmentation stats for the MRI image. Takes dummy input text."
76
  )
77
 
78
  llm = ChatGroq(
 
87
  verbose=True
88
  )
89
 
90
+ user_query = "Get segmentation details"
91
  classification = agent.run(user_query)
92
 
 
93
  prompt = PromptTemplate(
94
  input_variables=["result"],
95
+ template=(
96
+ "You are a medical imaging expert. Based on this tumor analysis result: {result}, "
97
+ "explain what this means for the patient in simple language."
98
+ )
99
  )
 
100
 
101
+ chain = prompt | llm
102
  description = chain.invoke({"result": classification})
103
 
104
  return mask, description
105
 
106
+
107
  # === Gradio UI ===
108
  inputs = [
109
  gr.Image(type="numpy", label="Upload Brain MRI Slice"),
 
121
  inputs=inputs,
122
  outputs=outputs,
123
  title="RishiGPT Medical Brain Segmentation",
124
+ description="UNet++ Brain Tumor Segmentation with LangChain integration. Includes tumor stats and medical explanation."
125
  ).launch()