# ─── Basic imports ─────────────────────────────────────────────────────────────
import os
import math
import sqlite3
import fitz  # PyMuPDF for PDF parsing
from flask_socketio import SocketIO

# ─── LangChain framework ───────────────────────────────────────────────────────
from typing import List, Dict, Callable
from langchain.tools import Tool
from langchain.chat_models import ChatOpenAI
from langchain_groq import ChatGroq
from langchain_mistralai import ChatMistralAI
from langchain.agents import initialize_agent, AgentType, create_sql_agent
from langchain.schema import Document, BaseMemory, AIMessage, HumanMessage, SystemMessage
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders import TextLoader, PyMuPDFLoader
from langchain_core.messages import get_buffer_string

# ─── File paths ────────────────────────────────────────────────────────────────
# Global upload paths come from the shared config module; keep this import near
# the top so every tool sees the same state.
# from app import DB_PATH, DOC_PATH, IMG_PATH, OTH_PATH
import config

# ─── SQL agent ─────────────────────────────────────────────────────────────────
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langgraph.prebuilt import create_react_agent

# ─── Memory ────────────────────────────────────────────────────────────────────
from langchain.memory import ConversationBufferMemory
from langchain.memory.chat_memory import BaseChatMemory
from langchain.llms.base import LLM
from pydantic import PrivateAttr
class AutoSummaryMemory(ConversationBufferMemory):
    """Buffer memory that auto-summarizes once the history grows too long."""

    _llm: LLM = PrivateAttr()
    _max_entries: int = PrivateAttr()
    _reduce_to: int = PrivateAttr()
    _summary_system_prompt: str = PrivateAttr()

    def __init__(
        self,
        llm: LLM,
        memory_key: str = "chat_history",
        return_messages: bool = True,
        max_entries: int = 20,
        reduce_to: int = 5,
        summary_system_prompt: str = (
            "Summarize the following conversation so far in a concise paragraph. "
            "Keep important facts and questions."
        ),
    ):
        super().__init__(memory_key=memory_key, return_messages=return_messages)
        self._llm = llm
        self._max_entries = max_entries
        self._reduce_to = reduce_to
        self._summary_system_prompt = summary_system_prompt

    def save_context(self, inputs: dict, outputs: dict) -> None:
        # Add the new turn as normal. Note: ConversationBufferMemory's hook is
        # save_context, not add_memory, so this override actually fires.
        super().save_context(inputs, outputs)
        # If the buffer has grown too long, summarize the older history and
        # keep only the summary plus the most recent messages.
        msgs = self.chat_memory.messages
        if len(msgs) >= self._max_entries:
            full_text = "\n".join(f"{m.type}: {m.content}" for m in msgs)
            summary = self._llm.predict(f"{self._summary_system_prompt}\n\n{full_text}")
            recent = msgs[-self._reduce_to:]
            self.chat_memory.messages = [
                SystemMessage(content="Conversation summary: " + summary),
                *recent,
            ]
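# A minimal usage sketch (defined but never called on import): shows how
# save_context feeds the buffer and triggers summarization once max_entries
# is reached. The tiny thresholds here are illustrative, not the defaults.
def _demo_auto_summary_memory(demo_llm: LLM) -> None:
    mem = AutoSummaryMemory(llm=demo_llm, max_entries=4, reduce_to=2)
    mem.save_context({"input": "Hi, I'm Alice."}, {"output": "Hello Alice!"})
    mem.save_context({"input": "I work on RAG."}, {"output": "Nice, tell me more."})
    # Once the threshold is crossed, chat_memory.messages starts with a
    # SystemMessage holding the summary, followed by the last reduce_to turns.
    print(mem.chat_memory.messages)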
# ─── Image processing ──────────────────────────────────────────────────────────
from PIL import Image, UnidentifiedImageError
import pytesseract
from transformers import pipeline, TrOCRProcessor, VisionEncoderDecoderModel
from groq import Groq
import requests
from io import BytesIO
import base64

# ─── Browser tooling ───────────────────────────────────────────────────────────
import json
from langchain.tools import tool  # or langchain_core.tools
from duckduckgo_search import DDGS
from bs4 import BeautifulSoup

# Attempt to import Playwright for dynamic page rendering
try:
    from playwright.sync_api import sync_playwright
    _playwright_available = True
except ImportError:
    _playwright_available = False

# Forbidden keywords for basic NSFW filtering
_forbidden = ["porn", "sex", "xxx", "nude", "erotic", "nsfw", "explicit"]
# ─── LLM setup ─────────────────────────────────────────────────────────────────
# API keys come from the environment / .env file; never hard-code real keys.
os.environ.setdefault("OPENAI_API_KEY", "<YOUR_OPENAI_KEY>")  # set your own key or env var
os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY", "default_key_or_placeholder")
os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY", "default_key_or_placeholder")

# Tavily API key
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY", "default_key_or_placeholder")

# Globals for the RAG system and uploaded-file paths
vector_store = None
rag_chain = None
DB_PATH = None   # set when a .db file is uploaded
DOC_PATH = None  # set when a document is uploaded
IMG_PATH = None  # set when an image is uploaded
OTH_PATH = None  # set when any other file is uploaded
# ─── LLMs ──────────────────────────────────────────────────────────────────────
# llm = ChatOpenAI(model_name="gpt-3.5-turbo", streaming=True, temperature=0)
llm = ChatGroq(model="meta-llama/llama-4-maverick-17b-128e-instruct", streaming=True, temperature=0)
# llm = ChatMistralAI(model="mistral-large-latest", streaming=True, temperature=0)

# ──────────────────────────────────────────────────────────────────────────────
# ─────────────────────────────── Tool for browsing ────────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def tavily_search(query: str, top_k: int = 3) -> List[Dict]:
    """Call the Tavily search API and return a list of result dicts."""
    if not TAVILY_API_KEY:
        print("[Tavily] No API key set. Skipping Tavily search.")
        return []
    url = "https://api.tavily.com/search"
    headers = {
        "Authorization": f"Bearer {TAVILY_API_KEY}",
        "Content-Type": "application/json",
    }
    # Tavily's documented request field for the result count is max_results
    payload = {"query": query, "max_results": top_k}
    try:
        resp = requests.post(url, headers=headers, json=payload, timeout=10)
        resp.raise_for_status()
        data = resp.json()
        results = []
        for item in data.get("results", []):
            results.append({
                "title": item.get("title", ""),
                "url": item.get("url", ""),
                "snippet": item.get("content", "")[:200],
                "source": "Tavily"
            })
        return results
    except (requests.exceptions.RequestException, ValueError) as e:
        print(f"[Tavily] search failed: {e}")
        return []
def duckduckgo_search(query: str, top_k: int = 3) -> List[Dict]:
    """Query DuckDuckGo and return up to top_k raw SERP hits."""
    try:
        results = []
        with DDGS() as ddgs:
            for hit in ddgs.text(query, safesearch="On", max_results=top_k):
                results.append({
                    "title": hit.get("title", ""),
                    "url": hit.get("href") or hit.get("url", ""),
                    "snippet": hit.get("body", ""),
                    "source": "DuckDuckGo"
                })
                if len(results) >= top_k:
                    break
        return results
    except Exception as e:
        print(f"[DuckDuckGo] search failed: {e}")
        return []
def hybrid_web_search(query: str, top_k: int = 3) -> str:
    """
    Returns a JSON string with combined Tavily + DuckDuckGo results.
    Always returns non-empty JSON with at least a placeholder result.
    """
    tavily = tavily_search(query, top_k)
    ddg = duckduckgo_search(query, top_k)
    combined = tavily + ddg
    # Always return at least a placeholder message to avoid agent crashes
    if not combined:
        combined = [{
            "title": "No results found",
            "url": "",
            "snippet": f"Could not find suitable web results for '{query}'.",
            "source": "None"
        }]
    output = {"query": query, "results": combined}
    return json.dumps(output, ensure_ascii=False, indent=2)
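# A small consumption sketch (hypothetical query string): because
# hybrid_web_search always returns valid JSON, downstream code can call
# json.loads without a guard.
def _demo_hybrid_search() -> None:
    raw = hybrid_web_search("langchain agents", top_k=2)
    parsed = json.loads(raw)  # safe: output is always well-formed JSON
    for r in parsed["results"]:
        print(r["source"], "-", r["title"], "->", r["url"])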
def web_search(query: str, top_k: int = 3) -> str:
    """
    Full hybrid search: Playwright/BeautifulSoup scraping + Tavily/DuckDuckGo.
    Always returns valid JSON output.
    """
    results: List[Dict] = []
    # Step 1: DuckDuckGo lookup + page scraping
    try:
        with DDGS() as ddgs:
            # Materialize results inside the context manager so the session
            # is still open while we read them
            hits = list(ddgs.text(query, safesearch="On", max_results=top_k))
    except Exception as e:
        print(f"[web_search] DuckDuckGo lookup failed: {e}")
        hits = []
    for hit in hits:
        url = hit.get("href") or hit.get("url")
        if not url:
            continue
        try:
            if _playwright_available:
                with sync_playwright() as pw:
                    browser = pw.chromium.launch(headless=True)
                    page = browser.new_page()
                    page.goto(url, wait_until="domcontentloaded", timeout=15000)
                    html = page.content()
                    browser.close()
            else:
                # Fallback when Playwright is not installed: plain HTTP fetch
                html = requests.get(url, timeout=15).text
            soup = BeautifulSoup(html, "html.parser")
            text = soup.get_text(separator=" ", strip=True)
        except Exception as e:
            print(f"[web_search] scraping failed for {url}: {e}")
            continue
        if any(f in text.lower() for f in _forbidden):
            continue
        excerpt = " ".join(text.split()[:200])
        results.append({
            "title": hit.get("title", ""),
            "url": url,
            "snippet": hit.get("body", ""),
            "content": excerpt
        })
    # Step 2: Parse the hybrid Tavily + DDG JSON into a list
    try:
        raw = hybrid_web_search(query, top_k)
        parsed = json.loads(raw)
        other = parsed.get("results", [])
    except Exception as e:
        print(f"[web_search] parsing hybrid results failed: {e}")
        other = []
    # Step 3: Combine and return
    combined = results + other
    if not combined:
        combined = [{
            "title": "No results found",
            "url": "",
            "snippet": f"Could not find suitable content for '{query}'.",
            "source": "None"
        }]
    output = {
        "query": query,
        "sources_count": len(combined),
        "results": combined,
        "sources": list({item.get("url", "") for item in combined if item.get("url")})
    }
    return json.dumps(output, ensure_ascii=False, indent=2)
# ──────────────────────────────────────────────────────────────────────────────
# ────────────────────────────── Tool for calculation ──────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def calculate(expr: str) -> str:
    """
    Safely evaluates a mathematical expression.
    Uses numexpr instead of eval() for security and speed.
    """
    try:
        # Allow math constants
        local_dict = {"pi": math.pi, "e": math.e}
        # Evaluate with numexpr: no arbitrary Python code execution
        import numexpr
        result = numexpr.evaluate(expr, local_dict=local_dict)
        return str(result.item())
    except Exception as e:
        return f"Error calculating expression: {e}"
# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────── Tool for date and time ─────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def get_current_date(_: str = "") -> str:
    """
    Returns the current date and time; ignores its input.
    """
    from datetime import datetime
    return datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────── Tool for SQL database ──────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def create_sql_agent_function(db_uri: str, top_k: int = 5):
    """
    Creates a full-fledged SQL agent that can answer natural-language questions
    over a SQL database.
    Args:
        db_uri (str): The SQLAlchemy database URI, e.g. "sqlite:///Chinook.db"
        top_k (int): Maximum number of rows to return (default 5)
    Returns:
        agent_executor: LangChain agent that can .run() or .stream()
    """
    # 1) Initialize the database + LLM + toolkit
    db = SQLDatabase.from_uri(db_uri)
    llm = ChatGroq(model="meta-llama/llama-4-maverick-17b-128e-instruct", streaming=False, temperature=0)
    toolkit = SQLDatabaseToolkit(db=db, llm=llm)
    # 2) Prompt with all required variables declared AND used
    prompt = PromptTemplate(
        template="""
You are an agent designed to interact with a SQL database.
Given the user question below, first generate a syntactically correct {dialect} query.
Then look at the results of that query, and return the answer.
Always limit to at most {top_k} rows unless the user specifies otherwise.
If you encounter an error, rewrite your SQL and retry.
DO NOT issue any INSERT/UPDATE/DELETE/DROP statements.
DO NOT create new tables or columns unless the user explicitly asks for them.
Always inspect the schema before querying.
Available tools: {tools}
Tool names: {tool_names}
User question: {input}
{agent_scratchpad}
""".strip(),
        input_variables=["input", "dialect", "top_k", "agent_scratchpad", "tools", "tool_names"],
    )
    # Bind the static values up front so the agent only fills the per-question
    # variables at run time
    prompt = prompt.partial(dialect=db.dialect, top_k=str(top_k))
    # 3) Create the agent with prompt + toolkit tools
    agent_executor = create_sql_agent(
        llm=llm,
        toolkit=toolkit,
        prompt=prompt,
        verbose=False,
    )
    return agent_executor
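# Usage sketch (hypothetical database file): build an agent over a local
# SQLite file and ask a question in plain English.
def _demo_sql_agent() -> None:
    executor = create_sql_agent_function("sqlite:///Chinook.db", top_k=5)
    print(executor.run("Which 5 customers spent the most in total?"))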
def execute_sql(query: str) -> str:
    """
    Runs a natural-language query against the uploaded SQLite DB (DB_PATH).
    Returns a string of results or an error message.
    """
    if DB_PATH is None:
        return "No database uploaded. Please upload a SQLite file first."
    print("DB_PATH--------->:", DB_PATH)
    db_uri = f"sqlite:///{DB_PATH}"
    agent_executor2 = create_sql_agent_function(db_uri, top_k=5)
    try:
        result = agent_executor2.run(query)
    except Exception as e:
        result = f"Agent / SQL error: {e}"
    return result
# ──────────────────────────────────────────────────────────────────────────────
# ─────────────────── Tool for RAG (document intelligence) ─────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def rag_index_document(doc_path: str) -> str:
    """
    Indexes the given document into the RAG vector store.
    Supports text files and PDFs. Uses recursive text splitting for better chunking.
    """
    global vector_store, rag_chain
    text = ""
    # Read text from the file
    if doc_path and doc_path.lower().endswith(".pdf"):
        doc = fitz.open(doc_path)
        for page in doc:
            text += page.get_text()
    else:
        with open(doc_path, 'r', encoding='utf-8') as f:
            text = f.read()
    # Split text using the recursive text splitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,    # adjust as needed (e.g. 500-1000)
        chunk_overlap=100  # overlap for better context between chunks
    )
    texts = text_splitter.split_text(text)
    # Create Document objects with metadata
    docs = [Document(page_content=t, metadata={"source": doc_path}) for t in texts]
    # Initialize or append to the FAISS vector store
    embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
    if vector_store is None:
        vector_store = FAISS.from_documents(docs, embeddings)
    else:
        vector_store.add_documents(docs)
    retriever = vector_store.as_retriever(
        search_type="mmr",
        search_kwargs={
            "k": 10,
            "fetch_k": 10,
            "lambda_mult": 0.25
        }
    )
    # Build or update the RetrievalQA chain
    rag_chain = RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=False
    )
    return f"Indexed {len(docs)} chunks from {doc_path}."
def rag_answer(query: str) -> str:
    """
    Answers a question using the RAG chain (over the indexed documents).
    """
    global rag_chain
    if rag_chain is None:
        return "No documents indexed. Please upload documents via /upload_doc."
    try:
        answer = rag_chain.run(query)
        return answer
    except Exception as e:
        return f"RAG error: {e}"
# ──────────────────────────────────────────────────────────────────────────────
# ─────────── Tool for image (understanding, captioning & classification) ──────
# ──────────────────────────────────────────────────────────────────────────────
# Vision tools and functions
# Load image function
def _load_image(resize_to=(512, 512)):
    """
    Load and resize the image at IMG_PATH.
    Raises ValueError if no image is uploaded or the file is not a valid image.
    """
    try:
        if IMG_PATH is None:
            raise ValueError("No image uploaded. Please upload an image first.")
        with open(IMG_PATH, "rb") as f:
            img = Image.open(f)
            img.verify()  # verify it's an image (leaves the file object unusable)
        img = Image.open(IMG_PATH).convert("RGB")  # reopen after verify and convert
        img = img.resize(resize_to)  # resize to reduce token size
        return img
    except UnidentifiedImageError:
        raise ValueError(f"File at {IMG_PATH} is not a valid image.")
    except Exception as e:
        raise ValueError(f"Failed to load image at {IMG_PATH}: {str(e)}")

def _encode_image_to_base64():
    img = _load_image()
    buffer = BytesIO()
    img.save(buffer, format="PNG", optimize=True)  # save an optimized PNG
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
def _call_llama_llm(prompt_text: str) -> str:
    b64 = _encode_image_to_base64()
    message = HumanMessage(
        content=[
            {"type": "text", "text": prompt_text},
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/png;base64,{b64}"
                }
            }
        ]
    )
    response = llm.invoke([message])
    return response.content.strip()

def vision_query(task_prompt: str) -> str:
    try:
        return _call_llama_llm(task_prompt)
    except Exception as llama_error:
        print(f"[LLaMA-4V failed] {llama_error}")
        try:
            img = _load_image()
            return pytesseract.image_to_string(img).strip()
        except Exception as ocr_error:
            print(f"[OCR fallback failed] {ocr_error}")
            return "Unable to process the image or image is not uploaded. Please try again with a different input."
#### Create LangChain tools ####
# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────── Assigning tools as a list ──────────────────────
# ──────────────────────────────────────────────────────────────────────────────
tool_list = [
    Tool(name="browse", func=web_search, description="Search the web and scrape top results. Uses DuckDuckGo (safe mode) for the query. Prefers Playwright for loading pages, with requests/BeautifulSoup as fallback. Filters out any explicit content. Returns JSON with titles, URLs, and page text."),
    Tool(name="calculate", func=calculate, description="Perform math calculations safely."),
    Tool(name="date", func=get_current_date, description="Fetch the current date and time."),
    Tool(name="sql", func=execute_sql, description="Execute a SQL query on the uploaded database."),
    Tool(name="rag", func=rag_answer, description="Answer questions using the uploaded documents with retrieval-augmented generation (RAG)."),
    Tool(
        name="vision",
        func=vision_query,
        description=(
            "Perform any image-understanding task (e.g. read text, classify objects, "
            "generate captions, count or locate items, answer questions about the scene, "
            "detect NSFW content), powered by LLaMA 4-Vision. "
            "If the request is OCR-style and LLaMA fails, it falls back to Tesseract OCR."
        ),
    ),
]
# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────── Added memory to agent ──────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
# 1) Instantiate the memory with the LLM
memory = AutoSummaryMemory(
    llm=llm,
    max_entries=20,  # when chat history reaches 20 messages, trigger a summary
    reduce_to=5      # keep only the last 5 messages after summarizing
)

# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────────── Initialize agent ───────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
# Initialize the agent with the LLM and our tools, using a chat-conversational
# ReAct agent.
agent_executor = initialize_agent(
    tools=tool_list,
    llm=llm,
    agent=AgentType.CHAT_CONVERSATIONAL_REACT_DESCRIPTION,
    memory=memory,
    verbose=True,
    handle_parsing_errors=True,
    # max_iterations=10,
)
# ─── Streaming & fallback ──────────────────────────────────────────────────────
def run_stream(query: str, data_paths: List[str] = None):
    """
    Progressive token-by-token streaming from the agent.
    Args:
        query: The user's natural-language question.
        data_paths: List of file paths (DB_PATH, DOC_PATH, IMG_PATH, OTH_PATH).
    """
    # Guard against None before filtering out empty entries
    data_paths = [p for p in (data_paths or []) if p]
    print(f"Data paths----------------->: {data_paths}")
    # Re-inject each path into the appropriate global (keeps them current)
    for path in data_paths:
        ext = os.path.splitext(path)[1].lower()
        if ext in {".png", ".jpg", ".jpeg", ".gif"}:
            globals()['IMG_PATH'] = path
        elif ext in {".pdf", ".txt", ".doc", ".docx"}:
            globals()['DOC_PATH'] = path
        elif ext in {".db", ".sqlite"}:
            globals()['DB_PATH'] = path
        else:
            globals()['OTH_PATH'] = path
    # Stream the agent response
    print("Memory now contains:", memory.chat_memory.messages)
    for chunk in agent_executor.stream({"input": query}):
        text = chunk.get("text")
        if text:
            yield text
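# Consumption sketch (hypothetical question): run_stream is a generator, so a
# Flask/SocketIO handler can forward chunks to the client as they arrive.
def _demo_stream() -> None:
    for token in run_stream("Summarize the uploaded document.", data_paths=[]):
        print(token, end="", flush=True)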
def run_full(query: str) -> str:
    """
    Fallback single-shot answer (for pure-tool use or final completeness).
    """
    return agent_executor.run(query)

# Expose for Flask
class AgentInterface:
    def __init__(self, executor):
        self.executor = executor
    def run_stream(self, q, data_paths=None):
        return run_stream(q, data_paths)
    def run_full(self, q):
        return run_full(q)

agent = AgentInterface(agent_executor)

__all__ = [
    'agent_executor', 'run_stream', 'run_full',
    'AgentInterface', 'DB_PATH', 'rag_index_document'
]
# ──────────────────────────────────────────────────────────────────────────────
# ───────────────────────────── Refresh memory session ─────────────────────────
# ──────────────────────────────────────────────────────────────────────────────
def refresh_memory():
    memory.clear()              # clear memory at the start of each new session
    memory.chat_memory.clear()  # clear chat history