import gradio as gr
import os
import tensorflow as tf
import numpy as np
import requests

from langchain_groq import ChatGroq
from langchain.prompts import PromptTemplate
from langchain.chains import LLMChain


model_path = "unet_model.h5"

# Safe download with streaming to avoid incomplete file
if not os.path.exists(model_path):
    hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
    print(f"Downloading model from {hf_url}...")
    with requests.get(hf_url, stream=True) as r:
        r.raise_for_status()
        with open(model_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)

print("Loading model...")
model = tf.keras.models.load_model(model_path, compile=False)


def classify_image(image_input):
    # Resize the uploaded slice to the 256x256 resolution the UNet++ model expects
    img = tf.image.resize(image_input, (256, 256))
    # Scale pixel values to [0, 1]
    img = img / 255.0
    # Add a batch dimension: (256, 256, C) -> (1, 256, 256, C)
    img = np.expand_dims(img, axis=0)

    # Threshold the predicted probability map into a binary 0/255 mask
    prediction = model.predict(img)[0]
    mask = (prediction > 0.5).astype(np.uint8) * 255

    # Drop the trailing channel axis so Gradio renders the mask as a grayscale image
    return np.squeeze(mask)


def rishigpt_handler(image_input, groq_api_key):
    # Expose the user-supplied key through the environment so ChatGroq can pick it up
    os.environ["GROQ_API_KEY"] = groq_api_key

    # Generate the tumor segmentation mask before the LLM call
    mask = classify_image(image_input)

    llm = ChatGroq(
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        temperature=0.3
    )

    prompt = PromptTemplate(
        input_variables=["result"],
        template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
    )

    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt
    )

    # Summarize the segmentation outcome and ask the LLM for a plain-language explanation
    classification = "The brain tumor mask has been generated and segmentation is complete."
    description = llm_chain.run({"result": classification})

    return mask, description


inputs = [
    gr.Image(type="numpy", label="Upload Brain MRI Slice"),
    gr.Textbox(type="password", label="Groq API Key")
]

outputs = [
    gr.Image(type="numpy", label="Tumor Segmentation Mask"),
    gr.Textbox(label="Medical Explanation")
]
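
# Illustrative sketch (not part of the original app): a quick way to exercise
# classify_image directly, without launching the Gradio UI. The file name
# "sample_mri.png" is a placeholder assumption, not a file shipped with this Space.
def run_segmentation_smoke_test(image_path="sample_mri.png"):
    from PIL import Image  # Pillow is already a Gradio dependency

    # Load the slice as an RGB array, matching what gr.Image(type="numpy") provides
    slice_array = np.array(Image.open(image_path).convert("RGB"))
    mask_array = classify_image(slice_array)
    print("Mask shape:", mask_array.shape, "tumor pixels:", int((mask_array > 0).sum()))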

if __name__ == "__main__":
    gr.Interface(
        fn=rishigpt_handler,
        inputs=inputs,
        outputs=outputs,
        title="RishiGPT Medical Brain Segmentation",
        description="UNet++ Brain Tumor Segmentation with LangChain integration"
    ).launch()