rishirajbal commited on
Commit
ec387c2
·
verified ·
1 Parent(s): 623f660

Create app.py

Browse files

Gradio utilisation for RishiGPT Medical Brain Segmentation

Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import tensorflow as tf
4
+ import numpy as np
5
+ import requests
6
+
7
+ from langchain_groq import ChatGroq
8
+ from langchain.agents import initialize_agent
9
+ from langchain.prompts import PromptTemplate
10
+ from langchain.chains import LLMChain
11
+ from langchain.tools import StructuredTool
12
+ from tensorflow.keras.preprocessing import image
13
+
14
+
15
+ model_path = "unet_model.h5"
16
+ if not os.path.exists(model_path):
17
+ hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
18
+ r = requests.get(hf_url)
19
+ with open(model_path, "wb") as f:
20
+ f.write(r.content)
21
+
22
+ model = tf.keras.models.load_model(model_path, compile=False)
23
+
24
+
25
+ def classify_image(image_input):
26
+ img = tf.image.resize(image_input, (256, 256))
27
+ img = img / 255.0
28
+ img = np.expand_dims(img, axis=0)
29
+
30
+ prediction = model.predict(img)[0]
31
+ mask = (prediction > 0.5).astype(np.uint8) * 255
32
+
33
+ return mask
34
+
35
+
36
+ def rishigpt_handler(image_input, groq_api_key):
37
+ os.environ["GROQ_API_KEY"] = groq_api_key
38
+
39
+ mask = classify_image(image_input)
40
+
41
+ def classify_image_tool(img_path):
42
+ return "Brain tumor mask generated."
43
+
44
+ tool = StructuredTool.from_function(
45
+ classify_image_tool,
46
+ name="segment_brain",
47
+ description="Segment brain MRI for tumor detection."
48
+ )
49
+
50
+ llm = ChatGroq(
51
+ model="meta-llama/llama-4-scout-17b-16e-instruct",
52
+ temperature=0.3
53
+ )
54
+
55
+ agent = initialize_agent(
56
+ tools=[tool],
57
+ llm=llm,
58
+ agent="zero-shot-react-description",
59
+ verbose=True
60
+ )
61
+
62
+ user_query = "I uploaded a brain MRI. What does the segmentation say?"
63
+ classification = agent.run(user_query)
64
+
65
+ prompt = PromptTemplate(
66
+ input_variables=["result"],
67
+ template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
68
+ )
69
+
70
+ llm_chain = LLMChain(
71
+ llm=llm,
72
+ prompt=prompt
73
+ )
74
+
75
+ description = llm_chain.run({"result": classification})
76
+
77
+ return mask, description
78
+
79
+
80
+ inputs = [
81
+ gr.Image(type="numpy", label="Upload Brain MRI Slice"),
82
+ gr.Textbox(type="password", label="Groq API Key")
83
+ ]
84
+
85
+ outputs = [
86
+ gr.Image(type="numpy", label="Tumor Segmentation Mask"),
87
+ gr.Textbox(label="Medical Explanation")
88
+ ]
89
+
90
+ gr.Interface(
91
+ fn=rishigpt_handler,
92
+ inputs=inputs,
93
+ outputs=outputs,
94
+ title="RishiGPT Medical Brain Segmentation",
95
+ description="UNet++ Brain Tumor Segmentation"
96
+ ).launch()