will33am commited on
Commit
448630d
·
1 Parent(s): 7c8067c

first commit

Browse files
Files changed (2) hide show
  1. .ipynb_checkpoints/app-checkpoint.py +76 -0
  2. app.py +76 -0
.ipynb_checkpoints/app-checkpoint.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+
4
+
5
+ # +
6
+ def get_methods_and_arch(dataset):
7
+ columns = dataset.column_names[5:]
8
+ methods = []
9
+ archs = []
10
+ for column in columns:
11
+ methods.append(column.split('_')[0])
12
+ archs.append(column.split('_')[1])
13
+ return list(set(methods)),list(set(archs))
14
+
15
+ def get_columns(arch,method):
16
+ columns = dataset.column_names[5:]
17
+ for col in columns:
18
+ if f'{method}_{arch}' in col:
19
+ return col
20
+ def button_fn(arch,method):
21
+ column_heatmap = get_columns(arch,method)
22
+ #print("Updated column: ",column_heatmap)
23
+ return column_heatmap,index_default,dataset[index_default]["image"],dataset[index_default][column_heatmap]
24
+
25
+ def func_slider(index,column_textbox):
26
+ #global column_heatmap
27
+ example = dataset[index]
28
+ return example['image'],example[column_textbox]
29
+
30
+
31
+ # -
32
+
33
+ dataset = load_dataset("GazeLocation/stimuli_heatmaps",split = 'train')
34
+ METHODS, ARCHS = get_methods_and_arch(dataset)
35
+ index_default = 0
36
+
37
+ if __name__ == '__main__':
38
+ demo = gr.Blocks()
39
+ with demo:
40
+ gr.Markdown("# Heatmap Gaze Location")
41
+
42
+ with gr.Row():
43
+ dropdown_arch = gr.Dropdown(choices = ARCHS,
44
+ value = 'resnet50',
45
+ label = 'Model')
46
+
47
+ dropdown_method = gr.Dropdown(choices = METHODS,
48
+ value = 'gradcam',
49
+ label = 'Method')
50
+ with gr.Row():
51
+ button = gr.Button(label = 'Update Heatmap Model - Method')
52
+
53
+ with gr.Row():
54
+ hf_slider = gr.Slider(minimum=0, maximum=len(dataset)-1,step = 1)
55
+ with gr.Row():
56
+ column_textbox = gr.Textbox(label = 'column name',
57
+ value = get_columns(ARCHS[0],METHODS[0]) )
58
+ with gr.Row():
59
+ with gr.Column():
60
+ image_input = gr.Image(label="Input Image",value = dataset[index_default]["image"])
61
+ with gr.Column():
62
+ image_output = gr.Image(label="Output",value = dataset[index_default][get_columns('resnet50','gradcam')])
63
+
64
+
65
+ button.click(fn = button_fn,
66
+ inputs = [dropdown_arch,dropdown_method],
67
+ outputs = [column_textbox,hf_slider,image_input,image_output])
68
+
69
+
70
+ hf_slider.change(func_slider,
71
+ inputs = [hf_slider,column_textbox],
72
+ outputs = [image_input, image_output])
73
+
74
+ demo.launch(share = True,debug = True)
75
+
76
+
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from datasets import load_dataset
3
+
4
+
5
+ # +
6
+ def get_methods_and_arch(dataset):
7
+ columns = dataset.column_names[5:]
8
+ methods = []
9
+ archs = []
10
+ for column in columns:
11
+ methods.append(column.split('_')[0])
12
+ archs.append(column.split('_')[1])
13
+ return list(set(methods)),list(set(archs))
14
+
15
+ def get_columns(arch,method):
16
+ columns = dataset.column_names[5:]
17
+ for col in columns:
18
+ if f'{method}_{arch}' in col:
19
+ return col
20
+ def button_fn(arch,method):
21
+ column_heatmap = get_columns(arch,method)
22
+ #print("Updated column: ",column_heatmap)
23
+ return column_heatmap,index_default,dataset[index_default]["image"],dataset[index_default][column_heatmap]
24
+
25
+ def func_slider(index,column_textbox):
26
+ #global column_heatmap
27
+ example = dataset[index]
28
+ return example['image'],example[column_textbox]
29
+
30
+
31
+ # -
32
+
33
+ dataset = load_dataset("GazeLocation/stimuli_heatmaps",split = 'train')
34
+ METHODS, ARCHS = get_methods_and_arch(dataset)
35
+ index_default = 0
36
+
37
+ if __name__ == '__main__':
38
+ demo = gr.Blocks()
39
+ with demo:
40
+ gr.Markdown("# Heatmap Gaze Location")
41
+
42
+ with gr.Row():
43
+ dropdown_arch = gr.Dropdown(choices = ARCHS,
44
+ value = 'resnet50',
45
+ label = 'Model')
46
+
47
+ dropdown_method = gr.Dropdown(choices = METHODS,
48
+ value = 'gradcam',
49
+ label = 'Method')
50
+ with gr.Row():
51
+ button = gr.Button(label = 'Update Heatmap Model - Method')
52
+
53
+ with gr.Row():
54
+ hf_slider = gr.Slider(minimum=0, maximum=len(dataset)-1,step = 1)
55
+ with gr.Row():
56
+ column_textbox = gr.Textbox(label = 'column name',
57
+ value = get_columns(ARCHS[0],METHODS[0]) )
58
+ with gr.Row():
59
+ with gr.Column():
60
+ image_input = gr.Image(label="Input Image",value = dataset[index_default]["image"])
61
+ with gr.Column():
62
+ image_output = gr.Image(label="Output",value = dataset[index_default][get_columns('resnet50','gradcam')])
63
+
64
+
65
+ button.click(fn = button_fn,
66
+ inputs = [dropdown_arch,dropdown_method],
67
+ outputs = [column_textbox,hf_slider,image_input,image_output])
68
+
69
+
70
+ hf_slider.change(func_slider,
71
+ inputs = [hf_slider,column_textbox],
72
+ outputs = [image_input, image_output])
73
+
74
+ demo.launch(share = True,debug = True)
75
+
76
+