import argparse
from io import BytesIO

import numpy as np
import pydicom
import requests
import torch
from PIL import Image
from pydicom.pixel_data_handlers.util import apply_voi_lut
from transformers import TextStreamer

from libra.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from libra.conversation import conv_templates, SeparatorStyle
from libra.model.builder import load_pretrained_model
from libra.utils import disable_torch_init
from libra.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria


def load_images(image_file):
"""
Load an image from a local file, a URL, or a DICOM file.
Args:
image_file (str): The path or URL of the image file to load.
Returns:
PIL.Image.Image: The loaded image in RGB format.
Raises:
ValueError: If the DICOM file does not contain image data.
TypeError: If the input is neither a valid file path nor a URL.
"""
if isinstance(image_file, str):
# Case 1: Load from URL
if image_file.startswith(('http://', 'https://')):
try:
response = requests.get(image_file)
response.raise_for_status()
image = Image.open(BytesIO(response.content)).convert('RGB')
except Exception as e:
raise ValueError(f"Error loading image from URL: {image_file}\n{e}")
# Case 2: Load from DICOM file
elif image_file.lower().endswith('.dcm'):
try:
dicom = pydicom.dcmread(image_file)
                if 'PixelData' in dicom:
                    data = apply_voi_lut(dicom.pixel_array, dicom)
                    # MONOCHROME1 stores inverted intensities; flip so high values are bright
                    if dicom.PhotometricInterpretation == "MONOCHROME1":
                        data = np.max(data) - data
                    # Normalize to the 0-255 range (guard against constant-valued images)
                    data = data - np.min(data)
                    max_val = np.max(data)
                    if max_val > 0:
                        data = data / max_val
                    data = (data * 255).astype(np.uint8)
                    # Convert single-channel data to 3-channel RGB if necessary
                    if data.ndim == 2:
                        data = np.stack([data] * 3, axis=-1)
                    image = Image.fromarray(data).convert('RGB')
                else:
                    raise ValueError("DICOM file does not contain image data")
except Exception as e:
raise ValueError(f"Error loading DICOM file: {image_file}\n{e}")
# Case 3: Load standard image files (e.g., PNG, JPG)
else:
try:
image = Image.open(image_file).convert('RGB')
except Exception as e:
raise ValueError(f"Error loading standard image file: {image_file}\n{e}")
else:
raise TypeError("image_file must be a string representing a file path or URL")
return image
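
# A minimal usage sketch for load_images. The file names below are
# hypothetical placeholders, not assets shipped with this repo:
#
#     img = load_images("https://example.com/chest_xray.png")  # from a URL
#     img = load_images("study/frontal_view.dcm")              # from a DICOM file
#     img = load_images("study/frontal_view.jpg")              # from PNG/JPG
#     img.size  # -> (width, height); always a 3-channel RGB PIL image
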
def main(args):
"""
Main function to load a pretrained model, process images, and interact with the user through a conversation loop.
Args:
args (Namespace): A namespace object containing the following attributes:
model_path (str): Path to the pretrained model.
model_base (str): Base model name.
load_8bit (bool): Flag to load the model in 8-bit precision.
load_4bit (bool): Flag to load the model in 4-bit precision.
device (str): Device to load the model on (e.g., 'cuda', 'cpu').
conv_mode (str, optional): Conversation mode to use. If None, it will be inferred from the model name.
image_file (list): List of paths to image files to be processed.
temperature (float): Sampling temperature for text generation.
max_new_tokens (int): Maximum number of new tokens to generate.
debug (bool): Flag to enable debug mode for additional output.
Raises:
EOFError: If an EOFError is encountered during user input, the loop will exit.
"""
# Model
disable_torch_init()
model_name = get_model_name_from_path(args.model_path)
tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit, device=args.device)
    if 'libra' in model_name.lower():
        conv_mode = "libra_v1"
    else:
        # Fall back to the user-supplied mode (or libra_v1) for other model names,
        # so the comparison below never hits an undefined conv_mode.
        conv_mode = args.conv_mode or "libra_v1"

    if args.conv_mode is not None and conv_mode != args.conv_mode:
        print(f'[WARNING] the auto-inferred conversation mode is {conv_mode}, '
              f'while `--conv-mode` is {args.conv_mode}; using {args.conv_mode}')
    else:
        args.conv_mode = conv_mode
conv = conv_templates[args.conv_mode].copy()
roles = conv.roles
    images = []
    for path in args.image_file:
        images.append(load_images(path))

    # Libra expects a (current, prior) image pair; if only the current image
    # is given, reuse it as a dummy prior.
    if len(images) == 1:
        print("Contains only current image. Adding a dummy prior image.")
        images.append(images[0])

    processed_images = []
    for img_data in images:
        image_temp = process_images([img_data], image_processor, model.config)[0]
        image_temp = image_temp.to(model.device, non_blocking=True)
        processed_images.append(image_temp)

    # Stack into shape (2, 1, C, H, W): current image first, prior image second.
    cur_images = [processed_images[0]]
    prior_images = [processed_images[1]]
    image_tensor = torch.stack([torch.stack(cur_images), torch.stack(prior_images)])
    image_tensor = image_tensor.to(model.device, dtype=torch.float16)
while True:
try:
inp = input(f"{roles[0]}: ")
except EOFError:
inp = ""
if not inp:
print("exit...")
break
print(f"{roles[1]}: ", end="")
        if images is not None:
            # On the first turn, prepend the image token(s) so the model attends
            # to the visual input; subsequent turns reuse the conversation history.
            if model.config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            images = None
conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
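        # tokenizer_image_token tokenizes the text around each image token and
        # splices in IMAGE_TOKEN_INDEX placeholders, which the model later
        # resolves to visual features from the image tensor.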
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(model.device)
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
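        # KeywordsStoppingCriteria halts generation once the conversation
        # separator appears in the decoded output, so the model's reply does
        # not spill into the next turn's role header.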
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
        attention_mask = torch.ones(input_ids.shape, dtype=torch.long, device=model.device)
        pad_token_id = tokenizer.pad_token_id
with torch.inference_mode():
output_ids = model.generate(
input_ids,
images=image_tensor,
                do_sample=args.temperature > 0,
temperature=args.temperature,
max_new_tokens=args.max_new_tokens,
streamer=streamer,
use_cache=True,
attention_mask=attention_mask,
pad_token_id=pad_token_id,
stopping_criteria=[stopping_criteria])
        outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:], skip_special_tokens=True).strip()
conv.messages[-1][-1] = outputs
if args.debug:
print("\n", {"prompt": prompt, "outputs": outputs}, "\n")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--model-path", type=str, default="X-iZhang/libra-v1.0-7b")
parser.add_argument("--model-base", type=str, default=None)
parser.add_argument("--image-file", type=str, nargs="+", required=True, help="List of image files to process.")
parser.add_argument("--device", type=str, default="cuda")
parser.add_argument("--conv-mode", type=str, default="libra_v1")
parser.add_argument("--temperature", type=float, default=0.5)
parser.add_argument("--max-new-tokens", type=int, default=512)
parser.add_argument("--load-8bit", action="store_true")
parser.add_argument("--load-4bit", action="store_true")
parser.add_argument("--debug", action="store_true")
args = parser.parse_args()
    main(args)
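
# Example invocation, assuming this file is saved as cli.py (the script name
# and image file names are illustrative):
#
#     python cli.py --image-file current.dcm prior.dcm --temperature 0.2
#
# Passing a single --image-file is also valid; the script duplicates it as a
# dummy prior image.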