import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM, pipeline as translation_pipeline
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image
import numpy as np
import os
import time
from Upsample import RealESRGAN
import spaces  # Import spaces for ZeroGPU compatibility
import re

# Translation pipeline (Korean -> English) used to normalize T2I prompts.
translator = translation_pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Matches any Hangul jamo or syllable; compiled once instead of per call.
_HANGUL_RE = re.compile(r'[ㄱ-ㅎㅏ-ㅣ가-힣]')


def translate_if_korean(prompt: str) -> str:
    """Return *prompt* translated to English if it contains Korean text.

    Translation is best-effort: on any pipeline failure the error is printed
    and the original prompt is returned unchanged.
    """
    if not _HANGUL_RE.search(prompt):
        return prompt
    try:
        return translator(prompt)[0]['translation_text']
    except Exception as e:
        print(f"Translation error: {e}")
        return prompt


# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(
    model_path,
    language_config=language_config,
    trust_remote_code=True,
)

# Single source of truth for the model dtype; inputs are cast to match it.
if torch.cuda.is_available():
    model_dtype = torch.bfloat16
    vl_gpt = vl_gpt.to(model_dtype).cuda()
else:
    # NOTE(review): float16 on CPU may be slow or unsupported for some ops — confirm.
    model_dtype = torch.float16
    vl_gpt = vl_gpt.to(model_dtype)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Super-resolution model (2x) applied to every generated image.
sr_model = RealESRGAN(torch.device(cuda_device), scale=2)
sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)


# spaces.GPU is the outermost decorator so that inference mode is entered
# inside the function ZeroGPU actually executes, not around its dispatcher.
@spaces.GPU(duration=120)
@torch.inference_mode()
def multimodal_understanding(image, question, seed, top_p, temperature):
    """Answer *question* about *image* with Janus-Pro.

    Args:
        image: numpy array (from ``gr.Image(type="numpy")``) or PIL image.
        question: user question about the image.
        seed: RNG seed; ``None`` (cleared UI field) leaves RNG state as is.
        top_p: nucleus-sampling threshold (used only when sampling).
        temperature: sampling temperature; 0 selects greedy decoding.

    Returns:
        The decoded answer text.

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(seed)
        np.random.seed(seed)
        torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "<|User|>",
            "content": f"\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]
    pil_images = [Image.fromarray(image)] if isinstance(image, np.ndarray) else [image]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=model_dtype)
    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    # Only pass sampling parameters when actually sampling; HF generate
    # warns about temperature/top_p when do_sample=False.
    do_sample = temperature != 0
    sampling_kwargs = {"temperature": temperature, "top_p": top_p} if do_sample else {}

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=512,
        do_sample=do_sample,
        use_cache=True,
        **sampling_kwargs,
    )
    return tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)


def generate(input_ids, width, height, temperature: float = 1, parallel_size: int = 5,
             cfg_weight: float = 5, image_token_num_per_image: int = 576, patch_size: int = 16):
    """Sample image tokens with classifier-free guidance and decode them.

    The batch holds ``parallel_size`` pairs of rows: even rows carry the full
    prompt (conditional) and odd rows have the prompt body replaced by padding
    (unconditional). Each step mixes the two with ``cfg_weight``.

    Args:
        input_ids: 1-D prompt token ids.
        width, height: output image size in pixels (multiples of ``patch_size``).
        temperature: sampling temperature; ``<= 0`` selects greedy decoding.
        parallel_size: number of images generated at once.
        cfg_weight: classifier-free guidance scale.
        image_token_num_per_image: number of image tokens to sample.
        patch_size: pixels per image token along each side.

    Returns:
        ``(generated_tokens, patches)`` where ``patches`` is the output of the
        vision decoder.
    """
    torch.cuda.empty_cache()
    with torch.no_grad():
        tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
        for i in range(parallel_size * 2):
            tokens[i, :] = input_ids
            if i % 2 != 0:
                # Unconditional branch: keep only the first and last tokens.
                tokens[i, 1:-1] = vl_chat_processor.pad_id
        inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)

        generated_tokens = torch.zeros(
            (parallel_size, image_token_num_per_image), dtype=torch.int
        ).to(cuda_device)
        pkv = None
        for i in range(image_token_num_per_image):
            outputs = vl_gpt.language_model.model(
                inputs_embeds=inputs_embeds, use_cache=True, past_key_values=pkv
            )
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)

            if temperature > 0:
                probs = torch.softmax(logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # The UI allows temperature 0; dividing by it would produce
                # inf/NaN probabilities, so decode greedily instead.
                next_token = torch.argmax(logits, dim=-1, keepdim=True)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)

            # Feed the same token to both the conditional and unconditional row.
            next_token = torch.cat(
                [next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1
            ).view(-1)
            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

        # Decoder shape is (batch, channels, rows, cols): height before width.
        patches = vl_gpt.gen_vision_model.decode_code(
            generated_tokens.to(dtype=torch.int),
            shape=[parallel_size, 8, height // patch_size, width // patch_size],
        )
    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    """Convert decoder output in [-1, 1] to a uint8 array of HxWx3 images."""
    # (batch, channels, h, w) -> (batch, h, w, channels)
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)
    visual_img = np.zeros((parallel_size, height, width, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec
    return visual_img


@spaces.GPU(duration=120)
@torch.inference_mode()
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    """Generate ``parallel_size`` upscaled images for *prompt*.

    Korean prompts are translated to English first. When *seed* is not
    ``None`` the torch, CUDA and numpy RNGs are seeded with it.
    """
    prompt = translate_if_korean(prompt)
    torch.cuda.empty_cache()
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)

    width = 384
    height = 384
    parallel_size = 5

    messages = [
        {'role': '<|User|>', 'content': prompt},
        {'role': '<|Assistant|>', 'content': ''},
    ]
    text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=messages,
        sft_format=vl_chat_processor.sft_format,
        system_prompt='',
    )
    text = text + vl_chat_processor.image_start_tag
    input_ids = torch.LongTensor(tokenizer.encode(text))

    # Round down to a multiple of the 16-pixel patch size.
    gen_width = width // 16 * 16
    gen_height = height // 16 * 16
    _, patches = generate(
        input_ids,
        gen_width,
        gen_height,
        cfg_weight=guidance,
        parallel_size=parallel_size,
        temperature=t2i_temperature,
    )
    images = unpack(patches, gen_width, gen_height, parallel_size=parallel_size)

    stime = time.time()
    ret_images = [image_upsample(Image.fromarray(img)) for img in images]
    print(f'upsample time: {time.time() - stime}')
    return ret_images


@spaces.GPU(duration=60)
def image_upsample(img: Image.Image) -> Image.Image:
    """Upscale *img* 2x with the module-level RealESRGAN model.

    Raises:
        Exception: if *img* is None or either side is >= 5000 pixels.
    """
    if img is None:
        raise Exception("Image not uploaded")
    width, height = img.size
    if width >= 5000 or height >= 5000:
        raise Exception("The image is too large.")
    return sr_model.predict(img.convert('RGB'))


# Custom CSS for a sleek, modern and highly readable interface
custom_css = """
body {
    background: #f0f2f5;
    font-family: 'Segoe UI', sans-serif;
    color: #333;
}
h1, h2, h3 {
    font-weight: 600;
}
.gradio-container {
    padding: 20px;
}
header {
    text-align: center;
    padding: 20px;
    margin-bottom: 20px;
}
header h1 {
    font-size: 3em;
    color: #2c3e50;
}
.gr-button {
    background-color: #3498db !important;
    color: #fff !important;
    border: none !important;
    padding: 10px 20px !important;
    border-radius: 5px !important;
    font-size: 1em !important;
}
.gr-button:hover {
    background-color: #2980b9 !important;
}
.gr-input, .gr-slider, .gr-number, .gr-textbox {
    border-radius: 5px;
}
.gr-gallery-item {
    border-radius: 10px;
    overflow: hidden;
    box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, title="Multimodal & T2I") as demo:
    with gr.Column(variant="panel"):
        # Rendered as HTML so the `header` / `header h1` CSS rules above apply.
        gr.Markdown("<header><h1>Chat With Janus-Pro-7B</h1></header>")
        with gr.Tabs():
            with gr.TabItem("Multimodal Understanding"):
                gr.Markdown("### Chat with Images")
                with gr.Row():
                    image_input = gr.Image(label="Upload Image", type="numpy")
                    with gr.Column():
                        question_input = gr.Textbox(
                            label="Question",
                            placeholder="Enter your question about the image here...",
                            lines=4,
                        )
                        und_seed_input = gr.Number(label="Seed", precision=0, value=42)
                        top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="Top_p")
                        temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="Temperature")
                understanding_button = gr.Button("Chat", elem_id="understanding-button")
                understanding_output = gr.Textbox(label="Response", lines=6)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Multimodal Understanding Examples",
                        examples=[
                            ["explain this meme", "doge.png"],
                        ],
                        inputs=[question_input, image_input],
                    )
                understanding_button.click(
                    multimodal_understanding,
                    inputs=[image_input, question_input, und_seed_input, top_p, temperature],
                    outputs=understanding_output,
                )
            with gr.TabItem("Text-to-Image Generation"):
                gr.Markdown("### Generate Images from Text")
                with gr.Row():
                    prompt_input = gr.Textbox(
                        label="Prompt",
                        placeholder="Enter detailed prompt for image generation...",
                        lines=4,
                    )
                with gr.Row():
                    seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
                    cfg_weight_input = gr.Slider(minimum=1, maximum=10, value=5, step=0.5, label="CFG Weight")
                    t2i_temperature = gr.Slider(minimum=0, maximum=1, value=1.0, step=0.05, label="Temperature")
                generation_button = gr.Button("Generate Images", elem_id="generation-button")
                image_output = gr.Gallery(label="Generated Images", columns=2, rows=2, height=300)
                with gr.Accordion("Examples", open=False):
                    gr.Examples(
                        label="Text-to-Image Examples",
                        examples=[
                            "Master shifu racoon wearing drip attire as a street gangster.",
                            "The face of a beautiful girl",
                            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
                            "A cute and adorable baby fox with big brown eyes, autumn leaves in the background enchanting, immortal, fluffy, shiny mane, petals, fairyism, unreal engine 5 and Octane Render, highly detailed, photorealistic, cinematic, natural colors.",
                            "고양이가 우주복을 입고 달에 있는 모습",
                        ],
                        inputs=prompt_input,
                    )
                generation_button.click(
                    fn=generate_image,
                    inputs=[prompt_input, seed_input, cfg_weight_input, t2i_temperature],
                    outputs=image_output,
                )
        gr.Markdown("")

demo.launch(share=True)