import copy
import os
import spaces
import subprocess
import time
import torch

from threading import Thread
from typing import List, Tuple
from urllib.parse import urlparse
from PIL import Image

import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
from transformers.image_transforms import resize, to_channel_dimension_format

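# Install flash-attn at runtime. FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA
# kernels from source (the Space environment is assumed to provide a compatible prebuilt wheel).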
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

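# Two checkpoints of the same Idefics2 repository, pinned to different training revisions so they
# can be compared from the model dropdown; both are loaded in bfloat16 and moved to the GPU.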
DEVICE = torch.device("cuda")
MODELS = {
    "288_ter - mix8 - opt 5'800": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="25bb7ad6d9ab9e43d5002d30f857d4106ed964f3",
    ).to(DEVICE),
    "288_ter - mix 8 - opt 11'000": AutoModelForCausalLM.from_pretrained(
        "HuggingFaceM4/idefics2",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        token=os.environ["HF_AUTH_TOKEN"],
        revision="7eccbf5178f85eee8fab9995f31ab12441ce767a",
    ).to(DEVICE),
}

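# Shared processor and the special-token constants used to build prompts and constrain generation:
# "<image>" placeholders are banned from the output and "<end_of_utterance>" terminates a turn.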
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2",
    token=os.environ["HF_AUTH_TOKEN"],
)
FAKE_TOK_AROUND_IMAGE = "<fake_token_around_image>"
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
EOS_WORDS_IDS = PROCESSOR.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids + [PROCESSOR.tokenizer.eos_token_id]
IMAGE_SEQ_LEN = 64

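# System conditioning prepended to every conversation; it is empty here, so no system prompt is applied.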
SYSTEM_PROMPT = []

API_TOKEN = os.getenv("HF_AUTH_TOKEN")

BOT_AVATAR = "IDEFICS_logo.png"


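# Flatten any transparency onto a white background so every image reaches the model as RGB.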
def convert_to_rgb(image):
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


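# Resize so the longest side is at most 980 px and the shortest at least 378 px (keeping the aspect
# ratio where possible), then rescale, normalize and convert to channels-first for the model.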
def custom_transform(x):
    x = convert_to_rgb(x)
    x = to_numpy_array(x)

    height, width = x.shape[:2]
    aspect_ratio = width / height
    if width >= height and width > 980:
        width = 980
        height = int(width / aspect_ratio)
    elif height > width and height > 980:
        height = 980
        width = int(height * aspect_ratio)
    width = max(width, 378)
    height = max(height, 378)

    x = resize(x, (height, width), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x


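# Tokenize the text prompts and pad each sample's images into a single
# (batch, num_images, 3, max_height, max_width) tensor with a matching pixel attention mask.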
def create_model_inputs(
    input_texts: List[str],
    image_lists: List[List[Image.Image]],
):
    """
    All this logic will eventually be handled inside the model processor.
    """
    inputs = PROCESSOR.tokenizer(
        input_texts,
        return_tensors="pt",
        add_special_tokens=False,
        padding=True,
    )

    output_images = [
        [PROCESSOR.image_processor(img, transform=custom_transform) for img in im_list]
        for im_list in image_lists
    ]
    total_batch_size = len(output_images)
    max_num_images = max([len(img_l) for img_l in output_images])
    if max_num_images > 0:
        max_height = max([i.size(2) for img_l in output_images for i in img_l])
        max_width = max([i.size(3) for img_l in output_images for i in img_l])
        padded_image_tensor = torch.zeros(total_batch_size, max_num_images, 3, max_height, max_width)
        padded_pixel_attention_masks = torch.zeros(
            total_batch_size, max_num_images, max_height, max_width, dtype=torch.bool
        )
        for batch_idx, img_l in enumerate(output_images):
            for img_idx, img in enumerate(img_l):
                im_height, im_width = img.size()[2:]
                padded_image_tensor[batch_idx, img_idx, :, :im_height, :im_width] = img
                padded_pixel_attention_masks[batch_idx, img_idx, :im_height, :im_width] = True

        inputs["pixel_values"] = padded_image_tensor
        inputs["pixel_attention_mask"] = padded_pixel_attention_masks

    return inputs


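# Helpers to decide whether a prompt-list element is an image reference (a URL or a file that
# Gradio stored in its temporary directory) rather than plain text.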
def is_image(string: str) -> bool:
    """
    Images can be referenced in two ways: a local image path or a URL.
    """
    return is_url(string) or string.startswith(DEFAULT_TEMP_DIR)


def is_url(string: str) -> bool:
    """
    Checks whether the passed string is a valid URL and nothing else; e.g. if a space is included,
    the string is immediately rejected.
    """
    if " " in string:
        return False
    result = urlparse(string)
    return all([result.scheme, result.netloc])


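# Turn the mixed text/image prompt list into a single string: each image entry becomes
# IMAGE_SEQ_LEN "<image>" placeholders wrapped in FAKE_TOK_AROUND_IMAGE, and the opened PIL
# images are returned alongside the text.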
def prompt_list_to_model_input(prompt_list: List[str]) -> Tuple[str, List[Image.Image]]:
    """
    Create the final input string and image list to feed to the model.
    """
    images = []
    for idx, part in enumerate(prompt_list):
        if is_image(part):
            images.append(Image.open(part))
            prompt_list[idx] = f"{FAKE_TOK_AROUND_IMAGE}{'<image>' * IMAGE_SEQ_LEN}{FAKE_TOK_AROUND_IMAGE}"
    input_text = "".join(prompt_list)
    # Consecutive images would otherwise produce a doubled separator token.
    input_text = input_text.replace(FAKE_TOK_AROUND_IMAGE * 2, FAKE_TOK_AROUND_IMAGE)
    input_text = BOS_TOKEN + input_text.strip()
    return input_text, images


def turn_is_pure_media(turn):
    # A history turn with no assistant message is a pure media (image upload) turn.
    return turn[1] is None


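# Build the flat prompt list consumed by the model: system prompt, then the chat history
# (alternating text and image-path entries), then the current user message. For a single uploaded
# image and no history, the result looks like (hypothetical path):
#   ["\nUser:", "/tmp/gradio/cat.png", "What is this?<end_of_utterance>\nAssistant:"]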
def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> List[str]:
    """
    Produces the resulting list that needs to go inside the processor.
    It handles the potential image(s), the history and the system conditioning.
    """
    resulting_list = copy.deepcopy(SYSTEM_PROMPT)

    for turn in chat_history:
        if turn_is_pure_media(turn):
            media = turn[0][0]
            if not resulting_list or resulting_list[-1].endswith("<end_of_utterance>"):
                resulting_list.append("\nUser:")
            resulting_list.append(media)
        else:
            user_utterance, assistant_utterance = turn
            if resulting_list and is_image(resulting_list[-1]):
                resulting_list.append(f"{user_utterance.strip()}<end_of_utterance>\nAssistant: {assistant_utterance}<end_of_utterance>")
            else:
                resulting_list.append(f"\nUser: {user_utterance.strip()}<end_of_utterance>\nAssistant: {assistant_utterance}<end_of_utterance>")

    if not user_prompt["files"]:
        resulting_list.append("\nUser: ")
    else:
        resulting_list.append("\nUser:")
        resulting_list.extend([im["path"] for im in user_prompt["files"]])
    resulting_list.append(f"{user_prompt['text']}<end_of_utterance>\nAssistant:")

    return resulting_list


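# Streams a reply from the selected model. On ZeroGPU Spaces, @spaces.GPU requests a GPU for this
# call, with duration=180 as the expected upper bound in seconds.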
@spaces.GPU(duration=180)
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    if user_prompt["text"].strip() == "" and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")

    if user_prompt["text"].strip() == "" and user_prompt["files"]:
        raise gr.Error("Please input a text query along with the image(s).")

    for file in user_prompt["files"]:
        if not file["mime_type"].startswith("image/"):
            raise gr.Error("Idefics2 only supports images. Please input a valid image.")

    formatted_prompt_list = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,
    )

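    # Common generation settings: bad_words_ids keeps the model from emitting raw image-placeholder
    # tokens, and eos_token_id stops generation at "<end_of_utterance>".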
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "bad_words_ids": BAD_WORDS_IDS,
        "eos_token_id": EOS_WORDS_IDS,
        "streamer": streamer,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

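    # Build the model inputs from the formatted prompt, move them to the GPU, and run generate()
    # in a background thread so tokens can be yielded from the streamer as they arrive.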
    input_text, images = prompt_list_to_model_input(formatted_prompt_list)
    inputs = create_model_inputs([input_text], [images])
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generation_args.update(inputs)

    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    print("start generating")
    acc_text = ""
    try:
        for text_token in streamer:
            acc_text += text_token
            time.sleep(0.03)
            yield acc_text
    except Exception as e:
        print("error")
        raise gr.Error(str(e))
    print(f"Success! Generated the following sequence: `{acc_text}`")


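# Hyper-parameter controls exposed as additional inputs of the chat interface.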
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Choose between greedy decoding and nucleus (top-p) sampling.",
)
temperature = gr.Slider(
    minimum=0.0,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    label="Top P",
    info="Higher values are equivalent to sampling more low-probability tokens.",
)


chatbot = gr.Chatbot(
    label="IDEFICS2",
    avatar_images=[None, BOT_AVATAR],
    height=500,
)


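# Assemble the demo: a model selector row, visibility toggles for the sampling sliders, and a
# multimodal ChatInterface wired to model_inference.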
with gr.Blocks(fill_height=True) as demo:
    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=list(MODELS.keys()),
            value=list(MODELS.keys())[0],
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=True,
        )

    # Show the temperature and top-p sliders only when a sampling strategy is selected.
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(
            visible=(
                selection in ["contrastive_sampling", "beam_sampling", "Top P Sampling", "sampling_top_k"]
            )
        ),
        inputs=decoding_strategy,
        outputs=temperature,
    )
    decoding_strategy.change(
        fn=lambda selection: gr.Slider(visible=(selection in ["Top P Sampling"])),
        inputs=decoding_strategy,
        outputs=top_p,
    )

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        title="Idefics2 Playground",
        multimodal=True,
        additional_inputs=[model_selector, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
    )

demo.launch()