import numpy as np
import requests
import cv2
from skimage import feature
from io import BytesIO
from PIL import Image, ImageFile
import torch

ImageFile.LOAD_TRUNCATED_IMAGES = True
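
# NOTE: Several functions below receive a `models` dict that this module does not
# build itself. Judging from how it is used here, it is expected to provide:
#   'processor_gnd', 'model_gnd'   - Grounding DINO processor and model for
#                                    open-vocabulary detection
#   'predictor'                    - a SAM-style predictor exposing set_image()
#                                    and predict()
#   'tokenizer_text', 'model_text' - tokenizer and encoder used for text embeddings
#   'device'                       - the torch device the models live on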


def get_canonical_label(object_name_phrase: str) -> str:
    print(f"\n [Label] Extracting label for: '{object_name_phrase}'")
    label = object_name_phrase.strip().lower().split()[-1]
    label = ''.join(filter(str.isalpha, label))
    print(f" [Label] βœ… Extracted label: '{label}'")
    return label if label else "unknown"


def download_image_from_url(image_url: str) -> Image.Image:
    print(f" [Download] Downloading image from: {image_url[:80]}...")
    # A timeout keeps the call from hanging indefinitely on unresponsive hosts.
    response = requests.get(image_url, timeout=30)
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    image_rgb = image.convert("RGB")
    print(" [Download] βœ… Image downloaded and standardized to RGB.")
    return image_rgb


def detect_and_crop(image: Image.Image, object_name: str, models: dict) -> Image.Image:
    print(f"\n [Detect & Crop] Starting detection for object: '{object_name}'")
    image_np = np.array(image.convert("RGB"))
    height, width = image_np.shape[:2]

    prompt = [[f"a {object_name}"]]
    inputs = models['processor_gnd'](
        images=image,
        text=prompt,
        return_tensors="pt"
    ).to(models['device'])

    with torch.no_grad():
        outputs = models['model_gnd'](**inputs)
        results = models['processor_gnd'].post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.4,
            text_threshold=0.3,
            target_sizes=[(height, width)]
        )

    if not results or len(results[0]['boxes']) == 0:
        print(" [Detect & Crop] ⚠ Warning: Grounding DINO did not detect the object. Using full image.")
        return image

    result = results[0]
    scores = result['scores']
    max_idx = int(torch.argmax(scores))
    box = result['boxes'][max_idx].cpu().numpy().astype(int)
    print(f" [Detect & Crop] βœ… Object detected with confidence: {scores[max_idx]:.2f}, Box: {box}")

    x1, y1, x2, y2 = box
    # Clamp the box to the image bounds so the crop below stays valid.
    x1, y1 = max(x1, 0), max(y1, 0)
    x2, y2 = min(x2, width), min(y2, height)
    models['predictor'].set_image(image_np)
    box_prompt = np.array([[x1, y1, x2, y2]])

    masks, _, _ = models['predictor'].predict(box=box_prompt, multimask_output=False)
    mask = masks[0]
    mask_bool = mask > 0

    cropped_img_rgba = np.zeros((height, width, 4), dtype=np.uint8)
    cropped_img_rgba[:, :, :3] = image_np
    cropped_img_rgba[:, :, 3] = mask_bool * 255
    cropped_img_rgba = cropped_img_rgba[y1:y2, x1:x2]

    return Image.fromarray(cropped_img_rgba, 'RGBA')


def extract_features(segmented_image: Image.Image) -> dict:
    image_rgba = np.array(segmented_image)
    if image_rgba.shape[2] != 4:
        raise ValueError("Segmented image must be RGBA")

    # The segmented crop comes from PIL, so the channel order is R, G, B, A (not OpenCV's BGR).
    r, g, b, a = cv2.split(image_rgba)
    image_rgb = cv2.merge((r, g, b))
    mask = a

    gray = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2GRAY)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    # Use the largest contour so stray speckles in the mask do not skew the shape descriptor.
    largest_contour = max(contours, key=cv2.contourArea) if contours else None
    hu_moments = cv2.HuMoments(cv2.moments(largest_contour)).flatten() if largest_contour is not None else np.zeros(7)

    color_hist = cv2.calcHist([image_rgb], [0, 1, 2], mask, [8, 8, 8],
                               [0, 256, 0, 256, 0, 256])
    cv2.normalize(color_hist, color_hist)
    color_hist = color_hist.flatten()

    gray_masked = cv2.bitwise_and(gray, gray, mask=mask)
    lbp = feature.local_binary_pattern(gray_masked, P=24, R=3, method="uniform")
    (texture_hist, _) = np.histogram(lbp.ravel(), bins=np.arange(0, 27), range=(0, 26))
    texture_hist = texture_hist.astype("float32")
    texture_hist /= (texture_hist.sum() + 1e-6)

    return {
        "shape_features": hu_moments.tolist(),
        "color_features": color_hist.tolist(),
        "texture_features": texture_hist.tolist()
    }


def get_text_embedding(text: str, models: dict) -> list:
    print(f" [Embedding] Generating text embedding for: '{text[:50]}...'")
    text_with_instruction = f"Represent this sentence for searching relevant passages: {text}"

    inputs = models['tokenizer_text'](
        text_with_instruction,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=512
    ).to(models['device'])

    with torch.no_grad():
        outputs = models['model_text'](**inputs)
        embedding = outputs.last_hidden_state[:, 0, :]
        embedding = torch.nn.functional.normalize(embedding, p=2, dim=1)

    print(" [Embedding] βœ… Text embedding generated.")
    return embedding.cpu().numpy()[0].tolist()


def cosine_similarity(vec1: np.ndarray, vec2: np.ndarray) -> float:
    return float(np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2)))


def stretch_image_score(score: float) -> float:
    # Leave low scores (< 0.4) and exact matches (1.0) untouched; linearly map
    # the remaining 0.4..1.0 range onto 0.7..0.99 to spread out mid-range scores.
    if score < 0.4 or score == 1.0:
        return score
    return 0.7 + (score - 0.4) * (0.99 - 0.7) / (1.0 - 0.4)
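

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). The model-dependent steps
# (detect_and_crop, get_text_embedding) need the `models` dict described near
# the top of the module and are therefore only outlined in comments here; the
# demo below exercises the pure helpers, which run without models or network
# access. The example phrase and URL are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    label = get_canonical_label("a shiny red apple!")  # -> "apple"
    print(f"canonical label: {label}")

    # With `models` loaded elsewhere, the full pipeline would look roughly like:
    #   image = download_image_from_url("https://example.com/apple.jpg")
    #   segmented = detect_and_crop(image, label, models)
    #   features = extract_features(segmented)
    #   embedding = np.array(get_text_embedding("a shiny red apple", models))

    # Similarity scoring on toy vectors:
    v1 = np.array([1.0, 0.0, 1.0])
    v2 = np.array([0.9, 0.1, 1.1])
    raw = cosine_similarity(v1, v2)
    print(f"raw similarity:       {raw:.3f}")
    print(f"stretched similarity: {stretch_image_score(raw):.3f}")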