sayedM commited on
Commit
471a3ca
Β·
verified Β·
1 Parent(s): 53fb72b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -12
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import torch
2
  import gradio as gr
3
  import numpy as np
@@ -5,16 +6,18 @@ from PIL import Image
5
  import torchvision.transforms.functional as TF
6
  from matplotlib import colormaps
7
  from transformers import AutoModel
 
8
 
9
  # ----------------------------
10
  # Configuration
11
  # ----------------------------
12
  # The model will be downloaded from the Hugging Face Hub
13
- MODEL_ID = "facebook/dinov3-vith16plus-pretrain-lvd1689m"
 
14
  PATCH_SIZE = 16
15
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- # Normalization constants
18
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
19
  IMAGENET_STD = (0.229, 0.224, 0.225)
20
 
@@ -25,14 +28,17 @@ def load_model_from_hub():
25
  """Loads the DINOv3 model from the Hugging Face Hub."""
26
  print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
27
  try:
28
- model = AutoModel.from_pretrained(MODEL_ID)
 
 
 
29
  model.to(DEVICE).eval()
30
  print(f"βœ… Model loaded successfully on device: {DEVICE}")
31
  return model
32
  except Exception as e:
33
  print(f"❌ Failed to load model: {e}")
34
- gr.Error(f"Could not load model from Hub: {e}")
35
- return None
36
 
37
  # Load the model globally when the app starts
38
  model = load_model_from_hub()
@@ -79,7 +85,7 @@ def generate_pca_visuals(
79
  ):
80
  """Main function to generate PCA visuals."""
81
  if model is None:
82
- raise gr.Error("DINOv3 model could not be loaded. Check the logs.")
83
  if image_pil is None:
84
  return None, None, "Please upload an image and click Generate.", None, None
85
 
@@ -94,20 +100,24 @@ def generate_pca_visuals(
94
  # 2. Feature Extraction
95
  progress(0.5, desc="πŸ¦– Extracting features with DINOv3...")
96
  outputs = model(t_norm)
97
- # The patch embeddings are in last_hidden_state, we skip the first token (CLS)
98
- patch_embeddings = outputs.last_hidden_state.squeeze(0)[1:, :]
 
 
 
 
99
 
100
  # 3. PCA Calculation
101
  progress(0.8, desc="πŸ”¬ Performing PCA...")
102
  X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
103
  U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
104
 
105
- # Stabilize the signs of the eigenvectors for deterministic output
 
106
  for i in range(V.shape[1]):
107
  max_abs_idx = torch.argmax(torch.abs(V[:, i]))
108
  if V[max_abs_idx, i] < 0:
109
  V[:, i] *= -1
110
-
111
  scores = X_centered @ V[:, :3]
112
 
113
  # 4. Explained Variance
@@ -121,8 +131,10 @@ def generate_pca_visuals(
121
  )
122
 
123
  # 5. Create Visualizations
 
124
  pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
125
  pc1_image_raw = colorize(pc1_map, cmap_name)
 
126
  pc_rgb_map = scores.reshape(Hp, Wp, 3).cpu().numpy()
127
  min_vals = pc_rgb_map.reshape(-1, 3).min(axis=0)
128
  max_vals = pc_rgb_map.reshape(-1, 3).max(axis=0)
@@ -137,7 +149,6 @@ def generate_pca_visuals(
137
  progress(1.0, desc="βœ… Done!")
138
  return pc1_image_smooth, pc_rgb_image_smooth, variance_text, blended_image, original_processed_image
139
 
140
-
141
  # ----------------------------
142
  # Gradio Interface
143
  # ----------------------------
@@ -152,7 +163,8 @@ with gr.Blocks(theme=gr.themes.Soft(), title="DINOv3 PCA Explorer") as demo:
152
 
153
  with gr.Row():
154
  with gr.Column(scale=2):
155
- input_image = gr.Image(type="pil", label="Upload Image", value="https://picsum.photos/id/1011/800/600")
 
156
 
157
  with gr.Accordion("βš™οΈ Visualization Controls", open=True):
158
  resolution_slider = gr.Slider(
 
1
+ # app.py
2
  import torch
3
  import gradio as gr
4
  import numpy as np
 
6
  import torchvision.transforms.functional as TF
7
  from matplotlib import colormaps
8
  from transformers import AutoModel
9
+ import os
10
 
11
  # ----------------------------
12
  # Configuration
13
  # ----------------------------
14
  # The model will be downloaded from the Hugging Face Hub
15
+ # Using the specific revision that works well with transformers AutoModel
16
+ MODEL_ID = "facebook/dinov3-vith16plus"
17
  PATCH_SIZE = 16
18
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
19
 
20
+ # Normalization constants (standard for ImageNet)
21
  IMAGENET_MEAN = (0.485, 0.456, 0.406)
22
  IMAGENET_STD = (0.229, 0.224, 0.225)
23
 
 
28
  """Loads the DINOv3 model from the Hugging Face Hub."""
29
  print(f"Loading model '{MODEL_ID}' from Hugging Face Hub...")
30
  try:
31
+ # Use your HF token if the model is gated
32
+ # You can set this as a secret in your Hugging Face Space settings
33
+ token = os.environ.get("HF_TOKEN")
34
+ model = AutoModel.from_pretrained(MODEL_ID, token=token, trust_remote_code=True)
35
  model.to(DEVICE).eval()
36
  print(f"βœ… Model loaded successfully on device: {DEVICE}")
37
  return model
38
  except Exception as e:
39
  print(f"❌ Failed to load model: {e}")
40
+ # This will display an error message in the Gradio interface
41
+ raise gr.Error(f"Could not load model from Hub. If it's a gated model, ensure you have access and have set your HF_TOKEN secret in the Space settings. Error: {e}")
42
 
43
  # Load the model globally when the app starts
44
  model = load_model_from_hub()
 
85
  ):
86
  """Main function to generate PCA visuals."""
87
  if model is None:
88
+ raise gr.Error("DINOv3 model is not available. Check the startup logs.")
89
  if image_pil is None:
90
  return None, None, "Please upload an image and click Generate.", None, None
91
 
 
100
  # 2. Feature Extraction
101
  progress(0.5, desc="πŸ¦– Extracting features with DINOv3...")
102
  outputs = model(t_norm)
103
+
104
+ # πŸ’‘ FIX: The model output includes a [CLS] token AND 4 register tokens.
105
+ # We must skip all of them (total 5) to get only the patch embeddings.
106
+ # The original code only skipped 1, causing the size mismatch.
107
+ n_special_tokens = 5 # 1 [CLS] token + 4 register tokens
108
+ patch_embeddings = outputs.last_hidden_state.squeeze(0)[n_special_tokens:, :]
109
 
110
  # 3. PCA Calculation
111
  progress(0.8, desc="πŸ”¬ Performing PCA...")
112
  X_centered = patch_embeddings.float() - patch_embeddings.float().mean(0, keepdim=True)
113
  U, S, V = torch.pca_lowrank(X_centered, q=3, center=False)
114
 
115
+ # πŸ’‘ IMPROVEMENT: Stabilize the signs of the eigenvectors for deterministic output.
116
+ # This prevents the colors from randomly inverting on different runs.
117
  for i in range(V.shape[1]):
118
  max_abs_idx = torch.argmax(torch.abs(V[:, i]))
119
  if V[max_abs_idx, i] < 0:
120
  V[:, i] *= -1
 
121
  scores = X_centered @ V[:, :3]
122
 
123
  # 4. Explained Variance
 
131
  )
132
 
133
  # 5. Create Visualizations
134
+ # This part should now work correctly as `scores` has the right shape (Hp*Wp, 3)
135
  pc1_map = scores[:, 0].reshape(Hp, Wp).cpu().numpy()
136
  pc1_image_raw = colorize(pc1_map, cmap_name)
137
+
138
  pc_rgb_map = scores.reshape(Hp, Wp, 3).cpu().numpy()
139
  min_vals = pc_rgb_map.reshape(-1, 3).min(axis=0)
140
  max_vals = pc_rgb_map.reshape(-1, 3).max(axis=0)
 
149
  progress(1.0, desc="βœ… Done!")
150
  return pc1_image_smooth, pc_rgb_image_smooth, variance_text, blended_image, original_processed_image
151
 
 
152
  # ----------------------------
153
  # Gradio Interface
154
  # ----------------------------
 
163
 
164
  with gr.Row():
165
  with gr.Column(scale=2):
166
+ # Added a default image URL for convenience
167
+ input_image = gr.Image(type="pil", label="Upload Image", value="https://images.squarespace-cdn.com/content/v1/607f89e638219e13eee71b1e/1684821560422-SD5V37BAG28BURTLIXUQ/michael-sum-LEpfefQf4rU-unsplash.jpg")
168
 
169
  with gr.Accordion("βš™οΈ Visualization Controls", open=True):
170
  resolution_slider = gr.Slider(