Spaces:
Running
Running
import streamlit as st | |
import pandas as pd | |
import numpy as np | |
import plotly.express as px | |
import plotly.graph_objects as go | |
from plotly.subplots import make_subplots | |
# Shared look-and-feel used by the charts below.
_ACCENT_YELLOW = "#FFD21E"
_ACCENT_BLUE = "#2563EB"
_PLOT_TEMPLATE = "simple_white"

# Pandas period frequency for each aggregation label that maps onto a period.
_PERIOD_FREQ = {"Week": "W", "Month": "M", "Quarter": "Q"}

# Display label -> pandas aggregation name.
_AGG_METHODS = {
    "Mean": "mean",
    "Sum": "sum",
    "Min": "min",
    "Max": "max",
    "Count": "count",
}


def render_dataset_visualization(dataset, dataset_type: str) -> None:
    """
    Render interactive Streamlit visualizations for a dataset.

    The user picks one of five views (Distribution, Correlation, Categories,
    Time Series, Custom); the matching helper draws the widgets and charts.

    Args:
        dataset: The dataset to visualize (pandas DataFrame) or None.
        dataset_type: The type of dataset (csv, json, etc.). Accepted for
            interface compatibility; this function does not read it.
    """
    # An empty frame has nothing to plot and would break window/bin maths.
    if dataset is None or dataset.empty:
        st.warning("No dataset to visualize.")
        return

    st.markdown("<h3>Dataset Visualization</h3>", unsafe_allow_html=True)

    # Classify columns once; every view works from these lists.
    numeric_cols = dataset.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = dataset.select_dtypes(include=['object', 'category']).columns.tolist()
    # is_datetime64_any_dtype also matches tz-aware and non-nanosecond
    # datetime columns, which a literal 'datetime64[ns]' comparison misses.
    date_cols = [
        col for col in dataset.columns
        if pd.api.types.is_datetime64_any_dtype(dataset[col])
    ]

    viz_type = st.selectbox(
        "Select visualization type",
        ["Distribution", "Correlation", "Categories", "Time Series", "Custom"],
        help="Choose the type of visualization to create"
    )

    if viz_type == "Distribution":
        _render_distribution(dataset, numeric_cols)
    elif viz_type == "Correlation":
        _render_correlation(dataset, numeric_cols)
    elif viz_type == "Categories":
        _render_categories(dataset, numeric_cols, categorical_cols)
    elif viz_type == "Time Series":
        _render_time_series(dataset, numeric_cols, categorical_cols, date_cols)
    elif viz_type == "Custom":
        _render_custom(dataset, numeric_cols, categorical_cols)


def _smoothed_density(values, max_bins=50, smoothing_window=5):
    """
    Estimate a smoothed probability-density curve for a numeric Series.

    The density comes from a normalised histogram whose bin heights are
    smoothed with a centred rolling mean, so the curve shares the y-scale of
    a 'probability density' histogram.

    Returns:
        (bin_centers, density) as numpy arrays, or None when fewer than two
        finite values are available.
    """
    finite = values.dropna().to_numpy(dtype=float)
    finite = finite[np.isfinite(finite)]
    if finite.size < 2:
        return None

    # Square-root rule, clamped so tiny and huge inputs stay readable.
    bin_count = int(min(max_bins, max(5, np.sqrt(finite.size))))
    density, edges = np.histogram(finite, bins=bin_count, density=True)
    centers = (edges[:-1] + edges[1:]) / 2
    smoothed = (
        pd.Series(density)
        .rolling(window=min(smoothing_window, bin_count), min_periods=1, center=True)
        .mean()
        .to_numpy()
    )
    return centers, smoothed


def _render_distribution(dataset, numeric_cols) -> None:
    """Draw histogram(s) and summary statistics for selected numeric columns."""
    if not numeric_cols:
        st.warning("No numeric columns found for distribution visualization.")
        return

    selected_cols = st.multiselect(
        "Select columns to visualize",
        numeric_cols,
        default=numeric_cols[:3]
    )
    if not selected_cols:
        st.warning("Please select at least one column to visualize.")
        return

    if len(selected_cols) == 1:
        # Single column: density-normalised histogram plus a smoothed curve.
        col = selected_cols[0]
        fig = px.histogram(
            dataset,
            x=col,
            histnorm='probability density',
            title=f"Distribution of {col}",
            color_discrete_sequence=[_ACCENT_YELLOW],
            template=_PLOT_TEMPLATE
        )
        curve = _smoothed_density(dataset[col])
        if curve is not None:
            centers, density = curve
            fig.add_trace(
                go.Scatter(
                    x=centers,
                    y=density,
                    mode='lines',
                    line=dict(color=_ACCENT_BLUE, width=3),
                    name='Smoothed'
                )
            )
        st.plotly_chart(fig, use_container_width=True)
    else:
        # Several columns: one histogram per cell in a grid up to 2 wide.
        num_cols = min(len(selected_cols), 2)
        num_rows = (len(selected_cols) + num_cols - 1) // num_cols
        fig = make_subplots(
            rows=num_rows,
            cols=num_cols,
            subplot_titles=[f"Distribution of {col}" for col in selected_cols]
        )
        for i, col in enumerate(selected_cols):
            row_offset, col_offset = divmod(i, num_cols)
            fig.add_trace(
                go.Histogram(
                    x=dataset[col],
                    name=col,
                    marker_color=_ACCENT_YELLOW
                ),
                row=row_offset + 1, col=col_offset + 1
            )
        fig.update_layout(
            title="Distribution of Selected Features",
            showlegend=False,
            template=_PLOT_TEMPLATE,
            height=300 * num_rows
        )
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Distribution Statistics")
    stats_df = dataset[selected_cols].describe().T
    st.dataframe(stats_df, use_container_width=True)


def _render_correlation(dataset, numeric_cols) -> None:
    """Draw a correlation heatmap, optional scatter matrix and ranked pair bars."""
    if len(numeric_cols) < 2:
        st.warning("Need at least two numeric columns for correlation analysis.")
        return

    st.markdown("### Correlation Matrix")
    selected_cols = st.multiselect(
        "Select columns for correlation analysis",
        numeric_cols,
        default=numeric_cols[:5]
    )
    if len(selected_cols) < 2:
        st.warning("Please select at least two columns for correlation analysis.")
        return

    corr = dataset[selected_cols].corr()

    fig = px.imshow(
        corr,
        color_continuous_scale="RdBu_r",
        title="Correlation Matrix",
        template=_PLOT_TEMPLATE,
        text_auto=True
    )
    st.plotly_chart(fig, use_container_width=True)

    # Scatter matrix only for 3-5 columns; beyond that it becomes unreadable.
    if 2 < len(selected_cols) <= 5:
        st.markdown("### Scatter Plot Matrix")
        fig = px.scatter_matrix(
            dataset,
            dimensions=selected_cols,
            color_discrete_sequence=[_ACCENT_BLUE],
            title="Scatter Plot Matrix",
            template=_PLOT_TEMPLATE
        )
        fig.update_traces(diagonal_visible=False)
        st.plotly_chart(fig, use_container_width=True)

    st.markdown("### Top Correlation Pairs")

    # Upper triangle (excluding the diagonal) lists each unordered pair once.
    first_idx, second_idx = np.triu_indices(len(corr.columns), k=1)
    corr_df = pd.DataFrame({
        'Feature 1': corr.columns[first_idx],
        'Feature 2': corr.columns[second_idx],
        'Correlation': corr.to_numpy()[first_idx, second_idx],
    })
    if corr_df.empty:
        return

    # Rank by |r|; NaN correlations (e.g. a constant column) go last instead
    # of landing at an arbitrary position as they would with a plain sort key.
    order = (
        corr_df['Correlation']
        .abs()
        .sort_values(ascending=False, kind='mergesort', na_position='last')
        .index
    )
    # Reset the index so the colour Series lines up with the plain lists.
    corr_df = corr_df.loc[order].reset_index(drop=True)

    pair_labels = [
        f"{first} & {second}"
        for first, second in zip(corr_df['Feature 1'], corr_df['Feature 2'])
    ]
    fig = px.bar(
        x=pair_labels,
        y=corr_df['Correlation'].abs().tolist(),
        color=corr_df['Correlation'],
        color_continuous_scale="RdBu_r",
        labels={'x': 'Feature Pairs', 'y': 'Absolute Correlation'},
        title="Top Feature Correlations"
    )
    st.plotly_chart(fig, use_container_width=True)


def _render_categories(dataset, numeric_cols, categorical_cols) -> None:
    """Draw category counts and, if possible, a numeric breakdown per category."""
    if not categorical_cols:
        st.warning("No categorical columns found for category visualization.")
        return

    selected_cat = st.selectbox("Select categorical column", categorical_cols)

    value_counts = dataset[selected_cat].value_counts()
    # Cap the bar chart at the 20 most frequent categories.
    if len(value_counts) > 20:
        st.info(f"Showing top 20 categories out of {len(value_counts)}")
        value_counts = value_counts.head(20)

    fig = px.bar(
        x=value_counts.index,
        y=value_counts.values,
        title=f"Category Counts for {selected_cat}",
        labels={'x': selected_cat, 'y': 'Count'},
        color_discrete_sequence=[_ACCENT_YELLOW]
    )
    st.plotly_chart(fig, use_container_width=True)

    if not numeric_cols:
        return

    st.markdown(f"### {selected_cat} vs Numeric Features")
    selected_num = st.selectbox("Select numeric column", numeric_cols)

    fig = px.box(
        dataset,
        x=selected_cat,
        y=selected_num,
        title=f"{selected_cat} vs {selected_num}",
        color_discrete_sequence=[_ACCENT_BLUE],
        template=_PLOT_TEMPLATE
    )
    st.plotly_chart(fig, use_container_width=True)

    st.markdown(f"### Statistics of {selected_num} by {selected_cat}")
    stats_by_cat = dataset.groupby(selected_cat)[selected_num].describe()
    st.dataframe(stats_by_cat, use_container_width=True)


def _render_time_series(dataset, numeric_cols, categorical_cols, date_cols) -> None:
    """Aggregate a numeric column over a chosen time period and plot it."""
    potential_date_cols = list(date_cols)
    # Cheap heuristic: text columns whose first values all contain '/' or '-'
    # are offered as date candidates; real parsing happens after selection.
    for col in categorical_cols:
        sample = dataset[col].dropna().head(5).tolist()
        if sample and all('/' in str(x) or '-' in str(x) for x in sample):
            potential_date_cols.append(col)

    if not potential_date_cols:
        st.warning("No date columns found for time series visualization.")
        return

    date_col = st.selectbox("Select date column", potential_date_cols)

    # Work on the date Series alone so the caller's frame is never copied
    # or mutated.
    dates = dataset[date_col]
    if not pd.api.types.is_datetime64_any_dtype(dates):
        try:
            dates = pd.to_datetime(dates)
        except (ValueError, TypeError, OverflowError):
            dates = None
        # A conversion may also yield a non-datetime result without raising;
        # the .dt accessor below would fail on it, so treat it as a failure.
        if dates is None or not pd.api.types.is_datetime64_any_dtype(dates):
            st.error(f"Could not convert {date_col} to datetime.")
            return

    if not numeric_cols:
        st.warning("No numeric columns found for time series values.")
        return

    value_col = st.selectbox("Select value column", numeric_cols)
    time_period = st.selectbox(
        "Aggregate by",
        ["Day", "Week", "Month", "Quarter", "Year"]
    )

    if time_period == "Day":
        # Keep a datetime dtype (midnight timestamps) rather than Python date
        # objects, matching the other periods and keeping x numeric-convertible
        # for the trendline.
        period_key = dates.dt.normalize()
    elif time_period in _PERIOD_FREQ:
        period_key = dates.dt.to_period(_PERIOD_FREQ[time_period]).dt.start_time
    else:  # Year
        period_key = dates.dt.year

    agg_method = st.selectbox("Aggregation method", list(_AGG_METHODS))

    # Avoid clobbering a value column that is itself called 'period'.
    period_name = 'period' if value_col != 'period' else 'time_period'
    frame = pd.DataFrame({period_name: period_key, value_col: dataset[value_col]})
    time_series = (
        frame.groupby(period_name)[value_col]
        .agg(_AGG_METHODS[agg_method])
        .reset_index()
    )

    axis_titles = dict(
        xaxis_title=time_period,
        yaxis_title=f"{agg_method} of {value_col}"
    )

    fig = px.line(
        time_series,
        x=period_name,
        y=value_col,
        title=f"{agg_method} of {value_col} by {time_period}",
        markers=True,
        color_discrete_sequence=[_ACCENT_BLUE],
        template=_PLOT_TEMPLATE
    )
    fig.update_layout(**axis_titles)
    st.plotly_chart(fig, use_container_width=True)

    if st.checkbox("Show trendline"):
        try:
            trend_fig = px.scatter(
                time_series,
                x=period_name,
                y=value_col,
                trendline="ols",
                title=f"{agg_method} of {value_col} by {time_period} with Trendline",
                color_discrete_sequence=[_ACCENT_BLUE],
                template=_PLOT_TEMPLATE
            )
        except ImportError:
            # The OLS trendline depends on the optional statsmodels package.
            st.warning("Trendline requires the optional 'statsmodels' package.")
        else:
            trend_fig.update_layout(**axis_titles)
            st.plotly_chart(trend_fig, use_container_width=True)

    st.dataframe(time_series, use_container_width=True)


def _pick_color_column(dataset):
    """Return the column chosen for colouring, or None if colouring is off."""
    if st.checkbox("Add color dimension"):
        return st.selectbox("Color by", dataset.columns.tolist())
    return None


def _render_custom(dataset, numeric_cols, categorical_cols) -> None:
    """Let the user build a plot by choosing its type and axes."""
    st.markdown("### Custom Visualization")
    st.info("Create a custom plot by selecting axes and plot type")

    plot_type = st.selectbox(
        "Select plot type",
        ["Scatter", "Line", "Bar", "Box", "Violin", "Histogram", "Pie", "3D Scatter"]
    )

    all_cols = dataset.columns.tolist()
    # Prefer typed column lists but fall back to every column if none match.
    numeric_choices = numeric_cols if numeric_cols else all_cols
    categorical_choices = categorical_cols if categorical_cols else all_cols

    if plot_type in ["Scatter", "Line", "Bar", "3D Scatter"]:
        x_col = st.selectbox("X-axis", all_cols)
        y_col = st.selectbox("Y-axis", numeric_choices)
        z_col = None
        if plot_type == "3D Scatter":
            z_col = st.selectbox("Z-axis", numeric_choices)
        color_col = _pick_color_column(dataset)

        if plot_type == "Scatter":
            fig = px.scatter(
                dataset,
                x=x_col,
                y=y_col,
                color=color_col,
                title=f"{y_col} vs {x_col}",
                template=_PLOT_TEMPLATE
            )
        elif plot_type == "Line":
            # Sort by x so the line is drawn left-to-right.
            fig = px.line(
                dataset.sort_values(x_col),
                x=x_col,
                y=y_col,
                color=color_col,
                title=f"{y_col} vs {x_col}",
                template=_PLOT_TEMPLATE
            )
        elif plot_type == "Bar":
            fig = px.bar(
                dataset,
                x=x_col,
                y=y_col,
                color=color_col,
                title=f"{y_col} by {x_col}",
                template=_PLOT_TEMPLATE
            )
        else:  # 3D Scatter
            fig = px.scatter_3d(
                dataset,
                x=x_col,
                y=y_col,
                z=z_col,
                color=color_col,
                title=f"3D Scatter: {x_col}, {y_col}, {z_col}",
                template=_PLOT_TEMPLATE
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type in ["Box", "Violin"]:
        x_col = st.selectbox("X-axis (categories)", categorical_choices)
        y_col = st.selectbox("Y-axis (values)", numeric_choices)
        color_col = _pick_color_column(dataset)

        if plot_type == "Box":
            fig = px.box(
                dataset,
                x=x_col,
                y=y_col,
                color=color_col,
                title=f"Box Plot: {y_col} by {x_col}",
                template=_PLOT_TEMPLATE
            )
        else:  # Violin
            fig = px.violin(
                dataset,
                x=x_col,
                y=y_col,
                color=color_col,
                title=f"Violin Plot: {y_col} by {x_col}",
                template=_PLOT_TEMPLATE
            )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Histogram":
        value_col = st.selectbox("Value column", all_cols)
        n_bins = st.slider("Number of bins", 5, 100, 20)
        color_col = _pick_color_column(dataset)

        fig = px.histogram(
            dataset,
            x=value_col,
            color=color_col,
            nbins=n_bins,
            title=f"Histogram of {value_col}",
            template=_PLOT_TEMPLATE
        )
        st.plotly_chart(fig, use_container_width=True)

    elif plot_type == "Pie":
        _render_custom_pie(dataset, numeric_cols, categorical_choices)


def _render_custom_pie(dataset, numeric_cols, categorical_choices) -> None:
    """Draw a pie chart of category counts or of a summed numeric column."""
    cat_col = st.selectbox("Category column", categorical_choices)

    use_values = st.checkbox("Use custom values")
    value_col = None
    if use_values and numeric_cols:
        value_col = st.selectbox("Value column", numeric_cols)

    top_n = st.slider("Limit to top N categories", 0, 20, 10,
                      help="Set to 0 to show all categories. Recommended to limit to top 10-15 categories for readability.")

    # Compare against None: a valid column label such as 0 is falsy.
    if value_col is not None:
        pie_data = dataset.groupby(cat_col)[value_col].sum().reset_index()
        if top_n > 0:
            pie_data = pie_data.sort_values(value_col, ascending=False).head(top_n)
    else:
        # value_counts() is already ordered by descending frequency.
        pie_data = dataset[cat_col].value_counts().reset_index()
        pie_data.columns = [cat_col, 'count']
        value_col = 'count'
        if top_n > 0:
            pie_data = pie_data.head(top_n)

    fig = px.pie(
        pie_data,
        names=cat_col,
        values=value_col,
        title=f"Pie Chart of {cat_col}",
        template=_PLOT_TEMPLATE
    )
    st.plotly_chart(fig, use_container_width=True)