import streamlit as st
import pandas as pd
import os
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
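# Optional (not in the original script): during local development the API keys below can be
# read from a .env file, assuming the python-dotenv package is installed.
# from dotenv import load_dotenv
# load_dotenv()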
# Load environment variables
openai_api_key = os.getenv("OPENAI_API_KEY")
groq_api_key = os.getenv("GROQ_API_KEY")
st.title("Chat with Patent Dataset Using PandasAI")
# Initialize the LLM based on user selection
def initialize_llm(model_choice):
    """Return a chat model for the selected provider, or None if its API key is missing."""
    if model_choice == "llama-3.3-70b":
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        # "llama-3.3-70b-versatile" is the Groq model id for Llama 3.3 70B
        return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")
    elif model_choice == "GPT-4o":
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
# Select LLM model
model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
llm = initialize_llm(model_choice)
# Cache dataset loading
@st.cache_data
def load_repo_dataset(file_path):
    return pd.read_csv(file_path)
@st.cache_data
def load_huggingface_dataset(dataset_name):
    # name="all" selects the full HUPD configuration; extra keyword arguments such as
    # uniform_split are forwarded to the dataset builder.
    dataset = load_dataset(dataset_name, name="all", split="train", trust_remote_code=True, uniform_split=True)
    if hasattr(dataset, "to_pandas"):
        return dataset.to_pandas()
    return pd.DataFrame(dataset)
@st.cache_data
def load_uploaded_csv(uploaded_file):
    return pd.read_csv(uploaded_file)
# Dataset selection logic
def load_dataset_into_session():
    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
        index=1,
        horizontal=True,
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Dataset"):
            try:
                st.session_state.df = load_repo_dataset(file_path)
                st.success(f"File loaded successfully from '{file_path}'!")
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if st.button("Load Dataset"):
            try:
                st.session_state.df = load_huggingface_dataset(dataset_name)
                st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
            except Exception as e:
                st.error(f"Error loading Hugging Face dataset: {e}")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if uploaded_file:
            try:
                st.session_state.df = load_uploaded_csv(uploaded_file)
                st.success("File uploaded successfully!")
            except Exception as e:
                st.error(f"Error reading uploaded file: {e}")
# Load dataset into session
load_dataset_into_session()
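# Hedged addition (assumption, not part of the original flow): if a dataset is loaded but no
# LLM could be initialized (missing API key), show a hint instead of silently hiding the
# chat section below.
if "df" in st.session_state and llm is None:
    st.info("Dataset loaded, but no LLM is available. Set OPENAI_API_KEY or GROQ_API_KEY to enable chat.")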
if "df" in st.session_state and llm:
    df = st.session_state.df

    # Display dataset metadata
    st.write("### Dataset Metadata")
    st.text(f"Number of Rows: {df.shape[0]}")
    st.text(f"Number of Columns: {df.shape[1]}")
    st.text(f"Column Names: {', '.join(df.columns)}")

    # Display dataset preview
    st.write("### Dataset Preview")
    num_rows = st.slider("Select number of rows to display:", min_value=5, max_value=50, value=10)
    st.dataframe(df.head(num_rows))

    # Create SmartDataframe
    chat_df = SmartDataframe(df, config={"llm": llm})

    # Chat functionality
    st.write("### Chat with Your Patent Data")
    user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")
    if user_query:
        try:
            response = chat_df.chat(user_query)
            st.success(f"Response: {response}")
        except Exception as e:
            st.error(f"Error: {e}")
    # Plot generation functionality
    st.write("### Generate and View Graphs")
    plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")
    if plot_query:
        try:
            with tempfile.TemporaryDirectory() as temp_dir:
                # PandasAI can handle plotting; for plot queries the response is
                # often the path of the chart image it saved.
                response = chat_df.chat(plot_query)
                if isinstance(response, str) and os.path.isfile(response):
                    st.image(response, caption="Generated Plot", use_container_width=True)
                else:
                    # Fall back to saving whatever is on the current Matplotlib figure.
                    temp_plot_path = os.path.join(temp_dir, "plot.png")
                    plt.savefig(temp_plot_path)
                    st.image(temp_plot_path, caption="Generated Plot", use_container_width=True)
        except Exception as e:
            st.error(f"Error generating plot: {e}")
    # Download processed dataset
    st.write("### Download Processed Dataset")
    st.download_button(
        label="Download Dataset as CSV",
        data=df.to_csv(index=False),
        file_name="processed_dataset.csv",
        mime="text/csv",
    )
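    # Hypothetical convenience (not in the original script): discard the cached dataset so a
    # different source can be selected without restarting the app.
    if st.button("Clear Loaded Dataset"):
        st.session_state.pop("df", None)
        st.rerun()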
# Sidebar instructions
with st.sidebar:
st.header("Instructions:")
st.markdown(
"1. Select how you want to input the dataset.\n"
"2. Upload, select, or fetch the dataset using the provided options.\n"
"3. Choose an LLM (Groq-based or OpenAI-based) to interact with the data.\n"
" - Example: 'Predict if the patent will be accepted.'\n"
" - Example: 'What is the primary classification of this patent?'\n"
" - Example: 'Summarize the abstract of this patent.'\n"
"4. Enter a query to generate and view graphs based on patent attributes.\n"
"5. Download the processed dataset as a CSV file."
)
st.markdown("---")
st.header("References:")
st.markdown(
"1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
)