stevenbucaille commited on
Commit
ee88db6
·
1 Parent(s): cf31499
Files changed (2) hide show
  1. app.py +20 -10
  2. requirements.txt +2 -1
app.py CHANGED
@@ -6,6 +6,7 @@ import torch
6
  import plotly.graph_objects as go
7
  from PIL import Image
8
  import spaces
 
9
 
10
  @spaces.GPU
11
  def process_images(image1, image2):
@@ -17,9 +18,15 @@ def process_images(image1, image2):
17
 
18
  images = [image1, image2]
19
  processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
20
- model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint")
21
-
22
  inputs = processor(images, return_tensors="pt")
 
 
 
 
 
 
 
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
 
@@ -39,9 +46,11 @@ def process_images(image1, image2):
39
  pil_img = Image.fromarray((image1 / 255.0 * 255).astype(np.uint8))
40
  pil_img2 = Image.fromarray((image2 / 255.0 * 255).astype(np.uint8))
41
 
42
- # Create Plotly figure
43
  fig = go.Figure()
44
 
 
 
 
45
  # Get keypoints
46
  keypoints0_x, keypoints0_y = output["keypoints0"].unbind(1)
47
  keypoints1_x, keypoints1_y = output["keypoints1"].unbind(1)
@@ -55,10 +64,13 @@ def process_images(image1, image2):
55
  output["matching_scores"],
56
  ):
57
  color_val = matching_score.item()
58
- color = f"rgba({int(255 * (1 - color_val))}, {int(255 * color_val)}, 0, 0.7)"
 
 
 
59
 
60
  hover_text = (
61
- f"Score: {matching_score.item():.2f}<br>"
62
  f"Point 1: ({keypoint0_x.item():.1f}, {keypoint0_y.item():.1f})<br>"
63
  f"Point 2: ({keypoint1_x.item():.1f}, {keypoint1_y.item():.1f})"
64
  )
@@ -78,7 +90,6 @@ def process_images(image1, image2):
78
 
79
  # Update layout to use images as background
80
  fig.update_layout(
81
- title="LightGlue Keypoint Matching",
82
  xaxis=dict(
83
  range=[0, width0 + width1],
84
  showgrid=False,
@@ -93,9 +104,8 @@ def process_images(image1, image2):
93
  scaleanchor="x",
94
  scaleratio=1,
95
  ),
96
- margin=dict(l=0, r=0, t=50, b=0),
97
- height=max(height0, height1),
98
- width=width0 + width1,
99
  images=[
100
  dict(
101
  source=pil_img,
@@ -152,7 +162,7 @@ with gr.Blocks(title="LightGlue Matching Demo") as demo:
152
  process_btn = gr.Button("Match Images", variant="primary")
153
 
154
  # Output plot
155
- output_plot = gr.Plot(label="Matching Results")
156
 
157
  # Connect the function
158
  process_btn.click(fn=process_images, inputs=[image1, image2], outputs=output_plot)
 
6
  import plotly.graph_objects as go
7
  from PIL import Image
8
  import spaces
9
+ import matplotlib.cm as cm
10
 
11
  @spaces.GPU
12
  def process_images(image1, image2):
 
18
 
19
  images = [image1, image2]
20
  processor = AutoImageProcessor.from_pretrained("ETH-CVG/lightglue_superpoint")
21
+ model = AutoModel.from_pretrained("ETH-CVG/lightglue_superpoint", device_map="auto")
 
22
  inputs = processor(images, return_tensors="pt")
23
+ inputs = inputs.to(model.device)
24
+ print(
25
+ "Model is on device: ",
26
+ model.device,
27
+ "and inputs are on device: ",
28
+ inputs.device,
29
+ )
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
 
 
46
  pil_img = Image.fromarray((image1 / 255.0 * 255).astype(np.uint8))
47
  pil_img2 = Image.fromarray((image2 / 255.0 * 255).astype(np.uint8))
48
 
 
49
  fig = go.Figure()
50
 
51
+ # Create colormap (red-yellow-green: red for low scores, green for high scores)
52
+ colormap = cm.RdYlGn
53
+
54
  # Get keypoints
55
  keypoints0_x, keypoints0_y = output["keypoints0"].unbind(1)
56
  keypoints1_x, keypoints1_y = output["keypoints1"].unbind(1)
 
64
  output["matching_scores"],
65
  ):
66
  color_val = matching_score.item()
67
+ rgba_color = colormap(color_val)
68
+
69
+ # Convert to rgba string with transparency
70
+ color = f"rgba({int(rgba_color[0] * 255)}, {int(rgba_color[1] * 255)}, {int(rgba_color[2] * 255)}, 0.8)"
71
 
72
  hover_text = (
73
+ f"Score: {matching_score.item():.3f}<br>"
74
  f"Point 1: ({keypoint0_x.item():.1f}, {keypoint0_y.item():.1f})<br>"
75
  f"Point 2: ({keypoint1_x.item():.1f}, {keypoint1_y.item():.1f})"
76
  )
 
90
 
91
  # Update layout to use images as background
92
  fig.update_layout(
 
93
  xaxis=dict(
94
  range=[0, width0 + width1],
95
  showgrid=False,
 
104
  scaleanchor="x",
105
  scaleratio=1,
106
  ),
107
+ margin=dict(l=0, r=0, t=0, b=0),
108
+ autosize=True,
 
109
  images=[
110
  dict(
111
  source=pil_img,
 
162
  process_btn = gr.Button("Match Images", variant="primary")
163
 
164
  # Output plot
165
+ output_plot = gr.Plot(label="Matching Results", scale=2)
166
 
167
  # Connect the function
168
  process_btn.click(fn=process_images, inputs=[image1, image2], outputs=output_plot)
requirements.txt CHANGED
@@ -5,4 +5,5 @@ transformers @ git+https://github.com/huggingface/transformers.git@e5a9ce48f711b
5
  matplotlib
6
  torch
7
  plotly
8
- spaces
 
 
5
  matplotlib
6
  torch
7
  plotly
8
+ spaces
9
+ accelerate