# Semantic memory store: embeds text with SentenceTransformers, indexes the
# vectors in FAISS, and persists both to disk from a background flush thread.
import logging
import os
import pickle
import threading
from datetime import datetime
from functools import lru_cache
from queue import Queue, Empty
from time import perf_counter

import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# Configuration (overridable via environment variables)
MEMORY_FILE = os.environ.get("MEMORY_FILE", "memory.pkl")
INDEX_FILE = os.environ.get("INDEX_FILE", "memory.index")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "all-MiniLM-L6-v2")

# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Load embedding model (may download weights; blocks at import time)
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Initialize memory store and FAISS index from disk, falling back to a fresh
# empty store on any failure (missing files, corrupt pickle, version skew).
# NOTE(review): unpickling an attacker-controlled MEMORY_FILE executes
# arbitrary code — these paths must be trusted.
try:
    # Context manager closes the handle promptly (original leaked the open file).
    with open(MEMORY_FILE, "rb") as _mem_fh:
        memory_data = pickle.load(_mem_fh)
    memory_index = faiss.read_index(INDEX_FILE)
    logging.info("Loaded existing memory and index.")
except Exception:
    memory_data = []
    dimension = embedding_model.get_sentence_embedding_dimension()
    memory_index = faiss.IndexFlatL2(dimension)
    logging.info("Initialized new memory and index.")

# Queue of "state changed" tokens consumed by the async flush worker.
_write_queue = Queue()
def _flush_worker():
    """Background thread: coalesce queued dirty-signals and persist state.

    Each item on ``_write_queue`` is merely a "state changed" token.  The
    worker blocks up to 5s for the first token, drains any others that piled
    up, then writes the full ``memory_data``/``memory_index`` to disk once
    per batch instead of once per change.  Runs forever; start as a daemon
    thread.
    """
    while True:
        batch = []
        try:
            # Block briefly for the first dirty signal.
            batch.append(_write_queue.get(timeout=5))
        except Empty:
            pass
        # Drain whatever accumulated so one disk write covers the batch.
        while not _write_queue.empty():
            batch.append(_write_queue.get_nowait())
        if batch:
            try:
                # Context manager guarantees the handle is flushed/closed even
                # if pickling raises (the original never closed the file).
                with open(MEMORY_FILE, "wb") as fh:
                    pickle.dump(memory_data, fh)
                faiss.write_index(memory_index, INDEX_FILE)
                # Lazy %-args: no string formatting unless INFO is enabled.
                logging.info("Flushed %d entries to disk.", len(batch))
            except Exception as e:
                # logging.exception keeps the message and adds the traceback
                # the original logging.error discarded.
                logging.exception("Flush error: %s", e)
# Start flush thread. daemon=True means it can never keep the interpreter
# alive at exit — but daemon threads are killed abruptly, so a flush that is
# pending at shutdown may be lost; verify this is acceptable for callers.
t = threading.Thread(target=_flush_worker, daemon=True)
t.start()
@lru_cache(maxsize=1024)
def get_embedding(text: str) -> np.ndarray:
    """Compute (and memoize) the embedding vector for *text*, with timing log.

    ``lru_cache`` was imported at file top but never applied; a bounded cache
    avoids re-encoding repeated strings.  Cached calls return the *same*
    ndarray object, so callers must treat the result as read-only (existing
    callers here only read it or copy it into a new array).
    """
    start = perf_counter()
    vec = embedding_model.encode(text)
    elapsed = perf_counter() - start
    # Lazy %-formatting: no string work unless INFO is enabled.
    logging.info("get_embedding: %.3fs for '%s...'", elapsed, text[:20])
    return vec
def embed_and_store(text: str, agent: str = "system", topic: str = ""):
    """Embed *text*, add it to the FAISS index, and queue an async disk flush.

    Args:
        text: Content to remember.
        agent: Originator label stored alongside the text.
        topic: Optional topic tag.

    Errors are logged with full traceback rather than raised, so one bad
    entry never crashes the caller.
    """
    try:
        vec = get_embedding(text)
        # FAISS requires a 2-D float32 array of shape (1, dim).
        memory_index.add(np.array([vec], dtype='float32'))
        memory_data.append({
            "text": text,
            "agent": agent,
            "topic": topic,
            # NOTE(review): naive local time; consider UTC for portability.
            "timestamp": datetime.now().isoformat()
        })
        # Token only — the flush worker re-pickles the whole store.
        _write_queue.put(True)
        logging.info("Queued memory: %s / '%s...'", agent, text[:20])
    except Exception as e:
        # Same message as before, plus the traceback the original discarded.
        logging.exception("embed_and_store error: %s", e)
def retrieve_relevant(query: str, k: int = 5) -> list:
    """Return up to *k* stored memory entries most similar to *query*.

    Each result is a copy of the stored dict with an added ``'similarity'``
    key computed as ``1 - L2_distance`` (larger means more similar under
    this scheme).  Returns ``[]`` on any error.
    """
    try:
        q_vec = get_embedding(query)
        distances, ids = memory_index.search(np.array([q_vec], dtype='float32'), k)
        results = []
        for dist, idx in zip(distances[0], ids[0]):
            # BUGFIX: FAISS pads with id -1 when the index holds fewer than
            # k vectors.  The original `idx < len(memory_data)` let -1
            # through, and negative indexing silently returned the *last*
            # entry as a bogus match.  Guard the lower bound too.
            if 0 <= idx < len(memory_data):
                entry_copy = memory_data[idx].copy()
                entry_copy['similarity'] = 1 - dist
                results.append(entry_copy)
        return results
    except Exception as e:
        # Preserve the message; logging.exception adds the traceback.
        logging.exception("retrieve_relevant error: %s", e)
        return []