Spaces:
Build error
Build error
lazy load image model
Browse files
app.py
CHANGED
@@ -43,6 +43,9 @@ Your final answer: FINAL ANSWER: Paris
|
|
43 |
|
44 |
# --- Tool Definitions ---
|
45 |
|
|
|
|
|
|
|
46 |
|
47 |
@tool
|
48 |
def web_search(query: str):
|
@@ -66,14 +69,19 @@ def math_calculator(expression: str):
|
|
66 |
|
67 |
@tool
|
68 |
def image_analyzer(image_url: str):
|
69 |
-
"""Analyzes an image and returns a description."""
|
|
|
70 |
print(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
|
71 |
try:
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
|
|
|
|
|
|
|
|
77 |
return description
|
78 |
except Exception as e:
|
79 |
return f"Error analyzing image: {e}"
|
@@ -112,7 +120,8 @@ class GaiaAgent:
|
|
112 |
]
|
113 |
|
114 |
# Initialize the LLM
|
115 |
-
print("Loading LLM...")
|
|
|
116 |
llm = HuggingFacePipeline.from_model_id(
|
117 |
model_id="microsoft/Phi-3-mini-4k-instruct",
|
118 |
task="text-generation",
|
@@ -121,12 +130,12 @@ class GaiaAgent:
|
|
121 |
"top_k": 50,
|
122 |
"temperature": 0.1,
|
123 |
"do_sample": False,
|
124 |
-
"torch_dtype": "auto", # Let transformers figure out the best dtype
|
125 |
-
"device_map": "auto",
|
126 |
},
|
127 |
-
|
|
|
|
|
128 |
)
|
129 |
-
print("LLM loaded.")
|
130 |
|
131 |
# Create the agent graph
|
132 |
prompt = PromptTemplate(
|
@@ -142,7 +151,7 @@ Question: {question}
|
|
142 |
|
143 |
self.agent = prompt | llm | StrOutputParser()
|
144 |
self.graph = self._create_graph()
|
145 |
-
print("GaiaAgent initialized.")
|
146 |
|
147 |
def _create_graph(self):
|
148 |
graph = StateGraph(AgentState)
|
@@ -175,8 +184,8 @@ Question: {question}
|
|
175 |
print("--- Calling Tools ---")
|
176 |
raw_tool_call = state["messages"][-1]
|
177 |
|
178 |
-
# Simple regex to find tool calls like tool_name("argument")
|
179 |
-
tool_call_match = re.search(r"(\w+)\((.*?)\)", raw_tool_call)
|
180 |
if not tool_call_match:
|
181 |
return {"messages": ["No valid tool call found."], "sender": "tools"}
|
182 |
|
@@ -184,7 +193,9 @@ Question: {question}
|
|
184 |
tool_input_str = tool_call_match.group(2).strip()
|
185 |
|
186 |
# Remove quotes from the input string if they exist
|
187 |
-
if tool_input_str.startswith('"') and tool_input_str.endswith('"')
|
|
|
|
|
188 |
tool_input = tool_input_str[1:-1]
|
189 |
else:
|
190 |
tool_input = tool_input_str
|
|
|
43 |
|
44 |
# --- Tool Definitions ---
|
45 |
|
46 |
+
# Global variable to cache the image-to-text pipeline. This allows for "lazy loading".
|
47 |
+
image_to_text_pipeline = None
|
48 |
+
|
49 |
|
50 |
@tool
|
51 |
def web_search(query: str):
|
|
|
69 |
|
70 |
@tool
|
71 |
def image_analyzer(image_url: str):
|
72 |
+
"""Analyzes an image and returns a description. Loads the model on first use."""
|
73 |
+
global image_to_text_pipeline
|
74 |
print(f"--- Calling Image Analyzer Tool with URL: {image_url} ---")
|
75 |
try:
|
76 |
+
if image_to_text_pipeline is None:
|
77 |
+
print("--- Initializing Image Analyzer pipeline for the first time... ---")
|
78 |
+
# Lazy-load the pipeline to conserve memory on startup
|
79 |
+
image_to_text_pipeline = pipeline(
|
80 |
+
"image-to-text", model="Salesforce/blip-image-captioning-base"
|
81 |
+
)
|
82 |
+
print("--- Image Analyzer pipeline initialized. ---")
|
83 |
+
|
84 |
+
description = image_to_text_pipeline(image_url)[0]["generated_text"]
|
85 |
return description
|
86 |
except Exception as e:
|
87 |
return f"Error analyzing image: {e}"
|
|
|
120 |
]
|
121 |
|
122 |
# Initialize the LLM
|
123 |
+
print("Loading LLM... This may take a few minutes on first startup.")
|
124 |
+
# Using a smaller, CPU-friendly model to avoid memory issues on Hugging Face Spaces
|
125 |
llm = HuggingFacePipeline.from_model_id(
|
126 |
model_id="microsoft/Phi-3-mini-4k-instruct",
|
127 |
task="text-generation",
|
|
|
130 |
"top_k": 50,
|
131 |
"temperature": 0.1,
|
132 |
"do_sample": False,
|
|
|
|
|
133 |
},
|
134 |
+
torch_dtype="auto",
|
135 |
+
trust_remote_code=True, # Required for Phi-3
|
136 |
+
device_map="auto",
|
137 |
)
|
138 |
+
print("LLM loaded successfully.")
|
139 |
|
140 |
# Create the agent graph
|
141 |
prompt = PromptTemplate(
|
|
|
151 |
|
152 |
self.agent = prompt | llm | StrOutputParser()
|
153 |
self.graph = self._create_graph()
|
154 |
+
print("GaiaAgent initialized successfully.")
|
155 |
|
156 |
def _create_graph(self):
|
157 |
graph = StateGraph(AgentState)
|
|
|
184 |
print("--- Calling Tools ---")
|
185 |
raw_tool_call = state["messages"][-1]
|
186 |
|
187 |
+
# Simple regex to find tool calls like tool_name("argument") or tool_name(argument)
|
188 |
+
tool_call_match = re.search(r"(\w+)\s*\((.*?)\)", raw_tool_call, re.DOTALL)
|
189 |
if not tool_call_match:
|
190 |
return {"messages": ["No valid tool call found."], "sender": "tools"}
|
191 |
|
|
|
193 |
tool_input_str = tool_call_match.group(2).strip()
|
194 |
|
195 |
# Remove quotes from the input string if they exist
|
196 |
+
if (tool_input_str.startswith('"') and tool_input_str.endswith('"')) or (
|
197 |
+
tool_input_str.startswith("'") and tool_input_str.endswith("'")
|
198 |
+
):
|
199 |
tool_input = tool_input_str[1:-1]
|
200 |
else:
|
201 |
tool_input = tool_input_str
|