DrishtiSharma committed
Commit 043fd83 · verified · 1 Parent(s): c9d8a81

Create interim.py

Files changed (1)
1. lab/interim.py  +134 -0
lab/interim.py ADDED
@@ -0,0 +1,134 @@
+ import streamlit as st
+ import pandas as pd
+ import os
+ from pandasai import SmartDataframe
+ from pandasai.llm import OpenAI
+ import tempfile
+ import matplotlib.pyplot as plt
+ from datasets import load_dataset
+ from langchain_groq import ChatGroq
+ from langchain_openai import ChatOpenAI
+
+ # Load environment variables
+ openai_api_key = os.getenv("OPENAI_API_KEY")
+ groq_api_key = os.getenv("GROQ_API_KEY")
+
+ st.title("Chat with Patent Dataset Using PandasAI")
+
+ # Initialize the LLM based on user selection
+ def initialize_llm(model_choice):
+     if model_choice == "llama-3.3-70b":
+         if not groq_api_key:
+             st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
+             return None
+         return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")
+     elif model_choice == "GPT-4o":
+         if not openai_api_key:
+             st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
+             return None
+         return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")
+
+ # Select LLM model
+ model_choice = st.radio("Select LLM", ["GPT-4o", "llama-3.3-70b"], index=0, horizontal=True)
+ llm = initialize_llm(model_choice)
+
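+ # NOTE: the loaders below call validate_and_clean_dataset(), but its exact behavior is
+ # not specified; this minimal placeholder (assumed: drop rows that are entirely empty)
+ # stands in and can be extended with dataset-specific validation as needed.
+ def validate_and_clean_dataset(df):
+     # Minimal assumed cleanup: remove rows where every column is NaN.
+     return df.dropna(how="all")
+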
+ def load_dataset_into_session():
+     input_option = st.radio(
+         "Select Dataset Input:",
+         ["Use Repo Directory Dataset", "Use Hugging Face Dataset", "Upload CSV File"], index=0, horizontal=True
+     )
+
+     # Option 1: Load dataset from the repo directory
+     if input_option == "Use Repo Directory Dataset":
+         file_path = "./source/test.csv"
+         if st.button("Load Dataset"):
+             try:
+                 st.session_state.df = pd.read_csv(file_path)
+                 st.success(f"File loaded successfully from '{file_path}'!")
+             except Exception as e:
+                 st.error(f"Error loading dataset from the repo directory: {e}")
+
+     # Option 2: Load dataset from Hugging Face
+     elif input_option == "Use Hugging Face Dataset":
+         dataset_name = st.text_input(
+             "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
+         )
+         if st.button("Load Hugging Face Dataset"):
+             try:
+                 dataset = load_dataset(dataset_name, split="train", trust_remote_code=True)
+                 if hasattr(dataset, "to_pandas"):
+                     st.session_state.df = dataset.to_pandas()
+                 else:
+                     st.session_state.df = pd.DataFrame(dataset)
+                 st.session_state.df = validate_and_clean_dataset(st.session_state.df)
+                 st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
+             except Exception as e:
+                 st.error(f"Error loading Hugging Face dataset: {e}")
+
+     # Option 3: Upload CSV File
+     elif input_option == "Upload CSV File":
+         uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
+         if uploaded_file:
+             try:
+                 st.session_state.df = pd.read_csv(uploaded_file)
+                 st.session_state.df = validate_and_clean_dataset(st.session_state.df)
+                 st.success("File uploaded successfully!")
+             except Exception as e:
+                 st.error(f"Error reading uploaded file: {e}")
+
+ # Load dataset into session
+ load_dataset_into_session()
+
+ if "df" in st.session_state and llm:
+     df = st.session_state.df
+     st.write("### Data Preview")
+     st.dataframe(df.head(10))
+
+     # Create SmartDataFrame
+     chat_df = SmartDataframe(df, config={"llm": llm})
+
+     st.write("### Chat with Your Patent Data")
+     user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")
+
+     if user_query:
+         try:
+             response = chat_df.chat(user_query)
+             st.success(f"Response: {response}")
+         except Exception as e:
+             st.error(f"Error: {e}")
+
+     st.write("### Generate and View Graphs")
+     plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")
+
+     if plot_query:
+         try:
+             with tempfile.TemporaryDirectory() as temp_dir:
+                 # PandasAI can handle plotting
+                 chat_df.chat(plot_query)
+
+                 # Save and display the plot
+                 temp_plot_path = os.path.join(temp_dir, "plot.png")
+                 plt.savefig(temp_plot_path)
+                 st.image(temp_plot_path, caption="Generated Plot", use_column_width=True)
+
+         except Exception as e:
+             st.error(f"Error: {e}")
+
+ # Instructions
+ with st.sidebar:
+     st.header("Instructions")
+     st.markdown(
+         "1. Select how you want to input the dataset.\n"
+         "2. Upload, select, or fetch the dataset using the provided options.\n"
+         "3. Choose an LLM (Groq-based or OpenAI-based) to interact with the data.\n"
+         " - Example: 'Predict if the patent will be accepted.'\n"
+         " - Example: 'What is the primary classification of this patent?'\n"
+         " - Example: 'Summarize the abstract of this patent.'\n"
+         "4. Enter a query to generate and view graphs based on patent attributes.\n"
+     )
+     st.markdown("---")
+     st.header("References")
+     st.markdown(
+         "1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
+     )
+
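A likely way to run this app locally (not stated in the commit itself) is to set the OPENAI_API_KEY and/or GROQ_API_KEY environment variables that the script reads at startup and then launch it with streamlit run lab/interim.py.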