# app.py — author: DrishtiSharma; last commit bc05c29 ("Update app.py", verified); size 5.23 kB.
import streamlit as st
import pandas as pd
import os
from pandasai import SmartDataframe
from pandasai.llm import OpenAI
import tempfile
import matplotlib.pyplot as plt
from datasets import load_dataset
# API keys are read from the process environment (e.g. deployment secrets).
# Nothing in this file loads a .env file, so the variables must already be
# exported before Streamlit starts.
openai_api_key = os.getenv("OPENAI_API_KEY")

# Stop only when no supported backend could possibly be configured. The
# per-model key checks (and their error messages) live in initialize_llm,
# so a Groq-only setup can still run with the llama-3.3-70b option.
if not openai_api_key and not os.getenv("GROQ_API_KEY"):
    st.error(
        "No LLM API key is set. Please set the OPENAI_API_KEY or "
        "GROQ_API_KEY environment variable."
    )
    st.stop()
def initialize_llm(model_choice):
    """Build the LLM backend for the model selected in the UI.

    Args:
        model_choice: The label chosen in the "Select LLM" radio, i.e.
            "GPT-4o" or "llama-3.3-70b".

    Returns:
        An LLM object suitable for ``SmartDataframe(config={"llm": ...})``,
        or ``None`` when the required API key or client package is missing
        or the choice is not recognized. In every ``None`` case an error is
        shown with ``st.error``.
    """
    if model_choice == "GPT-4o":
        openai_api_key = os.getenv("OPENAI_API_KEY")
        if not openai_api_key:
            st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
            return None
        # PandasAI's own OpenAI wrapper is already imported at module level,
        # so no extra client package is needed for this backend.
        return OpenAI(api_token=openai_api_key, model="gpt-4o")

    if model_choice == "llama-3.3-70b":
        groq_api_key = os.getenv("GROQ_API_KEY")
        if not groq_api_key:
            st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
            return None
        try:
            # Imported lazily so the GPT-4o path keeps working even when
            # langchain-groq is not installed.
            from langchain_groq import ChatGroq
        except ImportError:
            st.error("The langchain-groq package is required for llama-3.3-70b. Please install it.")
            return None
        # Groq's API expects the bare model id; a "groq/" prefix is a
        # LiteLLM-style routing prefix, not a Groq model name.
        return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")

    st.error(f"Unsupported model choice: {model_choice}")
    return None
# Offer the supported backends side by side; GPT-4o is preselected.
model_choice = st.radio(
    label="Select LLM",
    options=["GPT-4o", "llama-3.3-70b"],
    index=0,
    horizontal=True,
)
# initialize_llm returns None when the selected backend's API key is missing.
llm = initialize_llm(model_choice)
def validate_and_clean_dataset(df):
    """Drop fully-empty rows/columns from *df* and reject empty datasets.

    Args:
        df: A freshly loaded pandas DataFrame. It is not modified.

    Returns:
        A new DataFrame without rows or columns whose values are all missing.

    Raises:
        ValueError: If no rows or no columns remain after cleaning.
    """
    cleaned = df.dropna(axis=0, how="all").dropna(axis=1, how="all")
    if cleaned.empty:
        raise ValueError("The dataset contains no usable rows or columns.")
    return cleaned


def load_dataset_into_session():
    """Let the user pick a dataset source and store it in ``st.session_state.df``.

    Three sources are offered: a CSV bundled with the repo, the ``train``
    split of a Hugging Face dataset, or a user-uploaded CSV. Every source is
    passed through ``validate_and_clean_dataset``. ``st.session_state.df``
    is replaced only after both loading and validation succeed, so a failed
    attempt never leaves a half-processed frame behind. Failures are shown
    with ``st.error`` instead of being raised.
    """
    input_option = st.radio(
        "Select Dataset Input:",
        ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"],
    )

    # Option 1: Load dataset from the repo directory
    if input_option == "Use Repo Directory Dataset":
        file_path = "./source/test.csv"
        if st.button("Load Dataset"):
            try:
                df = validate_and_clean_dataset(pd.read_csv(file_path))
            except Exception as e:
                st.error(f"Error loading dataset from the repo directory: {e}")
            else:
                st.session_state.df = df
                st.success(f"File loaded successfully from '{file_path}'!")

    # Option 2: Load dataset from Hugging Face
    elif input_option == "Use Hugging Face Dataset":
        dataset_name = st.text_input(
            "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
        )
        if st.button("Load Hugging Face Dataset"):
            try:
                # SECURITY NOTE(review): trust_remote_code=True runs the loading
                # script of whatever dataset name the user typed. Consider an
                # allow-list of dataset names before exposing this publicly.
                dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
                if hasattr(dataset, "to_pandas"):
                    raw_df = dataset.to_pandas()
                else:
                    raw_df = pd.DataFrame(dataset)
                df = validate_and_clean_dataset(raw_df)
            except Exception as e:
                st.error(f"Error loading Hugging Face dataset: {e}")
            else:
                st.session_state.df = df
                st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")

    # Option 3: Upload CSV File
    elif input_option == "Upload CSV File":
        uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
        if uploaded_file is not None:
            try:
                df = validate_and_clean_dataset(pd.read_csv(uploaded_file))
            except Exception as e:
                st.error(f"Error reading uploaded file: {e}")
            else:
                st.session_state.df = df
                st.success("File uploaded successfully!")
st.title("Chat with Patent Dataset Using PandasAI")

# Sidebar usage guide: one markdown line per step/example, each newline-terminated.
st.sidebar.header("Instructions:")
st.sidebar.markdown(
    "".join(
        f"{step}\n"
        for step in (
            "1. Select how you want to input the dataset.",
            "2. Upload, select, or fetch the dataset using the provided options.",
            "3. Enter a question to interact with the patent data.",
            " - Example: 'Predict if the patent will be accepted.'",
            " - Example: 'What is the primary classification of this patent?'",
            " - Example: 'Summarize the abstract of this patent.'",
            "4. Enter a query to generate and view graphs based on patent attributes.",
        )
    )
)
# Load dataset into session
load_dataset_into_session()

if "df" in st.session_state:
    df = st.session_state.df
    st.write("### Data Preview")
    st.dataframe(df.head(10))

    if llm is None:
        # initialize_llm reports missing keys via st.error and returns None;
        # without a model PandasAI cannot answer, so only the preview is shown.
        st.warning("Select an LLM with a configured API key to chat with the data.")
    else:
        # Create SmartDataFrame
        chat_df = SmartDataframe(df, config={"llm": llm})

        st.write("### Chat with Your Patent Data")
        user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")
        if user_query:
            try:
                response = chat_df.chat(user_query)
                st.success(f"Response: {response}")
            except Exception as e:
                st.error(f"Error: {e}")

        st.write("### Generate and View Graphs")
        plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")
        if plot_query:
            try:
                response = chat_df.chat(plot_query)
                if isinstance(response, str) and os.path.isfile(response):
                    # PandasAI saved the chart itself and returned its file path.
                    st.image(response, caption="Generated Plot", use_column_width=True)
                elif plt.get_fignums():
                    # Fall back to a figure still open in pyplot; saving when no
                    # figure exists would only produce a blank image.
                    with tempfile.TemporaryDirectory() as temp_dir:
                        temp_plot_path = os.path.join(temp_dir, "plot.png")
                        plt.savefig(temp_plot_path)
                        st.image(temp_plot_path, caption="Generated Plot", use_column_width=True)
                else:
                    st.warning(f"No plot was generated. Response: {response}")
            except Exception as e:
                st.error(f"Error: {e}")
            finally:
                # Streamlit reruns this script on every interaction; close figures
                # so they do not accumulate in pyplot's global state.
                plt.close("all")