Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Create impressive.py
Browse files- lab/impressive.py +183 -0
    	
        lab/impressive.py
    ADDED
    
    | @@ -0,0 +1,183 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import streamlit as st
         | 
| 2 | 
            +
            import pandas as pd
         | 
| 3 | 
            +
            import os
         | 
| 4 | 
            +
            from pandasai import SmartDataframe
         | 
| 5 | 
            +
            from pandasai.llm import OpenAI
         | 
| 6 | 
            +
            import tempfile
         | 
| 7 | 
            +
            import matplotlib.pyplot as plt
         | 
| 8 | 
            +
            from datasets import load_dataset
         | 
| 9 | 
            +
            from langchain_groq import ChatGroq
         | 
| 10 | 
            +
            from langchain_openai import ChatOpenAI
         | 
| 11 | 
            +
            import time
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            # Read the provider credentials from the process environment; each is
            # None when the corresponding variable is unset.
            openai_api_key = os.environ.get("OPENAI_API_KEY")
            groq_api_key = os.environ.get("GROQ_API_KEY")

            # Page heading rendered at the top of the app.
            st.title("Chat with Patent Dataset Using PandasAI")
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            # Initialize the LLM based on user selection
         | 
| 20 | 
            +
            def initialize_llm(model_choice):
                """Build the chat model matching the option selected in the UI.

                Args:
                    model_choice: Label of the selected model. "llama-3.3-70b" and
                        "GPT-4o" are the recognised values.

                Returns:
                    A ``ChatGroq`` or ``ChatOpenAI`` instance, or ``None`` when the
                    API key for the chosen provider is not set (an error is shown
                    with ``st.error``) or when the label is not recognised.
                """
                if model_choice == "llama-3.3-70b":
                    if not groq_api_key:
                        st.error("Groq API key is missing. Please set the GROQ_API_KEY environment variable.")
                        return None
                    # Pass the bare Groq model id. A "groq/" prefix is a
                    # provider-routing convention of some wrapper libraries and is
                    # not part of the id the Groq API itself accepts.
                    return ChatGroq(groq_api_key=groq_api_key, model="llama-3.3-70b-versatile")

                if model_choice == "GPT-4o":
                    if not openai_api_key:
                        st.error("OpenAI API key is missing. Please set the OPENAI_API_KEY environment variable.")
                        return None
                    return ChatOpenAI(api_key=openai_api_key, model="gpt-4o")

                # Unrecognised label: callers treat a falsy result as "no model".
                return None
         | 
| 31 | 
            +
             | 
| 32 | 
            +
            # Let the user pick a model (the first entry is preselected), then build
            # the matching client; `llm` is None when no client could be created.
            model_choice = st.radio(
                "Select LLM",
                ["GPT-4o", "llama-3.3-70b"],
                index=0,
                horizontal=True,
            )
            llm = initialize_llm(model_choice)
         | 
| 35 | 
            +
             | 
| 36 | 
            +
            # Dataset loading without caching to support progress bar
         | 
| 37 | 
            +
            def load_huggingface_dataset(dataset_name):
                """Download a Hugging Face dataset and return it as a pandas DataFrame.

                The "train" split of the "sample" configuration is requested, and a
                Streamlit progress bar is advanced at fixed checkpoints (10, 50, 100).

                Args:
                    dataset_name: Dataset identifier handed to ``load_dataset``.

                Returns:
                    The loaded data as a ``pandas.DataFrame``.

                Raises:
                    Exception: Whatever loading or conversion raised, re-raised after
                        the progress bar has been reset to 0.
                """
                bar = st.progress(0)
                try:
                    bar.progress(10)
                    # NOTE(review): `uniform_split` is forwarded as an extra keyword;
                    # presumably it is consumed by the dataset's own loading script
                    # -- confirm for datasets other than the default one.
                    hf_data = load_dataset(
                        dataset_name,
                        name="sample",
                        split="train",
                        trust_remote_code=True,
                        uniform_split=True,
                    )
                    bar.progress(50)
                    # Prefer the object's own conversion when it offers one.
                    frame = (
                        hf_data.to_pandas()
                        if hasattr(hf_data, "to_pandas")
                        else pd.DataFrame(hf_data)
                    )
                    bar.progress(100)
                    return frame
                except Exception:
                    # Put the bar back to empty before handing the error to the caller.
                    bar.progress(0)
                    raise
         | 
| 54 | 
            +
             | 
| 55 | 
            +
            def load_uploaded_csv(uploaded_file):
                """Parse an uploaded CSV into a pandas DataFrame with progress feedback.

                Args:
                    uploaded_file: File-like object (or path) accepted by
                        ``pandas.read_csv``.

                Returns:
                    The parsed ``pandas.DataFrame``.

                Raises:
                    Exception: Whatever ``pandas.read_csv`` raised, re-raised after the
                        progress bar has been reset to 0.
                """
                progress_bar = st.progress(0)
                try:
                    progress_bar.progress(10)
                    # Rewind before parsing: a buffer that was already consumed would
                    # otherwise look empty to read_csv. No artificial delay is added
                    # here, because this function runs on every script rerun while a
                    # file is attached and a sleep would slow every interaction.
                    if hasattr(uploaded_file, "seek"):
                        uploaded_file.seek(0)
                    progress_bar.progress(50)
                    df = pd.read_csv(uploaded_file)
                    progress_bar.progress(100)
                    return df
                except Exception:
                    # Reset the bar so a failed load does not look half finished.
                    progress_bar.progress(0)
                    raise
         | 
| 69 | 
            +
             | 
| 70 | 
            +
            # Dataset selection logic
         | 
| 71 | 
            +
            def load_dataset_into_session():
                """Render the dataset-source picker and load the chosen dataset.

                Three sources are offered: a CSV bundled at a fixed repo path, a
                Hugging Face dataset named by the user, or an uploaded CSV. On success
                the DataFrame is stored in ``st.session_state.df`` and a success
                message is shown; any failure is reported with ``st.error`` instead of
                being raised.
                """
                repo_opt, hub_opt, upload_opt = (
                    "Use Repo Directory Dataset",
                    "Use Hugging Face Dataset",
                    "Upload CSV File",
                )
                input_option = st.radio(
                    "Select Dataset Input:",
                    [repo_opt, hub_opt, upload_opt],
                    index=1,
                    horizontal=True,
                )

                # Source 1: CSV that ships with the repository.
                if input_option == repo_opt:
                    file_path = "./source/test.csv"
                    if not st.button("Load Dataset"):
                        return
                    try:
                        with st.spinner("Loading dataset from the repo directory..."):
                            st.session_state.df = pd.read_csv(file_path)
                        st.success(f"File loaded successfully from '{file_path}'!")
                    except Exception as err:
                        st.error(f"Error loading dataset from the repo directory: {err}")
                    return

                # Source 2: dataset fetched from Hugging Face by name.
                if input_option == hub_opt:
                    dataset_name = st.text_input(
                        "Enter Hugging Face Dataset Name:", value="HUPD/hupd"
                    )
                    if not st.button("Load Dataset"):
                        return
                    try:
                        st.session_state.df = load_huggingface_dataset(dataset_name)
                        st.success(f"Hugging Face Dataset '{dataset_name}' loaded successfully!")
                    except Exception as err:
                        st.error(f"Error loading Hugging Face dataset: {err}")
                    return

                # Source 3: CSV uploaded through the browser.
                if input_option == upload_opt:
                    uploaded_file = st.file_uploader("Upload a CSV File:", type=["csv"])
                    if not uploaded_file:
                        return
                    try:
                        st.session_state.df = load_uploaded_csv(uploaded_file)
                        st.success("File uploaded successfully!")
                    except Exception as err:
                        st.error(f"Error reading uploaded file: {err}")
         | 
| 109 | 
            +
             | 
| 110 | 
            +
            # Render the dataset picker; a successful load stores the DataFrame in
            # st.session_state.df.
            load_dataset_into_session()

            # The rest of the page needs both a loaded dataset and a usable model.
            if "df" in st.session_state and llm:
                df = st.session_state.df

                # Display dataset metadata
                st.write("### Dataset Metadata")
                st.text(f"Number of Rows: {df.shape[0]}")
                st.text(f"Number of Columns: {df.shape[1]}")
                # Column labels are not guaranteed to be strings (e.g. integer
                # labels), and str.join raises TypeError on non-strings.
                st.text(f"Column Names: {', '.join(map(str, df.columns))}")

                # Display dataset preview
                st.write("### Dataset Preview")
                num_rows = st.slider("Select number of rows to display:", min_value=5, max_value=50, value=10)
                st.dataframe(df.head(num_rows))

                # Wrap the DataFrame so it can be queried in natural language.
                chat_df = SmartDataframe(df, config={"llm": llm})

                # Chat functionality
                st.write("### Chat with Your Patent Data")
                user_query = st.text_input("Enter your question about the patent data (e.g., 'Predict if the patent will be accepted.'):")

                if user_query:
                    try:
                        response = chat_df.chat(user_query)
                        st.success(f"Response: {response}")
                    except Exception as e:
                        st.error(f"Error: {e}")

                # Plot generation functionality
                st.write("### Generate and View Graphs")
                plot_query = st.text_input("Enter a query to generate a graph (e.g., 'Plot the number of patents by filing year.'):")

                if plot_query:
                    try:
                        with tempfile.TemporaryDirectory() as temp_dir:
                            # PandasAI can handle plotting
                            chat_df.chat(plot_query)

                            # Save whatever figure is current in pyplot and show it.
                            # NOTE(review): this assumes the chat call leaves its
                            # chart as the current pyplot figure -- confirm.
                            temp_plot_path = os.path.join(temp_dir, "plot.png")
                            figure = plt.gcf()
                            try:
                                figure.savefig(temp_plot_path)
                                st.image(temp_plot_path, caption="Generated Plot", use_container_width=True)
                            finally:
                                # pyplot keeps figures alive until closed; release
                                # this one so reruns do not accumulate figures.
                                plt.close(figure)

                    except Exception as e:
                        st.error(f"Error: {e}")

                # Download processed dataset
                st.write("### Download Processed Dataset")
                st.download_button(
                    label="Download Dataset as CSV",
                    data=df.to_csv(index=False),
                    file_name="processed_dataset.csv",
                    mime="text/csv"
                )
         | 
| 167 | 
            +
             | 
| 168 | 
            +
            # Sidebar: usage instructions followed by a reference link. The
            # st.sidebar.<element> calls place each element in the sidebar.
            st.sidebar.header("Instructions:")
            st.sidebar.markdown(
                "1. Choose an LLM (Groq-based or OpenAI-based) to interact with the data.\n"
                "2. Upload, select, or fetch the dataset using the provided options.\n"
                "3. Enter a query to generate and view graphs based on patent attributes.\n"
                "   - Example: 'Predict if the patent will be accepted.'\n"
                "   - Example: 'What is the primary classification of this patent?'\n"
                "   - Example: 'Summarize the abstract of this patent.'\n"
            )
            # Horizontal rule between the two sidebar sections.
            st.sidebar.markdown("---")
            st.sidebar.header("References:")
            st.sidebar.markdown(
                "1. [Chat With Your CSV File With PandasAI - Prince Krampah](https://medium.com/aimonks/chat-with-your-csv-file-with-pandasai-22232a13c7b7)"
            )
         | 
