rishirajbal commited on
Commit
64183cb
·
verified ·
1 Parent(s): 72fda4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -16
app.py CHANGED
@@ -5,13 +5,13 @@ import numpy as np
5
  import requests
6
 
7
  from langchain_groq import ChatGroq
 
8
  from langchain.prompts import PromptTemplate
9
- from langchain.chains import LLMChain
10
-
11
 
 
12
  model_path = "unet_model.h5"
13
-
14
- # Safe download with streaming to avoid incomplete file
15
  if not os.path.exists(model_path):
16
  hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
17
  print(f"Downloading model from {hf_url}...")
@@ -24,44 +24,65 @@ if not os.path.exists(model_path):
24
  print("Loading model...")
25
  model = tf.keras.models.load_model(model_path, compile=False)
26
 
27
-
28
  def classify_image(image_input):
29
  img = tf.image.resize(image_input, (256, 256))
30
  img = img / 255.0
31
  img = np.expand_dims(img, axis=0)
32
 
33
- prediction = model.predict(img)[0]
34
- mask = (prediction > 0.5).astype(np.uint8) * 255
35
 
36
- return mask
37
 
 
 
 
38
 
 
 
 
39
  def rishigpt_handler(image_input, groq_api_key):
40
  os.environ["GROQ_API_KEY"] = groq_api_key
41
 
42
  mask = classify_image(image_input)
43
 
 
 
 
 
 
 
 
 
 
 
44
  llm = ChatGroq(
45
  model="meta-llama/llama-4-scout-17b-16e-instruct",
46
  temperature=0.3
47
  )
48
 
 
 
 
 
 
 
 
 
 
 
 
49
  prompt = PromptTemplate(
50
  input_variables=["result"],
51
  template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
52
  )
 
53
 
54
- llm_chain = LLMChain(
55
- llm=llm,
56
- prompt=prompt
57
- )
58
-
59
- classification = "The brain tumor mask has been generated and segmentation is complete."
60
- description = llm_chain.run({"result": classification})
61
 
62
  return mask, description
63
 
64
-
65
  inputs = [
66
  gr.Image(type="numpy", label="Upload Brain MRI Slice"),
67
  gr.Textbox(type="password", label="Groq API Key")
 
5
  import requests
6
 
7
  from langchain_groq import ChatGroq
8
+ from langchain.agents import initialize_agent
9
  from langchain.prompts import PromptTemplate
10
+ from langchain_core.runnables import RunnableSequence # Modern replacement for LLMChain
11
+ from langchain.tools import StructuredTool
12
 
13
+ # === Download model if not exists ===
14
  model_path = "unet_model.h5"
 
 
15
  if not os.path.exists(model_path):
16
  hf_url = "https://huggingface.co/rishirajbal/UNET_plus_plus_Brain_segmentation/resolve/main/unet_model.h5"
17
  print(f"Downloading model from {hf_url}...")
 
24
  print("Loading model...")
25
  model = tf.keras.models.load_model(model_path, compile=False)
26
 
27
+ # === Segmentation ===
28
  def classify_image(image_input):
29
  img = tf.image.resize(image_input, (256, 256))
30
  img = img / 255.0
31
  img = np.expand_dims(img, axis=0)
32
 
33
+ prediction = model.predict(img)[0] # (256, 256, 1) maybe
 
34
 
35
+ mask = (prediction > 0.5).astype(np.uint8) * 255 # binary mask
36
 
37
+ # Squeeze to (H, W) if needed
38
+ if mask.ndim == 3 and mask.shape[-1] == 1:
39
+ mask = np.squeeze(mask, axis=-1)
40
 
41
+ return mask
42
+
43
+ # === Main handler ===
44
  def rishigpt_handler(image_input, groq_api_key):
45
  os.environ["GROQ_API_KEY"] = groq_api_key
46
 
47
  mask = classify_image(image_input)
48
 
49
+ # Dummy tool for LangChain agent
50
+ def segment_brain_tool():
51
+ return "A brain tumor mask was generated."
52
+
53
+ tool = StructuredTool.from_function(
54
+ segment_brain_tool,
55
+ name="segment_brain",
56
+ description="Segment brain MRI for tumor detection."
57
+ )
58
+
59
  llm = ChatGroq(
60
  model="meta-llama/llama-4-scout-17b-16e-instruct",
61
  temperature=0.3
62
  )
63
 
64
+ agent = initialize_agent(
65
+ tools=[tool],
66
+ llm=llm,
67
+ agent="zero-shot-react-description",
68
+ verbose=True
69
+ )
70
+
71
+ user_query = "I uploaded a brain MRI. What does the segmentation say?"
72
+ classification = agent.run(user_query)
73
+
74
+ # New style: RunnableSequence
75
  prompt = PromptTemplate(
76
  input_variables=["result"],
77
  template="You are a medical imaging expert. Based on the result: {result}, explain what this means for diagnosis."
78
  )
79
+ chain = prompt | llm
80
 
81
+ description = chain.invoke({"result": classification})
 
 
 
 
 
 
82
 
83
  return mask, description
84
 
85
+ # === Gradio UI ===
86
  inputs = [
87
  gr.Image(type="numpy", label="Upload Brain MRI Slice"),
88
  gr.Textbox(type="password", label="Groq API Key")