|
import copy |
|
import os |
|
import spaces |
|
import subprocess |
|
import time |
|
import torch |
|
|
|
from threading import Thread |
|
from typing import List, Dict, Union |
|
from urllib.parse import urlparse |
|
from PIL import Image |
|
|
|
import gradio as gr |
|
from transformers import AutoProcessor, TextIteratorStreamer |
|
from transformers import Idefics2ForConditionalGeneration |
|
|
|
# Install flash-attn at startup; it backs ``_attn_implementation="flash_attention_2"``
# used when the models are loaded below. FLASH_ATTENTION_SKIP_CUDA_BUILD is set only
# for this install (presumably so flash-attn skips compiling its CUDA extension —
# confirm against flash-attn's setup). The variable is layered on top of the current
# environment: passing a one-key dict as ``env`` would wipe PATH, HOME and the active
# virtualenv, so ``pip`` might not even be found.
subprocess.run(
    ["pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
|
|
|
DEVICE = torch.device("cuda")


def _load_idefics2(revision):
    """Load one pinned revision of the Idefics2 checkpoint and move it to ``DEVICE``.

    Every playground model comes from the same repository with the same loading
    options (bfloat16 weights, FlashAttention-2, remote code allowed,
    authenticated with ``HF_AUTH_TOKEN``); only the git ``revision`` differs.
    """
    return Idefics2ForConditionalGeneration.from_pretrained(
        "HuggingFaceM4/idefics2-tfrm-compatible",
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
        trust_remote_code=True,
        token=os.environ["HF_AUTH_TOKEN"],
        revision=revision,
    ).to(DEVICE)


# Display name (used as the model dropdown choice) -> loaded model.
# All entries are loaded eagerly at import time.
MODELS = {
    "idefics2 lima 200": _load_idefics2("11794e2ae02dbf1c55d0ebd92c28e5b0b604cf5f"),
    "idefics2 sft 12600": _load_idefics2("86f134822798266d0d8db049cc6458c625e32344"),
}
|
# Processor (tokenizer + image processing) for the same repository as the models.
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-tfrm-compatible",
    token=os.environ["HF_AUTH_TOKEN"],
)

# Token-id sequences passed to ``generate`` as ``bad_words_ids``: the image
# placeholder tokens must never appear in generated text.
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"],
    add_special_tokens=False,
).input_ids

# Ids passed to ``generate`` as ``eos_token_id``: the end-of-utterance marker's
# token id(s) followed by the tokenizer's regular EOS id.
EOS_WORDS_IDS = [
    *PROCESSOR.tokenizer("<end_of_utterance>", add_special_tokens=False).input_ids,
    PROCESSOR.tokenizer.eos_token_id,
]
|
|
|
# Chat messages placed before every conversation. The prompt builder deep-copies
# this list before appending to it, so building a prompt never mutates it.
# Currently empty: no system conditioning is applied.
SYSTEM_PROMPT: List[Dict[str, Union[List, str]]] = list()
|
|
|
# Hugging Face access token taken from the environment; None when the variable is unset.
API_TOKEN = os.environ.get("HF_AUTH_TOKEN")

# Image file used as the bot's avatar in the chatbot widget.
BOT_AVATAR = "IDEFICS_logo.png"
|
|
|
|
|
|
|
def turn_is_pure_media(turn):
    """Tell whether a ``[user, assistant]`` history turn is media-only.

    A turn counts as media-only when its assistant slot (index 1) is ``None``.
    """
    assistant_slot = turn[1]
    return assistant_slot is None
|
|
|
|
|
def _trailing_user_content(messages):
    """Return the content list of the last message, first appending an empty user message unless the list already ends with one."""
    if not messages or messages[-1]["role"] != "user":
        messages.append({"role": "user", "content": []})
    return messages[-1]["content"]


def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> List[Dict[str, Union[List, str]]]:
    """
    Produce the message list that is passed to the processor's chat template.

    The list starts with a deep copy of ``SYSTEM_PROMPT`` (the module constant
    is never mutated), then replays ``chat_history``, then appends the current
    ``user_prompt``.

    Args:
        user_prompt: Multimodal textbox value. Reads ``user_prompt["text"]``
            and ``user_prompt["files"]``; each file entry is opened through its
            ``"path"`` key.
        chat_history: Iterable of ``[user, assistant]`` turns. A turn whose
            assistant slot is ``None`` is media-only, and ``turn[0][0]`` is
            opened as an image path.

    Returns:
        A list of ``{"role": ..., "content": [...]}`` dicts. Content items are
        stripped strings or ``PIL.Image.Image`` objects. Consecutive user items
        (media-only turns and the text that follows them) share one user
        message.
    """
    resulting_list = copy.deepcopy(SYSTEM_PROMPT)

    for turn in chat_history:
        user_content = _trailing_user_content(resulting_list)
        if turn_is_pure_media(turn):
            # Media-only turn: the first element of the user slot is the file path.
            user_content.append(Image.open(turn[0][0]))
        else:
            user_utterance, assistant_utterance = turn
            user_content.append(user_utterance.strip())
            resulting_list.append({"role": "assistant", "content": [assistant_utterance]})

    # Current prompt: uploaded images first, then the text (stripped like the
    # history utterances). If the history ended with media-only turns, these
    # items join that same user message instead of starting a second,
    # back-to-back user message.
    user_content = _trailing_user_content(resulting_list)
    user_content.extend(Image.open(im["path"]) for im in user_prompt["files"])
    user_content.append(user_prompt["text"].strip())

    return resulting_list
|
|
|
|
|
def extract_images_from_msg_list(msg_list):
    """Return every ``PIL.Image.Image`` found in the messages' ``content`` lists, in order of appearance."""
    return [
        piece
        for message in msg_list
        for piece in message["content"]
        if isinstance(piece, Image.Image)
    ]
|
|
|
|
|
def _validate_user_prompt(user_prompt):
    """Raise ``gr.Error`` when the multimodal prompt cannot be sent to the model.

    Rejects an empty prompt, images without accompanying text, and any uploaded
    file whose MIME type does not start with ``image/``.
    """
    has_text = user_prompt["text"].strip() != ""
    files = user_prompt["files"]
    if not has_text and not files:
        raise gr.Error("Please input a query and optionally image(s).")
    if not has_text:
        raise gr.Error("Please input a text query along the image(s).")
    for file in files:
        # Treat a missing (None/empty) mime_type as "not an image" rather than
        # crashing on ``None.startswith``.
        mime_type = file["mime_type"] or ""
        if not mime_type.startswith("image/"):
            raise gr.Error("Idefics2 only supports images. Please input a valid image.")


@spaces.GPU(duration=180)
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    """Stream the selected Idefics2 model's reply to ``user_prompt``.

    Args:
        user_prompt: Multimodal textbox value with ``"text"`` and ``"files"`` keys.
        chat_history: Previous ``[user, assistant]`` turns from the chatbot.
        model_selector: Key into ``MODELS`` naming the checkpoint to run.
        decoding_strategy: ``"Greedy"`` or ``"Top P Sampling"``.
        temperature: Sampling temperature; only used with ``"Top P Sampling"``.
        max_new_tokens: Maximum number of tokens to generate.
        repetition_penalty: Passed straight through to ``generate``.
        top_p: Nucleus threshold; only used with ``"Top P Sampling"``.

    Yields:
        The reply accumulated so far, with a trailing ``<end_of_utterance>``
        marker removed.

    Raises:
        gr.Error: If the prompt is empty, has images but no text, or includes
            non-image files.
        ValueError: If ``decoding_strategy`` is not a supported value.
    """
    # gr.Error only reaches the UI when raised; building it alone does nothing.
    _validate_user_prompt(user_prompt)

    # ``generate`` runs on a background thread and pushes decoded text into
    # this streamer; ``timeout`` bounds each wait for new text.
    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.,
    )

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "bad_words_ids": BAD_WORDS_IDS,
        "eos_token_id": EOS_WORDS_IDS,
        "streamer": streamer,
    }
    # Explicit check instead of ``assert`` so it also runs under ``python -O``.
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p
    else:
        raise ValueError(f"Unsupported decoding strategy: {decoding_strategy!r}")

    formated_prompt_list = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    inputs = PROCESSOR.apply_chat_template(formated_prompt_list, add_generation_prompt=True, return_tensors="pt")
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    generation_args.update(inputs)

    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    print("start generating")
    end_marker = "<end_of_utterance>"
    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)
        acc_text += text_token
        if acc_text.endswith(end_marker):
            acc_text = acc_text[: -len(end_marker)]
        yield acc_text
    # The streamer is exhausted, so generation has finished; reap the worker thread.
    thread.join()
    print("success - generated the following text:", acc_text)
|
|
|
|
|
|
|
# Generation controls; passed to ``model_inference`` as additional inputs of the chat interface.
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.0,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    # The previous text was copy-pasted from the Top P slider and did not describe this choice.
    info="Greedy always picks the most likely token; Top P Sampling samples from the most likely tokens.",
)
# The sampling-only controls start hidden because the default strategy is "Greedy";
# the decoding-strategy change handler shows them when sampling is selected.
temperature = gr.Slider(
    # Strictly positive: transformers rejects temperature <= 0 when sampling.
    minimum=0.1,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
    visible=False,
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    label="Top P",
    info="Higher values is equivalent to sampling more low-probability tokens.",
    visible=False,
)

chatbot = gr.Chatbot(
    label="IDEFICS2",
    avatar_images=[None, BOT_AVATAR],
    height=750,
)
|
|
|
|
|
with gr.Blocks(fill_height=True) as demo:
    # Dropdown choices are the MODELS keys; pass a real list rather than a dict view.
    model_names = list(MODELS)
    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=model_names,
            value=model_names[0],
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=True,
        )

    def _toggle_sampling_controls(selection):
        """Show the temperature and top-p sliders only while "Top P Sampling" is selected."""
        is_sampling = selection == "Top P Sampling"
        return gr.Slider(visible=is_sampling), gr.Slider(visible=is_sampling)

    # "Top P Sampling" is the only sampling strategy offered by the radio, so a
    # single handler drives both sliders.
    decoding_strategy.change(
        fn=_toggle_sampling_controls,
        inputs=decoding_strategy,
        outputs=[temperature, top_p],
    )

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        title="Idefics2 Playground",
        multimodal=True,
        additional_inputs=[model_selector, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
    )

demo.launch()
|
|