File size: 8,459 Bytes
61286f7 b51c72b 61286f7 907461e 61286f7 907461e 61286f7 b51c72b 61286f7 907461e 61286f7 907461e 61286f7 907461e 61286f7 907461e 61286f7 907461e 61286f7 907461e 61286f7 b51c72b 61286f7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 |
import os
import json
import re
import sys
import io
import contextlib
import warnings
from typing import Optional, List, Any, Tuple
from PIL import Image
import streamlit as st
import pandas as pd
import base64
from io import BytesIO
from together import Together
from e2b_code_interpreter import Sandbox
# Silence noisy pydantic UserWarnings emitted by dependency imports.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
# Regex for a fenced ```python ... ``` block; DOTALL lets the body span newlines.
pattern = re.compile(r"```python\n(.*?)\n```", re.DOTALL)
def code_interpret(e2b_code_interpreter: Sandbox, code: str) -> Optional[List[Any]]:
    """Run ``code`` in the E2B sandbox and return its result objects.

    Any stdout/stderr produced locally while the sandbox call runs is
    captured and re-emitted with a label so it is distinguishable in logs.

    Args:
        e2b_code_interpreter: An open E2B sandbox session.
        code: Python source to execute remotely.

    Returns:
        The sandbox execution's ``results`` list, or ``None`` when the
        sandbox reports an execution error.
    """
    with st.spinner('Executing code in E2B sandbox...'):
        stdout_capture = io.StringIO()
        stderr_capture = io.StringIO()

        # Capture local stdout/stderr around the sandbox call and suppress
        # warnings raised while it executes.
        with contextlib.redirect_stdout(stdout_capture), contextlib.redirect_stderr(stderr_capture):
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                # Renamed from `exec`, which shadowed the builtin of that name.
                execution = e2b_code_interpreter.run_code(code)

        # Replay anything captured during the call, clearly labelled.
        if stderr_capture.getvalue():
            print("[Code Interpreter Warnings/Errors]", file=sys.stderr)
            print(stderr_capture.getvalue(), file=sys.stderr)
        if stdout_capture.getvalue():
            print("[Code Interpreter Output]", file=sys.stdout)
            print(stdout_capture.getvalue(), file=sys.stdout)

        if execution.error:
            print(f"[Code Interpreter ERROR] {execution.error}", file=sys.stderr)
            return None
        return execution.results
# Compiled once at import time; matches the first fenced ```python block.
# Kept private so the function no longer depends on the distant module
# global `pattern` (which remains defined for any other users).
_CODE_BLOCK_RE = re.compile(r"```python\n(.*?)\n```", re.DOTALL)


def match_code_blocks(llm_response: str) -> str:
    """Extract the first fenced ```python code block from an LLM response.

    Args:
        llm_response: Full text of the model's reply.

    Returns:
        The code inside the first fence, or "" when no block is present.
    """
    match = _CODE_BLOCK_RE.search(llm_response)
    return match.group(1) if match else ""
def chat_with_llm(e2b_code_interpreter: Sandbox, user_message: str, dataset_path: str) -> Tuple[Optional[List[Any]], str]:
    """Ask the Together AI model about the dataset and run any returned code.

    Args:
        e2b_code_interpreter: Open sandbox in which extracted code is executed.
        user_message: The user's natural-language question.
        dataset_path: Path of the uploaded dataset inside the sandbox.

    Returns:
        A tuple of (sandbox results or None, the model's full text response).
    """
    # System prompt tells the model where the file lives and how to load
    # either CSV or Excel inputs.
    system_prompt = f"""You're a Python data scientist and data visualization expert. You are given a dataset at path '{dataset_path}' (could be CSV or Excel) and the user's query.
You need to analyze the dataset and answer the user's query with a response and you run Python code to solve them.
IMPORTANT:
- Use pd.read_csv() for .csv files and pd.read_excel() for .xlsx/.xls files
- Always use the dataset path variable '{dataset_path}' in your code when reading the file"""

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_message},
    ]

    with st.spinner('Getting response from Together AI LLM model...'):
        client = Together(api_key=st.session_state.together_api_key)
        response = client.chat.completions.create(
            model=st.session_state.model_name,
            messages=messages,
        )

        response_message = response.choices[0].message
        python_code = match_code_blocks(response_message.content)

        if python_code:
            code_interpreter_results = code_interpret(e2b_code_interpreter, python_code)
            return code_interpreter_results, response_message.content
        else:
            # Plain string: the original used an f-string with no placeholders.
            st.warning("Failed to match any Python code in model's response")
            return None, response_message.content
def upload_dataset(code_interpreter: Sandbox, uploaded_file) -> str:
    """Write the uploaded file into the sandbox filesystem.

    Args:
        code_interpreter: Open E2B sandbox session.
        uploaded_file: Streamlit UploadedFile object (has `.name`, file-like).

    Returns:
        The path of the dataset inside the sandbox.

    Raises:
        Exception: Re-raises whatever the sandbox write raised, after
            surfacing it in the UI.
    """
    dataset_path = f"./{uploaded_file.name}"
    try:
        code_interpreter.files.write(dataset_path, uploaded_file)
        return dataset_path
    except Exception as error:
        st.error(f"Error during file upload: {error}")
        # Bare `raise` preserves the original traceback; `raise error`
        # (the original form) would restart the chain from here.
        raise
def main():
    """Main Streamlit application.

    Lays out the page, collects API keys and a model choice in the sidebar,
    previews the uploaded CSV/Excel file, and on demand runs the
    LLM-analysis pipeline inside an E2B sandbox and renders its results.
    """
    st.set_page_config(page_title="π AI Data Visualization Agent", page_icon="π", layout="wide")
    st.title("π AI Data Visualization Agent")
    st.write("Upload your dataset (CSV or Excel) and ask questions about it!")

    # Initialize session state variables so later reads never KeyError.
    if 'together_api_key' not in st.session_state:
        st.session_state.together_api_key = ''
    if 'e2b_api_key' not in st.session_state:
        st.session_state.e2b_api_key = ''
    if 'model_name' not in st.session_state:
        st.session_state.model_name = ''

    # Sidebar for API keys and model configuration
    with st.sidebar:
        st.header("π API Keys and Model Configuration")
        st.session_state.together_api_key = st.text_input("Together AI API Key", type="password")
        st.info("π‘ Everyone gets a free $1 credit by Together AI - AI Acceleration Cloud platform")
        st.markdown("[Get Together AI API Key](https://api.together.ai/signin)")
        st.session_state.e2b_api_key = st.text_input("Enter E2B API Key", type="password")
        st.markdown("[Get E2B API Key](https://e2b.dev/docs/legacy/getting-started/api-key)")

        # Display-name -> Together model identifier.
        model_options = {
            "Meta-Llama 3.1 405B": "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
            "DeepSeek V3": "deepseek-ai/DeepSeek-V3",
            "Qwen 2.5 7B": "Qwen/Qwen2.5-7B-Instruct-Turbo",
            "Meta-Llama 3.3 70B": "meta-llama/Llama-3.3-70B-Instruct-Turbo"
        }
        st.session_state.model_name = st.selectbox(
            "Select Model",
            options=list(model_options.keys()),
            index=0  # Default to first option
        )
        # Resolve the display name to the API model identifier.
        st.session_state.model_name = model_options[st.session_state.model_name]

    # Main content layout: upload/preview on the left, Q&A on the right.
    col1, col2 = st.columns([1, 2])

    with col1:
        st.header("π Upload Dataset")
        uploaded_file = st.file_uploader("Choose a CSV or Excel file", type=["csv", "xlsx", "xls"], key="file_uploader")

        if uploaded_file is not None:
            # Pick the pandas reader from the file extension.
            if uploaded_file.name.endswith(('.xlsx', '.xls')):
                df = pd.read_excel(uploaded_file)
            else:
                df = pd.read_csv(uploaded_file)

            st.write("### Dataset Preview")
            show_full = st.checkbox("Show full dataset")
            if show_full:
                st.dataframe(df)
            else:
                st.write("Preview (first 5 rows):")
                st.dataframe(df.head())

    with col2:
        if uploaded_file is not None:
            st.header("β Ask a Question")
            query = st.text_area(
                "What would you like to know about your data?",
                "Can you compare the average cost for two people between different categories?",
                height=100
            )

            if st.button("Analyze", type="primary", key="analyze_button"):
                if not st.session_state.together_api_key or not st.session_state.e2b_api_key:
                    st.error("Please enter both API keys in the sidebar.")
                else:
                    # Sandbox is a context manager, so it is torn down even
                    # if the analysis raises.
                    with Sandbox(api_key=st.session_state.e2b_api_key) as code_interpreter:
                        # Upload the dataset, then hand its sandbox path to the LLM.
                        dataset_path = upload_dataset(code_interpreter, uploaded_file)
                        code_results, llm_response = chat_with_llm(code_interpreter, query, dataset_path)

                        # Display LLM's text response
                        st.header("π€ AI Response")
                        st.write(llm_response)

                        # Render each sandbox result with the widget that fits it.
                        if code_results:
                            st.header("π Analysis Results")
                            for result in code_results:
                                if hasattr(result, 'png') and result.png:
                                    # Base64-encoded PNG from the sandbox.
                                    png_data = base64.b64decode(result.png)
                                    image = Image.open(BytesIO(png_data))
                                    # `use_column_width` is deprecated in
                                    # Streamlit; `use_container_width` is the
                                    # supported replacement.
                                    st.image(image, caption="Generated Visualization", use_container_width=True)
                                elif hasattr(result, 'figure'):
                                    # Matplotlib figure.
                                    st.pyplot(result.figure)
                                elif hasattr(result, 'show'):
                                    # Plotly-style figure.
                                    st.plotly_chart(result)
                                elif isinstance(result, (pd.DataFrame, pd.Series)):
                                    st.dataframe(result)
                                else:
                                    st.write(result)


if __name__ == "__main__":
    main()