import os
import openai
import gradio as gr
import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from langchain_community.embeddings import HuggingFaceEmbeddings # Updated import
from rapidfuzz import fuzz # Replaced fuzzywuzzy with rapidfuzz
import logging
import re # To help with preprocessing
# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Detect GPU availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")
# Suppress specific warning about clean_up_tokenization_spaces
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="clean_up_tokenization_spaces was not set")
# Load environment variables
load_dotenv()
# Access Pinecone and OpenAI API keys from environment variables
pinecone_api_key = os.getenv("PINECONE_API_KEY")
openai.api_key = os.getenv("OPENAI_API_KEY")
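# The two keys above are read from a local .env file (or the Space's secrets).
# At minimum, that file is expected to define:
#   PINECONE_API_KEY=<your Pinecone API key>
#   OPENAI_API_KEY=<your OpenAI API key>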
index_name = "amtrak-acela-ai-demo"
# Initialize Pinecone using a class-based method
pc = Pinecone(api_key=pinecone_api_key)
# Check if the index exists; currently this only warns if it does not (creation is left as a TODO)
def initialize_pinecone_index(index_name):
    available_indexes = pc.list_indexes().names()
    if index_name not in available_indexes:
        print(f"Index '{index_name}' does not exist.")
        # TODO: create the index here if necessary for ZeroGPU usage
    return pc.Index(index_name)

index = initialize_pinecone_index(index_name)
# Initialize HuggingFace embedding model
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-distilbert-base-v4")
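# Note: the Pinecone index dimension must match this model's output size
# (msmarco-distilbert-base-v4 produces 768-dimensional sentence embeddings).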
# Initialize chat history manually
chat_history = []
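# Note: this module-level list is shared by every user/session of the app;
# per-user history would require keying on Gradio's session state instead.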
# Helper function to normalize text (lowercase, strip punctuation/special characters)
def preprocess_text(text):
    # Convert text to lowercase and remove special characters
    text = re.sub(r'[^\w\s]', '', text.lower())
    return text.strip()
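# Example: preprocess_text("What's the HVAC filter part #?") -> "whats the hvac filter part"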
# Helper function to recursively flatten any nested list into a single string
def flatten_to_string(data):
    if isinstance(data, list):
        return " ".join([flatten_to_string(item) for item in data])
    if data is None:
        return ""
    return str(data)
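# Example: flatten_to_string(["page 1", ["page 2"]]) -> "page 1 page 2"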
# Function to interact with Pinecone and the OpenAI chat model (gpt-3.5-turbo)
def get_model_response(human_input):
    try:
        # Preprocess the human input (used only for fuzzy matching against figure descriptions)
        processed_input = preprocess_text(human_input)

        # Embed the raw query; the tensor round-trip keeps the device logic explicit,
        # although the embedding is converted straight back to a plain list for Pinecone
        query_embedding = torch.tensor(embedding_model.embed_query(human_input)).to(device)
        query_embedding = query_embedding.cpu().numpy().tolist()

        # Query Pinecone index with top_k=5 to get more potential matches
        search_results = index.query(vector=query_embedding, top_k=5, include_metadata=True)

        context_list, images = [], []
        for result in search_results['matches']:
            document_content = flatten_to_string(result.get('metadata', {}).get('content', 'No content found'))
            image_url = flatten_to_string(result.get('metadata', {}).get('image_path', None))
            figure_desc = flatten_to_string(result.get('metadata', {}).get('figure_description', ''))

            # Preprocess the figure description and fuzzy-match it against the query
            processed_figure_desc = preprocess_text(figure_desc)
            similarity_score = fuzz.token_set_ratio(processed_input, processed_figure_desc)
            logging.info(f"Matching '{processed_input}' with '{processed_figure_desc}', similarity score: {similarity_score}")

            if similarity_score >= 80:  # Keep the threshold at 80 for now
                context_list.append(f"Relevant information: {document_content}")
                if image_url and figure_desc:
                    images.append((figure_desc, image_url))

        context_string = '\n\n'.join(context_list)

        # Add user message to chat history
        chat_history.append({"role": "user", "content": human_input})

        # Create messages for OpenAI's API
        messages = [{"role": "system", "content": "You are a helpful assistant."}] + chat_history + [
            {"role": "system", "content": f"Here is some context:\n{context_string}"},
            {"role": "user", "content": human_input}
        ]

        # Validate messages before sending to OpenAI
        for message in messages:
            if not isinstance(message, dict) or "role" not in message or "content" not in message:
                raise ValueError(f"Invalid message format: {message}")

        # Send the conversation to OpenAI's API (pre-1.0 openai SDK interface)
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=500,
            temperature=0.5
        )
        output_text = response['choices'][0]['message']['content'].strip()

        # Add assistant message to chat history
        chat_history.append({"role": "assistant", "content": output_text})
        return output_text, images
    except Exception as e:
        return f"Error invoking model: {str(e)}", []
# Function to format text and images for display and track the conversation
def get_model_response_with_images(human_input, history=None):
    output_text, images = get_model_response(human_input)
    if images:
        # Append images in Markdown format for Gradio to render
        image_output = "".join([f"\n\n**{figure_desc}**\n![{figure_desc}]({image_path})" for figure_desc, image_path in images])
        return output_text + image_output
    return output_text
# Set up Gradio interface
gr_interface = gr.ChatInterface(
    fn=get_model_response_with_images,
    title="Maintenance Assistant",
    description="Ask questions related to the RMMM documents."
)

# Ensure ZeroGPU or Hugging Face Spaces handles launching properly
if __name__ == "__main__":
    gr_interface.launch()
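# When testing outside Spaces, launch() also accepts standard Gradio options,
# e.g. gr_interface.launch(server_name="0.0.0.0", server_port=7860).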