# dotBlood / app.py (Hugging Face Space)
# Last commit: 602a264 "Update app.py" by kedimestan (verified), 2.5 kB
import os

import gradio as gr
import pandas as pd
import plotly.express as px
from shapely.geometry import Point, Polygon
def load_csv(file):
    """Read an uploaded CSV file into a DataFrame.

    Args:
        file: The value delivered by the ``gr.File`` upload event. Depending on
            the Gradio version / ``type`` setting this may be a filesystem path
            (``str`` or ``os.PathLike``) or a file-like wrapper that exposes the
            temporary file's path as ``.name``. ``None`` (no file) is tolerated.

    Returns:
        The parsed DataFrame, or an empty DataFrame when no file was given.
    """
    if file is None:
        return pd.DataFrame()
    # Check for path types first: pathlib.Path also has a ``.name`` attribute,
    # but it is only the basename, not the full path to read from.
    path = file if isinstance(file, (str, os.PathLike)) else file.name
    return pd.read_csv(path)
def update_dropdowns(df):
    """Give both column pickers the columns of ``df`` as their choices.

    Returns a 2-tuple ``(x_picker, y_picker)`` of ``gr.Dropdown`` objects, each
    built from its own list of the DataFrame's column labels.
    """
    x_picker = gr.Dropdown(choices=df.columns.tolist())
    y_picker = gr.Dropdown(choices=df.columns.tolist())
    return x_picker, y_picker
def create_plot(df, x_col, y_col):
    """Build a lasso-enabled scatter plot of ``y_col`` against ``x_col``.

    Args:
        df: Data to plot (the uploaded CSV), or ``None``.
        x_col: Column name for the x axis, or ``None`` if not chosen yet.
        y_col: Column name for the y axis, or ``None`` if not chosen yet.

    Returns:
        A Plotly figure with ``dragmode='lasso'``, or ``None`` when there is
        nothing valid to plot.
    """
    if df is None or x_col is None or y_col is None:
        return None
    # A column chosen for a previously uploaded CSV may not exist in the current
    # one. px.scatter would raise in that case, so leave the plot empty instead.
    if x_col not in df.columns or y_col not in df.columns:
        return None
    fig = px.scatter(df, x=x_col, y=y_col, title=f"{x_col} vs {y_col}")
    # Make lasso the default drag tool so users can draw a selection polygon.
    fig.update_layout(dragmode='lasso')
    return fig
def find_points_in_polygon(selected_data, df, x_col, y_col):
    """Return the rows of ``df`` whose ``(x_col, y_col)`` point lies in a selection polygon.

    Args:
        selected_data: Selection payload. It is assumed to be a dict shaped like
            Plotly's selection event data:
            - an optional ``"lassoPoints"`` entry ``{"x": [...], "y": [...]}``
              giving the lasso outline;
            - and/or a ``"points"`` list of ``{"x": ..., "y": ...}`` dicts.
            NOTE(review): confirm what the Gradio select event actually
            delivers here.
        df: DataFrame to filter, or ``None``.
        x_col: Column holding the x coordinate.
        y_col: Column holding the y coordinate.

    Returns:
        The matching subset of ``df``, with its original index preserved.
        Returns an empty ``pd.DataFrame()`` in these cases:
        - there is no selection;
        - there is no data to filter;
        - an axis column is unset;
        - the polygon has fewer than three vertices.
    """
    # Bail out only when there is nothing selected or nothing to filter.
    if not selected_data or df is None or df.empty:
        return pd.DataFrame()
    if x_col is None or y_col is None:
        return pd.DataFrame()

    # Prefer the lasso outline itself when it is present in the payload.
    lasso = selected_data.get("lassoPoints") or {}
    vertices = list(zip(lasso.get("x") or [], lasso.get("y") or []))
    if not vertices:
        # Fall back to treating the selected points as the polygon's vertices.
        # NOTE(review): these arrive in arbitrary order, so the resulting
        # polygon may self-intersect.
        vertices = [(p["x"], p["y"]) for p in selected_data.get("points") or []]

    # A polygon ring needs at least three distinct vertices; shapely rejects
    # shorter coordinate sequences.
    if len(vertices) < 3:
        return pd.DataFrame()

    polygon = Polygon(vertices)
    # covers() also accepts points lying on the outline. With contains(), points
    # that are themselves vertices (the fallback above) would always be dropped.
    inside = [
        polygon.covers(Point(x, y))
        for x, y in zip(df[x_col], df[y_col])
    ]
    return df[inside]
# Build the UI: upload a CSV, pick X/Y columns, and lasso-select points on the
# scatter plot to list them in the results table.
# NOTE(review): the original indentation was lost. plot/results are assumed to
# sit below the upload/dropdown row rather than inside it.
with gr.Blocks() as demo:
    gr.Markdown("## Interactive CSV Explorer with Polygon Selection")
    # Session state:
    # - df_state holds the loaded DataFrame (starts as an empty frame);
    # - x_col_state / y_col_state mirror the last chosen axis columns.
    df_state = gr.State(pd.DataFrame())
    x_col_state = gr.State()
    y_col_state = gr.State()
    with gr.Row():
        csv_upload = gr.File(label="Upload CSV", file_types=[".csv"])
        x_col = gr.Dropdown(label="X Column")
        y_col = gr.Dropdown(label="Y Column")
    plot = gr.Plot(label="Scatter Plot")
    results = gr.DataFrame(label="Selected Points")
    # Upload handling: parse the file into df_state, then offer its columns
    # as the choices of both dropdowns.
    csv_upload.upload(
        load_csv,
        inputs=csv_upload,
        outputs=df_state
    ).then(
        update_dropdowns,
        inputs=df_state,
        outputs=[x_col, y_col]
    )
    # Plot updates: changing either dropdown redraws the plot. The
    # dropdown's value is then copied into its *_col_state, which the
    # selection handler below reads.
    x_col.change(
        create_plot,
        inputs=[df_state, x_col, y_col],
        outputs=plot
    ).then(
        lambda x: x,  # Store x_col in state
        inputs=x_col,
        outputs=x_col_state
    )
    y_col.change(
        create_plot,
        inputs=[df_state, x_col, y_col],
        outputs=plot
    ).then(
        lambda y: y,  # Store y_col in state
        inputs=y_col,
        outputs=y_col_state
    )
    # Selection handling: filter df_state to the selected region and show
    # the result in the table.
    # NOTE(review): the first input is a fresh gr.State() with no initial
    # value, so find_points_in_polygon presumably always receives None as
    # selected_data and returns an empty frame. Gradio normally passes event
    # payloads through a handler parameter annotated `gr.SelectData`.
    # Also confirm that gr.Plot exposes a `.select` event in the installed
    # Gradio version.
    plot.select(
        find_points_in_polygon,
        inputs=[gr.State(), df_state, x_col_state, y_col_state],
        outputs=results
    )
# Start the web server when the module runs.
demo.launch()