import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
from Upsample import RealESRGAN
import spaces  # Import spaces for ZeroGPU compatibility

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)

# bfloat16 on GPU, float16 on CPU; reused below when moving processor outputs.
model_dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float16
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(model_dtype).cuda()
else:
    vl_gpt = vl_gpt.to(model_dtype)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# SR model (2x RealESRGAN) used to upscale generated images.
sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)


# NOTE: spaces.GPU is the outermost decorator so that torch.inference_mode
# wraps the function body wherever spaces actually executes it (ZeroGPU may
# run the decorated function in a separate worker).
@spaces.GPU(duration=120)
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
    """Answer a question about an image with the Janus language model.

    Args:
        image: The uploaded image, either a numpy array (as produced by the
            ``gr.Image(type="numpy")`` input) or an already-built PIL image.
        question: The user's question about the image.
        seed: RNG seed for torch / numpy; ``None`` (cleared field) leaves the
            RNG state untouched.
        top_p: Nucleus-sampling threshold passed to ``generate``.
        temperature: Sampling temperature; ``0`` switches to greedy decoding.

    Returns:
        The decoded answer text (special tokens stripped).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("Image not uploaded")

    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    if seed is not None:
        seed = int(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "<|User|>",
            # Janus' chat processor expects an <image_placeholder> tag marking
            # where the image embeddings are inserted into the prompt.
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    pil_images = [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=model_dtype)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=temperature != 0,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer


def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    """Sample image tokens with classifier-free guidance and decode them.

    The prompt is duplicated into ``2 * parallel_size`` rows: even rows keep
    the full prompt (conditional branch), odd rows have every token except the
    first and last replaced by ``pad_id`` (unconditional branch). At each step
    the two branches' logits are mixed with ``cfg_weight``.

    Args:
        input_ids: 1-D tensor of prompt token ids (``generate_image`` passes a
            ``torch.LongTensor``).
        width: Output width in pixels.
        height: Output height in pixels.
        temperature: Softmax temperature; ``<= 0`` selects greedy argmax
            decoding (the UI slider allows 0, which would otherwise divide by
            zero and make ``torch.multinomial`` fail).
        parallel_size: Number of images generated in one batch.
        cfg_weight: Classifier-free guidance scale.
        image_token_num_per_image: Number of image tokens sampled per image.
        patch_size: Pixel size of one image token; used to derive the token
            grid passed to the decoder.

    Returns:
        ``(generated_tokens, patches)``: the sampled int tokens of shape
        ``(parallel_size, image_token_num_per_image)`` and the output of
        ``gen_vision_model.decode_code`` (presumably ``(N, C, H, W)`` in
        [-1, 1] — see ``unpack``; TODO confirm).
    """
    torch.cuda.empty_cache()
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for row in range(parallel_size * 2):
        tokens[row, :] = input_ids
        if row % 2 != 0:
            # Unconditional branch: keep only the boundary tokens.
            tokens[row, 1:-1] = vl_chat_processor.pad_id
    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for step in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            generated_tokens[:, step] = next_token.squeeze(dim=-1)
            # Feed each sampled token to both its conditional and unconditional row.
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    # Decoder grid is (batch, channels, rows, cols) = (..., height, width).
    patches = vl_gpt.gen_vision_model.decode_code(
        generated_tokens.to(dtype=torch.int),
        shape=[parallel_size, 8, height // patch_size, width // patch_size],
    )

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    """Convert decoder output into a batch of uint8 RGB images.

    Assumes ``dec`` is ``(parallel_size, 3, height, width)`` with values in
    [-1, 1], as implied by the transpose and rescale below — TODO confirm.

    Returns:
        A ``(parallel_size, height, width, 3)`` uint8 numpy array.
    """
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, height, width, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    """Generate ``parallel_size`` images for ``prompt`` and upscale them 2x.

    Args:
        prompt: Text prompt.
        seed: Optional RNG seed; ``None`` leaves the RNG state untouched.
        guidance: Classifier-free guidance weight.
        t2i_temperature: Sampling temperature (0 means greedy decoding).

    Returns:
        A list of upscaled PIL images.
    """
    torch.cuda.empty_cache()
    if seed is not None:
        seed = int(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': prompt},
                    {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
            conversations=messages,
            sft_format=vl_chat_processor.sft_format,
            system_prompt='',
        )
        # The image-start tag cues the model to begin emitting image tokens.
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        _, patches = generate(input_ids,
                              width // 16 * 16,
                              height // 16 * 16,
                              cfg_weight=guidance,
                              parallel_size=parallel_size,
                              temperature=t2i_temperature)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)

        stime = time.time()
        ret_images = [image_upsample(Image.fromarray(img)) for img in images]
        print(f'upsample time: {time.time() - stime}')
        return ret_images


@spaces.GPU(duration=60)
def image_upsample(img: Image.Image) -> Image.Image:
    """Upscale ``img`` 2x with the RealESRGAN model.

    Raises:
        gr.Error: If ``img`` is missing or either side is 5000 px or larger.
    """
    if img is None:
        raise gr.Error("Image not uploaded")

    width, height = img.size

    if width >= 5000 or height >= 5000:
        raise gr.Error("The image is too large.")

    return sr_model.predict(img.convert('RGB'))


# Custom CSS for a sleek, modern and highly readable interface
custom_css = """
body {
    background: #f0f2f5;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
h1, h2, h3 {
    font-weight: 600;
}
.gradio-container {
    padding: 20px;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-input, .gr-slider, .gr-number, .gr-textbox {
    border-radius: 5px;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, title="Multimodal & T2I Demo") as demo:
    with gr.Column(variant="panel"):
        gr.Markdown("\n\nElegant Janus Multimodal & T2I Demo\n\n")

        with gr.Tabs():
            with gr.TabItem("Multimodal Understanding"):
                gr.Markdown("### Chat with Images")
                with gr.Row():
                    image_input = gr.Image(label="Upload Image", type="numpy")
                    with gr.Column():
                        question_input = gr.Textbox(
                            label="Question",
                            placeholder="Enter your question about the image here...",
                            lines=4,
                        )
                        und_seed_input = gr.Number(label="Seed", precision=0, value=42)
                        top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
                        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
                understanding_button = gr.Button("Chat", elem_id="understanding-button")
                understanding_output = gr.Textbox(label="Response", lines=6)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Multimodal Understanding Examples",
                        examples=[
                            ["explain this meme", "doge.png"],
                            ["이 이미지를 설명해줘", "korean_example.png"],
                        ],
                        inputs=[question_input, image_input],
                    )
                understanding_button.click(
                    multimodal_understanding,
                    inputs=[image_input, question_input, und_seed_input, top_p, temperature],
                    outputs=understanding_output,
                )

            with gr.TabItem("Text-to-Image Generation"):
                gr.Markdown("### Generate Images from Text")
                with gr.Row():
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter detailed prompt for image generation...",
                        lines=4,
                    )
                with gr.Row():
                    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
                    cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
                    t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
                generation_button = gr.Button("Generate Images", elem_id="generation-button")
                image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Text-to-Image Examples",
                        examples=[
                            "Master shifu racoon wearing drip attire as a street gangster.",
                            "The face of a beautiful girl",
                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
                            "고양이가 우주복을 입고 달에 있는 모습",
                        ],
                        inputs=prompt_input,
                    )
                generation_button.click(
                    fn=generate_image,
                    inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
                    outputs=image_output,
                )

        gr.Markdown("")

demo.launch(share=True)