Image-Text-to-Text
Safetensors
openvla
custom_code
emrys-hong committed · verified · Commit d575b00 · 1 Parent(s): 19f0038

Update gripper_position.py

Files changed (1)
  1. gripper_position.py +126 -0
gripper_position.py CHANGED
@@ -0,0 +1,126 @@
+ import matplotlib.patches
+ import numpy as np
+ import torch
+ from transformers import SamModel, SamProcessor, pipeline
+
+
+ # Zero-shot detector to localize the gripper; SAM turns the box into a mask.
+ checkpoint = "google/owlvit-base-patch16"
+ detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device="cuda")
+ sam_model = SamModel.from_pretrained("facebook/sam-vit-base").cuda()
+ sam_processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
+
+ # image_dims = (256, 256)
+ image_dims = (224, 224)
+
+
+ def get_bounding_boxes(img, prompt="the black robotic gripper"):
+     """Run zero-shot detection for the gripper; the low threshold keeps weak candidates."""
+     predictions = detector(img, candidate_labels=[prompt], threshold=0.01)
+     return predictions
+
+
+ def show_box(box, ax, meta, color):
+     """Draw a detection box and its confidence score on a matplotlib axis."""
+     x0, y0 = box["xmin"], box["ymin"]
+     w, h = box["xmax"] - box["xmin"], box["ymax"] - box["ymin"]
+     ax.add_patch(
+         matplotlib.patches.FancyBboxPatch((x0, y0), w, h, edgecolor=color, facecolor=(0, 0, 0, 0), lw=2)
+     )
+     ax.text(x0, y0 + 10, "{:.3f}".format(meta["score"]), color="white")
+
+
+ def get_median(mask, p):
+     """Return the first row index at which the cumulative row mass of `mask`
+     reaches fraction `p` of its total mass."""
+     row_sum = np.sum(mask, axis=1)
+     cumulative_sum = np.cumsum(row_sum)
+
+     p = min(p, 1.0)  # clamp to a valid fraction
+     total_sum = np.sum(row_sum)
+     threshold = p * total_sum
+
+     return np.argmax(cumulative_sum >= threshold)
+
+
+ def get_gripper_mask(img, pred):
+     """Segment the gripper with SAM, prompted by a detector bounding box."""
+     box = [
+         round(pred["box"]["xmin"], 2),
+         round(pred["box"]["ymin"], 2),
+         round(pred["box"]["xmax"], 2),
+         round(pred["box"]["ymax"], 2),
+     ]
+
+     inputs = sam_processor(img, input_boxes=[[[box]]], return_tensors="pt")
+     for k in inputs.keys():
+         inputs[k] = inputs[k].cuda()
+     with torch.no_grad():
+         outputs = sam_model(**inputs)
+
+     # First image, first box prompt, first of the predicted masks.
+     mask = (
+         sam_processor.image_processor.post_process_masks(
+             outputs.pred_masks, inputs["original_sizes"], inputs["reshaped_input_sizes"]
+         )[0][0][0]
+         .cpu()
+         .numpy()
+     )
+
+     return mask
+
+
+ def sq(w, h):
+     """Build an (h, w, 2) grid holding the (x, y) coordinates of every pixel."""
+     return np.concatenate(
+         [
+             (np.arange(w * h).reshape(h, w) % w)[:, :, None],
+             (np.arange(w * h).reshape(h, w) // w)[:, :, None],
+         ],
+         axis=-1,
+     )
+
+
+ def mask_to_pos_weighted(mask):
+     """Reduce a mask to a point, weighting pixels toward the bottom-right,
+     where the gripper tip tends to sit."""
+     pos = sq(*image_dims)
+
+     weight = pos[:, :, 0] + pos[:, :, 1]
+     weight = weight * weight
+
+     x = np.sum(mask * pos[:, :, 0] * weight) / np.sum(mask * weight)
+     y = get_median(mask * weight, 0.95)
+
+     return x, y
+
+
+ def mask_to_pos_naive(mask):
+     """Reduce a mask to a point: the masked pixel with the largest x + y,
+     shifted by fixed heuristic offsets."""
+     pos = sq(*image_dims)
+     weight = pos[:, :, 0] + pos[:, :, 1]
+     max_pos = np.argmax((weight * mask).flatten())
+
+     return max_pos % image_dims[0] - (image_dims[0] / 16), max_pos // image_dims[0] - (image_dims[0] / 24)
+
+
+ def get_gripper_pos_raw(img):
+     """Detect, segment, and localize the gripper in a single image.
+     Returns ((x, y) scaled to 224x224 coordinates, the SAM mask, the raw detection)."""
+     predictions = get_bounding_boxes(img)
+
+     if len(predictions) > 0:
+         mask = get_gripper_mask(img, predictions[0])
+         pos = mask_to_pos_naive(mask)
+     else:
+         # No detection: empty mask and sentinel position.
+         mask = np.zeros(image_dims)
+         pos = (-1, -1)
+         predictions = [None]
+
+     return (int(pos[0] * 224 / image_dims[0]), int(pos[1] * 224 / image_dims[1])), mask, predictions[0]
+
+
+ if __name__ == "__main__":
+     pass
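
A minimal usage sketch, not part of the commit: it assumes the models above load on a CUDA device, and the frame path is hypothetical.

from PIL import Image

# Hypothetical test frame; any RGB image works, resized here to image_dims.
img = Image.open("frame_000.png").convert("RGB").resize(image_dims)

(x, y), mask, det = get_gripper_pos_raw(img)
if det is None:
    print("no gripper detected")
else:
    print(f"gripper at ({x}, {y}) in 224x224 coordinates, score={det['score']:.3f}")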