File size: 2,265 Bytes
ec387c2 b31004e ec387c2 82768ad b31004e ec387c2 72fda4e ec387c2 b31004e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 |
import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests
from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain
model_path = "unet_model.h5"

# Fetch the weights on first run only. The download is streamed into a
# temporary sibling file and renamed into place only after it completes, so an
# interrupted transfer cannot leave a truncated file at ``model_path`` that a
# later run would mistake for a finished download.
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    partial_path = model_path + ".part"
    print(f"Downloading model from {hf_url}...")
    try:
        # (connect, read) timeouts: a stalled connection raises instead of
        # hanging application start-up indefinitely.
        with requests.get(hf_url, stream=True, timeout=(10, 60)) as r:
            r.raise_for_status()
            with open(partial_path, "wb") as f:
                for chunk in r.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Rename within the same directory; the final path only ever appears
        # with the complete contents.
        os.replace(partial_path, model_path)
    finally:
        # After a successful rename the partial file no longer exists; after a
        # failure this removes the incomplete data.
        if os.path.exists(partial_path):
            os.remove(partial_path)

print("Loading model...")
# compile=False: the model is only used for inference via ``predict`` here.
model = tf.keras.models.load_model(model_path, compile=False)
def classify_image(image_input):
    """Run the module-level ``model`` on one image and binarise its output.

    Args:
        image_input: Image array to segment, as delivered by the Gradio image
            input. Assumed to be height x width x channels with 8-bit
            intensities -- TODO confirm against the model's expected input.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` containing 255 wherever the
        model output exceeds 0.5 and 0 elsewhere.
        NOTE(review): the mask keeps whatever shape the model emits for a
        single sample; confirm the image output component accepts it.

    Raises:
        ValueError: If ``image_input`` is ``None`` (no image was supplied).
    """
    # Fail early with a readable message instead of an opaque TensorFlow
    # conversion error when the user submits without uploading an image.
    if image_input is None:
        raise ValueError("No image provided. Please upload a brain MRI slice.")

    # Resize to the fixed 256x256 spatial size fed to the network.
    img = tf.image.resize(image_input, (256, 256))
    # Scale intensities by 1/255 (maps 8-bit values into [0, 1]).
    img = img / 255.0
    # Add a leading batch axis: ``predict`` works on batches.
    img = np.expand_dims(img, axis=0)
    # Single-item batch, so take the first (only) prediction.
    prediction = model.predict(img)[0]
    # Threshold at 0.5 and stretch {0, 1} to {0, 255} for display.
    mask = (prediction > 0.5).astype(np.uint8) * 255
    return mask
def rishigpt_handler(image_input, groq_api_key):
    """Gradio callback: segment the uploaded image and request an explanation.

    Args:
        image_input: Image array from the upload component; forwarded
            unchanged to ``classify_image``.
        groq_api_key: Groq API key typed by the user for this request.

    Returns:
        A ``(mask, description)`` tuple: the mask returned by
        ``classify_image`` and the value returned by the LLM chain for the
        fixed status sentence.

    Raises:
        ValueError: If ``groq_api_key`` is missing, not a string, or blank.
    """
    # Validate before doing any model work so a missing key fails fast.
    api_key = groq_api_key.strip() if isinstance(groq_api_key, str) else ""
    if not api_key:
        raise ValueError("A Groq API key is required.")

    mask = classify_image(image_input)

    # The key is handed straight to the client rather than written to
    # os.environ: the environment is process-wide, so a key stored there would
    # outlive this request and be visible to every other user of the app.
    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3,
        api_key=api_key,
    )

    prompt = PromptTemplate(
        input_variables=["result"],
        template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
    )

    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt
    )

    # The prompt receives only this fixed status sentence, not the mask data.
    classification = "The brain tumor mask has been generated and segmentation is complete."
    description = llm_chain.run({"result": classification})

    return mask, description
# Input widgets, in the positional order of the handler's parameters:
# the image to segment, then the (masked) API key field.
_mri_upload = gr.Image(type="numpy", label="Upload Brain MRI Slice")
_api_key_field = gr.Textbox(type="password", label="Groq API Key")
inputs = [_mri_upload, _api_key_field]

# Output widgets, in the order of the handler's returned tuple:
# the mask image, then the generated text.
_mask_view = gr.Image(type="numpy", label="Tumor Segmentation Mask")
_explanation_view = gr.Textbox(label="Medical Explanation")
outputs = [_mask_view, _explanation_view]
if __name__ == "__main__":
    # Wire the handler to the declared input/output components, then start
    # serving the interface. Only runs when executed as a script.
    demo = gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration",
    )
    demo.launch()
|