Spaces:
Sleeping
Sleeping
File size: 3,400 Bytes
1a08a52 cb5b9a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import pickle
import threading
import logging
from queue import Queue, Empty
from datetime import datetime
from functools import lru_cache
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from time import perf_counter
# Configuration (overridable via environment variables)
MEMORY_FILE = os.environ.get("MEMORY_FILE", "memory.pkl")
INDEX_FILE = os.environ.get("INDEX_FILE", "memory.index")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "all-MiniLM-L6-v2")

# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Load embedding model once at import time (heavy, downloads on first use)
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Initialize memory store and FAISS index: resume from disk if possible,
# otherwise start empty. Any load failure (missing file, corrupt pickle,
# unreadable index) falls through to the fresh-start branch.
try:
    # SECURITY NOTE: pickle.load executes arbitrary code if MEMORY_FILE is
    # untrusted — only load files this process wrote itself.
    with open(MEMORY_FILE, "rb") as fh:  # close deterministically (was a bare open())
        memory_data = pickle.load(fh)
    memory_index = faiss.read_index(INDEX_FILE)
    logging.info("Loaded existing memory and index.")
except Exception:
    memory_data = []
    # Size the fresh L2 index to the model's embedding dimension.
    dimension = embedding_model.get_sentence_embedding_dimension()
    memory_index = faiss.IndexFlatL2(dimension)
    logging.info("Initialized new memory and index.")
# Queue and worker for async flushing. Producers enqueue a token after each
# mutation; the worker coalesces all pending tokens into one disk write.
_write_queue = Queue()
def _flush_worker():
    """Background thread: batch queued write requests and persist to disk.

    Runs forever as a daemon. Blocks up to 5s for the first token, then
    drains everything else queued so one flush covers many mutations.
    """
    while True:
        batch = []
        try:
            # Short timeout keeps the loop responsive without busy-waiting.
            batch.append(_write_queue.get(timeout=5))
        except Empty:
            pass
        # Drain remaining tokens EAFP-style. The original checked
        # `while not _write_queue.empty(): get_nowait()`, which is a
        # check-then-act race in general; catching Empty is always safe.
        try:
            while True:
                batch.append(_write_queue.get_nowait())
        except Empty:
            pass
        if batch:
            try:
                # NOTE(review): memory_data/memory_index are mutated by other
                # threads without a lock, so this snapshot may be mid-update —
                # confirm best-effort persistence is acceptable here.
                with open(MEMORY_FILE, "wb") as fh:  # was a bare open(): handle leaked
                    pickle.dump(memory_data, fh)
                faiss.write_index(memory_index, INDEX_FILE)
                logging.info(f"Flushed {len(batch)} entries to disk.")
            except Exception as e:
                logging.error(f"Flush error: {e}")
# Start flush thread (daemon: must not block interpreter exit)
t = threading.Thread(target=_flush_worker, daemon=True)
t.start()
@lru_cache(maxsize=512)
def get_embedding(text: str) -> np.ndarray:
    """Encode *text* into an embedding vector, logging the model-call latency.

    Memoized for up to 512 distinct texts, so repeated queries skip the
    model entirely.
    """
    t0 = perf_counter()
    vector = embedding_model.encode(text)
    duration = perf_counter() - t0
    logging.info(f"get_embedding: {duration:.3f}s for '{text[:20]}...'")
    return vector
def embed_and_store(text: str, agent: str = "system", topic: str = ""):
    """Embed *text*, add it to the FAISS index, and queue an async disk flush.

    Storage is best-effort: any failure is logged and swallowed rather
    than raised to the caller.
    """
    try:
        vector = get_embedding(text)
        memory_index.add(np.array([vector], dtype='float32'))
        record = {
            "text": text,
            "agent": agent,
            "topic": topic,
            "timestamp": datetime.now().isoformat(),
        }
        memory_data.append(record)
        # Token value is irrelevant; the flush worker only counts entries.
        _write_queue.put(True)
        logging.info(f"Queued memory: {agent} / '{text[:20]}...'")
    except Exception as e:
        logging.error(f"embed_and_store error: {e}")
def retrieve_relevant(query: str, k: int = 5) -> list:
    """Return up to *k* stored memory entries most similar to *query*.

    Each result is a copy of the stored dict with an added 'similarity'
    key computed as ``1 - L2_distance``. NOTE(review): that value can be
    negative for distant vectors — confirm callers use it for ranking only.
    Returns [] on any error.
    """
    try:
        q_vec = get_embedding(query)
        D, I = memory_index.search(np.array([q_vec], dtype='float32'), k)
        results = []
        for dist, idx in zip(D[0], I[0]):
            # BUGFIX: FAISS pads I with -1 when the index holds fewer than
            # k vectors. The original guard `idx < len(memory_data)` let -1
            # through, and memory_data[-1] spuriously returned the newest
            # entry. Require a valid non-negative, in-range index.
            if 0 <= idx < len(memory_data):
                entry_copy = memory_data[idx].copy()
                entry_copy['similarity'] = 1 - dist
                results.append(entry_copy)
        return results
    except Exception as e:
        logging.error(f"retrieve_relevant error: {e}")
        return []