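"""Multi-agent RAG demo app (a Hugging Face Space).

Pipeline: a retrieval agent pulls the nearest chunks from a FAISS index,
a grading agent scores their relevance to the query, a query-rewrite agent
reformulates low-scoring queries, and a generation agent writes the final
answer. A Gradio UI wraps the whole loop.

Assumed dependencies (derived from the imports below, not pinned in the
original): openai, faiss-cpu or faiss-gpu, gradio, langchain,
langchain-community, sentence-transformers, torch, numpy.
"""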
import os
import re

import faiss
import gradio as gr
import numpy as np
import torch
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from openai import OpenAI
from sentence_transformers import SentenceTransformer
class MultiAgentRAG:
    """RAG pipeline with retrieval, grading, query-rewriting, and generation agents."""

    def __init__(self, embedding_model_name, openai_model_id, data_folder, api_key=None):
        self.use_gpu = torch.cuda.is_available()
        self.all_splits = self.load_documents(data_folder)
        self.embeddings = SentenceTransformer(embedding_model_name)
        self.faiss_index = self.create_faiss_index()
        self.openai_client = OpenAI(api_key=api_key or os.environ.get("OPENAI_API_KEY"))
        self.openai_model_id = openai_model_id
    def load_documents(self, folder_path):
        """Load every text file in folder_path and split it into overlapping chunks."""
        loader = DirectoryLoader(folder_path, loader_cls=TextLoader)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=250)
        return text_splitter.split_documents(documents)
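    # The DirectoryLoader call above assumes folder_path contains plain-text
    # files readable by TextLoader; a hypothetical layout might be:
    #   sample_embedding_folder2/
    #       thread_001.txt
    #       thread_002.txt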
    def create_faiss_index(self):
        """Embed all chunks and build a flat L2 index, moving it to GPU when possible."""
        all_texts = [split.page_content for split in self.all_splits]
        embeddings = self.embeddings.encode(all_texts, convert_to_tensor=True).cpu().numpy()
        index = faiss.IndexFlatL2(embeddings.shape[1])
        index.add(embeddings)
        if self.use_gpu:
            try:
                # Requires the faiss-gpu build; fall back to the CPU index otherwise.
                gpu_resource = faiss.StandardGpuResources()
                return faiss.index_cpu_to_gpu(gpu_resource, 0, index)
            except Exception:
                pass
        return index
    def generate_openai_response(self, messages, max_tokens=1000):
        """Call the Chat Completions API and return the reply text."""
        try:
            response = self.openai_client.chat.completions.create(
                model=self.openai_model_id,
                messages=messages,
                max_tokens=max_tokens,
                temperature=0.8,
                top_p=1.0,
                frequency_penalty=0,
                presence_penalty=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            # Surface the error instead of swallowing it silently.
            return f"Text generation process encountered an error: {e}"
    def retrieval_agent(self, query):
        """Return the page content of the three chunks nearest to the query."""
        query_embedding = self.embeddings.encode(query, convert_to_tensor=True).cpu().numpy()
        distances, indices = self.faiss_index.search(np.array([query_embedding]), k=3)
        return "\n".join(self.all_splits[idx].page_content for idx in indices[0])
    def grading_agent(self, query, retrieved_content):
        """Ask the model to rate relevance on a 1-10 scale; default to 5 if no number is found."""
        messages = [
            {"role": "system", "content": "You are an expert at evaluating the relevance of retrieved content to a query."},
            {"role": "user", "content": f"Query: {query}\nRetrieved Content:\n{retrieved_content}\nRate the relevance on a scale of 1-10."},
        ]
        grading_response = self.generate_openai_response(messages)
        # Extract the first standalone 1-10 from the model's reply.
        match = re.search(r'\b([1-9]|10)\b', grading_response)
        rating = int(match.group()) if match else 5
        return rating, grading_response
    def query_rewrite_agent(self, original_query):
        """Rewrite the query to improve retrieval on the next attempt."""
        messages = [
            {"role": "system", "content": "You are an expert at rewriting queries to improve document retrieval."},
            {"role": "user", "content": f"Original Query: {original_query}\nRewritten Query:"},
        ]
        return self.generate_openai_response(messages).strip()
    def generation_agent(self, query, retrieved_content):
        """Answer the query grounded in the retrieved content."""
        messages = [
            {"role": "system", "content": "You are a knowledgeable assistant."},
            # The retrieved content must be included here, or the answer is not
            # grounded in the documents at all.
            {"role": "user", "content": f"Query: {query}\nRetrieved Content:\n{retrieved_content}\nAnswer:"},
        ]
        return self.generate_openai_response(messages)
    def run_multi_agent_rag(self, query):
        """Retrieve, grade, and answer; rewrite the query and retry (up to 3 times)
        whenever the graded relevance falls below 7."""
        for _ in range(3):
            retrieved_content = self.retrieval_agent(query)
            relevance_score, grading_explanation = self.grading_agent(query, retrieved_content)
            if relevance_score >= 7:
                return self.generation_agent(query, retrieved_content), retrieved_content, grading_explanation
            query = self.query_rewrite_agent(query)
        return "Unable to find a relevant answer.", "", "Low relevance across all attempts."
    def qa_infer_gradio(self, query):
        """Gradio entry point: return the answer plus the retrieval and grading details."""
        answer, retrieved_content, grading_explanation = self.run_multi_agent_rag(query)
        return answer, f"Retrieved Content:\n{retrieved_content}\n\nGrading Explanation:\n{grading_explanation}"
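# A minimal programmatic smoke test, bypassing the UI (model names, folder
# path, and query here are illustrative, not part of the original app):
#
#   rag = MultiAgentRAG("all-MiniLM-L6-v2", "gpt-4-turbo", "docs/")
#   answer, context, grading = rag.run_multi_agent_rag(
#       "How many cameras can the TDA2x video input ports support?")
#   print(answer)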
def launch_interface(doc_retrieval_gen):
    """Launch the Gradio UI for the multi-agent RAG pipeline."""
    css_code = """
    .gradio-container { background-color: #daccdb; }
    button { background-color: #927fc7; color: black; border: 1px solid black; padding: 10px; margin-right: 10px; font-size: 16px; font-weight: bold; }
    """
    EXAMPLES = [
        "On which devices can the VIP and CSI2 modules operate simultaneously?",
        "I'm using Code Composer Studio 5.4.0.00091 and enabled FPv4SPD16 floating point support for Cortex-M4 in TDA2. However, after building the project, the .asm file shows --float_support=vfplib instead of FPv4SPD16. Why is this happening?",
        "Could you clarify the maximum number of cameras that can be connected simultaneously to the video input ports on the TDA2x SoC?",
    ]
    interface = gr.Interface(
        fn=doc_retrieval_gen.qa_infer_gradio,
        inputs=[gr.Textbox(label="QUERY", placeholder="Enter your query here")],
        outputs=[
            gr.Textbox(label="RESPONSE"),
            # This box shows the retrieval and grading details returned by
            # qa_infer_gradio, so label it accordingly.
            gr.Textbox(label="RETRIEVED CONTENT & GRADING"),
        ],
        allow_flagging='never',
        examples=EXAMPLES,
        cache_examples=False,
        css=css_code,
        title="TI E2E FORUM Multi-Agent RAG",
    )
    interface.launch(debug=True)
if __name__ == "__main__":
    embedding_model_name = 'flax-sentence-embeddings/all_datasets_v3_MiniLM-L12'
    openai_model_id = "gpt-4-turbo"
    data_folder = 'sample_embedding_folder2'
    try:
        multi_agent_rag = MultiAgentRAG(embedding_model_name, openai_model_id, data_folder)
        launch_interface(multi_agent_rag)
    except Exception as e:
        print(f"Error initializing Multi-Agent RAG: {e}")