Xenova HF staff committed on
Commit
ece4a6d
·
verified ·
1 Parent(s): f0c221e

Create index.js

Browse files
Files changed (1) hide show
  1. index.js +301 -0
index.js ADDED
@@ -0,0 +1,301 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import {
2
+ SamModel,
3
+ AutoProcessor,
4
+ RawImage,
5
+ Tensor,
6
+ } from "https://cdn.jsdelivr.net/npm/@huggingface/[email protected]";
7
+
8
// Reference the elements we will use
const statusLabel = document.getElementById("status");
const fileUpload = document.getElementById("upload");
const imageContainer = document.getElementById("container");
const example = document.getElementById("example");
const uploadButton = document.getElementById("upload-button");
const resetButton = document.getElementById("reset-image");
const clearButton = document.getElementById("clear-points");
const cutButton = document.getElementById("cut-mask");
const starIcon = document.getElementById("star-icon");
const crossIcon = document.getElementById("cross-icon");
const maskCanvas = document.getElementById("mask-output");
const maskContext = maskCanvas.getContext("2d");

// Image that is encoded when the "example" element is clicked
const EXAMPLE_URL =
  "https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/corgi.jpg";

// State variables
let isEncoding = false; // true while encode() is running
let isDecoding = false; // true while decode() is running
let decodePending = false; // decode() was requested while another was running
let lastPoints = null; // prompt points ({ position, label }) used by decode()
let isMultiMaskMode = false; // true once the user has clicked to place points
let imageInput = null; // image returned by RawImage.fromURL
let imageProcessed = null; // output of processor(imageInput)
let imageEmbeddings = null; // output of model.get_image_embeddings
34
+
35
/**
 * Run the mask decoder for the current prompt points and draw the result.
 *
 * Only one decode runs at a time. A call made while another is in flight
 * just sets `decodePending`, and the running call re-runs once it finishes.
 * If the image or the points have been cleared, the call is a no-op.
 *
 * @returns {Promise<void>}
 */
async function decode() {
  // Only proceed if we are not already decoding
  if (isDecoding) {
    decodePending = true;
    return;
  }

  // The image or points may have been cleared while a decode was queued
  if (
    !imageProcessed ||
    !imageEmbeddings ||
    !lastPoints ||
    lastPoints.length === 0
  ) {
    decodePending = false;
    return;
  }
  isDecoding = true;

  try {
    // Snapshot the shared state so every step below sees consistent values
    const prompts = lastPoints;
    const processed = imageProcessed;
    const embeddings = imageEmbeddings;

    // Prepare inputs for decoding: scale normalised positions to the
    // reshaped input size ([height, width])
    const [reshapedHeight, reshapedWidth] = processed.reshaped_input_sizes[0];
    const points = prompts.flatMap(({ position: [x, y] }) => [
      x * reshapedWidth,
      y * reshapedHeight,
    ]);
    const labels = prompts.map(({ label }) => BigInt(label));

    const num_points = prompts.length;
    const input_points = new Tensor("float32", points, [1, 1, num_points, 2]);
    const input_labels = new Tensor("int64", labels, [1, 1, num_points]);

    // Generate the mask
    const { pred_masks, iou_scores } = await model({
      ...embeddings,
      input_points,
      input_labels,
    });

    // Post-process the mask
    const masks = await processor.post_process_masks(
      pred_masks,
      processed.original_sizes,
      processed.reshaped_input_sizes,
    );

    // Do not paint a stale mask if the image was reset or the points were
    // cleared while we were awaiting the model
    if (imageEmbeddings === embeddings && lastPoints !== null) {
      updateMaskOverlay(RawImage.fromTensor(masks[0][0]), iou_scores.data);
    }
  } finally {
    // Always release the lock, otherwise one failure blocks all later decodes
    isDecoding = false;
  }

  // Check if another decode is pending
  if (decodePending) {
    decodePending = false;
    decode();
  }
}
78
+
79
/**
 * Paint the best-scoring candidate mask onto the mask canvas and show its
 * score in the status label.
 *
 * @param {{width: number, height: number, data: ArrayLike<number>}} mask
 *   Image whose data interleaves one channel per candidate mask; it is read
 *   as `data[numMasks * pixel + maskIndex]`.
 * @param {ArrayLike<number>} scores One score per candidate mask.
 */
function updateMaskOverlay(mask, scores) {
  // Update canvas dimensions (if different)
  if (maskCanvas.width !== mask.width || maskCanvas.height !== mask.height) {
    maskCanvas.width = mask.width;
    maskCanvas.height = mask.height;
  }

  // Allocate buffer for pixel data
  const imageData = maskContext.createImageData(
    maskCanvas.width,
    maskCanvas.height,
  );

  // Select best mask
  const numMasks = scores.length;
  let bestIndex = 0;
  for (let i = 1; i < numMasks; ++i) {
    if (scores[i] > scores[bestIndex]) {
      bestIndex = i;
    }
  }
  statusLabel.textContent = `Segment score: ${scores[bestIndex].toFixed(2)}`;

  // Fill mask with colour. The buffer holds 4 bytes (RGBA) per pixel, so
  // iterate over pixels rather than over bytes.
  const pixelData = imageData.data;
  const numPixels = pixelData.length / 4;
  for (let i = 0; i < numPixels; ++i) {
    if (mask.data[numMasks * i + bestIndex] === 1) {
      const offset = 4 * i;
      pixelData[offset] = 0; // red
      pixelData[offset + 1] = 114; // green
      pixelData[offset + 2] = 189; // blue
      pixelData[offset + 3] = 255; // alpha
    }
  }

  // Draw image data to context
  maskContext.putImageData(imageData, 0, 0);
}
117
+
118
/**
 * Discard all prompt points, remove their icons and wipe the mask canvas.
 */
function clearPointsAndMask() {
  // Reset state
  isMultiMaskMode = false;
  lastPoints = null;

  // Cancel any queued decode: it would otherwise run against null points
  decodePending = false;

  // Remove points from previous mask (if any)
  document.querySelectorAll(".icon").forEach((e) => e.remove());

  // Disable cut button
  cutButton.disabled = true;

  // Reset mask canvas
  maskContext.clearRect(0, 0, maskCanvas.width, maskCanvas.height);
}
132
clearButton.addEventListener("click", clearPointsAndMask);

// Reset button: drop the current image and return to the upload screen
resetButton.addEventListener("click", () => {
  // Reset the state
  imageInput = null;
  imageProcessed = null;
  imageEmbeddings = null;
  isEncoding = false;
  isDecoding = false;

  // Cancel any queued decode: it would otherwise run without an image
  decodePending = false;

  // Clear points and mask (if present)
  clearPointsAndMask();

  // Update UI
  cutButton.disabled = true;
  imageContainer.style.backgroundImage = "none";
  uploadButton.style.display = "flex";
  statusLabel.textContent = "Ready";
});
151
+
152
/**
 * Load the image at `url`, show it in the container and compute its
 * embeddings. Calls made while an encode is already running are ignored.
 *
 * @param {string} url Image location passed to `RawImage.fromURL`.
 * @returns {Promise<void>}
 * @throws Rethrows any error from loading or encoding the image, after
 *   updating the status label.
 */
async function encode(url) {
  if (isEncoding) return;
  isEncoding = true;
  statusLabel.textContent = "Extracting image embedding...";

  try {
    imageInput = await RawImage.fromURL(url);

    // Update UI
    imageContainer.style.backgroundImage = `url(${url})`;
    uploadButton.style.display = "none";
    cutButton.disabled = true;

    // Recompute image embeddings
    imageProcessed = await processor(imageInput);
    imageEmbeddings = await model.get_image_embeddings(imageProcessed);

    statusLabel.textContent = "Embedding extracted!";
  } catch (err) {
    statusLabel.textContent = "Failed to extract image embedding";
    throw err;
  } finally {
    // Always release the lock, otherwise one failure blocks all later encodes
    isEncoding = false;
  }
}
171
+
172
// Handle file selection
fileUpload.addEventListener("change", (e) => {
  const file = e.target.files[0];
  if (!file) return;

  const reader = new FileReader();

  // Once the file has been read as a data URL, encode it
  reader.onload = (e2) => encode(e2.target.result);

  // Surface read failures instead of leaving the UI silently unchanged
  reader.onerror = () => {
    statusLabel.textContent = "Failed to read the selected file";
  };

  reader.readAsDataURL(file);
});

// Load the example image when the example element is clicked
example.addEventListener("click", (e) => {
  e.preventDefault();
  encode(EXAMPLE_URL);
});
189
+
190
// Attach click handler to image container: each left/right click adds a
// positive/negative prompt point and re-runs the decoder
imageContainer.addEventListener("mousedown", (e) => {
  if (e.button !== 0 && e.button !== 2) {
    return; // Ignore other buttons
  }
  if (!imageEmbeddings) {
    return; // Ignore if not encoded yet
  }
  if (!isMultiMaskMode) {
    // First click: discard the hover point and start collecting clicked points
    lastPoints = [];
    isMultiMaskMode = true;
    cutButton.disabled = false;
  }

  const point = getPoint(e);
  lastPoints.push(point);

  // Add an icon at the clicked position (star = positive, cross = negative)
  const icon = (point.label === 1 ? starIcon : crossIcon).cloneNode();
  icon.style.left = `${point.position[0] * 100}%`;
  icon.style.top = `${point.position[1] * 100}%`;
  imageContainer.appendChild(icon);

  // Run decode
  decode();
});
216
+
217
/**
 * Clamp a value inside a range [min, max].
 *
 * @param {number} x Value to clamp.
 * @param {number} [min=0] Lower bound.
 * @param {number} [max=1] Upper bound.
 * @returns {number} `x` limited to the range.
 */
function clamp(x, min = 0, max = 1) {
  return Math.max(Math.min(x, max), min);
}
221
+
222
/**
 * Convert a mouse event into a prompt point.
 *
 * @param {MouseEvent} e Event carrying `clientX`, `clientY` and `button`.
 * @returns {{position: [number, number], label: number}} Position relative
 *   to the image container, each coordinate clamped to [0, 1], and a label
 *   of 0 for a right click (negative prompt) or 1 otherwise (positive prompt).
 */
function getPoint(e) {
  // Get bounding box
  const bb = imageContainer.getBoundingClientRect();

  // Get the mouse coordinates relative to the container
  const mouseX = clamp((e.clientX - bb.left) / bb.width);
  const mouseY = clamp((e.clientY - bb.top) / bb.height);

  return {
    position: [mouseX, mouseY],
    label:
      e.button === 2 // right click
        ? 0 // negative prompt
        : 1, // positive prompt
  };
}
238
+
239
// Do not show context menu on right click
imageContainer.addEventListener("contextmenu", (e) => e.preventDefault());

// Attach hover event to image container: preview the mask under the cursor
imageContainer.addEventListener("mousemove", (e) => {
  if (!imageEmbeddings || isMultiMaskMode) {
    // Ignore mousemove events if the image is not encoded yet,
    // or we are in multi-mask mode
    return;
  }
  lastPoints = [getPoint(e)];

  decode();
});
253
+
254
// Handle cut button click: download the masked region of the image as a PNG
cutButton.addEventListener("click", async () => {
  const [w, h] = [maskCanvas.width, maskCanvas.height];

  // Get the mask pixel data (and use this as a buffer)
  const maskImageData = maskContext.getImageData(0, 0, w, h);

  // Create a new canvas to hold the cut-out
  const cutCanvas = new OffscreenCanvas(w, h);
  const cutContext = cutCanvas.getContext("2d");

  // Copy the image pixel data to the cut canvas
  const maskPixelData = maskImageData.data;
  const imagePixelData = imageInput.data;
  // Use the image's own channel count as the source stride, falling back to
  // RGB. NOTE(review): assumes `imageInput.channels` describes the
  // interleaved layout of `imageInput.data` — confirm against RawImage.
  const channels = imageInput.channels ?? 3;
  for (let i = 0; i < w * h; ++i) {
    const sourceOffset = channels * i;
    const targetOffset = 4 * i; // RGBA

    if (maskPixelData[targetOffset + 3] > 0) {
      // Only copy opaque pixels
      for (let j = 0; j < 3; ++j) {
        // With fewer than 3 channels, reuse the first channel for R, G and B
        const sourceIndex = channels >= 3 ? sourceOffset + j : sourceOffset;
        maskPixelData[targetOffset + j] = imagePixelData[sourceIndex];
      }
    }
  }
  cutContext.putImageData(maskImageData, 0, 0);

  // Download image
  const blobUrl = URL.createObjectURL(await cutCanvas.convertToBlob());
  const link = document.createElement("a");
  link.download = "image.png";
  link.href = blobUrl;
  link.click();

  // Release the blob once the download has had time to start
  setTimeout(() => URL.revokeObjectURL(blobUrl), 40000);
});
288
+
289
// Load the model and its processor, then enable the user interface
const model_id = "Xenova/slimsam-77-uniform";
statusLabel.textContent = "Loading model...";
let model;
let processor;
try {
  // The two downloads are independent, so run them in parallel
  [model, processor] = await Promise.all([
    SamModel.from_pretrained(model_id, {
      dtype: "fp16", // or "fp32"
      device: "webgpu",
    }),
    AutoProcessor.from_pretrained(model_id),
  ]);
} catch (err) {
  // Do not leave the status stuck at "Loading model..." on failure
  statusLabel.textContent = "Failed to load model";
  throw err;
}
statusLabel.textContent = "Ready";

// Enable the user interface
fileUpload.disabled = false;
uploadButton.style.opacity = 1;
example.style.pointerEvents = "auto";