File size: 5,753 Bytes
a33458e
 
9f0d171
 
 
a33458e
 
 
 
 
 
9f0d171
 
 
 
a33458e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f0d171
 
a33458e
9f0d171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a33458e
 
 
9f0d171
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a33458e
9f0d171
 
 
 
 
 
 
 
 
 
 
a33458e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9f0d171
 
 
 
 
a33458e
 
 
9f0d171
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import sys
import time
import random
import logging
from langchain.vectorstores import Qdrant
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from qdrant_client import QdrantClient
from qdrant_client.models import Distance, VectorParams

# Configure logging at import time: the root logger is set to INFO and this
# module gets its own named logger, used by MemoryManager below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Add project root to path for imports.
# abspath(__file__) is this file; the three dirname() calls walk up three
# directory levels from it, and that directory is appended to sys.path so the
# absolute `app.*` imports below can resolve.
# NOTE(review): presumably that directory is the project root containing the
# `app` package - confirm against the repository layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
from app.config import VECTOR_DB_PATH, COLLECTION_NAME
from app.core.llm import get_llm, get_embeddings, get_chat_model

class MemoryManager:
    """Manages the RAG memory system using a vector database.

    On construction the manager obtains the embedding model, the LLM and the
    chat model from ``app.core.llm``, opens a Qdrant client, makes sure the
    configured collection exists and creates a conversation buffer memory.

    Attributes:
        embeddings: Embedding model returned by ``get_embeddings()``.
        llm: Language model returned by ``get_llm()``; used by the RAG chain.
        chat_model: Model returned by ``get_chat_model()``; stored but not
            used by any method of this class.
        client: The ``QdrantClient`` opened by ``_init_qdrant_client``.
        vectorstore: LangChain ``Qdrant`` wrapper used for storage/search.
        memory: ``ConversationBufferMemory`` shared with the RAG chain.
    """

    def __init__(self):
        self.embeddings = get_embeddings()
        self.llm = get_llm()
        self.chat_model = get_chat_model()
        self.client = self._init_qdrant_client()
        self.vectorstore = self._init_vector_store()
        # output_key must be set explicitly: create_rag_chain() builds the
        # chain with return_source_documents=True, so the chain returns both
        # "answer" and "source_documents". Without output_key the memory
        # cannot decide which one to store and raises
        # "ValueError: One output key expected".
        self.memory = ConversationBufferMemory(
            memory_key="chat_history",
            return_messages=True,
            output_key="answer",
        )

    def _init_qdrant_client(self):
        """Initialize the Qdrant client with retry logic for concurrent access issues.

        Each attempt first tries a per-instance sub-directory of
        ``VECTOR_DB_PATH`` and, if that fails for any reason, the shared
        ``VECTOR_DB_PATH`` itself. Only the "already accessed by another
        instance" ``RuntimeError`` on the shared path is retried; after
        ``max_retries`` such failures an in-memory client is returned.

        NOTE(review): the per-instance directory name is random, so data
        written by one instance is presumably not visible to another -
        confirm this is intended.

        Returns:
            A ``QdrantClient`` (on-disk if possible, otherwise in-memory).

        Raises:
            RuntimeError: If opening the shared path fails with a
                ``RuntimeError`` that is not the concurrent-access error.
            Exception: Any non-``RuntimeError`` raised while opening the
                shared path propagates unchanged.
        """
        # Create directory if it doesn't exist
        os.makedirs(VECTOR_DB_PATH, exist_ok=True)

        # Add a small random delay to reduce chance of concurrent access
        time.sleep(random.uniform(0.1, 0.5))

        # Generate a unique path for this instance to avoid collision
        instance_id = str(random.randint(10000, 99999))
        unique_path = os.path.join(VECTOR_DB_PATH, f"instance_{instance_id}")

        max_retries = 3

        for attempt in range(1, max_retries + 1):
            logger.info(
                "Attempting to initialize Qdrant client (attempt %d/%d)",
                attempt,
                max_retries,
            )

            # Try to use the unique path first
            try:
                os.makedirs(unique_path, exist_ok=True)
                return QdrantClient(path=unique_path)
            except Exception as e:
                logger.warning("Could not use unique path %s: %s", unique_path, e)

            # Try the main path as fallback
            try:
                return QdrantClient(path=VECTOR_DB_PATH)
            except RuntimeError as e:
                if "already accessed by another instance" not in str(e):
                    # Different error, don't retry
                    raise
                if attempt == max_retries:
                    # Nothing left to retry; skip the pointless back-off.
                    break
                # Linear back-off with jitter: later attempts wait longer.
                wait_time = random.uniform(0.5, 2.0) * attempt
                logger.warning(
                    "Qdrant concurrent access detected. Retrying in %.2f seconds...",
                    wait_time,
                )
                time.sleep(wait_time)

        # If all retries failed, try to use in-memory storage as last resort
        logger.warning("All Qdrant client initialization attempts failed. Using in-memory mode.")
        return QdrantClient(":memory:")

    def _init_vector_store(self):
        """Initialize the vector store.

        Creates the ``COLLECTION_NAME`` collection (cosine distance) on
        ``self.client`` when it does not exist yet, then wraps the client in
        a LangChain ``Qdrant`` store.

        If anything in that process raises, the error is logged and a
        separate in-memory store seeded with one greeting text is returned
        instead; ``self.client`` is left untouched in that case.

        Returns:
            A LangChain ``Qdrant`` vector store.
        """
        try:
            existing = {
                collection.name
                for collection in self.client.get_collections().collections
            }

            if COLLECTION_NAME not in existing:
                # The vector dimension is only needed when creating the
                # collection, so the probe embedding is computed here rather
                # than on every start-up.
                vector_size = len(self.embeddings.embed_query("test"))
                self.client.create_collection(
                    collection_name=COLLECTION_NAME,
                    vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
                )
                logger.info("Created new collection: %s", COLLECTION_NAME)

            return Qdrant(
                client=self.client,
                collection_name=COLLECTION_NAME,
                embeddings=self.embeddings,
            )
        except Exception as e:
            logger.error("Error initializing vector store: %s", e)
            # Create a simple in-memory fallback
            logger.warning("Using in-memory vector store as fallback.")
            return Qdrant.from_texts(
                ["Hello, I am your AI assistant."],
                self.embeddings,
                location=":memory:",
                collection_name=COLLECTION_NAME,
            )

    def get_retriever(self):
        """Get the retriever for RAG.

        Returns:
            The vector store's retriever, configured for similarity search
            returning the 5 closest documents.
        """
        return self.vectorstore.as_retriever(
            search_type="similarity",
            search_kwargs={"k": 5},
        )

    def create_rag_chain(self):
        """Create a RAG chain for question answering.

        Returns:
            A ``ConversationalRetrievalChain`` built from ``self.llm``, the
            retriever from ``get_retriever()`` and the shared conversation
            memory, configured to also return the source documents.
        """
        # The chain is driven by the regular LLM (self.llm), not self.chat_model.
        return ConversationalRetrievalChain.from_llm(
            llm=self.llm,
            retriever=self.get_retriever(),
            memory=self.memory,
            return_source_documents=True,
        )

    def add_texts(self, texts, metadatas=None):
        """Add texts to the vector store.

        Args:
            texts: Texts to embed and store.
            metadatas: Optional metadata passed through to the vector store
                alongside ``texts``.

        Returns:
            Whatever the vector store's ``add_texts`` returns. If the store
            raises, the error is logged and a single-element list holding a
            placeholder ``"error-id-<random>"`` string is returned instead.

        NOTE(review): the placeholder id does not correspond to any stored
        document; callers that persist returned ids should check for the
        ``"error-id-"`` prefix.
        """
        try:
            return self.vectorstore.add_texts(texts=texts, metadatas=metadatas)
        except Exception as e:
            logger.error("Error adding texts to vector store: %s", e)
            return ["error-id-" + str(random.randint(10000, 99999))]

    def similarity_search(self, query, k=5):
        """Perform a similarity search.

        Args:
            query: Query text to search for.
            k: Maximum number of results requested from the store.

        Returns:
            The vector store's search results, or an empty list if the
            search raises (the error is logged).
        """
        try:
            return self.vectorstore.similarity_search(query, k=k)
        except Exception as e:
            logger.error("Error during similarity search: %s", e)
            return []