import os
import chromadb
import numpy as np
from dotenv import load_dotenv
import gradio as gr
from groq import Groq
import torch
from transformers import AutoTokenizer, AutoModel

# Load environment variables
load_dotenv()
# List of API keys for Groq
api_keys = [
os.getenv("GROQ_API_KEY"),
os.getenv("GROQ_API_KEY_2"),
os.getenv("GROQ_API_KEY_3"),
os.getenv("GROQ_API_KEY_4"),
]
if not any(api_keys):
raise ValueError("At least one GROQ_API_KEY environment variable must be set.")
# Initialize Groq client with the first API key
current_key_index = 0
client = Groq(api_key=api_keys[current_key_index])
# Groq-based chatbot with automatic API-key fallback
class GroqChatbot:
    def __init__(self, api_keys):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        # The embedding model and tokenizer are loaded lazily on first use and
        # cached, so they are not reloaded on every call
        self._tokenizer = None
        self._model = None
        self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, trying each key at most once."""
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."

    def _load_embedding_model(self):
        """Load and cache the embedding model and tokenizer on first use."""
        if self._model is None:
            self._tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
            self._model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B").to(self._device)
            self._model.eval()
            # Ensure the tokenizer has a padding token
            if self._tokenizer.pad_token is None:
                self._tokenizer.pad_token = self._tokenizer.eos_token

    def text_to_embedding(self, text):
        """Convert text to a normalized embedding via masked mean pooling."""
        try:
            self._load_embedding_model()
            # Tokenize the text
            encoded_input = self._tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(self._device)
            # Run the model without tracking gradients
            with torch.no_grad():
                model_output = self._model(**encoded_input)
            token_embeddings = model_output.last_hidden_state
            # Mean pooling over non-padding tokens only
            attention_mask = encoded_input["attention_mask"]
            mask = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
            summed = torch.sum(token_embeddings * mask, dim=1)
            counts = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
            mean_pooled = (summed / counts).squeeze()
            # Move to CPU, convert to numpy, and L2-normalize
            embedding = mean_pooled.cpu().numpy()
            embedding = embedding / np.linalg.norm(embedding)
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None
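
    # Note: mean-pooling the hidden states of a decoder-only LLM is a rough
    # way to obtain sentence embeddings. A model trained specifically for
    # embeddings (e.g. a sentence-transformers checkpoint) could be swapped
    # into _load_embedding_model; that substitution is a suggestion, not part
    # of the original app.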

# Local embedding store backed by ChromaDB
class LocalEmbeddingStore:
    def __init__(self, storage_dir="./chromadb_storage"):
        # ChromaDB client with persistent on-disk storage
        self.client = chromadb.PersistentClient(path=storage_dir)
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],  # The doc ID doubles as the stored document text here
            embeddings=[np.asarray(embedding).tolist()],  # Embedding vector for the document
            metadatas=[metadata],  # Optional metadata
            ids=[doc_id],  # Unique ID for the record
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents by embedding similarity."""
        results = self.collection.query(
            query_embeddings=[np.asarray(query_embedding).tolist()],
            n_results=num_results,
        )
        print(f"Search results: {results}")
        # ChromaDB returns one list per query, so these are lists of lists
        return results["documents"], results["distances"]

# RAG system that combines ChromaDB retrieval with Groq generation
class RAGSystem:
    def __init__(self, groq_client, embedding_store):
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the most relevant document based on embedding similarity."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        # Results are nested one level per query, so index twice to get the
        # single best match and its distance
        if docs and docs[0]:
            return docs[0][0], distances[0][0]
        return None, None

    def chat_with_rag(self, user_input):
        """Handle the full RAG round trip: embed, retrieve, generate."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."
        context_document_id, similarity_score = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."
        # Placeholder metadata; replace with a real lookup as needed
        context_metadata = f"Metadata for {context_document_id}"
        prompt = f"""Context (similarity score {similarity_score:.2f}):
{context_metadata}

User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)
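
    # A possible way to replace the placeholder above (a sketch, assuming the
    # metadata was stored at ingestion time) is to fetch it back from ChromaDB:
    #   record = self.embedding_store.collection.get(
    #       ids=[context_document_id], include=["metadatas"]
    #   )
    #   context_metadata = record["metadatas"][0] if record["metadatas"] else {}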
# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)
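
# This app only queries the store; documents are assumed to have been embedded
# and indexed beforehand. A hypothetical one-time ingestion step might look like:
#   doc_text = "Groq provides fast LLM inference."
#   doc_embedding = chatbot.text_to_embedding(doc_text)
#   if doc_embedding is not None:
#       embedding_store.add_embedding("doc-1", doc_embedding, {"text": doc_text})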

# Gradio UI
def chat_ui(user_input, chat_history):
    """Handle a chat turn and update the history."""
    if not user_input.strip():
        return chat_history, user_input
    ai_response = rag_system.chat_with_rag(user_input)
    chat_history.append((user_input, ai_response))
    # Clear the textbox after a successful submission
    return chat_history, ""

# Gradio interface
with gr.Blocks() as demo:
    chat_history = gr.Chatbot(label="Groq Chatbot with RAG", elem_id="chatbox")
    user_input = gr.Textbox(placeholder="Enter your prompt here...")
    submit_button = gr.Button("Submit")
    submit_button.click(chat_ui, inputs=[user_input, chat_history], outputs=[chat_history, user_input])
    # Allow pressing Enter in the textbox to submit as well
    user_input.submit(chat_ui, inputs=[user_input, chat_history], outputs=[chat_history, user_input])

if __name__ == "__main__":
    demo.launch()