import os
import tempfile

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from datasets import load_dataset
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from pandasai import SmartDataframe

# Load environment variables
openai_api_key = os.getenv("OPENAI_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")

st.title("Chat with Patent Dataset Using PandasAI")

# Initialize the LLM based on user selection
def initialize_llm(model_choice):
    if model_choice == "llama-3.3-70b":
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # Groq model ids are passed without a provider prefix
        return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")
    elif model_choice == "GPT-4o":
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")

# Select LLM model
model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
llm = initialize_llm(model_choice)
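
# validate_and_clean_dataset is called below when loading data but is not
# defined elsewhere in this script, so here is a minimal sketch (assumption:
# dropping fully-empty rows/columns and tidying column names is sufficient).
def validate_and_clean_dataset(df: pd.DataFrame) -> pd.DataFrame:
    # Drop rows and columns that contain no data at all
    df = df.dropna(how="all").dropna(axis=1, how="all")
    # Strip whitespace from column names so queries match them reliably
    df.columns = [str(col).strip() for col in df.columns]
    return df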

def load_dataset_into_session():
    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"], index=0, horizontal=True
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Dataset"):
            try:
                st.session_state.df = pd.read_csv(file_path)
                st.session_state.df = validate_and_clean_dataset(st.session_state.df)
                st.success(f"File loaded successfully from '{file_path}'!")
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if st.button("Load Hugging Face Dataset"):
            try:
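                # Some datasets (e.g., HUPD/hupd) ship a loading script, hence
                # trust_remote_code=True; large corpora can take a while to download.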
                dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
                if hasattr(dataset, "to_pandas"):
                    st.session_state.df = dataset.to_pandas()
                else:
                    st.session_state.df = pd.DataFrame(dataset)
                st.session_state.df = validate_and_clean_dataset(st.session_state.df)
                st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
            except Exception as e:
                st.error(f"Error loading Hugging Face dataset: {e}")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if uploaded_file:
            try:
                st.session_state.df = pd.read_csv(uploaded_file)
                st.session_state.df = validate_and_clean_dataset(st.session_state.df)
                st.success("File uploaded successfully!")
            except Exception as e:
                st.error(f"Error reading uploaded file: {e}")

# Load dataset into session
load_dataset_into_session()

if "df" in st.session_state and llm:
    df = st.session_state.df
    st.write("### Data Preview")
    st.dataframe(df.head(10))

    # Wrap the DataFrame for conversational queries; PandasAI translates
    # natural-language questions into pandas code via the configured LLM
    # (recent PandasAI versions accept LangChain chat models here).
    chat_df = SmartDataframe(df, config={"llm": llm})

    st.write("### Chat with Your Patent Data")
    user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")

    if user_query:
        try:
            response = chat_df.chat(user_query)
            st.success(f"Response: {response}")
        except Exception as e:
            st.error(f"Error: {e}")

    st.write("### Generate and View Graphs")
    plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")

    if plot_query:
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # Ask PandasAI to generate the plot; for chart queries it
                # often returns the path of an image it saved itself.
                response = chat_df.chat(plot_query)

                if isinstance(response, str) and os.path.isfile(response):
                    st.image(response, caption="Generated Plot", use_container_width=True)
                else:
                    # Fall back to saving whatever is on the current matplotlib
                    # figure (may be blank if PandasAI already closed it).
                    temp_plot_path = os.path.join(temp_dir, "plot.png")
                    plt.savefig(temp_plot_path)
                    st.image(temp_plot_path, caption="Generated Plot", use_container_width=True)

        except Exception as e:
            st.error(f"Error: {e}")

# Instructions
with st.sidebar:
    st.header("Instructions")
    st.markdown(
        "1. Select how you want to input the dataset.\n"
        "2. Upload, select, or fetch the dataset using the provided options.\n"
        "3. Choose an LLM (Groq-based or OpenAI-based) and ask questions about the data, e.g.:\n"
        "   - 'Predict if the patent will be accepted.'\n"
        "   - 'What is the primary classification of this patent?'\n"
        "   - 'Summarize the abstract of this patent.'\n"
        "4. Enter a query to generate and view graphs based on patent attributes.\n"
    )
    st.markdown("---")
    st.header("References")
    st.markdown(
        "1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
    )
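
# Usage sketch (assuming this script is saved as app.py and keys are exported):
#   export OPENAI_API_KEY=...   # and/or GROQ_API_KEY=...
#   streamlit run app.py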