stevenbucaille committed on
Commit
159cb1e
·
1 Parent(s): a9bdf65

Implement LightGlue image matching app with Gradio interface and necessary dependencies

Browse files
Files changed (2) hide show
  1. app.py +175 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from transformers import AutoImageProcessor, AutoModel
4
+ from transformers.image_utils import to_numpy_array
5
+ import torch
6
+ import plotly.graph_objects as go
7
+ from PIL import Image
8
+
9
+
10
+ def process_images(image1, image2):
11
+ """
12
+ Process two images and return a plot of the matching keypoints.
13
+ """
14
+ if image1 is None or image2 is None:
15
+ return None
16
+
17
+ images = [image1, image2]
18
+ processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
19
+ model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
20
+
21
+ inputs = processor(images, return_tensors="pt")
22
+ with torch.no_grad():
23
+ outputs = model(**inputs)
24
+
25
+ image_sizes = [[(image.height, image.width) for image in images]]
26
+ outputs = processor.post_process_keypoint_matching(
27
+ outputs, image_sizes, threshold=0.2
28
+ )
29
+ output = outputs[0]
30
+
31
+ image1 = to_numpy_array(image1)
32
+ image2 = to_numpy_array(image2)
33
+
34
+ height0, width0 = image1.shape[:2]
35
+ height1, width1 = image2.shape[:2]
36
+
37
+ # Create PIL image from numpy array
38
+ pil_img = Image.fromarray((image1 / 255.0 * 255).astype(np.uint8))
39
+ pil_img2 = Image.fromarray((image2 / 255.0 * 255).astype(np.uint8))
40
+
41
+ # Create Plotly figure
42
+ fig = go.Figure()
43
+
44
+ # Get keypoints
45
+ keypoints0_x, keypoints0_y = output["keypoints0"].unbind(1)
46
+ keypoints1_x, keypoints1_y = output["keypoints1"].unbind(1)
47
+
48
+ # Add a separate trace for each match (line + markers) to enable highlighting
49
+ for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
50
+ keypoints0_x,
51
+ keypoints0_y,
52
+ keypoints1_x,
53
+ keypoints1_y,
54
+ output["matching_scores"],
55
+ ):
56
+ color_val = matching_score.item()
57
+ color = f"rgba({int(255 * (1 - color_val))}, {int(255 * color_val)}, 0, 0.7)"
58
+
59
+ hover_text = (
60
+ f"Score: {matching_score.item():.2f}<br>"
61
+ f"Point 1: ({keypoint0_x.item():.1f}, {keypoint0_y.item():.1f})<br>"
62
+ f"Point 2: ({keypoint1_x.item():.1f}, {keypoint1_y.item():.1f})"
63
+ )
64
+
65
+ fig.add_trace(
66
+ go.Scatter(
67
+ x=[keypoint0_x.item(), keypoint1_x.item() + width0],
68
+ y=[keypoint0_y.item(), keypoint1_y.item()],
69
+ mode="lines+markers",
70
+ line=dict(color=color, width=2),
71
+ marker=dict(color=color, size=5, opacity=0.8),
72
+ hoverinfo="text",
73
+ hovertext=hover_text,
74
+ showlegend=False,
75
+ )
76
+ )
77
+
78
+ # Update layout to use images as background
79
+ fig.update_layout(
80
+ title="LightGlue Keypoint Matching",
81
+ xaxis=dict(
82
+ range=[0, width0 + width1],
83
+ showgrid=False,
84
+ zeroline=False,
85
+ showticklabels=False,
86
+ ),
87
+ yaxis=dict(
88
+ range=[max(height0, height1), 0],
89
+ showgrid=False,
90
+ zeroline=False,
91
+ showticklabels=False,
92
+ scaleanchor="x",
93
+ scaleratio=1,
94
+ ),
95
+ margin=dict(l=0, r=0, t=50, b=0),
96
+ height=max(height0, height1),
97
+ width=width0 + width1,
98
+ images=[
99
+ dict(
100
+ source=pil_img,
101
+ xref="x",
102
+ yref="y",
103
+ x=0,
104
+ y=0,
105
+ sizex=width0,
106
+ sizey=height0,
107
+ sizing="stretch",
108
+ opacity=1,
109
+ layer="below",
110
+ ),
111
+ dict(
112
+ source=pil_img2,
113
+ xref="x",
114
+ yref="y",
115
+ x=width0,
116
+ y=0,
117
+ sizex=width1,
118
+ sizey=height1,
119
+ sizing="stretch",
120
+ opacity=1,
121
+ layer="below",
122
+ ),
123
+ ],
124
+ )
125
+
126
+ return fig
127
+
128
+
129
+ # Create the Gradio interface
130
+ with gr.Blocks(title="LightGlue Matching Demo") as demo:
131
+ gr.Markdown("# LightGlue Matching Demo")
132
+ gr.Markdown(
133
+ "Upload two images and get a side-by-side matching of your images using LightGlue."
134
+ )
135
+ gr.Markdown("""
136
+ ## How to use:
137
+ 1. Upload two images using the file uploaders above
138
+ 2. Click the 'Match Images' button
139
+ 3. View the matched output image below
140
+
141
+ The app will create a side-by-side matching of your images using LightGlue.
142
+ You can also select an example image pair from the dataset.
143
+ """)
144
+
145
+ with gr.Row():
146
+ # Input images on the same row
147
+ image1 = gr.Image(label="First Image", type="pil")
148
+ image2 = gr.Image(label="Second Image", type="pil")
149
+
150
+ # Process button
151
+ process_btn = gr.Button("Match Images", variant="primary")
152
+
153
+ # Output plot
154
+ output_plot = gr.Plot(label="Matching Results")
155
+
156
+ # Connect the function
157
+ process_btn.click(fn=process_images, inputs=[image1, image2], outputs=output_plot)
158
+
159
+ # Add some example usage
160
+
161
+ examples = gr.Dataset(
162
+ components=[image1, image2],
163
+ label="Example Image Pairs",
164
+ samples=[
165
+ [
166
+ "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_98169888_3347710852.jpg",
167
+ "https://raw.githubusercontent.com/magicleap/SuperGluePretrainedNetwork/refs/heads/master/assets/phototourism_sample_images/united_states_capitol_26757027_6717084061.jpg",
168
+ ],
169
+ ],
170
+ )
171
+
172
+ examples.select(lambda x: (x[0], x[1]), [examples], [image1, image2])
173
+
174
+ if __name__ == "__main__":
175
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio>=5.34.2
2
+ Pillow>=10.0.0
3
+ numpy>=1.24.0
4
+ transformers
5
+ matplotlib  # NOTE(review): not imported by app.py (plotting uses plotly) — candidate for removal
6
+ torch
7
+ plotly