File size: 4,178 Bytes
3e1d9f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import copy
from typing import Optional

from PIL import Image

from .single_image_convsation import SingleImageConvDatasetMixin


class SingleImageInteractive(SingleImageConvDatasetMixin):
    """Mutable single-image, multi-turn conversation used for interactive inference.

    Messages are accumulated with :meth:`append_message`. Boxes and points referenced
    by a message are de-duplicated into ``self.boxes`` / ``self.points``, and each
    stored message keeps indices into those instance-wide lists. The object acts as a
    dataset of length 1 (see ``__len__``) whose only item is the current conversation,
    so the mixin's ``__getitem__`` can be reused to build model inputs.
    """

    # NOTE(review): presumably suppresses a one-time sample printout in the mixin;
    # confirm against SingleImageConvDatasetMixin.
    _printed_sample = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.image: Optional[Image.Image] = None
        self.roles = ('human', 'gpt')
        # Unique boxes / points seen so far; messages reference them by index.
        self.boxes = []
        self.points = []
        # NOTE(review): not read anywhere in this class; kept in case other code uses it.
        self.raw_conv = []
        # List of {'from', 'value', 'boxes_seq', 'points_seq'} dicts, in order.
        self.conversations = []

    def set_image(self, image: Image.Image):
        """Attach the conversation's image. May only be called once per instance."""
        assert self.image is None, f"an image is already set; refusing to replace it with {image}"
        self.image = image

    def append_message(self, role: str, message: str, *, boxes=None, points=None, boxes_seq=None, points_seq=None):
        """Append a new message.

        Args:
            role: must be one of ``self.roles`` (``'human'`` or ``'gpt'``).
            message: the message text. When an image is set, the conversation keeps
                exactly one ``<image>`` placeholder: it is prepended to this message if
                no earlier message has one, and removed from this message if an earlier
                message already has one.
            boxes: coordinates (4 numbers each) referenced by ``boxes_seq``.
            points: coordinates (2 numbers each) referenced by ``points_seq``.
            boxes_seq: iterable of index groups into ``boxes``; stored remapped to
                indices into ``self.boxes``.
            points_seq: iterable of index groups into ``points``; stored remapped to
                indices into ``self.points``.

        Raises:
            AssertionError: if ``role`` is not a known role, or a box/point is malformed.
            ValueError: if a ``*_seq`` is given without the matching ``boxes``/``points``.
        """
        assert role in self.roles

        def convert_idx(objs_seq, objs_value, get_obj_idx_func, name):
            # Remap message-local indices to indices into the instance-wide,
            # de-duplicated list. Evaluation order is preserved, so new objects are
            # appended in the order they are referenced.
            if objs_seq is None:
                return None
            if objs_value is None:
                raise ValueError(f"{name}_seq was given but {name} is None")
            return tuple(
                tuple(get_obj_idx_func(objs_value[idx]) for idx in objs_idx)
                for objs_idx in objs_seq
            )

        boxes_seq = convert_idx(boxes_seq, boxes, self._get_box_idx, 'boxes')
        points_seq = convert_idx(points_seq, points, self._get_point_idx, 'points')

        if self.image is not None:
            previous_message_has_image_placeholder = any(
                '<image>' in item['value'] for item in self.conversations
            )
            if not previous_message_has_image_placeholder and '<image>' not in message:
                message = '<image> ' + message
            if previous_message_has_image_placeholder and '<image>' in message:
                message = message.replace('<image>', '')

        self.conversations.append(
            {
                'from': role,
                'value': message,
                'boxes_seq': copy.deepcopy(boxes_seq),
                'points_seq': copy.deepcopy(points_seq),
            }
        )

    def get_raw_item(self, index=None):
        """Return a deep copy of the current conversation as a raw dataset item.

        ``index`` is accepted for interface compatibility and ignored (there is only
        one item). If the last message is from ``self.roles[0]`` (the human), an empty
        ``self.roles[1]`` message is appended to the copy — presumably so the prompt
        ends with an open assistant turn.

        Raises:
            AssertionError: if no message has been appended yet, or the conversation
                does not start with a human message.
        """
        assert self.conversations, "no messages have been appended yet"
        ret = copy.deepcopy({
            'image': self.image,
            'target': {
                'boxes': self.boxes,
                'points': self.points,
            },
            'conversations': self.conversations,
        })
        assert ret['conversations'][0]['from'] == self.roles[0]
        if ret['conversations'][-1]['from'] == self.roles[0]:
            ret['conversations'].append(
                {
                    'from': self.roles[1],
                    'value': '',
                }
            )
        return ret

    def to_model_input(self, device='cuda'):
        """Build a batch-of-one model input from the current conversation.

        Args:
            device: target device for the tensors. Defaults to ``'cuda'`` (the
                current CUDA device, as ``.cuda()`` would pick).

        Returns:
            ``{'input_ids': ..., 'images': ...}`` with a leading batch dimension of 1;
            ``'images'`` is ``None`` when the processed item carries no image.
        """
        item = self.__getitem__(0)
        ret = {'input_ids': item['input_ids'].unsqueeze(0).to(device)}
        if 'image' in item and item['image'] is not None:
            ret['images'] = item['image'].unsqueeze(0).to(device)
        else:
            ret['images'] = None
        return ret

    def to_gradio_chatbot_new_messages(self):
        """Return the last two rendered ``(role, text)`` messages for display.

        Expanded image tokens are stripped: ``<im_patch>`` and ``<im_end>`` are removed
        and ``<im_start>`` is shown as ``<image>``.
        """
        conv = self.__getitem__(0, return_conv=True)
        # NOTE(review): presumably the latest human/assistant exchange; relies on the
        # mixin returning a conversation object with a ``messages`` list of pairs.
        new_messages = conv.messages[-2:]
        ret_messages = []
        for r, m in new_messages:
            nm = m.replace('<im_patch>', '').replace('<im_end>', '').replace('<im_start>', '<image>')
            ret_messages.append((r, nm))
        return ret_messages

    @staticmethod
    def _get_or_add_index(items, obj):
        """Return the index of ``obj`` in ``items``, appending it first if absent."""
        try:
            return items.index(obj)
        except ValueError:
            items.append(obj)
            return len(items) - 1

    def _get_box_idx(self, box):
        """Validate a 4-number box and return its index in ``self.boxes`` (added if new)."""
        assert isinstance(box, (tuple, list)), f"{type(box)}"
        assert isinstance(box[0], (int, float)), f"{type(box[0])}"
        assert len(box) == 4
        return self._get_or_add_index(self.boxes, tuple(box))

    def _get_point_idx(self, point):
        """Validate a 2-number point and return its index in ``self.points`` (added if new)."""
        assert isinstance(point, (tuple, list)), f"{type(point)}"
        assert isinstance(point[0], (int, float)), f"{type(point[0])}"
        assert len(point) == 2
        return self._get_or_add_index(self.points, tuple(point))

    def __len__(self):
        # The interactive session is exposed as a single-item dataset.
        return 1