import gradio as gr
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from janus.models import MultiModalityCausalLM, VLChatProcessor
from janus.utils.io import load_pil_images
from PIL import Image

import numpy as np
import os
import time
import spaces  # Import spaces for ZeroGPU compatibility

# Load model and processor
model_path = "deepseek-ai/Janus-Pro-7B"
config = AutoConfig.from_pretrained(model_path)
language_config = config.language_config
language_config._attn_implementation = 'eager'
vl_gpt = AutoModelForCausalLM.from_pretrained(model_path,
                                              language_config=language_config,
                                              trust_remote_code=True)
if torch.cuda.is_available():
    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
else:
    vl_gpt = vl_gpt.to(torch.float16)

vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
tokenizer = vl_chat_processor.tokenizer
cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Multimodal Understanding function
@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image, question, seed, top_p, temperature):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # Set seed for reproducibility
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)

    conversation = [
        {
            "role": "<|User|>",
            "content": f"<image_placeholder>\n{question}",
            "images": [image],
        },
        {"role": "<|Assistant|>", "content": ""},
    ]

    pil_images = [Image.fromarray(image)]
    prepare_inputs = vl_chat_processor(
        conversations=conversation, images=pil_images, force_batchify=True
    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

    outputs = vl_gpt.language_model.generate(
        inputs_embeds=inputs_embeds,
        attention_mask=prepare_inputs.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        max_new_tokens=4000,
        do_sample=False if temperature == 0 else True,
        use_cache=True,
        temperature=temperature,
        top_p=top_p,
    )

    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
    return answer
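# --- Illustrative example (not part of the original app) --------------------
# A minimal sketch of how multimodal_understanding could be called outside the
# Gradio UI, e.g. as a quick smoke test. The image path, question, and sampling
# values below are assumptions chosen for illustration only; the helper is
# defined but never called, so it does not change the app's behaviour.
def _example_understanding_call(image_path: str = "doge.png") -> str:
    img = np.array(Image.open(image_path).convert("RGB"))
    return multimodal_understanding(
        image=img,
        question="explain this meme",
        seed=42,
        top_p=0.95,
        temperature=0.1,
    )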
def generate(input_ids,
             width,
             height,
             temperature: float = 1,
             parallel_size: int = 5,
             cfg_weight: float = 5,
             image_token_num_per_image: int = 576,
             patch_size: int = 16):
    # Clear CUDA cache before generating
    torch.cuda.empty_cache()

    # Interleave conditional/unconditional inputs for classifier-free guidance:
    # even rows keep the full prompt, odd rows are padded to act as the unconditional branch.
    tokens = torch.zeros((parallel_size * 2, len(input_ids)), dtype=torch.int).to(cuda_device)
    for i in range(parallel_size * 2):
        tokens[i, :] = input_ids
        if i % 2 != 0:
            tokens[i, 1:-1] = vl_chat_processor.pad_id

    inputs_embeds = vl_gpt.language_model.get_input_embeddings()(tokens)
    generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).to(cuda_device)

    pkv = None
    for i in range(image_token_num_per_image):
        with torch.no_grad():
            outputs = vl_gpt.language_model.model(inputs_embeds=inputs_embeds,
                                                  use_cache=True,
                                                  past_key_values=pkv)
            pkv = outputs.past_key_values
            hidden_states = outputs.last_hidden_state
            logits = vl_gpt.gen_head(hidden_states[:, -1, :])
            logit_cond = logits[0::2, :]
            logit_uncond = logits[1::2, :]
            # Classifier-free guidance: steer logits toward the conditional branch.
            logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            probs = torch.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated_tokens[:, i] = next_token.squeeze(dim=-1)
            next_token = torch.cat([next_token.unsqueeze(dim=1), next_token.unsqueeze(dim=1)], dim=1).view(-1)

            img_embeds = vl_gpt.prepare_gen_img_embeds(next_token)
            inputs_embeds = img_embeds.unsqueeze(dim=1)

    patches = vl_gpt.gen_vision_model.decode_code(generated_tokens.to(dtype=torch.int),
                                                  shape=[parallel_size, 8, width // patch_size, height // patch_size])

    return generated_tokens.to(dtype=torch.int), patches


def unpack(dec, width, height, parallel_size=5):
    dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
    dec = np.clip((dec + 1) / 2 * 255, 0, 255)

    visual_img = np.zeros((parallel_size, width, height, 3), dtype=np.uint8)
    visual_img[:, :, :] = dec

    return visual_img


@torch.inference_mode()
@spaces.GPU(duration=120)  # Specify a duration to avoid timeout
def generate_image(prompt, seed=None, guidance=5, t2i_temperature=1.0):
    # Clear CUDA cache and avoid tracking gradients
    torch.cuda.empty_cache()
    # Set the seed for reproducible results
    if seed is not None:
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        np.random.seed(seed)
    width = 384
    height = 384
    parallel_size = 5

    with torch.no_grad():
        messages = [{'role': '<|User|>', 'content': prompt},
                    {'role': '<|Assistant|>', 'content': ''}]
        text = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(conversations=messages,
                                                                           sft_format=vl_chat_processor.sft_format,
                                                                           system_prompt='')
        text = text + vl_chat_processor.image_start_tag

        input_ids = torch.LongTensor(tokenizer.encode(text))
        output, patches = generate(input_ids,
                                   width // 16 * 16,
                                   height // 16 * 16,
                                   cfg_weight=guidance,
                                   parallel_size=parallel_size,
                                   temperature=t2i_temperature)
        images = unpack(patches,
                        width // 16 * 16,
                        height // 16 * 16,
                        parallel_size=parallel_size)

        return [Image.fromarray(images[i]).resize((768, 768), Image.LANCZOS) for i in range(parallel_size)]
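# --- Illustrative note (not part of the original app) -----------------------
# Sanity-check sketch of the token-grid arithmetic used above: a 384 x 384
# image decoded with 16-pixel patches yields a 24 x 24 grid, i.e. 576 tokens,
# which is why image_token_num_per_image defaults to 576 in generate(). The
# helper is defined for illustration only and is never called.
def _expected_image_token_count(width: int = 384, height: int = 384, patch_size: int = 16) -> int:
    return (width // patch_size) * (height // patch_size)  # 384 // 16 = 24; 24 * 24 = 576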
# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown(value="# Multimodal Understanding")
    with gr.Row():
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

    understanding_button = gr.Button("Chat")
    understanding_output = gr.Textbox(label="Response")

    examples_inpainting = gr.Examples(
        label="Multimodal Understanding examples",
        examples=[
            [
                "explain this meme",
                "doge.png",
            ],
            [
                """Analyze the provided fundus image in exhaustive detail, following the standard ophthalmological protocol for fundus examination. Output an HTML report structured as a formal medical document. The report MUST:

1. **Image Quality Assessment:** Begin with a concise assessment of image quality, noting focus, illumination, field of view, and any artifacts (and their impact on assessability).

2. **Detailed Clinical Findings:** Describe each of the following areas with the utmost precision and specificity, using proper ophthalmological terminology:
    * **Optic Disc:**
        * Size and shape (including any abnormalities).
        * Color (specifically noting any pallor and its location).
        * Cup-to-Disc Ratio (CDR), providing both vertical and horizontal estimates.
        * Neuroretinal Rim: Assess rim thickness in all quadrants (superior, inferior, nasal, temporal). Explicitly state whether the ISNT rule is followed or violated. Describe any notching or focal thinning.
        * Peripapillary Region: Describe the presence/absence of peripapillary atrophy (PPA), differentiating between alpha and beta zones. Note any hemorrhages.
    * **Retinal Vasculature:**
        * Arterioles: Describe caliber (narrowing, dilation), tortuosity, and any focal abnormalities.
        * Venules: Describe caliber, tortuosity, and any abnormalities.
        * Arteriovenous (A/V) Ratio: Estimate the A/V ratio.
        * Crossing Changes: Note any arteriovenous nicking or other crossing abnormalities.
        * Vessel Course: Describe the course of the major vessels, and check for abnormalities.
    * **Macula:**
        * Foveal Reflex: Describe the presence/absence and quality of the foveal reflex.
        * Pigment Changes: Note any pigmentary abnormalities, drusen, or other lesions.
        * Edema/Exudates: Describe any signs of macular edema or exudates.
    * **Peripheral Retina:**
        * Mid-Periphery: Describe any abnormalities (hemorrhages, exudates, tears, etc.).
        * Far Periphery: Note the extent of visualization and any findings.

3. **Differential Diagnosis:** Based solely on the image findings, provide a prioritized differential diagnosis. Include the most likely diagnosis and any other plausible possibilities. For each diagnosis, explain the reasoning based on the observed features.

4. **Diagnostic Confidence:** Indicate the confidence level for the primary diagnosis. List the key image findings that support the diagnosis.

5. **Simulated AI Attention Metrics:** Create a table representing a *simulated* AI attention distribution. This should reflect the expected focus areas for the most likely diagnosis, based on the known importance of different features. Provide percentages for:
    * Optic Disc (Total)
    * Cup
    * Neuroretinal Rim (subdivided by region if significant differences exist)
    * Peripapillary Atrophy
    * Vessels
    * Macula
    * Periphery

6. **Summary and Impression:** Provide a concise summary of the key findings and the overall impression.

7. **Recommendations:**
    * Provide specific, actionable recommendations based on the image findings.
    * If referral is warranted, clearly state the urgency and the type of specialist.
    * List any recommended investigations (e.g., OCT, visual fields).

8. **Disclaimer:** Include a disclaimer stating that the report is based on image analysis alone and does not replace a full clinical examination.

9. **HTML Structure:** Use semantic HTML elements (h1-h3, p, ol, ul, table, div) to create a well-structured, readable report. Include:
    * A report header with a title ("EyeUnit.ai | AI for Ophthalmology") and a logo placeholder.
    * An image comparison section displaying the original fundus image and a placeholder for a heatmap (a canvas element with id "heatmapCanvas"). No actual heatmap generation is required; the canvas is a placeholder.
    * A placeholder for patient information (PATIENT ID, NAME, AGE, DATE OF EXAM)
    * Clearly labeled sections for each part of the analysis.
    * Tables for the "Overall Analysis Coverage" and "AI-Driven Attention Metrics."

10. **CSS Styling:** Apply CSS styles to make the report visually appealing and professional. The report should be suitable for both screen viewing and printing (use a `@media print` block to optimize for print).
    * **Crucial Details:**
        * **PATIENT ID, NAME, AGE and DATE OF EXAM**

11. **Crucial Details:** Output ONLY the complete HTML code. Do not provide any surrounding text or explanations. Focus solely on generating the HTML report.

12. **IMG SOURCE:** Use this image as the image source: `