leofltt committed
Commit 3dd74bd
1 Parent(s): 7737a19

lazy load image model

Files changed (1)
  1. app.py +26 -15
app.py CHANGED
@@ -43,6 +43,9 @@ Your final answer: FINAL ANSWER: Paris
 
 # --- Tool Definitions ---
 
+# Global variable to cache the image-to-text pipeline. This allows for "lazy loading".
+image_to_text_pipeline = None
+
 
 @tool
 def web_search(query: str):
@@ -66,14 +69,19 @@ def math_calculator(expression: str):
 
 @tool
 def image_analyzer(image_url: str):
-    """Analyzes an image and returns a description."""
+    """Analyzes an image and returns a description. Loads the model on first use."""
+    global image_to_text_pipeline
     print(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
     try:
-        # Using a CPU-friendly image-to-text model
-        image_to_text = pipeline(
-            "image-to-text", model="Salesforce/blip-image-captioning-base"
-        )
-        description = image_to_text(image_url)[0]["generated_text"]
+        if image_to_text_pipeline is None:
+            print("--- Initializing Image Analyzer pipeline for the first time... ---")
+            # Lazy-load the pipeline to conserve memory on startup
+            image_to_text_pipeline = pipeline(
+                "image-to-text", model="Salesforce/blip-image-captioning-base"
+            )
+            print("--- Image Analyzer pipeline initialized. ---")
+
+        description = image_to_text_pipeline(image_url)[0]["generated_text"]
         return description
     except Exception as e:
         return f"Error analyzing image: {e}"
@@ -112,7 +120,8 @@ class GaiaAgent:
         ]
 
         # Initialize the LLM
-        print("Loading LLM...")
+        print("Loading LLM... This may take a few minutes on first startup.")
+        # Using a smaller, CPU-friendly model to avoid memory issues on Hugging Face Spaces
         llm = HuggingFacePipeline.from_model_id(
             model_id="microsoft/Phi-3-mini-4k-instruct",
             task="text-generation",
@@ -121,12 +130,12 @@ class GaiaAgent:
                 "top_k": 50,
                 "temperature": 0.1,
                 "do_sample": False,
-                "torch_dtype": "auto",  # Let transformers figure out the best dtype
-                "device_map": "auto",
             },
-            trust_remote_code=True,  # Phi-3 requires this
+            torch_dtype="auto",
+            trust_remote_code=True,  # Required for Phi-3
+            device_map="auto",
         )
-        print("LLM loaded.")
+        print("LLM loaded successfully.")
 
         # Create the agent graph
         prompt = PromptTemplate(
@@ -142,7 +151,7 @@ Question: {question}
 
         self.agent = prompt | llm | StrOutputParser()
         self.graph = self._create_graph()
-        print("GaiaAgent initialized.")
+        print("GaiaAgent initialized successfully.")
 
     def _create_graph(self):
         graph = StateGraph(AgentState)
@@ -175,8 +184,8 @@ Question: {question}
         print("--- Calling Tools ---")
         raw_tool_call = state["messages"][-1]
 
-        # Simple regex to find tool calls like tool_name("argument")
-        tool_call_match = re.search(r"(\w+)\((.*?)\)", raw_tool_call)
+        # Simple regex to find tool calls like tool_name("argument") or tool_name(argument)
+        tool_call_match = re.search(r"(\w+)\s*\((.*?)\)", raw_tool_call, re.DOTALL)
         if not tool_call_match:
             return {"messages": ["No valid tool call found."], "sender": "tools"}
 
@@ -184,7 +193,9 @@ Question: {question}
         tool_input_str = tool_call_match.group(2).strip()
 
         # Remove quotes from the input string if they exist
-        if tool_input_str.startswith('"') and tool_input_str.endswith('"'):
+        if (tool_input_str.startswith('"') and tool_input_str.endswith('"')) or (
+            tool_input_str.startswith("'") and tool_input_str.endswith("'")
+        ):
             tool_input = tool_input_str[1:-1]
         else:
             tool_input = tool_input_str
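
For reference, the lazy-loading pattern this commit introduces can be reduced to a small standalone sketch. It assumes the transformers pipeline API used in the diff; describe_image and _image_to_text_pipeline are illustrative names only and do not appear in app.py.

from transformers import pipeline

# Module-level cache: the captioning model is loaded on first use, not at import time.
_image_to_text_pipeline = None

def describe_image(image_url: str) -> str:
    """Return a caption for the image, loading the BLIP model lazily."""
    global _image_to_text_pipeline
    if _image_to_text_pipeline is None:
        # The first call pays the download/load cost; later calls reuse the cached pipeline.
        _image_to_text_pipeline = pipeline(
            "image-to-text", model="Salesforce/blip-image-captioning-base"
        )
    return _image_to_text_pipeline(image_url)[0]["generated_text"]

Keeping the pipeline construction out of module import means the Space starts up without loading the BLIP weights; they are only pulled into memory if the image_analyzer tool is actually called.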