"""Gradio demo that answers free-text questions about an uploaded image.

The script loads the Janus-1.3B multimodal model at import time, exposes
``medical_analysis`` as the inference entry point, and wires it to a small
Gradio Blocks UI.
"""

import gradio as gr
import numpy as np
import torch
from diffusers import AutoencoderKL
from janus.models import MultiModalityCausalLM, VLChatProcessor
from PIL import Image

# Configure device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


# Initialize medical imaging components
def load_medical_models():
    """Load the chat processor, the Janus model and the SDXL VAE.

    Both networks are moved to ``device`` and put in eval mode. They use
    bfloat16 on CUDA and float32 on CPU.

    Returns:
        A ``(processor, model, vae)`` tuple.

    Raises:
        Exception: Any loading error is printed and then re-raised unchanged.
    """
    weights_dtype = torch.bfloat16 if device == "cuda" else torch.float32
    try:
        # Load processor and tokenizer
        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")

        # Load base model
        model = (
            MultiModalityCausalLM.from_pretrained(
                "deepseek-ai/Janus-1.3B",
                torch_dtype=weights_dtype,
            )
            .to(device)
            .eval()
        )

        # Load VAE for image processing.
        # NOTE(review): the VAE is returned but not used anywhere in the
        # visible part of this file; it is kept so the return shape stays
        # the same for any other caller.
        vae = (
            AutoencoderKL.from_pretrained(
                "stabilityai/sdxl-vae",
                torch_dtype=weights_dtype,
            )
            .to(device)
            .eval()
        )
    except Exception as e:
        print(f"Error loading models: {str(e)}")
        raise
    return processor, model, vae


processor, model, vae = load_medical_models()


# Medical image analysis function
def medical_analysis(image, question, seed=42, top_p=0.95, temperature=0.1):
    """Answer ``question`` about ``image`` with the Janus model.

    Args:
        image: A PIL image or a numpy array; converted to RGB before use.
        question: The free-text query about the image.
        seed: Seed applied to torch and numpy before generation.
        top_p: Nucleus-sampling threshold, used only when sampling.
        temperature: Sampling temperature; ``0`` (or less) selects greedy
            decoding.

    Returns:
        The decoded model answer, or a string starting with
        ``"Analysis error: "`` when the inputs are missing or inference fails.
    """
    # Fail early with a readable message instead of a processor traceback.
    if image is None:
        return "Analysis error: no image provided"
    if not question or not question.strip():
        return "Analysis error: no clinical query provided"

    try:
        # Set random seed for reproducibility
        torch.manual_seed(seed)
        np.random.seed(seed)

        # Prepare inputs. PIL inputs are converted too, because grayscale or
        # RGBA uploads would otherwise reach the processor unconverted.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)
        image = image.convert("RGB")

        # Janus expects a chat-style conversation with an image placeholder,
        # followed by an empty assistant turn for the model to complete.
        conversation = [
            {
                "role": "User",
                "content": f"<image_placeholder>\n{question}",
                "images": [image],
            },
            {"role": "Assistant", "content": ""},
        ]
        inputs = processor(
            conversations=conversation,
            images=[image],
            force_batchify=True,
        ).to(device, dtype=model.dtype)

        tokenizer = processor.tokenizer

        # temperature/top_p only take effect when sampling is enabled.
        if temperature > 0:
            sampling = {
                "do_sample": True,
                "temperature": temperature,
                "top_p": top_p,
            }
        else:
            sampling = {"do_sample": False}

        # Generate analysis
        with torch.inference_mode():
            # The image encoder runs here and its output is merged with the
            # text embeddings before the language model generates.
            inputs_embeds = model.prepare_inputs_embeds(**inputs)
            outputs = model.language_model.generate(
                inputs_embeds=inputs_embeds,
                attention_mask=inputs.attention_mask,
                pad_token_id=tokenizer.eos_token_id,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                max_new_tokens=512,
                use_cache=True,
                **sampling,
            )

        return tokenizer.decode(
            outputs[0].cpu().tolist(),
            skip_special_tokens=True,
        )
    except Exception as e:
        # UI boundary: report the failure in the output box rather than crash.
        return f"Analysis error: {str(e)}"


# Medical interface
with gr.Blocks(title="Medical Imaging Assistant") as demo:
    gr.Markdown("# Medical Imaging AI Assistant")

    with gr.Tab("Analysis"):
        with gr.Row():
            med_image = gr.Image(label="Input Image", type="pil")
            med_question = gr.Textbox(label="Clinical Query")
        analysis_output = gr.Textbox(label="Findings")

        gr.Examples(
            examples=[
                [
                    "ultrasound_sample.jpg",
                    "Identify any abnormalities in this ultrasound",
                ],
                [
                    "xray_sample.jpg",
                    "Describe the bone structure visible in this X-ray",
                ],
            ],
            inputs=[med_image, med_question],
        )

        # Pressing Enter in the query box runs the analysis with the default
        # seed / top_p / temperature.
        med_question.submit(
            medical_analysis,
            inputs=[med_image, med_question],
            outputs=analysis_output,
        )

demo.launch()