"""Gradio demo for Janus-Pro-7B: image understanding and text-to-image.

The script loads the Janus model, its chat processor and a RealESRGAN x2
upsampler at import time, then builds and launches a two-tab Gradio UI.
"""

import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image

import numpy as np
import os
import time
from Upsample import RealESRGAN
import spaces  # Import spaces for ZeroGPU compatibility

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'
# dtype used for processor outputs; mirrors the dtype the model was cast to.
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16

# SR model
sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)


def _seed_everything(seed):
    """Seed torch (CPU and CUDA) and NumPy; do nothing when *seed* is None."""
    if seed is None:
        return
    seed = int(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)


@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
    """Answer *question* about *image* with the Janus language model.

    Args:
        image: The uploaded image, a NumPy array or a PIL image.
        question: The user's question text.
        seed: RNG seed; ``None`` (e.g. a cleared number field) skips seeding.
        top_p: Nucleus-sampling threshold, used only when sampling.
        temperature: Sampling temperature; ``0`` selects greedy decoding.

    Returns:
        The decoded answer string with special tokens removed.
    """
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # A cleared seed field arrives as None; seeding with None would raise.
    _seed_everything(seed)

    conversation = [
        {
            "role": "<|User|>",
            # The image tag marks where the image embeddings are inserted.
            # NOTE(review): the tag was missing from the file (apparently
            # stripped as markup); restored per the Janus README usage.
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    if isinstance(image, np.ndarray):
        pil_images = [Image.fromarray(image)]
    else:
        pil_images = [image]

    prepare_inputs = vl_chat_processor(
        conversations=conversation,
        images=pil_images,
        force_batchify=True,
    ).to(cuda_device, dtype=model_dtype)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    do_sample = temperature != 0
    # Only hand sampling knobs to generate() when sampling is enabled;
    # temperature=0 is not a meaningful sampling temperature.
    sampling_kwargs = (
        {"temperature": temperature, "top_p": top_p} if do_sample else {}
    )

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=do_sample,
        use_cache=True,
        **sampling_kwargs,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    """Autoregressively sample image tokens and decode them to patches.

    Rows of the token batch alternate between the conditional prompt (even
    rows) and a copy whose interior tokens are replaced by the pad id (odd
    rows); their logits are combined with ``cfg_weight``.

    Args:
        input_ids: 1-D tensor of prompt token ids.
        width: Image width in pixels, passed to the decoder as
            ``width // patch_size``.
        height: Image height in pixels, passed as ``height // patch_size``.
        temperature: Softmax temperature; values ``<= 0`` select the
            arg-max token instead of sampling (avoids dividing by zero).
        parallel_size: Number of images generated in parallel.
        cfg_weight: Guidance weight applied to (cond - uncond) logits.
        image_token_num_per_image: Number of image tokens to generate.
        patch_size: Pixel size of one image token along each axis.

    Returns:
        A tuple ``(generated_tokens, patches)``: the int tensor of sampled
        tokens and the output of ``gen_vision_model.decode_code``.
    """
    torch.cuda.empty_cache()

    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for row in range(parallel_size * 2):
        tokens[row, :] = input_ids
        if row % 2 != 0:
            # Odd rows: blank out everything but the first and last token.
            tokens[row, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros(
        (parallel_size, image_token_num_per_image), dtype=torch.int
    ).to(cuda_device)

    pkv = None
    for step in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds,
                use_cache=True,
                past_key_values=pkv,
            )
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state

            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # The UI slider allows 0; dividing by it would give inf/nan
                # probabilities, so fall back to greedy selection.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            generated_tokens[:, step] = next_token.squeeze(dim=-1)

            # Duplicate each token so both the cond and uncond row of a
            # pair receive the same next input.
            next_token = torch.cat(
                [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
            ).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, width // patch_size, height // patch_size],
    )
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    """Convert decoder output to a uint8 image batch.

    Args:
        dec: Tensor that is moved to CPU and transposed with (0, 2, 3, 1);
            presumably channels-first in [-1, 1] — TODO confirm.
        width: Size of axis 1 of the returned array.
        height: Size of axis 2 of the returned array.
        parallel_size: Number of images in the batch.

    Returns:
        ``np.uint8`` array of shape ``(parallel_size, width, height, 3)``.
    """
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    # Map [-1, 1] to [0, 255] and clamp anything outside that range.
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    """Generate and upsample images for *prompt*.

    Args:
        prompt: Text description of the desired image.
        seed: Optional RNG seed; ``None`` leaves the RNG state untouched.
        guidance: Guidance weight forwarded to ``generate`` as ``cfg_weight``.
        t2i_temperature: Sampling temperature forwarded to ``generate``.

    Returns:
        A list of ``parallel_size`` upsampled PIL images.
    """
    torch.cuda.empty_cache()
    _seed_everything(seed)

    width = 384
    height = 384
    parallel_size = 5
    # Round down to a multiple of the 16-pixel patch size.
    gen_width = width // 16 * 16
    gen_height = height // 16 * 16

    with torch.no_grad():
        messages = [
            {'role': '<|User|>', 'content': prompt},
            {'role': '<|Assistant|>', 'content': ''},
        ]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        _, patches = generate(
            input_ids,
            gen_width,
            gen_height,
            cfg_weight=guidance,
            parallel_size=parallel_size,
            temperature=t2i_temperature,
        )
        images = unpack(patches, gen_width, gen_height, parallel_size=parallel_size)

        stime = time.time()
        ret_images = [image_upsample(Image.fromarray(img)) for img in images]
        print(f'upsample time: {time.time() - stime}')
        return ret_images


@spaces.GPU(duration=60)
def image_upsample(img: Image.Image) -> Image.Image:
    """Upsample *img* with the module-level RealESRGAN model.

    Args:
        img: Image to upsample; it is converted to RGB first.

    Returns:
        The result of ``sr_model.predict``.

    Raises:
        ValueError: If *img* is ``None`` or either side is >= 5000 pixels.
    """
    if img is None:
        raise ValueError("Image not uploaded")

    width, height = img.size
    if width >= 5000 or height >= 5000:
        raise ValueError("The image is too large.")

    return sr_model.predict(img.convert('RGB'))


# Custom CSS for a sleek, modern and highly readable interface
custom_css = """
body {
    background: #f0f2f5;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
h1, h2, h3 {
    font-weight: 600;
}
.gradio-container {
    padding: 20px;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-input, .gr-slider, .gr-number, .gr-textbox {
    border-radius: 5px;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
    with gr.Column(variant="panel"):
        # The <header>/<h1> wrapper is what the `header h1` CSS rule styles.
        gr.Markdown("<header><h1>Janus Multimodal Demo</h1></header>")
        with gr.Tabs():
            with gr.TabItem("Multimodal Understanding"):
                gr.Markdown("### Chat with Images")
                with gr.Row():
                    image_input = gr.Image(label="Upload Image", type="numpy", tool="editor")
                    with gr.Column():
                        question_input = gr.Textbox(
                            label="Question",
                            placeholder="Enter your question about the image here...",
                            lines=4,
                        )
                        und_seed_input = gr.Number(label="Seed", precision=0, value=42)
                        top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
                        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
                understanding_button = gr.Button("Chat", elem_id="understanding-button")
                understanding_output = gr.Textbox(label="Response", lines=6)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Multimodal Understanding Examples",
                        examples=[
                            ["explain this meme", "doge.png"],
                            ["Convert the formula into LaTeX code.", "equation.png"],
                        ],
                        inputs=[question_input, image_input],
                    )
                understanding_button.click(
                    multimodal_understanding,
                    inputs=[image_input, question_input, und_seed_input, top_p, temperature],
                    outputs=understanding_output,
                )
            with gr.TabItem("Text-to-Image Generation"):
                gr.Markdown("### Generate Images from Text")
                with gr.Row():
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter detailed prompt for image generation...",
                        lines=4,
                    )
                with gr.Row():
                    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
                    cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
                    t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
                generation_button = gr.Button("Generate Images", elem_id="generation-button")
                image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Text-to-Image Examples",
                        examples=[
                            "Master shifu racoon wearing drip attire as a street gangster.",
                            "The face of a beautiful girl",
                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
                            "An intricately designed eye with ornate swirl patterns, vivid blue iris, and classical architectural motifs, exuding mysterious timelessness.",
                        ],
                        inputs=prompt_input,
                    )
                generation_button.click(
                    fn=generate_image,
                    inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
                    outputs=image_output,
                )
        gr.Markdown("")

demo.launch(share=True)