Spaces:
Running
Running
File size: 1,976 Bytes
8366946 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
"""Agent calling a remote ensemble model on Modal.
Computes final price from multiple predictions.
"""
import modal
from src.agents.base_agent import Agent
from src.agents.ft_price_agent import FTPriceAgent
from src.agents.rag_price_agent import RAGPriceAgent
from src.agents.xgb_price_agent import XGBoostPriceAgent
from src.config.constants import CURRENCY
from src.modal_services.app_config import APP_NAME
class EnsemblePriceAgent(Agent):
    """Agent that aggregates FT, RAG, and XGB predictions.

    Each description is priced by three sub-agents (FTPriceAgent,
    RAGPriceAgent, XGBoostPriceAgent). Their three estimates are then sent
    to the remote ``EnsemblePricer`` class deployed on Modal, whose result
    is the final price.
    """

    name = "EnsemblePrice Agent"
    color = "magenta"

    def __init__(self) -> None:
        """Create the three sub-agents and a handle to the remote ensemble.

        ``modal.Cls.from_name`` looks up the ``EnsemblePricer`` class deployed
        under ``APP_NAME``. The instance created from it is what
        :meth:`price` calls through ``price.remote(...)``.
        """
        # Tracks whether the one-time "connecting" message has been logged.
        self._modal_called = False
        self.ft_agent = FTPriceAgent()
        self.rag_agent = RAGPriceAgent()
        self.xgb_agent = XGBoostPriceAgent()
        remote_ensemble = modal.Cls.from_name(APP_NAME, "EnsemblePricer")
        self.ensemble = remote_ensemble()
        self.log("is ready")

    def price(self, description: str) -> float:
        """Estimate an item's price from its text description.

        Args:
            description: Item description, passed unchanged to each sub-agent.

        Returns:
            The remote ensemble's estimate, converted to ``float``.

        Raises:
            RuntimeError: If the remote ``EnsemblePricer`` call fails, or if it
                returns a value that cannot be converted to ``float``.
                Exceptions raised by the sub-agents propagate unchanged.
        """
        ft_pred = self.ft_agent.price(description)
        rag_pred = self.rag_agent.price(description)
        xgb_pred = self.xgb_agent.price(description)

        if not self._modal_called:
            # Announce the remote connection only on the first call.
            self.log("📡 Connecting to Modal — Loading trained linear model...")
            self._modal_called = True

        self.log(
            f"Predictions — FT={CURRENCY}{ft_pred}, "
            f"RAG={CURRENCY}{rag_pred}, "
            f"XGB={CURRENCY}{xgb_pred}"
        )

        # Keep the try narrow. Only the remote call and the conversion of its
        # result count as "remote EnsemblePricer failed". The success log runs
        # afterwards, so a logging problem is not misreported as a remote
        # failure. float() also guarantees the declared return type.
        try:
            raw = self.ensemble.price.remote(ft_pred, rag_pred, xgb_pred)
            result = float(raw)
        except Exception as e:
            self.log(f"[ERROR] Remote EnsemblePricer failed: {e}")
            raise RuntimeError("EnsemblePriceAgent failed to get final price.") from e

        self.log(f"Final estimate: {CURRENCY}{result:.2f}")
        return result
|