import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "False"
os.environ["TOKENIZERS_PARALLELISM"] = "true"

import tempfile

import gradio as gr
import numpy as np  # Needed below for image post-processing.
import torch
import wandb  # Needed below for optional generation logging.
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoTokenizer

from models import Showo, MAGVITv2, get_mask_chedule
from prompting_utils import UniversalPrompting, create_attention_mask_predict_next
from share_btn import share_js, save_js

# Prepare model.
config = OmegaConf.load("configs/showo_demo.yaml")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(config.model.showo.llm_model_path, padding_side="left")
uni_prompting = UniversalPrompting(tokenizer,
                                   max_text_len=config.dataset.preprocessing.max_seq_length,
                                   special_tokens=("<|soi|>", "<|eoi|>", "<|sov|>", "<|eov|>",
                                                   "<|t2i|>", "<|mmu|>", "<|t2v|>", "<|v2v|>", "<|lvg|>"),
                                   ignore_id=-100,
                                   cond_dropout_prob=config.training.cond_dropout_prob)

vq_model = MAGVITv2(config.model.vq_model.type)
vq_model = vq_model.from_pretrained(config.model.vq_model.vq_model_name).to(device)
vq_model.requires_grad_(False)
vq_model.eval()

model = Showo.from_pretrained(config.model.showo.pretrained_model_path).to(device)
model.eval()
mask_token_id = model.config.mask_token_id

css = """
#chatbot { min-height: 300px; }
#save-btn {
  background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#save-btn:hover {
  background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#share-btn {
  background-image: linear-gradient(to right bottom, rgba(130,217,244, 0.9), rgba(158,231,214, 1.0));
}
#share-btn:hover {
  background-image: linear-gradient(to right bottom, rgba(110,197,224, 0.9), rgba(138,211,194, 1.0));
}
#gallery { z-index: 999999; }
#gallery img:hover { transform: scale(2.3); z-index: 999999; position: relative; padding-right: 30%; padding-bottom: 30%; }
#gallery button img:hover { transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0; }
@media (hover: none) {
  #gallery img:hover { transform: none; z-index: 999999; position: relative; padding-right: 0; padding-bottom: 0; }
}
.html2canvas-container { width: 3000px !important; height: 3000px !important; }
"""


def upload_image(state, image_input):
    conversation = state[0]
    chat_history = state[1]
    input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
    input_image.save(image_input.name)  # Overwrite with smaller image.
    # Render the uploaded image inline in the chatbot. The <img> markup was
    # stripped from the original file; reconstructed here.
    conversation += [(f'<img src="./file={image_input.name}" style="display: inline-block;">', "")]
    return [conversation, chat_history + [input_image, ""]], conversation
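
# Note on state layout (inferred from the handlers in this file): `state[0]`
# holds the (user_html, bot_html) pairs rendered by gr.Chatbot, while
# `state[1]` holds the raw history (PIL images and strings) fed to the model.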

def reset():
    return [[], []], []


def reset_last(state):
    conversation = state[0][:-1]
    chat_history = state[1][:-2]
    return [conversation, chat_history], conversation


def save_image_to_local(image: Image.Image):
    # tempfile._get_candidate_names() is a private API, but it is a convenient
    # source of unique filenames for a demo.
    filename = next(tempfile._get_candidate_names()) + '.png'
    image.save(filename)
    return filename


def text_to_image_generation(input_text, state, guidance_scale, generation_timesteps):
    prompts = [input_text]
    config.training.batch_size = config.batch_size = 1
    config.training.guidance_scale = config.guidance_scale = guidance_scale
    config.training.generation_timesteps = config.generation_timesteps = generation_timesteps

    # Start from a fully masked image: every VQ token is the mask token.
    image_tokens = torch.ones((len(prompts), config.model.showo.num_vq_tokens),
                              dtype=torch.long, device=device) * mask_token_id

    input_ids, _ = uni_prompting((prompts, image_tokens), 't2i_gen')

    if config.training.guidance_scale > 0:
        # Classifier-free guidance: append an unconditional (empty-prompt) batch.
        uncond_input_ids, _ = uni_prompting(([''] * len(prompts), image_tokens), 't2i_gen')
        attention_mask = create_attention_mask_predict_next(torch.cat([input_ids, uncond_input_ids], dim=0),
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
    else:
        attention_mask = create_attention_mask_predict_next(input_ids,
                                                            pad_id=int(uni_prompting.sptids_dict['<|pad|>']),
                                                            soi_id=int(uni_prompting.sptids_dict['<|soi|>']),
                                                            eoi_id=int(uni_prompting.sptids_dict['<|eoi|>']),
                                                            rm_pad_in_image=True)
        uncond_input_ids = None

    if config.get("mask_schedule", None) is not None:
        schedule = config.mask_schedule.schedule
        args = config.mask_schedule.get("params", {})
        mask_schedule = get_mask_chedule(schedule, **args)
    else:
        mask_schedule = get_mask_chedule(config.training.get("mask_schedule", "cosine"))

    with torch.no_grad():
        gen_token_ids = model.t2i_generate(
            input_ids=input_ids,
            uncond_input_ids=uncond_input_ids,
            attention_mask=attention_mask,
            guidance_scale=config.training.guidance_scale,
            temperature=config.training.get("generation_temperature", 1.0),
            timesteps=config.training.generation_timesteps,
            noise_schedule=mask_schedule,
            noise_type=config.training.get("noise_type", "mask"),
            seq_len=config.model.showo.num_vq_tokens,
            uni_prompting=uni_prompting,
            config=config,
        )

    gen_token_ids = torch.clamp(gen_token_ids, max=config.model.showo.codebook_size - 1, min=0)
    images = vq_model.decode_code(gen_token_ids)

    # Map decoded images from [-1, 1] to uint8 [0, 255], NHWC.
    images = torch.clamp((images + 1.0) / 2.0, min=0.0, max=1.0)
    images *= 255.0
    images = images.permute(0, 2, 3, 1).cpu().numpy().astype(np.uint8)
    pil_images = [Image.fromarray(image) for image in images]

    # Log to wandb only when a run is active; the original called wandb.log
    # with an undefined `step`, which would crash.
    if wandb.run is not None:
        wandb_images = [wandb.Image(image, caption=prompts[i]) for i, image in enumerate(pil_images)]
        wandb.log({"generated_images": wandb_images})
    return pil_images
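
# A minimal smoke test for the text-to-image path (a sketch: the prompt and
# sampler settings below are illustrative, and it assumes the checkpoints
# referenced in configs/showo_demo.yaml are available):
#
#     imgs = text_to_image_generation("a red bicycle", [[], []],
#                                     guidance_scale=5.0, generation_timesteps=18)
#     imgs[0].save("sample.png")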

def generate_for_prompt(input_text, state, ret_scale_factor, num_words, temperature):
    # Use the module-level `device` so the demo also runs on CPU
    # (the original hard-coded 'cuda').
    g_cuda = torch.Generator(device=device).manual_seed(1337)

    # Ignore empty inputs. Return four values to match the outputs wired below.
    if len(input_text) == 0:
        return state, state[0], gr.update(visible=True), gr.update(visible=True)

    input_prompt = 'Q: ' + input_text + '\nA:'
    conversation = state[0]
    chat_history = state[1]
    print('Generating for', chat_history, flush=True)

    # If an image was uploaded, it is already at the front of chat_history,
    # so it is prepended to the model inputs.
    model_inputs = chat_history
    model_inputs.append(input_prompt)
    # Remove empty text.
    model_inputs = [s for s in model_inputs if s != '']

    top_p = 1.0
    if temperature != 0.0:
        top_p = 0.95

    print('Running model.generate_for_images_and_texts with', model_inputs, flush=True)
    model_outputs = model.generate_for_images_and_texts(model_inputs,
                                                        num_words=max(num_words, 1),
                                                        ret_scale_factor=ret_scale_factor,
                                                        top_p=top_p,
                                                        temperature=temperature,
                                                        max_num_rets=1,
                                                        num_inference_steps=50,
                                                        generator=g_cuda)
    print('model_outputs', model_outputs, ret_scale_factor, flush=True)

    response = ''
    text_outputs = []
    for output_i, p in enumerate(model_outputs):
        if type(p) == str:
            if output_i > 0:
                response += '<br/>'
            # Remove the image tokens for output.
            text_outputs.append(p.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', ''))
            response += p
            if len(model_outputs) > 1:
                response += '<br/>'
        elif type(p) == dict:
            # Decide whether to generate or retrieve. The inline <img> markup
            # was stripped from the original file; reconstructed here.
            if p['decision'] is not None and p['decision'][0] == 'gen':
                image = p['gen'][0][0]  # .resize((224, 224))
                filename = save_image_to_local(image)
                response += (f'<img src="./file={filename}" style="display: inline-block;">'
                             '<p>(Generated)</p>')
            else:
                image = p['ret'][0][0]  # .resize((224, 224))
                filename = save_image_to_local(image)
                response += (f'<img src="./file={filename}" style="display: inline-block;">'
                             '<p>(Retrieved)</p>')

    # Remove [RET] from outputs.
    chat_history = model_inputs + [' '.join([s for s in model_outputs if type(s) == str]) + '\n']

    conversation.append((input_text,
                         response.replace('[IMG0] [IMG1] [IMG2] [IMG3] [IMG4] [IMG5] [IMG6] [IMG7]', '')))

    # Set input image to None.
    print('state', state, flush=True)
    print('updated state', [conversation, chat_history], flush=True)
    return [conversation, chat_history], conversation, gr.update(visible=True), gr.update(visible=True)
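
# `model_outputs` interleaves plain strings with dicts: each dict carries a
# 'decision' entry ('gen' vs 'ret') plus the generated or retrieved PIL
# images. The handler returns four values, matching the Gradio outputs wired
# below: [gr_state, chatbot, share_group, save_group].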

with gr.Blocks(css=css) as demo:
    gr.HTML("""
        <h1>🐟 GILL</h1>
        <p>This is the official Gradio demo for the GILL model, a model that can
        process arbitrarily interleaved image and text inputs, and produce image
        and text outputs.</p>
        <p>Paper: <a href="https://arxiv.org/abs/2305.17216">Generating Images with Multimodal Language Models</a><br/>
        Project Website: <a href="https://jykoh.com/gill">GILL Website</a><br/>
        Code and Models: <a href="https://github.com/kohjingyu/gill">GitHub</a></p>
        <p><strong>Tips:</strong></p>
    """)

    gr_state = gr.State([[], []])  # conversation, chat_history

    with gr.Row():
        with gr.Column(scale=0.7, min_width=500):
            with gr.Row():
                chatbot = gr.Chatbot(elem_id="chatbot", label="🐟 GILL Chatbot")
            with gr.Row():
                image_btn = gr.UploadButton("🖼️ Upload Image", file_types=["image"])
                text_input = gr.Textbox(label="Message", placeholder="Type a message")
                with gr.Column():
                    submit_btn = gr.Button("Submit", interactive=True, variant="primary")
                    clear_last_btn = gr.Button("Undo")
                    clear_btn = gr.Button("Reset All")
            with gr.Row(visible=False) as save_group:
                save_button = gr.Button("💾 Save Conversation as .png", elem_id="save-btn")
            with gr.Row(visible=False) as share_group:
                share_button = gr.Button("🤗 Share to Community (opens new window)", elem_id="share-btn")

        with gr.Column(scale=0.3, min_width=400):
            ret_scale_factor = gr.Slider(minimum=0.0, maximum=3.0, value=1.3, step=0.1, interactive=True,
                                         label="Frequency multiplier for returning images (higher means more frequent)")
            gr_max_len = gr.Slider(minimum=1, maximum=64, value=32, step=1, interactive=True,
                                   label="Max # of words")
            gr_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.0, step=0.1, interactive=True,
                                       label="Temperature (0 for deterministic, higher for more randomness)")

            # `examples` was undefined in the original; it should list paths to
            # example-conversation screenshots shipped with the demo. Left empty
            # here as a placeholder.
            examples = []
            gallery = gr.Gallery(
                value=[Image.open(e) for e in examples],
                label="Example Conversations", show_label=True, elem_id="gallery",
            ).style(grid=[2], height="auto")

    text_input.submit(generate_for_prompt,
                      [text_input, gr_state, ret_scale_factor, gr_max_len, gr_temperature],
                      [gr_state, chatbot, share_group, save_group])
    text_input.submit(lambda: "", None, text_input)  # Reset chatbox.
    submit_btn.click(generate_for_prompt,
                     [text_input, gr_state, ret_scale_factor, gr_max_len, gr_temperature],
                     [gr_state, chatbot, share_group, save_group])
    submit_btn.click(lambda: "", None, text_input)  # Reset chatbox.

    image_btn.upload(upload_image, [gr_state, image_btn], [gr_state, chatbot])
    clear_last_btn.click(reset_last, [gr_state], [gr_state, chatbot])
    clear_btn.click(reset, [], [gr_state, chatbot])
    share_button.click(None, [], [], _js=share_js)
    save_button.click(None, [], [], _js=save_js)

demo.queue(concurrency_count=1, api_open=False, max_size=16)
demo.launch(debug=True, server_name="0.0.0.0")
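
# Note: `.style(...)`, `_js=`, and `queue(concurrency_count=...)` are
# Gradio 3.x APIs; Gradio 4 moved styling kwargs into component constructors,
# renamed `_js` to `js`, and replaced `concurrency_count` with
# `default_concurrency_limit`. Pin gradio<4 to run this script as written.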