import os
import pinecone
import openai
import gradio as gr
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
import boto3
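# Assumptions: this script targets the class-based Pinecone client (pinecone-client >= 3.x)
# and the legacy OpenAI SDK (openai < 1.0, which still exposes openai.ChatCompletion).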
# Load environment variables
load_dotenv()
# Access secrets from environment variables
openai.api_key = os.getenv("OPENAI_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
aws_access_key = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
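# Note: os.getenv returns None for unset variables, so a missing key surfaces later as an auth error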
bucket_name = 'amtrak-superliner-ai-poc'
txt_file_name = 'combined_extracted_text.txt'
index_name = "amtrak-acela-ai-demo"
# Initialize Pinecone using the new class-based method
pc = pinecone.Pinecone(api_key=pinecone_api_key)
# Initialize AWS S3 client
s3_client = boto3.client(
    's3',
    aws_access_key_id=aws_access_key,
    aws_secret_access_key=aws_secret_key,
    region_name='us-east-1'
)
# Initialize Pinecone index (check if it exists, otherwise create it)
def initialize_pinecone_index(index_name, embedding_dim):
    available_indexes = pc.list_indexes().names()
    if index_name not in available_indexes:
        pc.create_index(
            name=index_name,
            dimension=embedding_dim,
            metric="cosine",
            spec=pinecone.ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )
    return pc.Index(index_name)
embedding_dim = 768
index = initialize_pinecone_index(index_name, embedding_dim)
# Initialize HuggingFace embedding model
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-distilbert-base-v4")
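# msmarco-distilbert-base-v4 produces 768-dimensional vectors, matching embedding_dim above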
# Download and load text from S3
def download_text_from_s3(s3_client, bucket_name, file_name):
    local_txt_path = os.path.join(os.getcwd(), file_name)
    s3_client.download_file(bucket_name, file_name, local_txt_path)
    with open(local_txt_path, 'r', encoding='utf-8') as f:
        return f.read()
doc_text = download_text_from_s3(s3_client, bucket_name, txt_file_name)
# Split and embed the document text
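# CharacterTextSplitter counts characters (not tokens), so chunk_size=3000 with
# chunk_overlap=500 yields roughly 3 KB chunks that overlap by 500 characters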
def process_text_into_embeddings(doc_text):
    text_splitter = CharacterTextSplitter(separator='\n', chunk_size=3000, chunk_overlap=500)
    docs = text_splitter.split_documents([Document(page_content=doc_text)])
    doc_embeddings = embedding_model.embed_documents([doc.page_content for doc in docs])
    return docs, doc_embeddings
# Check if embeddings already exist in Pinecone
def check_embeddings_in_pinecone(index):
    try:
        stats = index.describe_index_stats()
        return stats['total_vector_count'] > 0
    except Exception as e:
        print(f"Error checking Pinecone index: {e}")
        return False
# Only process embeddings if they don't already exist in Pinecone
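# Note: vectors are upserted one at a time for simplicity; for a larger corpus,
# batching the upserts (e.g., ~100 vectors per call) would be noticeably faster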
if not check_embeddings_in_pinecone(index):
    split_docs, doc_embeddings = process_text_into_embeddings(doc_text)
    for i, doc in enumerate(split_docs):
        metadata = {'content': doc.page_content}
        index.upsert(vectors=[(str(i), doc_embeddings[i], metadata)])
else:
    print("Embeddings already exist in Pinecone. Skipping embedding process.")
# Query Pinecone and the OpenAI chat API to generate a response
def get_model_response(human_input, chat_history=None):
    try:
        # Embed the query using the embedding model
        query_embedding = embedding_model.embed_query(human_input)
        # Query the Pinecone index to retrieve relevant content
        search_results = index.query(vector=query_embedding, top_k=3, include_metadata=True)
        # Prepare content and image data
        context_list = []
        images = []
        # Extract the content from Pinecone's search results
        for ind, result in enumerate(search_results['matches']):
            document_content = result.get('metadata', {}).get('content', 'No content found')
            image_url = result.get('metadata', {}).get('image_path', None)
            figure_desc = result.get('metadata', {}).get('figure_description', '')
            context_list.append(f"Document {ind+1}: {document_content}")
            if image_url and figure_desc:  # Only append images that exist and have a description
                images.append((figure_desc, image_url))
        # Combine context from the search results
        context_string = '\n\n'.join(context_list)
        # Build the messages list for OpenAI
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},  # System prompt
            {"role": "user", "content": f"Here is some context:\n{context_string}\n\nUser's question: {human_input}"}
        ]
        # Send the conversation to OpenAI's API, using GPT-3.5 instead of GPT-4
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=500,
            temperature=0.5
        )
        # Get the model's response
        output_text = response['choices'][0]['message']['content'].strip()
        # Return both the output text and any images found
        return output_text, images
    except Exception as e:
        return f"Error invoking model: {str(e)}", []
# Function to format text and images for display
def get_model_response_with_history(human_input, chat_history=None):
    if chat_history is None:
        chat_history = []
    # get_model_response returns (output_text, images), not an updated chat history
    output_text, images = get_model_response(human_input, chat_history)
    # Handle image display
    def process_image(image_data):
        if isinstance(image_data, list):
            # If a list is passed, flatten it to a string
            return " ".join(str(item) for item in image_data)
        return str(image_data)
    # Append any retrieved figure descriptions and paths so they are visible in the chat
    if images:
        figure_lines = [f"{figure_desc}: {process_image(image_url)}" for figure_desc, image_url in images]
        output_text = output_text + "\n\n" + "\n".join(figure_lines)
    return output_text
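# gr.ChatInterface calls fn(message, history) and expects a plain string back,
# which is why the wrapper above returns only output_text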
# Set up the Gradio interface (share=True is omitted for now to avoid launch errors)
gr_interface = gr.ChatInterface(
    fn=get_model_response_with_history,
    title="Maintenance Assistant",
    description="Ask questions related to the RMM documents."
)
# Launch the Gradio interface
gr_interface.launch()