DrishtiSharma committed
Commit f84b0ed · verified · 1 Parent(s): 51ae25c

Update app.py

Files changed (1): app.py (+27 -22)
app.py CHANGED
@@ -1,4 +1,5 @@
 import os
+import tempfile
 import streamlit as st
 import pandas as pd
 from datasets import load_dataset
@@ -19,7 +20,7 @@ def load_huggingface_dataset(dataset_name):
     try:
         # Incrementally update progress
         progress_bar.progress(10)
-        dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True, uniform_split=True)
+        dataset = load_dataset(dataset_name, name="sample", split="train", trust_remote_code=True)
         progress_bar.progress(50)
         if hasattr(dataset, "to_pandas"):
             df = dataset.to_pandas()
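Note on this hunk: uniform_split is not a parameter of datasets.load_dataset itself; unrecognized keyword arguments are forwarded to the dataset's BuilderConfig, where an unknown key typically raises a ValueError at load time, which is presumably why the argument was dropped. A minimal sketch of the failure mode and the corrected call, using a hypothetical dataset id:

from datasets import load_dataset

try:
    # Unknown kwargs such as uniform_split are forwarded to the BuilderConfig and
    # typically fail with "BuilderConfig ... doesn't have a 'uniform_split' key".
    ds = load_dataset("user/patents-sample", split="train", uniform_split=True)  # hypothetical id
except ValueError as err:
    print(f"load_dataset rejected the extra kwarg: {err}")

# The patched call passes only supported arguments:
ds = load_dataset("user/patents-sample", name="sample", split="train", trust_remote_code=True)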
@@ -107,15 +108,16 @@ if "df" in st.session_state:
     st.header("Run Queries on Patent Data")

     with st.spinner("Setting up LangChain CSV Agent..."):
-        df.to_csv("patent_data.csv", index=False)
-
-        csv_agent = create_csv_agent(
-            ChatOpenAI(temperature=0, model="gpt-4", api_key=os.getenv("OPENAI_API_KEY")),
-            path=["patent_data.csv"],
-            verbose=True,
-            agent_type=AgentType.OPENAI_FUNCTIONS,
-            allow_dangerous_code=True
-        )
+        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_file:
+            df.to_csv(temp_file.name, index=False)
+
+        csv_agent = create_csv_agent(
+            ChatOpenAI(temperature=0, model="gpt-4", api_key=os.getenv("OPENAI_API_KEY")),
+            path=[temp_file.name],
+            verbose=True,
+            agent_type=AgentType.OPENAI_FUNCTIONS,
+            allow_dangerous_code=True
+        )

     # Query Input and Execution
     query = st.text_area("Enter your natural language query:", "How many patents are related to AI?")
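One caveat on the tempfile approach introduced here: delete=False is what lets create_csv_agent reopen the CSV by path after the handle closes, but it also means nothing ever removes the file. A minimal cleanup sketch under that assumption (run_agent_on_frame and build_agent are illustrative names, not part of this app):

import os
import tempfile

import pandas as pd

def run_agent_on_frame(df: pd.DataFrame, build_agent, query: str):
    """Write df to a temporary CSV, run the agent, then always remove the file."""
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp:
        df.to_csv(tmp.name, index=False)
        path = tmp.name  # file survives the with-block because delete=False
    try:
        agent = build_agent(path)  # e.g. a wrapper around create_csv_agent
        return agent.invoke(query)
    finally:
        os.unlink(path)  # clean up even if invoke() raises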
@@ -123,24 +125,27 @@ if "df" in st.session_state:
     if st.button("Run Query"):
         with st.spinner("Running query..."):
             try:
-                # Check if the dataset is too large and split if needed
-                max_rows = 1000  # Limit chunk size to manage token limits
+                # Token limit configuration
+                max_rows = 1000  # Adjust chunk size dynamically
                 total_rows = len(df)

                 if total_rows > max_rows:
                     results = []
                     for start in range(0, total_rows, max_rows):
                         chunk = df.iloc[start:start + max_rows]
-                        chunk.to_csv("chunk_data.csv", index=False)
-                        partial_agent = create_csv_agent(
-                            ChatOpenAI(temperature=0, model="gpt-4", api_key=os.getenv("OPENAI_API_KEY")),
-                            path=["chunk_data.csv"],
-                            verbose=True,
-                            agent_type=AgentType.OPENAI_FUNCTIONS,
-                            allow_dangerous_code=True
-                        )
-                        result = partial_agent.invoke(query)
-                        results.append(result)
+                        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as chunk_file:
+                            chunk.to_csv(chunk_file.name, index=False)
+
+                        # Update the agent dynamically with the chunk
+                        csv_agent = create_csv_agent(
+                            ChatOpenAI(temperature=0, model="gpt-4", api_key=os.getenv("OPENAI_API_KEY")),
+                            path=[chunk_file.name],
+                            verbose=False,
+                            agent_type=AgentType.OPENAI_FUNCTIONS,
+                            allow_dangerous_code=False
+                        )
+                        result = csv_agent.invoke(query)
+                        results.append(result)

                 st.success("Query executed successfully!")
                 st.write("### Combined Query Results:")
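Note on the chunked path: each csv_agent.invoke(query) returns a per-chunk answer (for LangChain agents this is typically a dict with an "output" string), so aggregate questions like the default "How many patents are related to AI?" still need an explicit step to merge the results list before display. A hedged sketch of one way to do that (combine_chunk_results is an illustrative helper, not part of this commit):

def combine_chunk_results(results):
    """Flatten a list of per-chunk agent responses into one displayable string."""
    answers = []
    for i, result in enumerate(results, start=1):
        # LangChain agent responses are usually dicts with an "output" key.
        text = result.get("output", str(result)) if isinstance(result, dict) else str(result)
        answers.append(f"Chunk {i}: {text}")
    return "\n\n".join(answers)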
 