import os
import openai
import gradio as gr
import torch
from dotenv import load_dotenv
from pinecone import Pinecone
from langchain_community.embeddings import HuggingFaceEmbeddings  # Updated import
from rapidfuzz import fuzz  # Replaced fuzzywuzzy with rapidfuzz
import logging
import re  # To help with preprocessing

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Detect GPU availability and set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on device: {device}")

# Suppress specific warning about clean_up_tokenization_spaces
import warnings
warnings.filterwarnings("ignore", category=FutureWarning, message="clean_up_tokenization_spaces was not set")

# Load environment variables
load_dotenv()

# Access Pinecone and OpenAI API keys from environment variables
pinecone_api_key = os.getenv("PINECONE_API_KEY")
openai.api_key = os.getenv("OPENAI_API_KEY")
index_name = "amtrak-acela-ai-demo"

# Initialize Pinecone using the class-based client
pc = Pinecone(api_key=pinecone_api_key)

# Check whether the index exists before returning a handle to it
def initialize_pinecone_index(index_name):
    available_indexes = pc.list_indexes().names()
    if index_name not in available_indexes:
        print(f"Index '{index_name}' does not exist.")
        # Create the index here if necessary for ZeroGPU usage
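        # A minimal sketch of index creation, assuming a serverless index and a
        # 768-dimensional embedding model (msmarco-distilbert-base-v4); the cloud
        # and region values below are illustrative placeholders, not settings
        # taken from this project:
        # from pinecone import ServerlessSpec
        # pc.create_index(
        #     name=index_name,
        #     dimension=768,
        #     metric="cosine",
        #     spec=ServerlessSpec(cloud="aws", region="us-east-1"),
        # )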
    return pc.Index(index_name)

index = initialize_pinecone_index(index_name)

# Initialize HuggingFace embedding model on the detected device
embedding_model = HuggingFaceEmbeddings(
    model_name="sentence-transformers/msmarco-distilbert-base-v4",
    model_kwargs={"device": str(device)},
)

# Conversation history (module-level, so it is shared across every Gradio session)
chat_history = []

# Helper function to preprocess text (lowercase and strip punctuation/special characters)
def preprocess_text(text):
    # Convert text to lowercase and remove special characters
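    # e.g. preprocess_text("What's the brake reset?") -> "whats the brake reset"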
    text = re.sub(r'[^\w\s]', '', text.lower())
    return text.strip()

# Helper function to recursively flatten any list to a string
def flatten_to_string(data):
    if isinstance(data, list):
        return " ".join([flatten_to_string(item) for item in data])
    if data is None:
        return ""
    return str(data)

# Function to retrieve context from Pinecone and query the OpenAI chat model
def get_model_response(human_input):
    try:
        # Preprocess the human input (lowercase, strip punctuation) for fuzzy keyword matching below
        processed_input = preprocess_text(human_input)

        # Embed the query (embed_query returns a plain list of floats, which is what Pinecone expects)
        query_embedding = embedding_model.embed_query(human_input)

        # Query Pinecone index with top_k=5 to get more potential matches
        search_results = index.query(vector=query_embedding, top_k=5, include_metadata=True)

        context_list, images = [], []
        for ind, result in enumerate(search_results['matches']):
            document_content = flatten_to_string(result.get('metadata', {}).get('content', 'No content found'))
            image_url = flatten_to_string(result.get('metadata', {}).get('image_path', None))
            figure_desc = flatten_to_string(result.get('metadata', {}).get('figure_description', ''))

            # Preprocess the figure description and match keywords
            processed_figure_desc = preprocess_text(figure_desc)
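            # token_set_ratio returns a 0-100 score and is insensitive to word order and duplicates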
            similarity_score = fuzz.token_set_ratio(processed_input, processed_figure_desc)
            logging.info(f"Matching '{processed_input}' with '{processed_figure_desc}', similarity score: {similarity_score}")

            if similarity_score >= 80:  # Keep the threshold at 80 for now
                context_list.append(f"Relevant information: {document_content}")
                if image_url and figure_desc:
                    images.append((figure_desc, image_url))

        context_string = '\n\n'.join(context_list)

        # Add user message to chat history
        chat_history.append({"role": "user", "content": human_input})

        # Create messages for OpenAI's API; chat_history already ends with the current
        # user message, so exclude it here to avoid sending the question twice
        messages = [{"role": "system", "content": "You are a helpful assistant."}] + chat_history[:-1] + [
            {"role": "system", "content": f"Here is some context:\n{context_string}"},
            {"role": "user", "content": human_input}
        ]

        # Validate messages before sending to OpenAI
        for message in messages:
            if not isinstance(message, dict) or "role" not in message or "content" not in message:
                raise ValueError(f"Invalid message format: {message}")

        # Send the conversation to OpenAI's API
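        # (This uses the legacy pre-1.0 openai Python client interface)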
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo", 
            messages=messages, 
            max_tokens=500, 
            temperature=0.5
        )

        output_text = response['choices'][0]['message']['content'].strip()

        # Add assistant message to chat history
        chat_history.append({"role": "assistant", "content": output_text})

        return output_text, images

    except Exception as e:
        logging.error(f"Error invoking model: {e}")
        return f"Error invoking model: {str(e)}", []

# Wrapper for Gradio: format the response text plus any matched images as Markdown
def get_model_response_with_images(human_input, history=None):  # history is unused; chat_history is tracked manually
    output_text, images = get_model_response(human_input)
    if images:
        # Append images in Markdown format for Gradio to render
        image_output = "".join([f"\n\n**{figure_desc}**\n![{figure_desc}]({image_path})" for figure_desc, image_path in images])
        return output_text + image_output
    return output_text

# Set up Gradio interface
gr_interface = gr.ChatInterface(
    fn=get_model_response_with_images, 
    title="Maintenance Assistant", 
    description="Ask questions related to the RMMM documents."
)

# Launch the Gradio app when run as a script (Hugging Face Spaces / ZeroGPU entry point)
if __name__ == "__main__":
    gr_interface.launch()