import os

import chromadb
import gradio as gr
import numpy as np
import torch
from dotenv import load_dotenv
from groq import Groq
from transformers import AutoTokenizer, AutoModel

# Load environment variables
load_dotenv()

# Collect whichever Groq API keys are configured, dropping unset ones so the
# client is never constructed with a None key.
api_keys = [
    key
    for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]

if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")

# Define Groq-based model with fallback
class GroqChatbot:
    def __init__(self, api_keys):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        # The embedding model and tokenizer are loaded lazily and cached here,
        # so they are not reloaded from disk on every text_to_embedding call.
        self.tokenizer = None
        self.model = None
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, switching keys on failure."""
        # Try each key at most once per request instead of looping forever.
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."

    def text_to_embedding(self, text):
        """Convert text to a normalised embedding, loading the model on first use."""
        try:
            # Load the model and tokenizer once, then reuse the cached instances.
            if self.model is None or self.tokenizer is None:
                self.tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
                self.model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")
                self.model.to(self.device)
                self.model.eval()

                # Ensure the tokenizer has a padding token
                if self.tokenizer.pad_token is None:
                    self.tokenizer.pad_token = self.tokenizer.eos_token

            # Tokenize the text
            encoded_input = self.tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors='pt'
            ).to(self.device)

            # Generate embeddings
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                sentence_embeddings = model_output.last_hidden_state

                # Mean pooling over non-padding tokens only
                attention_mask = encoded_input['attention_mask']
                mask = attention_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
                masked_embeddings = sentence_embeddings * mask
                summed = torch.sum(masked_embeddings, dim=1)
                summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
                mean_pooled = (summed / summed_mask).squeeze()

            # Move to CPU, convert to numpy, and L2-normalise (guarding against
            # division by zero for a degenerate all-zero vector)
            embedding = mean_pooled.cpu().numpy()
            norm = np.linalg.norm(embedding)
            if norm > 0:
                embedding = embedding / norm

            print(f"Generated embedding for text: {text}")
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None
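
# Sanity check (illustrative, not part of the app): text_to_embedding returns a
# 1-D, L2-normalised numpy vector, so its dot product with itself should be
# ~1.0. Uncomment to verify locally; it requires the model weights to download.
# _bot = GroqChatbot(api_keys=api_keys)
# _vec = _bot.text_to_embedding("hello world")
# assert _vec is not None and abs(float(np.dot(_vec, _vec)) - 1.0) < 1e-4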

# Local embedding store backed by ChromaDB
class LocalEmbeddingStore:
    def __init__(self, storage_dir="./chromadb_storage"):
        self.client = chromadb.PersistentClient(path=storage_dir)  # ChromaDB client with persistent storage
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],  # Document text (here the ID doubles as the text)
            embeddings=[np.asarray(embedding).tolist()],  # Embedding as a plain list
            metadatas=[metadata],  # Optional metadata
            ids=[doc_id]  # Unique ID for the entry
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents by embedding similarity."""
        results = self.collection.query(
            query_embeddings=[np.asarray(query_embedding).tolist()],
            n_results=num_results
        )
        print(f"Search results: {results}")
        # Both values are lists of lists: one inner list per query embedding.
        return results['documents'], results['distances']
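
# Shape note (based on ChromaDB's query API): results are nested one level per
# query embedding, e.g. for the single query above:
#   results["documents"] -> [["doc_a", "doc_b", "doc_c"]]
#   results["distances"] -> [[0.12, 0.34, 0.56]]
# Callers therefore index with [0] first to get the hits for that query, and
# lower distances mean closer matches.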

# RAG system that grounds Groq responses in ChromaDB search results
class RAGSystem:
    def __init__(self, groq_client, embedding_store):
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the most relevant document and its distance."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        # ChromaDB nests results per query, so unwrap the first (only) query
        # before taking the top hit.
        if docs and docs[0]:
            return docs[0][0], distances[0][0]
        return None, None

    def chat_with_rag(self, user_input):
        """Handle the RAG process."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."

        context_document_id, distance = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."

        # Assuming metadata retrieval works
        context_metadata = f"Metadata for {context_document_id}"  # Placeholder, implement as needed

        prompt = f"""Context (retrieval distance {distance:.2f}; lower is closer):
{context_metadata}

User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)
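
# Illustrative round trip (assumes the collection already contains documents):
#   rag = RAGSystem(groq_client=GroqChatbot(api_keys=api_keys),
#                   embedding_store=LocalEmbeddingStore())
#   print(rag.chat_with_rag("What is in the knowledge base?"))
# With an empty collection this returns "No relevant documents found."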

# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)
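
# The ChromaDB collection starts empty on a fresh install, so retrieval finds
# nothing until documents are indexed. A minimal seeding sketch follows; the
# document IDs and texts are made up for illustration only.
# for doc_id, text in {
#     "doc_groq": "Groq serves low-latency LLM inference.",
#     "doc_chroma": "ChromaDB is a local persistent vector store.",
# }.items():
#     emb = chatbot.text_to_embedding(text)
#     if emb is not None:
#         embedding_store.add_embedding(doc_id, emb, {"source": "seed"})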

# Gradio UI
def chat_ui(user_input, chat_history):
    """Handle chat interactions and update history."""
    chat_history = chat_history or []  # the Chatbot value can start out as None
    if not user_input.strip():
        return chat_history, ""
    ai_response = rag_system.chat_with_rag(user_input)
    chat_history.append((user_input, ai_response))
    # Return the updated history and an empty string to clear the textbox.
    return chat_history, ""

# Gradio interface
with gr.Blocks() as demo:
    chat_history = gr.Chatbot(label="Groq Chatbot with RAG", elem_id="chatbox")
    user_input = gr.Textbox(placeholder="Enter your prompt here...")
    submit_button = gr.Button("Submit")
    # Both the button and pressing Enter in the textbox submit the prompt.
    submit_button.click(chat_ui, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
    user_input.submit(chat_ui, inputs=[user_input, chat_history], outputs=[chat_history, user_input])

if __name__ == "__main__":
    demo.launch()