|
import copy |
|
import os |
|
import spaces |
|
import subprocess |
|
import time |
|
import torch |
|
|
|
from threading import Thread |
|
from typing import List, Tuple |
|
from urllib.parse import urlparse |
|
from PIL import Image |
|
|
|
import gradio as gr |
|
from gradio_client.client import DEFAULT_TEMP_DIR |
|
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer |
|
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension |
|
from transformers.image_transforms import resize, to_channel_dimension_format |
|
|
|
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True) |
|
|
|
DEVICE = torch.device("cuda") |
|
MODELS = { |
|
"288_ter - mix8 - opt 5'800": AutoModelForCausalLM.from_pretrained( |
|
"HuggingFaceM4/idefics2", |
|
trust_remote_code=True, |
|
torch_dtype=torch.bfloat16, |
|
token=os.environ["HF_AUTH_TOKEN"], |
|
revision="25bb7ad6d9ab9e43d5002d30f857d4106ed964f3", |
|
).to(DEVICE), |
|
"288_ter - mix 8 - opt 11'000": AutoModelForCausalLM.from_pretrained( |
|
"HuggingFaceM4/idefics2", |
|
trust_remote_code=True, |
|
torch_dtype=torch.bfloat16, |
|
token=os.environ["HF_AUTH_TOKEN"], |
|
revision="7eccbf5178f85eee8fab9995f31ab12441ce767a", |
|
).to(DEVICE), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
PROCESSOR = AutoProcessor.from_pretrained( |
|
"HuggingFaceM4/idefics2", |
|
token=os.environ["HF_AUTH_TOKEN"], |
|
) |
|
FAKE_TOK_AROUND_IMAGE = "<fake_token_around_image>" |
|
BOS_TOKEN = PROCESSOR.tokenizer.bos_token |
|
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids |
|
EOS_WORDS_IDS = PROCESSOR.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids + [PROCESSOR.tokenizer.eos_token_id] |
|
IMAGE_SEQ_LEN = 64 |
|
|
|
SYSTEM_PROMPT = [ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
] |
|
|
|
API_TOKEN = os.getenv("HF_AUTH_TOKEN") |
|
|
|
BOT_AVATAR = "IDEFICS_logo.png" |
|
|
|
|
|
|
|
def convert_to_rgb(image): |
|
|
|
|
|
if image.mode == "RGB": |
|
return image |
|
|
|
image_rgba = image.convert("RGBA") |
|
background = Image.new("RGBA", image_rgba.size, (255, 255, 255)) |
|
alpha_composite = Image.alpha_composite(background, image_rgba) |
|
alpha_composite = alpha_composite.convert("RGB") |
|
return alpha_composite |
|
|
|
|
|
def custom_transform(x): |
|
x = convert_to_rgb(x) |
|
x = to_numpy_array(x) |
|
|
|
height, width = x.shape[:2] |
|
aspect_ratio = width / height |
|
if width >= height and width > 980: |
|
width = 980 |
|
height = int(width / aspect_ratio) |
|
elif height > width and height > 980: |
|
height = 980 |
|
width = int(height * aspect_ratio) |
|
width = max(width, 378) |
|
height = max(height, 378) |
|
|
|
x = resize(x, (height, width), resample=PILImageResampling.BILINEAR) |
|
x = PROCESSOR.image_processor.rescale(x, scale=1 / 255) |
|
x = PROCESSOR.image_processor.normalize( |
|
x, |
|
mean=PROCESSOR.image_processor.image_mean, |
|
std=PROCESSOR.image_processor.image_std |
|
) |
|
x = to_channel_dimension_format(x, ChannelDimension.FIRST) |
|
x = torch.tensor(x) |
|
return x |
|
|
|
|
|
def create_model_inputs( |
|
input_texts: List[str], |
|
image_lists: List[List[Image.Image]], |
|
): |
|
""" |
|
All this logic will eventually be handled inside the model processor. |
|
""" |
|
inputs = PROCESSOR.tokenizer( |
|
input_texts, |
|
return_tensors="pt", |
|
add_special_tokens=False, |
|
padding=True, |
|
) |
|
|
|
output_images = [ |
|
[PROCESSOR.image_processor(img, transform=custom_transform) for img in im_list] |
|
for im_list in image_lists |
|
] |
|
total_batch_size = len(output_images) |
|
max_num_images = max([len(img_l) for img_l in output_images]) |
|
if max_num_images > 0: |
|
max_height = max([i.size(2) for img_l in output_images for i in img_l]) |
|
max_width = max([i.size(3) for img_l in output_images for i in img_l]) |
|
padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width) |
|
padded_pixel_attention_masks = torch.zeros( |
|
total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool |
|
) |
|
for batch_idx, img_l in enumerate(output_images): |
|
for img_idx, img in enumerate(img_l): |
|
im_height, im_width = img.size()[2:] |
|
padded_image_tensor[batch_idx, img_idx, :, :im_height, :im_width] = img |
|
padded_pixel_attention_masks[batch_idx, img_idx, :im_height, :im_width] = True |
|
|
|
inputs["pixel_values"] = padded_image_tensor |
|
inputs["pixel_attention_mask"] = padded_pixel_attention_masks |
|
|
|
return inputs |
|
|
|
|
|
|
|
def is_image(string: str) -> bool: |
|
""" |
|
There are two ways for images: local image path or url. |
|
""" |
|
return is_url(string) or string.startswith(DEFAULT_TEMP_DIR) |
|
|
|
|
|
def is_url(string: str) -> bool: |
|
""" |
|
Checks if the passed string contains a valid url and nothing else. e.g. if space is included it's immediately |
|
invalidated the url |
|
""" |
|
if " " in string: |
|
return False |
|
result = urlparse(string) |
|
return all([result.scheme, result.netloc]) |
|
|
|
|
|
def prompt_list_to_model_input(prompt_list: List[str]) -> Tuple[str, List[Image.Image]]: |
|
""" |
|
Create the final input string and image list to feed to the model. |
|
""" |
|
images = [] |
|
for idx, part in enumerate(prompt_list): |
|
if is_image(part): |
|
images.append(Image.open(part)) |
|
prompt_list[idx] = f"{FAKE_TOK_AROUND_IMAGE}{'<image>' * IMAGE_SEQ_LEN}{FAKE_TOK_AROUND_IMAGE}" |
|
input_text = "".join(prompt_list) |
|
input_text = input_text.replace(FAKE_TOK_AROUND_IMAGE * 2, FAKE_TOK_AROUND_IMAGE) |
|
input_text = BOS_TOKEN + input_text.strip() |
|
return input_text, images |
|
|
|
|
|
def turn_is_pure_media(turn): |
|
return turn[1] is None |
|
|
|
|
|
def format_user_prompt_with_im_history_and_system_conditioning( |
|
user_prompt, chat_history |
|
) -> List[str]: |
|
""" |
|
Produces the resulting list that needs to go inside the processor. |
|
It handles the potential image(s), the history and the system conditionning. |
|
""" |
|
resulting_list = copy.deepcopy(SYSTEM_PROMPT) |
|
|
|
|
|
for turn in chat_history: |
|
if turn_is_pure_media(turn): |
|
media = turn[0][0] |
|
if resulting_list == [] or (resulting_list != [] and resulting_list[-1].endswith("<end_of_utterance>")): |
|
resulting_list.append("\nUser:") |
|
resulting_list.append(media) |
|
else: |
|
user_utterance, assistant_utterance = turn |
|
if resulting_list and is_image(resulting_list[-1]): |
|
resulting_list.append(f"{user_utterance.strip()}<end_of_utterance>\nAssistant: {assistant_utterance}<end_of_utterance>") |
|
else: |
|
resulting_list.append(f"\nUser: {user_utterance.strip()}<end_of_utterance>\nAssistant: {assistant_utterance}<end_of_utterance>") |
|
|
|
|
|
if not user_prompt["files"]: |
|
resulting_list.append(f"\nUser: ") |
|
else: |
|
|
|
resulting_list.append("\nUser:") |
|
resulting_list.extend([im["path"] for im in user_prompt["files"]]) |
|
resulting_list.append(f"{user_prompt['text']}<end_of_utterance>\nAssistant:") |
|
|
|
return resulting_list |
|
|
|
|
|
@spaces.GPU(duration=180) |
|
def model_inference( |
|
user_prompt, |
|
chat_history, |
|
model_selector, |
|
decoding_strategy, |
|
temperature, |
|
max_new_tokens, |
|
repetition_penalty, |
|
top_p, |
|
): |
|
if user_prompt["text"].strip() == "" and not user_prompt["files"]: |
|
gr.Error("Please input a query and optionally image(s).") |
|
|
|
if user_prompt["text"].strip() == "" and user_prompt["files"]: |
|
gr.Error("Please input a text query along the image(s).") |
|
|
|
for file in user_prompt["files"]: |
|
if not file["mime_type"].startswith("image/"): |
|
gr.Error("Idefics2 only supports images. Please input a valid image.") |
|
|
|
formated_prompt_list = format_user_prompt_with_im_history_and_system_conditioning( |
|
user_prompt=user_prompt, |
|
chat_history=chat_history, |
|
) |
|
|
|
streamer = TextIteratorStreamer( |
|
PROCESSOR.tokenizer, |
|
skip_prompt=True, |
|
timeout=5., |
|
) |
|
|
|
|
|
|
|
generation_args = { |
|
"max_new_tokens": max_new_tokens, |
|
"repetition_penalty": repetition_penalty, |
|
"bad_words_ids": BAD_WORDS_IDS, |
|
"eos_token_id": EOS_WORDS_IDS, |
|
"streamer": streamer, |
|
} |
|
|
|
assert decoding_strategy in [ |
|
"Greedy", |
|
"Top P Sampling", |
|
] |
|
if decoding_strategy == "Greedy": |
|
generation_args["do_sample"] = False |
|
elif decoding_strategy == "Top P Sampling": |
|
generation_args["temperature"] = temperature |
|
generation_args["do_sample"] = True |
|
generation_args["top_p"] = top_p |
|
|
|
|
|
|
|
input_text, images = prompt_list_to_model_input(formated_prompt_list) |
|
inputs = create_model_inputs([input_text], [images]) |
|
inputs = {k: v.to(DEVICE) for k, v in inputs.items()} |
|
generation_args.update(inputs) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
thread = Thread( |
|
target=MODELS[model_selector].generate, |
|
kwargs=generation_args, |
|
) |
|
thread.start() |
|
|
|
print("start generating") |
|
acc_text = "" |
|
try: |
|
for text_token in streamer: |
|
acc_text += text_token |
|
time.sleep(0.03) |
|
yield acc_text |
|
except Exception as e: |
|
print("error") |
|
gr.Error(e) |
|
print(f"Success! Generated the following sequence: `{acc_text}`") |
|
|
|
|
|
|
|
|
|
max_new_tokens = gr.Slider( |
|
minimum=8, |
|
maximum=1024, |
|
value=512, |
|
step=1, |
|
interactive=True, |
|
label="Maximum number of new tokens to generate", |
|
) |
|
repetition_penalty = gr.Slider( |
|
minimum=0.01, |
|
maximum=5.0, |
|
value=1.0, |
|
step=0.01, |
|
interactive=True, |
|
label="Repetition penalty", |
|
info="1.0 is equivalent to no penalty", |
|
) |
|
decoding_strategy = gr.Radio( |
|
[ |
|
"Greedy", |
|
"Top P Sampling", |
|
], |
|
value="Greedy", |
|
label="Decoding strategy", |
|
interactive=True, |
|
info="Higher values is equivalent to sampling more low-probability tokens.", |
|
) |
|
temperature = gr.Slider( |
|
minimum=0.0, |
|
maximum=5.0, |
|
value=0.4, |
|
step=0.1, |
|
interactive=True, |
|
label="Sampling temperature", |
|
info="Higher values will produce more diverse outputs.", |
|
) |
|
top_p = gr.Slider( |
|
minimum=0.01, |
|
maximum=0.99, |
|
value=0.8, |
|
step=0.01, |
|
interactive=True, |
|
label="Top P", |
|
info="Higher values is equivalent to sampling more low-probability tokens.", |
|
) |
|
|
|
|
|
chatbot = gr.Chatbot( |
|
label="IDEFICS2", |
|
avatar_images=[None, BOT_AVATAR], |
|
height=500, |
|
) |
|
|
|
|
|
with gr.Blocks(fill_height=True) as demo: |
|
|
|
with gr.Row(elem_id="model_selector_row"): |
|
model_selector = gr.Dropdown( |
|
choices=MODELS.keys(), |
|
value=list(MODELS.keys())[0], |
|
interactive=True, |
|
show_label=False, |
|
container=False, |
|
label="Model", |
|
visible=True, |
|
) |
|
|
|
decoding_strategy.change( |
|
fn=lambda selection: gr.Slider( |
|
visible=( |
|
selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"] |
|
) |
|
), |
|
inputs=decoding_strategy, |
|
outputs=temperature, |
|
) |
|
decoding_strategy.change( |
|
fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])), |
|
inputs=decoding_strategy, |
|
outputs=top_p, |
|
) |
|
|
|
gr.ChatInterface( |
|
fn=model_inference, |
|
chatbot=chatbot, |
|
|
|
title="Idefics2 Playground", |
|
multimodal=True, |
|
additional_inputs=[model_selector, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p], |
|
) |
|
|
|
demo.launch() |
|
|