File size: 8,258 Bytes
23c9ef8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import argparse
import torch

from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from libra.conversation import conv_templates, SeparatorStyle
from libra.model.builder import load_pretrained_model
from libra.utils import disable_torch_init
from libra.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

import requests
import pydicom
from PIL import Image
from io import BytesIO
from pydicom.pixel_data_handlers.util import apply_voi_lut
from transformers import TextStreamer


def load_images(image_file):
    """
    Load an image from a local file, a URL, or a DICOM file.

    Args:
        image_file (str): The path or URL of the image file to load.

    Returns:
        PIL.Image.Image: The loaded image in RGB format.

    Raises:
        ValueError: If the DICOM file does not contain image data.
        TypeError: If the input is neither a valid file path nor a URL.
    """
    if isinstance(image_file, str):
        # Case 1: Load from URL
        if image_file.startswith(('http://', 'https://')):
            try:
                response = requests.get(image_file)
                response.raise_for_status()
                image = Image.open(BytesIO(response.content)).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading image from URL: {image_file}\n{e}")

        # Case 2: Load from DICOM file
        elif image_file.lower().endswith('.dcm'):
            try:
                dicom = pydicom.dcmread(image_file)
                if 'PixelData' in dicom:
                    data = apply_voi_lut(dicom.pixel_array, dicom)

                    # Handle MONOCHROME1 images
                    if dicom.PhotometricInterpretation == "MONOCHROME1":
                        data = np.max(data) - data

                    # Normalize the image data
                    data = data - np.min(data)
                    data = data / np.max(data)
                    data = (data * 255).astype(np.uint8)

                    # Convert to 3-channel RGB if necessary
                    if data.ndim == 2:
                        data = np.stack([data] * 3, axis=-1)

                    image = Image.fromarray(data).convert('RGB')
                else:
                    raise ValueError("DICOM file does not contain image data")
            except Exception as e:
                raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")

        # Case 3: Load standard image files (e.g., PNG, JPG)
        else:
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                raise ValueError(f"Error loading standard image file: {image_file}\n{e}")

    else:
        raise TypeError("image_file must be a string representing a file path or URL")

    return image

def main(args):
    """
    Main function to load a pretrained model, process images, and interact with the user through a conversation loop.
    Args:
        args (Namespace): A namespace object containing the following attributes:
            model_path (str): Path to the pretrained model.
            model_base (str): Base model name.
            load_8bit (bool): Flag to load the model in 8-bit precision.
            load_4bit (bool): Flag to load the model in 4-bit precision.
            device (str): Device to load the model on (e.g., 'cuda', 'cpu').
            conv_mode (str, optional): Conversation mode to use. If None, it will be inferred from the model name.
            image_file (list): List of paths to image files to be processed.
            temperature (float): Sampling temperature for text generation.
            max_new_tokens (int): Maximum number of new tokens to generate.
            debug (bool): Flag to enable debug mode for additional output.
    Raises:
        EOFError: If an EOFError is encountered during user input, the loop will exit.
    """
    # Model
    disable_torch_init()

    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)

    if 'libra' in model_name.lower():
        conv_mode = "libra_v1"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print('[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}'.format(conv_mode, args.conv_mode, args.conv_mode))
    else:
        args.conv_mode = conv_mode

    conv = conv_templates[args.conv_mode].copy()
    roles = conv.roles
    
    image=[]
    for path in args.image_file:
        img = load_images(path)
        image.append(img)

    # set dummy prior image
    if len(image) == 1:
        print("Contains only current image. Adding a dummy prior image.")
        image.append(image[0])

    processed_images = []
    for img_data in image:
        image_temp = process_images([img_data], image_processor, model.config)[0]
        image_temp = image_temp.to(device='cuda',non_blocking=True) 
        processed_images.append(image_temp)
        
    cur_images = [processed_images[0]]
    prior_images = [processed_images[1]]
    image_tensor = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)

    while True:
        try:
            inp = input(f"{roles[0]}: ")
        except EOFError:
            inp = ""
        if not inp:
            print("exit...")
            break

        print(f"{roles[1]}: ", end="")

        if image is not None:
            # first message
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            image = None

        conv.append_message(conv.roles[0], inp)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
        stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
        streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
        pad_token_id = tokenizer.pad_token_id
        
        with torch.inference_mode():
            output_ids = model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True if args.temperature > 0 else False,
                temperature=args.temperature,
                max_new_tokens=args.max_new_tokens,
                streamer=streamer,
                use_cache=True,
                attention_mask=attention_mask, 
                pad_token_id=pad_token_id,
                stopping_criteria=[stopping_criteria])
        
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:],skip_special_tokens=True).strip()
        conv.messages[-1][-1] = outputs

        if args.debug:
            print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model-path", type=str, default="X-iZhang/libra-v1.0-7b")
    parser.add_argument("--model-base", type=str, default=None)
    parser.add_argument("--image-file", type=str, nargs="+", required=True, help="List of image files to process.")
    parser.add_argument("--device", type=str, default="cuda")
    parser.add_argument("--conv-mode", type=str, default="libra_v1")
    parser.add_argument("--temperature", type=float, default=0.5)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--load-8bit", action="store_true")
    parser.add_argument("--load-4bit", action="store_true")
    parser.add_argument("--debug", action="store_true")
    args = parser.parse_args()
    
    main(args)