|
import os |
|
import subprocess |
|
|
|
|
|
import sys

# Best-effort install of flash-attn at startup (needed for the
# "flash_attention_2" attention implementation requested below).
# FLASH_ATTENTION_SKIP_CUDA_BUILD presumably tells flash-attn's build to skip
# compiling CUDA kernels locally (per its name) — TODO confirm.
# `env=` REPLACES the child's environment, so the current environment (PATH,
# HOME, proxies, ...) is merged in explicitly. Running pip via sys.executable
# installs into the interpreter that is actually running this app.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,  # deliberate best-effort: a failure surfaces later at model load
)
|
|
|
|
|
import copy |
|
import spaces |
|
import time |
|
import torch |
|
|
|
from threading import Thread |
|
from typing import List, Dict, Union |
|
from urllib.parse import urlparse |
|
from PIL import Image |
|
|
|
import gradio as gr |
|
from transformers import AutoProcessor, TextIteratorStreamer |
|
from transformers import Idefics2ForConditionalGeneration |
|
|
|
|
|
DEVICE = torch.device("cuda")


def _load_idefics2(repo_id):
    """Load an Idefics2 checkpoint (bfloat16, FlashAttention-2) and move it onto DEVICE.

    Reads the Hub token from the HF_AUTH_TOKEN environment variable; a missing
    variable raises KeyError.
    """
    model = Idefics2ForConditionalGeneration.from_pretrained(
        repo_id,
        torch_dtype=torch.bfloat16,
        _attn_implementation="flash_attention_2",
        trust_remote_code=True,
        token=os.environ["HF_AUTH_TOKEN"],
    )
    return model.to(DEVICE)


# Display name (shown in the model dropdown) -> loaded model, in dropdown order.
MODELS = {
    "idefics2-8b (sft)": _load_idefics2("HuggingFaceM4/idefics2-8b"),
    "idefics2-8b-chatty (chat)": _load_idefics2("HuggingFaceM4/idefics2-8b-chatty"),
}

# One processor, taken from the base checkpoint, is shared by both models.
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/idefics2-8b",
    token=os.environ["HF_AUTH_TOKEN"],
)
|
|
|
# Messages placed (deep-copied) at the start of every conversation sent to the
# model, before the replayed chat history. Currently empty.
SYSTEM_PROMPT = []

# Hub token from the environment; None when HF_AUTH_TOKEN is unset.
API_TOKEN = os.getenv("HF_AUTH_TOKEN")

# Image file used as the bot's avatar in the chatbot widget.
BOT_AVATAR = "IDEFICS_logo.png"
|
|
|
|
|
|
|
def turn_is_pure_media(turn):
    """Tell whether a chat-history turn is a media-only entry.

    A turn is treated as pure media when its second element (the paired reply)
    is None.
    """
    paired_reply = turn[1]
    return paired_reply is None
|
|
|
|
|
def format_user_prompt_with_im_history_and_system_conditioning(
    user_prompt, chat_history
) -> "tuple[List[Dict[str, Union[List, str]]], List[Image.Image]]":
    """
    Build the chat messages and the matching image list to feed the processor.

    The result starts with a deep copy of SYSTEM_PROMPT, replays the chat
    history, then appends the new user prompt. Each ``{"type": "image"}``
    placeholder in the messages corresponds, in order, to one entry of the
    returned image list.

    Args:
        user_prompt: dict with a "text" string and a "files" list; each file is
            a dict with a "path" key.
        chat_history: list of turns. A turn whose second element is None is a
            pure-media turn whose first element holds the file path at index 0;
            any other turn is a (user_text, assistant_text) pair.

    Returns:
        A ``(messages, images)`` tuple. ``messages`` is a list of
        ``{"role": ..., "content": [...]}`` dicts and ``images`` is the list of
        opened PIL images.
    """
    resulting_messages = copy.deepcopy(SYSTEM_PROMPT)
    resulting_images = []

    for turn in chat_history:
        # Media and text turns that follow each other are merged into a single
        # "user" message; a new one is opened only after an assistant reply.
        if not resulting_messages or resulting_messages[-1]["role"] != "user":
            resulting_messages.append({"role": "user", "content": []})

        if turn_is_pure_media(turn):
            media = turn[0][0]
            resulting_messages[-1]["content"].append({"type": "image"})
            resulting_images.append(Image.open(media))
        else:
            user_utterance, assistant_utterance = turn
            resulting_messages[-1]["content"].append(
                {"type": "text", "text": user_utterance.strip()}
            )
            # The assistant message must carry the assistant's own reply.
            resulting_messages.append(
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": assistant_utterance.strip()}],
                }
            )

    files = user_prompt["files"]
    # One image placeholder per uploaded file (fresh dicts rather than shared
    # references), followed by the prompt text. With no files this reduces to a
    # text-only user message.
    content = [{"type": "image"} for _ in files]
    content.append({"type": "text", "text": user_prompt["text"]})
    resulting_messages.append({"role": "user", "content": content})
    resulting_images.extend(Image.open(im["path"]) for im in files)

    return resulting_messages, resulting_images
|
|
|
|
|
def extract_images_from_msg_list(msg_list):
    """Collect every PIL image found in the "content" lists of the given messages.

    Messages are scanned in order, and the images keep their original order.
    """
    return [
        part
        for message in msg_list
        for part in message["content"]
        if isinstance(part, Image.Image)
    ]
|
|
|
|
|
# Marker the models may emit at the end of an answer; it is trimmed from the output.
_END_OF_UTTERANCE = "<end_of_utterance>"
# Decoding strategies offered by the UI radio button.
_DECODING_STRATEGIES = ("Greedy", "Top P Sampling")


def _validate_user_prompt(user_prompt):
    """Raise gr.Error unless the prompt has text and only image attachments."""
    has_text = user_prompt["text"].strip() != ""
    if not has_text and not user_prompt["files"]:
        raise gr.Error("Please input a query and optionally image(s).")
    if not has_text:
        raise gr.Error("Please input a text query along the image(s).")
    for file in user_prompt["files"]:
        # A missing/None MIME type is treated as "not an image" instead of crashing.
        mime_type = file.get("mime_type") or ""
        if not mime_type.startswith("image/"):
            raise gr.Error("Idefics2 only supports images. Please input a valid image.")


def _build_generation_args(
    decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p, streamer
):
    """Translate the UI decoding settings into keyword arguments for ``generate``."""
    if decoding_strategy not in _DECODING_STRATEGIES:
        raise gr.Error(f"Unknown decoding strategy: {decoding_strategy!r}")

    generation_args = {
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    else:  # "Top P Sampling"
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p
    return generation_args


@spaces.GPU(duration=180)
def model_inference(
    user_prompt,
    chat_history,
    model_selector,
    decoding_strategy,
    temperature,
    max_new_tokens,
    repetition_penalty,
    top_p,
):
    """
    Stream the selected model's answer to ``user_prompt`` given ``chat_history``.

    Args:
        user_prompt: multimodal input dict with "text" and "files".
        chat_history: prior turns, in the format accepted by
            format_user_prompt_with_im_history_and_system_conditioning.
        model_selector: key into MODELS.
        decoding_strategy: "Greedy" or "Top P Sampling".
        temperature, top_p: sampling parameters (used only for Top P Sampling).
        max_new_tokens, repetition_penalty: generation limits/penalties.

    Yields:
        The cumulative generated text, without a trailing end-of-utterance marker.

    Raises:
        gr.Error: on an empty or non-image input, or an unknown decoding strategy.
    """
    _validate_user_prompt(user_prompt)

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
        timeout=5.0,  # max seconds to wait for each new chunk from the generator
    )
    generation_args = _build_generation_args(
        decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p, streamer
    )

    messages, images = format_user_prompt_with_im_history_and_system_conditioning(
        user_prompt=user_prompt,
        chat_history=chat_history,
    )
    prompt = PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
    inputs = PROCESSOR(text=prompt, images=images or None, return_tensors="pt")
    generation_args.update({k: v.to(DEVICE) for k, v in inputs.items()})

    # generate() blocks until completion, so it runs in a worker thread while
    # this generator consumes its output through the streamer.
    thread = Thread(
        target=MODELS[model_selector].generate,
        kwargs=generation_args,
    )
    thread.start()

    print("Start generating")
    acc_text = ""
    for text_token in streamer:
        time.sleep(0.04)  # small pause between UI updates
        acc_text += text_token
        if acc_text.endswith(_END_OF_UTTERANCE):
            acc_text = acc_text[: -len(_END_OF_UTTERANCE)]
        yield acc_text
    thread.join()
    print("Success - generated the following text:", acc_text)
    print("-----")
|
|
|
|
|
|
|
# Generation controls, passed to model_inference as ChatInterface additional inputs.
max_new_tokens = gr.Slider(
    minimum=8,
    maximum=1024,
    value=512,
    step=1,
    interactive=True,
    label="Maximum number of new tokens to generate",
)
repetition_penalty = gr.Slider(
    minimum=0.01,
    maximum=5.0,
    value=1.1,
    step=0.01,
    interactive=True,
    label="Repetition penalty",
    info="1.0 is equivalent to no penalty",
)
decoding_strategy = gr.Radio(
    [
        "Greedy",
        "Top P Sampling",
    ],
    value="Greedy",
    label="Decoding strategy",
    interactive=True,
    info="Greedy decoding is deterministic; Top P Sampling adds randomness controlled by the temperature and Top P settings.",
)
# Temperature and Top P only matter for "Top P Sampling". The default strategy
# is "Greedy", so these two start hidden.
# Sampling with temperature 0 is rejected by transformers, so the minimum is
# strictly positive.
temperature = gr.Slider(
    minimum=0.1,
    maximum=5.0,
    value=0.4,
    step=0.1,
    interactive=True,
    visible=False,
    label="Sampling temperature",
    info="Higher values will produce more diverse outputs.",
)
top_p = gr.Slider(
    minimum=0.01,
    maximum=0.99,
    value=0.8,
    step=0.01,
    interactive=True,
    visible=False,
    label="Top P",
    info="Higher values is equivalent to sampling more low-probability tokens.",
)

chatbot = gr.Chatbot(
    label="Idefics2",
    avatar_images=[None, BOT_AVATAR],
)
|
|
|
|
|
def _sampling_slider_update(selection):
    """Return a Slider update that is visible only when "Top P Sampling" is selected."""
    return gr.Slider(visible=selection == "Top P Sampling")


with gr.Blocks(fill_height=True, css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img { width: auto; max-width: 30%; height: auto; max-height: 30%; }") as demo:
    default_model = list(MODELS.keys())[0]

    with gr.Row(elem_id="model_selector_row"):
        model_selector = gr.Dropdown(
            choices=list(MODELS.keys()),  # a real list, not a dict view
            value=default_model,
            interactive=True,
            show_label=False,
            container=False,
            label="Model",
            visible=True,
        )

    # Temperature and Top P only affect "Top P Sampling", so they follow the
    # selected strategy. Repetition penalty is passed to generate() for every
    # strategy (see model_inference), so it stays visible at all times.
    for sampling_slider in (temperature, top_p):
        decoding_strategy.change(
            fn=_sampling_slider_update,
            inputs=decoding_strategy,
            outputs=sampling_slider,
        )

    example_queries = [
        ("./example_images/docvqa_example.png", "How many items are sold?"),
        ("./example_images/example_images_travel_tips.jpg", "I want to go somewhere similar to the one in the photo. Give me destinations and travel tips."),
        ("./example_images/baklava.png", "Where is this pastry from?"),
        ("./example_images/dummy_pdf.png", "How much percent is the order status?"),
        ("./example_images/art_critic.png", "As an art critic AI assistant, could you describe this painting in details and make a thorough critic?."),
        ("./example_images/s2w_example.png", "What is this UI about?"),
    ]
    # A multimodal ChatInterface with additional_inputs expects each example as
    # [message dict {"text", "files"}, *one value per additional input, in order]
    # i.e. model_selector, decoding_strategy, temperature, max_new_tokens,
    # repetition_penalty, top_p.
    examples = [
        [{"text": text, "files": [image_path]}, default_model, "Greedy", 0.4, 512, 1.2, 0.8]
        for image_path, text in example_queries
    ]

    description = "Try [IDEFICS2-8B](https://huggingface.co/HuggingFaceM4/idefics2-8b), the instruction fine-tuned IDEFICS2, and [IDEFICS2 Chatty](https://huggingface.co/HuggingFaceM4/idefics2-8b-chatty) in this demo. 🐶💬 IDEFICS2 is a state-of-the-art vision language model in various benchmarks. To get started, upload an image and write a text prompt or try one of the examples. You can also play with advanced generation parameters. To learn more about IDEFICS2, read [the blog](https://huggingface.co/blog/idefics2)."

    gr.ChatInterface(
        fn=model_inference,
        chatbot=chatbot,
        examples=examples,
        title="Idefics2 Playground 💬",
        multimodal=True,
        description=description,
        additional_inputs=[model_selector, decoding_strategy, temperature, max_new_tokens, repetition_penalty, top_p],
    )

demo.launch()
|
|