import os
import pickle
import threading
import logging
from queue import Queue, Empty
from datetime import datetime
from functools import lru_cache
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from time import perf_counter

# Configuration
MEMORY_FILE = os.environ.get("MEMORY_FILE", "memory.pkl")
INDEX_FILE = os.environ.get("INDEX_FILE", "memory.index")
EMBEDDING_MODEL = os.environ.get("EMBEDDING_MODEL", "all-MiniLM-L6-v2")

# Logging setup
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")

# Load embedding model
embedding_model = SentenceTransformer(EMBEDDING_MODEL)

# Initialize memory store and FAISS index
try:
    with open(MEMORY_FILE, "rb") as f:
        memory_data = pickle.load(f)
    memory_index = faiss.read_index(INDEX_FILE)
    logging.info("Loaded existing memory and index.")
except Exception:
    memory_data = []
    dimension = embedding_model.get_sentence_embedding_dimension()
    memory_index = faiss.IndexFlatL2(dimension)
    logging.info("Initialized new memory and index.")

# Queue and worker for async flushing; a lock keeps the in-memory store and
# FAISS index consistent between the main thread and the flush thread.
_write_queue = Queue()
_memory_lock = threading.Lock()

def _flush_worker():
    """Background thread: batch writes to disk."""
    while True:
        batch = []
        try:
            item = _write_queue.get(timeout=5)
            batch.append(item)
        except Empty:
            pass
        # Drain queue
        while not _write_queue.empty():
            batch.append(_write_queue.get_nowait())
        if batch:
            try:
                with _memory_lock:
                    with open(MEMORY_FILE, "wb") as f:
                        pickle.dump(memory_data, f)
                    faiss.write_index(memory_index, INDEX_FILE)
                logging.info(f"Flushed {len(batch)} entries to disk.")
            except Exception as e:
                logging.error(f"Flush error: {e}")

# Start flush thread
t = threading.Thread(target=_flush_worker, daemon=True)
t.start()

@lru_cache(maxsize=512)
def get_embedding(text: str) -> np.ndarray:
    """Compute embedding with timing."""
    start = perf_counter()
    vec = embedding_model.encode(text)
    elapsed = perf_counter() - start
    logging.info(f"get_embedding: {elapsed:.3f}s for '{text[:20]}...'")
    return vec


def embed_and_store(text: str, agent: str = "system", topic: str = ""):
    """Embed text, add it to the FAISS index and memory store, and queue a disk write."""
    try:
        vec = get_embedding(text)
        with _memory_lock:
            memory_index.add(np.array([vec], dtype='float32'))
            memory_data.append({
                "text": text,
                "agent": agent,
                "topic": topic,
                "timestamp": datetime.now().isoformat()
            })
        _write_queue.put(True)
        logging.info(f"Queued memory: {agent} / '{text[:20]}...'")
    except Exception as e:
        logging.error(f"embed_and_store error: {e}")


def retrieve_relevant(query: str, k: int = 5) -> list:
    """Return the top-k most relevant memory entries for a query."""
    try:
        q_vec = get_embedding(query)
        with _memory_lock:
            D, I = memory_index.search(np.array([q_vec], dtype='float32'), k)
            results = []
            for dist, idx in zip(D[0], I[0]):
                # FAISS pads results with -1 when fewer than k vectors are indexed.
                if 0 <= idx < len(memory_data):
                    entry_copy = memory_data[idx].copy()
                    # Rough relevance score derived from L2 distance; not a bounded cosine similarity.
                    entry_copy['similarity'] = 1.0 - float(dist)
                    results.append(entry_copy)
        return results
    except Exception as e:
        logging.error(f"retrieve_relevant error: {e}")
        return []
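

# A minimal usage sketch, run only when the module is executed directly; the
# example texts, agents, and topics below are illustrative, not part of any
# real dataset. It stores two memories, retrieves the closest matches for a
# query, and waits long enough for the background worker to flush to disk.
if __name__ == "__main__":
    import time

    embed_and_store("The nightly deployment cron runs at 02:00 UTC.", agent="ops", topic="deployment")
    embed_and_store("User prefers concise answers with code samples.", agent="assistant", topic="preferences")

    for entry in retrieve_relevant("When does the nightly deployment run?", k=2):
        print(f"{entry['agent']:>10} | {entry['similarity']:.3f} | {entry['text']}")

    # The flush worker polls the queue with a 5-second timeout, so give it a
    # moment to persist the queued writes before the process exits.
    time.sleep(6)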