File size: 7,084 Bytes
b405caf
 
 
 
 
 
 
48a23ed
8b2a8c3
 
888ea53
b405caf
 
 
8b2a8c3
b405caf
ea65a33
 
8b2a8c3
bc05c29
 
 
 
 
 
 
 
 
 
 
 
b11659d
 
 
48a23ed
888ea53
 
 
 
 
 
 
da8bfb8
888ea53
 
 
 
 
 
5987bcb
888ea53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
981ebc4
 
48a23ed
 
 
981ebc4
48a23ed
 
 
 
 
 
 
5987bcb
888ea53
48a23ed
 
 
 
 
 
 
 
 
b0290f9
48a23ed
888ea53
48a23ed
 
 
 
 
 
 
 
 
888ea53
48a23ed
 
 
 
 
 
b405caf
8b2a8c3
48a23ed
981ebc4
 
 
 
 
 
 
 
3e59218
981ebc4
 
b405caf
 
 
 
981ebc4
f1dbd3a
c47fc55
b405caf
 
 
 
 
 
 
 
981ebc4
b405caf
c47fc55
b405caf
 
 
 
 
 
 
 
 
 
f6cdc89
b405caf
 
 
b11659d
981ebc4
f6164cd
 
 
 
 
 
 
981ebc4
 
b11659d
4189793
b11659d
dcbbc0e
b11659d
8395e91
b11659d
 
 
 
f64187e
4189793
f64187e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import streamlit as st
import pandas as pd
import os
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
import time

# Provider API keys come from the process environment; each is None when its
# variable is not set.
openai_api_key = os.environ.get("OPENAI_API_KEY")
groq_api_key = os.environ.get("GROQ_API_KEY")

# Page heading
st.title("Chat with Patent Dataset Using PandasAI")

# Initialize the LLM based on user selection
def initialize_llm(model_choice):
    """Build the chat model that backs the selected radio option.

    Args:
        model_choice: Label picked in the UI, either "GPT-4o" or
            "llama-3.3-70b".

    Returns:
        A ``ChatOpenAI`` or ``ChatGroq`` instance, or ``None`` when the
        matching API key is not configured or the label is not recognised.
        In both failure cases an error is shown on the Streamlit page.
    """
    if model_choice == "llama-3.3-70b":
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # ChatGroq talks to Groq directly, so the model id is passed without
        # a "groq/" provider prefix (that spelling is a router convention and
        # is not a Groq model id).
        # NOTE(review): confirm the id against Groq's current model list.
        return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")

    if model_choice == "GPT-4o":
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")

    # Make the "no model" outcome explicit instead of falling off the end.
    st.error(f"Unsupported model choice: {model_choice}")
    return None

# Let the user pick the backing model, then build it. `llm` is None when the
# model could not be initialised.
model_choice = st.radio(
    label="Select LLM",
    options=["GPT-4o", "llama-3.3-70b"],
    index=0,
    horizontal=True,
)
llm = initialize_llm(model_choice)

# Dataset loading without caching to support progress bar
def load_huggingface_dataset(dataset_name):
    """Fetch a Hugging Face dataset and return it as a pandas DataFrame.

    A Streamlit progress bar is moved through fixed checkpoints (10, 50, 100)
    around the download and the conversion, and is put back to 0 when either
    step fails.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``.

    Returns:
        The "train" split of the "sample" configuration as a DataFrame.

    Raises:
        Exception: Whatever the download or conversion raised, re-raised
            after the progress bar has been reset.
    """
    bar = st.progress(0)
    try:
        bar.progress(10)
        # NOTE(review): trust_remote_code=True is combined with a dataset
        # name typed by the user; presumably this allows the dataset repo's
        # loading script to run on this host. Confirm this is acceptable.
        hf_data = load_dataset(
            dataset_name,
            name="sample",
            split="train",
            trust_remote_code=True,
            uniform_split=True,
        )
        bar.progress(50)
        # Prefer the dataset's own converter; otherwise let pandas build the
        # frame from the object directly.
        frame = hf_data.to_pandas() if hasattr(hf_data, "to_pandas") else pd.DataFrame(hf_data)
        bar.progress(100)
    except Exception:
        bar.progress(0)
        raise
    return frame

def load_uploaded_csv(uploaded_file):
    """Parse an uploaded CSV into a DataFrame while showing a progress bar.

    Args:
        uploaded_file: File-like object accepted by ``pd.read_csv``.

    Returns:
        The parsed DataFrame.

    Raises:
        Exception: Any parsing error, re-raised after the progress bar has
            been reset to 0.
    """
    bar = st.progress(0)
    try:
        bar.progress(10)
        # Artificial one-second pause that simulates processing time before
        # the bar moves on.
        time.sleep(1)
        bar.progress(50)
        parsed = pd.read_csv(uploaded_file)
        bar.progress(100)
    except Exception:
        bar.progress(0)
        raise
    return parsed

# Dataset selection logic
def load_dataset_into_session():
    """Render the dataset-source picker and load the chosen dataset.

    On success the DataFrame is stored in ``st.session_state.df``. Failures
    are reported on the page with ``st.error`` rather than raised.
    """
    source_kind = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
        index=1,
        horizontal=True,
    )

    # Source 1: CSV shipped in the repository
    if source_kind == "Use Repo Directory Dataset":
        csv_path = "./source/test.csv"
        if not st.button("Load Dataset"):
            return
        try:
            with st.spinner("Loading dataset from the repo directory..."):
                st.session_state.df = pd.read_csv(csv_path)
            st.success(f"File loaded successfully from '{csv_path}'!")
        except Exception as exc:
            st.error(f"Error loading dataset from the repo directory: {exc}")
        return

    # Source 2: dataset pulled from Hugging Face
    if source_kind == "Use Hugging Face Dataset":
        hub_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if not st.button("Load Dataset"):
            return
        try:
            st.session_state.df = load_huggingface_dataset(hub_name)
            st.success(f"Hugging Face Dataset '{hub_name}' loaded successfully!")
        except Exception as exc:
            st.error(f"Error loading Hugging Face dataset: {exc}")
        return

    # Source 3: CSV uploaded through the browser
    if source_kind == "Upload CSV File":
        upload = st.file_uploader("Upload a CSV File:", type=["csv"])
        if not upload:
            return
        try:
            st.session_state.df = load_uploaded_csv(upload)
            st.success("File uploaded successfully!")
        except Exception as exc:
            st.error(f"Error reading uploaded file: {exc}")

# Load dataset into session
load_dataset_into_session()

# The chat UI is rendered only once a dataset is in the session and an LLM
# could be initialised.
if "df" in st.session_state and llm:
    df = st.session_state.df

    # Display dataset metadata
    st.write("### Dataset Metadata")
    st.text(f"Number of Rows: {df.shape[0]}")
    st.text(f"Number of Columns: {df.shape[1]}")
    # Column labels are not guaranteed to be strings; cast before joining so
    # a frame with non-string labels does not raise TypeError here.
    st.text(f"Column Names: {', '.join(map(str, df.columns))}")

    # Display dataset preview
    st.write("### Dataset Preview")
    num_rows = st.slider("Select number of rows to display:", min_value=5, max_value=50, value=10)
    st.dataframe(df.head(num_rows))

    # Create SmartDataFrame
    chat_df = SmartDataframe(df, config={"llm": llm})

    # Chat functionality
    st.write("### Chat with Patent Data")
    user_query = st.text_input("Enter your question about the patent data:", value = "Have the patents with the numbers 14908945, 14994130, 14909084, and 14995057 been accepted or rejected? What are their titles?")

    if user_query:
        try:
            response = chat_df.chat(user_query)
            st.success(f"Response: {response}")
        except Exception as e:
            st.error(f"Error: {e}")

    # Plot generation functionality
    st.write("### Generate and View Graphs")
    plot_query = st.text_input("Enter a query to generate a graph:", value = "What is the distribution of patents categorized as 'ACCEPTED', 'REJECTED', or 'PENDING'?")

    if plot_query:
        # pyplot keeps its figures in module-level state, which outlives a
        # single run of this script. Drop anything left over (from an earlier
        # run or from the text query above) so only a figure produced by this
        # query can be picked up below.
        plt.close("all")
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # PandasAI can handle plotting
                plot_response = chat_df.chat(plot_query)

                if plt.get_fignums():
                    # A figure is open: save and display it as before.
                    temp_plot_path = os.path.join(temp_dir, "plot.png")
                    plt.savefig(temp_plot_path)
                    st.image(temp_plot_path, caption="Generated Plot", use_container_width=True)
                elif isinstance(plot_response, str) and os.path.isfile(plot_response):
                    # NOTE(review): PandasAI presumably saves the chart itself
                    # and returns its path in some versions - confirm. Showing
                    # that file avoids publishing a blank canvas.
                    st.image(plot_response, caption="Generated Plot", use_container_width=True)
                else:
                    # Nothing was drawn; savefig would only emit an empty image.
                    st.info("No plot was generated for this query.")

        except Exception as e:
            st.error(f"Error: {e}")
        finally:
            # Release the figure so it neither leaks nor reappears as a stale
            # chart on the next run.
            plt.close("all")

    # Download processed dataset
    #st.write("### Download Processed Dataset")
    #st.download_button(
     #   label="Download Dataset as CSV",
     #   data=df.to_csv(index=False),
     #   file_name="processed_dataset.csv",
     #   mime="text/csv"
    #)

# Sidebar instructions
with st.sidebar:
    # Usage instructions. The header emoji were stored as mis-decoded UTF-8
    # ("πŸ“‹" / "πŸ“š"); the intended characters are restored here.
    st.header("📋 Instructions:")
    st.markdown(
        "1. Choose an LLM (Groq-based or OpenAI-based) to interact with the data.\n"
        "2. Upload, select, or fetch the dataset using the provided options.\n"
        "3. Enter a query to generate and view graphs based on patent attributes.\n"
        "   - Example: 'Predict if the patent will be accepted.'\n"
        "   - Example: 'What is the primary classification of this patent?'\n"
        "   - Example: 'Summarize the abstract of this patent.'\n"
    )
    st.markdown("---")
    # Further reading
    st.header("📚 References:")
    st.markdown(
        "1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
    )