File size: 2,497 Bytes
c6379ae
9be415c
602a264
9239ab3
c6379ae
9239ab3
 
 
 
 
f305b38
9239ab3
 
 
602a264
 
 
 
f305b38
9239ab3
602a264
 
 
 
 
 
 
 
 
 
 
 
 
 
9239ab3
602a264
f305b38
 
602a264
ee16303
67adc66
602a264
 
ee16303
 
9239ab3
 
 
ee16303
602a264
 
ee16303
602a264
9239ab3
67adc66
9239ab3
 
 
 
 
 
 
 
602a264
9239ab3
 
 
 
602a264
 
 
 
9239ab3
602a264
9239ab3
 
 
 
602a264
 
 
 
9239ab3
 
602a264
9239ab3
 
602a264
9239ab3
 
f305b38
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os

import gradio as gr
import pandas as pd
import plotly.express as px
from shapely.geometry import Point, Polygon

def load_csv(file):
    """Read an uploaded CSV file into a DataFrame.

    Args:
        file: The upload handed over by the ``gr.File`` component. Accepted
            forms are a path (``str`` or ``os.PathLike``), an object that
            exposes the path through a ``name`` attribute, or ``None`` when
            nothing has been uploaded.

    Returns:
        The parsed ``pandas.DataFrame``, or an empty DataFrame when ``file``
        is ``None``.
    """
    if file is None:
        return pd.DataFrame()

    if isinstance(file, (str, os.PathLike)):
        # Paths are checked first: ``pathlib.Path.name`` is only the final
        # path component, so reading ``.name`` would drop the directory.
        path = file
    else:
        # File wrappers carry the on-disk location in ``name``; anything
        # else is passed through and left for pandas to interpret.
        path = getattr(file, "name", file)

    return pd.read_csv(path)

def update_dropdowns(df):
    """Build the X and Y dropdown updates from the columns of ``df``.

    Args:
        df: DataFrame whose column labels become the dropdown choices.

    Returns:
        A 2-tuple of ``gr.Dropdown`` objects (X first, then Y), each offering
        every column of ``df`` as a choice.
    """
    column_names = df.columns.tolist()
    x_dropdown = gr.Dropdown(choices=column_names)
    # Give the second dropdown its own list so the two never share one object.
    y_dropdown = gr.Dropdown(choices=list(column_names))
    return x_dropdown, y_dropdown

def create_plot(df, x_col, y_col):
    """Build a scatter plot of two DataFrame columns with lasso drag mode.

    Args:
        df: DataFrame holding the data, or ``None``.
        x_col: Name of the column plotted on the x axis, or ``None``.
        y_col: Name of the column plotted on the y axis, or ``None``.

    Returns:
        A Plotly figure titled ``"<x_col> vs <y_col>"``, or ``None`` when any
        argument is ``None`` or a requested column is not present in ``df``.
    """
    if df is None or x_col is None or y_col is None:
        return None

    # A column name that is not in the frame (for example a selection left
    # over from a previously uploaded CSV) would make px.scatter raise, so
    # treat it like "nothing to plot yet".
    if x_col not in df.columns or y_col not in df.columns:
        return None

    fig = px.scatter(df, x=x_col, y=y_col, title=f"{x_col} vs {y_col}")
    fig.update_layout(dragmode='lasso')
    return fig

def find_points_in_polygon(selected_data, df, x_col, y_col):
    """Return the rows of ``df`` lying inside the polygon of selected points.

    The ``x``/``y`` values of the entries in ``selected_data['points']`` are
    used, in order, as the vertices of a polygon; every row of ``df`` whose
    ``(x_col, y_col)`` coordinate is contained in that polygon is kept.

    Args:
        selected_data: Mapping with a ``'points'`` list whose entries are
            mappings carrying ``'x'`` and ``'y'``; may be ``None`` or empty.
        df: DataFrame to filter; may be ``None`` or empty.
        x_col: Name of the column holding x coordinates.
        y_col: Name of the column holding y coordinates.

    Returns:
        The matching rows of ``df``. An empty DataFrame is returned when
        there is no selection, no data, an unknown column name, or fewer
        than three vertices.
    """
    # The original guard read ``not df.empty`` and therefore bailed out for
    # every frame that actually contained data.
    if not selected_data or df is None or df.empty:
        return pd.DataFrame()

    # Column names may still be unset (None) or stale; row lookup would
    # raise a KeyError in that case.
    if x_col not in df.columns or y_col not in df.columns:
        return pd.DataFrame()

    selected_points = selected_data.get('points')
    if not selected_points:
        return pd.DataFrame()

    # Extract coordinates of polygon vertices
    polygon_points = [(p['x'], p['y']) for p in selected_points]
    if len(polygon_points) < 3:
        # Shapely cannot build a polygon ring from fewer than three
        # distinct vertices and raises instead.
        return pd.DataFrame()

    # NOTE(review): ``contains`` is False for points exactly on the polygon
    # boundary, so rows coinciding with a vertex are not returned -- confirm
    # this is the intended selection semantics.
    polygon = Polygon(polygon_points)
    mask = [
        polygon.contains(Point(x, y))
        for x, y in zip(df[x_col], df[y_col])
    ]

    return df[mask]

# Build the UI and wire the events. Components are created in the order they
# should appear, so statement order inside this block matters.
with gr.Blocks() as demo:
    gr.Markdown("## Interactive CSV Explorer with Polygon Selection")
    
    # Per-session state: the loaded DataFrame (starts empty) and the column
    # names currently chosen for the x and y axes (start unset).
    df_state = gr.State(pd.DataFrame())
    x_col_state = gr.State()
    y_col_state = gr.State()
    
    with gr.Row():
        csv_upload = gr.File(label="Upload CSV", file_types=[".csv"])
        x_col = gr.Dropdown(label="X Column")
        y_col = gr.Dropdown(label="Y Column")
    
    plot = gr.Plot(label="Scatter Plot")
    results = gr.DataFrame(label="Selected Points")
    
    # Upload handling: parse the CSV into df_state, then refresh the choices
    # of both column dropdowns from the new frame.
    csv_upload.upload(
        load_csv,
        inputs=csv_upload,
        outputs=df_state
    ).then(
        update_dropdowns,
        inputs=df_state,
        outputs=[x_col, y_col]
    )
    
    # Plot updates: changing either dropdown redraws the scatter plot and
    # afterwards copies the dropdown value into the matching state object.
    x_col.change(
        create_plot,
        inputs=[df_state, x_col, y_col],
        outputs=plot
    ).then(
        lambda x: x,  # Store x_col in state
        inputs=x_col,
        outputs=x_col_state
    )
    
    y_col.change(
        create_plot,
        inputs=[df_state, x_col, y_col],
        outputs=plot
    ).then(
        lambda y: y,  # Store y_col in state
        inputs=y_col,
        outputs=y_col_state
    )
    
    # Selection handling: filter the loaded data and show it in `results`.
    # NOTE(review): the first input is a gr.State() created inline with no
    # initial value, and nothing visible here ever writes to it, so
    # find_points_in_polygon presumably always receives selected_data=None --
    # confirm how the selection payload is supposed to reach the handler.
    # NOTE(review): verify that gr.Plot exposes a `select` event in the
    # Gradio version this app targets.
    plot.select(
        find_points_in_polygon,
        inputs=[gr.State(), df_state, x_col_state, y_col_state],
        outputs=results
    )

# Start the app as soon as the module is executed.
demo.launch()