import os
import chromadb
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoTokenizer, AutoModel
from groq import Groq
import gradio as gr
import httpx  # Used to make async HTTP requests to the FastAPI endpoints
import uvicorn  # Serves the combined FastAPI + Gradio app
# Load environment variables
load_dotenv()

# Collect whichever GROQ_API_KEY* variables are set, skipping unset ones
# so a Groq client is never initialised with a None key
api_keys = [
    key
    for key in (
        os.getenv("GROQ_API_KEY"),
        os.getenv("GROQ_API_KEY_2"),
        os.getenv("GROQ_API_KEY_3"),
        os.getenv("GROQ_API_KEY_4"),
    )
    if key
]
if not api_keys:
    raise ValueError("At least one GROQ_API_KEY environment variable must be set.")
# FastAPI app
app = FastAPI()
# Define Groq-based model with API-key fallback
class GroqChatbot:
    def __init__(self, api_keys):
        self.api_keys = api_keys
        self.current_key_index = 0
        self.client = Groq(api_key=self.api_keys[self.current_key_index])

    def switch_key(self):
        """Switch to the next API key in the list."""
        self.current_key_index = (self.current_key_index + 1) % len(self.api_keys)
        self.client = Groq(api_key=self.api_keys[self.current_key_index])
        print(f"Switched to API key index {self.current_key_index}")

    def get_response(self, prompt):
        """Get a response from the API, rotating keys on failure."""
        # Try each key at most once per call instead of looping forever
        for _ in range(len(self.api_keys)):
            try:
                response = self.client.chat.completions.create(
                    messages=[
                        {"role": "system", "content": "You are a helpful AI assistant."},
                        {"role": "user", "content": prompt},
                    ],
                    model="llama3-70b-8192",
                )
                return response.choices[0].message.content
            except Exception as e:
                print(f"Error: {e}")
                self.switch_key()
        return "All API keys have been exhausted. Please try again later."

    # The embedding model and tokenizer are cached at class level so they
    # are loaded once rather than on every request
    _tokenizer = None
    _model = None
    _device = None

    @classmethod
    def _load_embedding_model(cls):
        if cls._model is None:
            cls._tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-3.2-1B")
            cls._model = AutoModel.from_pretrained("NousResearch/Llama-3.2-1B")
            # Move model to GPU if available
            cls._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            cls._model = cls._model.to(cls._device)
            cls._model.eval()
            # Ensure the tokenizer has a padding token
            if cls._tokenizer.pad_token is None:
                cls._tokenizer.pad_token = cls._tokenizer.eos_token

    def text_to_embedding(self, text):
        """Convert text to a mean-pooled, L2-normalised embedding."""
        try:
            self._load_embedding_model()
            tokenizer, model, device = self._tokenizer, self._model, self._device

            # Tokenize the text
            encoded_input = tokenizer(
                text,
                padding=True,
                truncation=True,
                max_length=512,
                return_tensors="pt",
            ).to(device)

            # Generate token-level hidden states
            with torch.no_grad():
                model_output = model(**encoded_input)
            sentence_embeddings = model_output.last_hidden_state

            # Mean pooling over non-padding tokens only
            attention_mask = encoded_input["attention_mask"]
            mask = attention_mask.unsqueeze(-1).expand(sentence_embeddings.size()).float()
            masked_embeddings = sentence_embeddings * mask
            summed = torch.sum(masked_embeddings, dim=1)
            summed_mask = torch.clamp(torch.sum(attention_mask, dim=1).unsqueeze(-1), min=1e-9)
            mean_pooled = (summed / summed_mask).squeeze()

            # Move to CPU, convert to numpy, and normalize the vector
            embedding = mean_pooled.cpu().numpy()
            embedding = embedding / np.linalg.norm(embedding)
            print(f"Generated embedding for text: {text}")
            return embedding
        except Exception as e:
            print(f"Error generating embedding: {e}")
            return None
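
    # Note: Llama-style decoder models are not trained as sentence encoders;
    # mean pooling over the last hidden state is a common heuristic, not an
    # official embedding API. Illustration of the pooling arithmetic on a
    # toy tensor (assumed shapes, made-up values):
    #
    #   hidden: (1, 4, 8)  # batch=1, 4 tokens, 8 dims
    #   mask:   (1, 4)     # e.g. [1, 1, 1, 0] -> last token is padding
    #   pooled = (hidden * mask[..., None]).sum(1) / mask.sum(1)[..., None]
    #   pooled: (1, 8)     # average over the 3 real tokens only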

# LocalEmbeddingStore backed by ChromaDB
class LocalEmbeddingStore:
    def __init__(self, storage_dir="./chromadb_storage"):
        # ChromaDB client with persistent on-disk storage
        self.client = chromadb.PersistentClient(path=storage_dir)
        self.collection_name = "chatbot_docs"  # Collection for storing embeddings
        self.collection = self.client.get_or_create_collection(name=self.collection_name)

    def add_embedding(self, doc_id, embedding, metadata):
        """Add a document and its embedding to ChromaDB."""
        self.collection.add(
            documents=[doc_id],  # The ID doubles as the stored document text
            embeddings=[np.asarray(embedding).tolist()],  # Embedding as a plain list
            metadatas=[metadata],  # Optional metadata
            ids=[doc_id],  # Unique record ID
        )
        print(f"Added embedding for document ID: {doc_id}")

    def search_embedding(self, query_embedding, num_results=3):
        """Search for the most relevant documents by embedding similarity."""
        results = self.collection.query(
            query_embeddings=[np.asarray(query_embedding).tolist()],
            n_results=num_results,
        )
        print(f"Search results: {results}")
        # ChromaDB nests results per query: one inner list per query embedding
        return results["documents"], results["distances"]
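
    # Illustrative shape of a query result for a single query embedding
    # (values are made up):
    #
    #   documents: [["doc_a", "doc_b", "doc_c"]]
    #   distances: [[0.12, 0.34, 0.56]]
    #
    # Callers must index [0] first to get the hits for their query.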

# RAGSystem wires ChromaDB retrieval into the chat flow
class RAGSystem:
    def __init__(self, groq_client, embedding_store):
        self.groq_client = groq_client
        self.embedding_store = embedding_store

    def get_most_relevant_document(self, query_embedding):
        """Retrieve the most relevant document by embedding distance."""
        docs, distances = self.embedding_store.search_embedding(query_embedding)
        # Index into the first query's result list, which may be empty
        if docs and docs[0]:
            return docs[0][0], distances[0][0]
        return None, None

    def chat_with_rag(self, user_input):
        """Handle the RAG round trip: embed, retrieve, generate."""
        query_embedding = self.groq_client.text_to_embedding(user_input)
        if query_embedding is None or query_embedding.size == 0:
            return "Failed to generate embeddings."

        context_document_id, distance = self.get_most_relevant_document(query_embedding)
        if not context_document_id:
            return "No relevant documents found."

        # Placeholder metadata lookup; implement real retrieval as needed
        context_metadata = f"Metadata for {context_document_id}"

        prompt = f"""Context (retrieval distance {distance:.2f}):
{context_metadata}

User: {user_input}
AI:"""
        return self.groq_client.get_response(prompt)
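
    # Illustrative end-to-end flow (hypothetical values):
    #   "what is groq?" -> query embedding -> nearest doc "sample_document"
    #   -> prompt assembled with context -> Groq completion returned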

# Initialize components
embedding_store = LocalEmbeddingStore(storage_dir="./chromadb_storage")
chatbot = GroqChatbot(api_keys=api_keys)
rag_system = RAGSystem(groq_client=chatbot, embedding_store=embedding_store)
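
# The store starts empty, so /chat finds nothing until documents are added,
# e.g. via the add-document endpoint below or a one-off seeding sketch like
# this (hypothetical, left commented out):
#
#   for doc in ["Groq serves fast LLM inference.", "ChromaDB stores embeddings."]:
#       emb = chatbot.text_to_embedding(doc)
#       if emb is not None:
#           embedding_store.add_embedding(doc, emb, metadata={"source": "seed"})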

# Pydantic models for API request and response
class UserInput(BaseModel):
    input_text: str


class ChatResponse(BaseModel):
    response: str


@app.get("/")
async def read_root():
    return {"message": "Welcome to the Groq and ChromaDB integration API!"}


@app.post("/chat")
async def chat(user_input: UserInput):
    """Handle chat interactions with Groq and ChromaDB."""
    ai_response = rag_system.chat_with_rag(user_input.input_text)
    return ChatResponse(response=ai_response)


# The "/embed" route path is an assumption
@app.post("/embed")
async def embed_text(user_input: UserInput):
    """Handle text embedding."""
    embedding = chatbot.text_to_embedding(user_input.input_text)
    if embedding is not None:
        return ChatResponse(response="Text embedded successfully.")
    raise HTTPException(status_code=400, detail="Embedding generation failed.")


# The "/add_document" route path is an assumption
@app.post("/add_document")
async def add_document(user_input: UserInput):
    """Add a document embedding to ChromaDB."""
    embedding = chatbot.text_to_embedding(user_input.input_text)
    if embedding is not None:
        doc_id = "sample_document"  # You can generate or pass a doc ID
        embedding_store.add_embedding(doc_id, embedding, metadata={"source": "user_input"})
        return ChatResponse(response="Document added to the database.")
    raise HTTPException(status_code=400, detail="Embedding generation failed.")
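
# A fixed doc_id overwrites the same record on every call; a hypothetical
# improvement is a unique ID per document, e.g.:
#
#   import uuid
#   doc_id = str(uuid.uuid4())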

# Gradio callback that queries the FastAPI /chat endpoint over HTTP
async def gradio_chatbot(input_text: str):
    async with httpx.AsyncClient() as client:
        response = await client.post(
            "http://127.0.0.1:7860/chat",  # FastAPI endpoint, served below
            json={"input_text": input_text},
        )
        response_data = response.json()
        return response_data["response"]

# Gradio Interface
iface = gr.Interface(fn=gradio_chatbot, inputs="text", outputs="text")

# Mount the Gradio UI onto the FastAPI app so one server exposes both the
# API routes and the UI; the "/ui" mount path is an assumption
app = gr.mount_gradio_app(app, iface, path="/ui")

if __name__ == "__main__":
    # Serve on port 7860 so the httpx call above reaches this same process
    uvicorn.run(app, host="0.0.0.0", port=7860)