File size: 5,225 Bytes
b405caf
 
 
 
 
 
 
48a23ed
b405caf
 
 
 
 
 
 
 
 
 
bc05c29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b405caf
48a23ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b405caf
e8cdaa1
 
48a23ed
e8cdaa1
48a23ed
 
 
 
 
 
 
e8cdaa1
 
48a23ed
 
b405caf
48a23ed
 
b405caf
 
 
 
 
 
48a23ed
 
b405caf
 
 
 
 
 
 
 
 
48a23ed
b405caf
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
import pandas as pd
import os
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset

# Fail fast when no OpenAI key is configured. The key is read directly from the
# process environment; nothing in this file loads a .env file, so the variable
# must be exported before Streamlit starts.
# NOTE(review): this gate also applies when the Groq model is selected below.
openai_api_key = os.getenv("OPENAI_API_KEY")

if not openai_api_key:
    st.error("OpenAI API key is not set. Please set the OPENAI_API_KEY environment variable.")
    st.stop()

def initialize_llm(model_choice):
    """Build the LLM client for the model selected in the UI.

    Args:
        model_choice: The label chosen in the "Select LLM" radio button
            (``"GPT-4o"`` or ``"llama-3.3-70b"``).

    Returns:
        An LLM object suitable for ``SmartDataframe(config={"llm": ...})``, or
        ``None`` when no client could be created. Whenever ``None`` is
        returned, an explanatory ``st.error`` has already been rendered.
    """
    if model_choice == "GPT-4o":
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        # `ChatOpenAI` is not imported anywhere in this file (the previous code
        # raised NameError here on startup); use the PandasAI OpenAI wrapper
        # that the file already imports.
        return OpenAI(api_token=openai_api_key, model="gpt-4o")

    if model_choice == "llama-3.3-70b":
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # TODO: `ChatGroq` was referenced here but never imported, so this
        # branch always raised NameError. To enable it, add the langchain-groq
        # dependency, import `from langchain_groq import ChatGroq` at the top of
        # the file and return
        # `ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")`
        # (the "groq/" model prefix used previously is not a Groq model id).
        st.error("Groq support is not available: no Groq chat client is imported in this app.")
        return None

    st.error(f"Unsupported model selection: {model_choice!r}")
    return None

model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
llm = initialize_llm(model_choice)

# initialize_llm shows its own error before returning None; halt the script here
# instead of handing `llm=None` to PandasAI further down. The radio above stays
# rendered, so the user can pick another model.
if llm is None:
    st.stop()


def validate_and_clean_dataset(df):
    """Check that a freshly loaded dataset is usable and tidy it up.

    Rows in which every value is missing are dropped and the index is reset.

    Args:
        df: The DataFrame produced by one of the loaders in
            ``load_dataset_into_session``.

    Returns:
        A new, cleaned DataFrame; ``df`` itself is not modified.

    Raises:
        TypeError: If ``df`` is not a pandas DataFrame.
        ValueError: If ``df`` has no columns, or no rows remain after cleaning.
    """
    if not isinstance(df, pd.DataFrame):
        raise TypeError(f"Expected a pandas DataFrame, got {type(df).__name__}.")
    if df.columns.empty:
        raise ValueError("The dataset has no columns.")
    cleaned = df.dropna(how="all").reset_index(drop=True)
    if cleaned.empty:
        raise ValueError("The dataset has no non-empty rows.")
    return cleaned


def load_dataset_into_session():
    """Render the dataset-source picker and load the chosen dataset.

    On success the validated DataFrame is stored in ``st.session_state.df`` and
    a success message is shown. On failure an error is shown and whatever was
    previously in ``st.session_state.df`` (if anything) is left untouched,
    because the session value is only assigned after validation succeeds.
    """
    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Dataset"):
            try:
                st.session_state.df = validate_and_clean_dataset(pd.read_csv(file_path))
                st.success(f"File loaded successfully from '{file_path}'!")
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if st.button("Load Hugging Face Dataset"):
            try:
                # SECURITY NOTE(review): `trust_remote_code=True` runs the
                # dataset repo's loading code in this process, and
                # `dataset_name` is free text typed by the user. Consider
                # restricting it to an allow-list of known datasets.
                dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
                if hasattr(dataset, "to_pandas"):
                    raw_df = dataset.to_pandas()
                else:
                    raw_df = pd.DataFrame(dataset)
                st.session_state.df = validate_and_clean_dataset(raw_df)
                st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
            except Exception as e:
                st.error(f"Error loading Hugging Face dataset: {e}")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if uploaded_file:
            try:
                st.session_state.df = validate_and_clean_dataset(pd.read_csv(uploaded_file))
                st.success("File uploaded successfully!")
            except Exception as e:
                st.error(f"Error reading uploaded file: {e}")

st.title("Chat with Patent Dataset Using PandasAI")

# Sidebar: a numbered usage guide, with sample questions nested under step 3.
with st.sidebar:
    st.header("Instructions:")
    st.markdown(
        "\n".join(
            [
                "1. Select how you want to input the dataset.",
                "2. Upload, select, or fetch the dataset using the provided options.",
                "3. Enter a question to interact with the patent data.",
                "   - Example: 'Predict if the patent will be accepted.'",
                "   - Example: 'What is the primary classification of this patent?'",
                "   - Example: 'Summarize the abstract of this patent.'",
                "4. Enter a query to generate and view graphs based on patent attributes.",
            ]
        )
        + "\n"
    )

# Render the dataset picker; a successful load ends up in st.session_state.df.
load_dataset_into_session()

if "df" in st.session_state:
    df = st.session_state.df
    st.write("### Data Preview")
    st.dataframe(df.head(10))

    # Wrap the loaded frame so it can be queried in natural language.
    chat_df = SmartDataframe(df, config={"llm": llm})

    st.write("### Chat with Your Patent Data")
    user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")

    if user_query:
        try:
            response = chat_df.chat(user_query)
            st.success(f"Response: {response}")
        except Exception as e:
            st.error(f"Error: {e}")

    st.write("### Generate and View Graphs")
    plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")

    if plot_query:
        try:
            # Discard figures left over from earlier reruns so a stale chart
            # can never be shown as the answer to this query.
            plt.close("all")
            plot_result = chat_df.chat(plot_query)

            if isinstance(plot_result, str) and os.path.isfile(plot_result):
                # The chat call returned the path of a chart image on disk.
                st.image(plot_result, caption="Generated Plot", use_column_width=True)
            elif plt.get_fignums():
                # Fallback: the generated code left a matplotlib figure open.
                with tempfile.TemporaryDirectory() as temp_dir:
                    temp_plot_path = os.path.join(temp_dir, "plot.png")
                    plt.savefig(temp_plot_path)
                    st.image(temp_plot_path, caption="Generated Plot", use_column_width=True)
            else:
                st.warning(f"No plot was produced. Response: {plot_result}")
        except Exception as e:
            st.error(f"Error: {e}")
        finally:
            # Release figure memory; pyplot state persists across reruns.
            plt.close("all")