import gradio as gr
# NOTE: assumes a llava_med wrapper package exposing HF-style classes is
# installed; these class names are not part of the core transformers library.
from llava_med import LlavaMedProcessor, LlavaMedForCausalLM
from PIL import Image
import torch

# Load model and processor
model = LlavaMedForCausalLM.from_pretrained(
    "microsoft/llava-med-v1.5-mistral-7b",
    torch_dtype=torch.float32,  # Use float32 for CPU stability
    low_cpu_mem_usage=True,
    device_map="cpu"
)
processor = LlavaMedProcessor.from_pretrained(
    "microsoft/llava-med-v1.5-mistral-7b"
)

def analyze_medical_image(image, question):
    # Prepare the prompt in a simple Question/Answer format
    prompt = f"Question: {question} Answer:"

    # Process text and image inputs together
    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt",
        padding=True
    ).to("cpu")

    # Generate response (sampling with temperature/top-p for varied output)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9
        )

    # Decode and keep only the text after the "Answer:" marker
    response = processor.batch_decode(
        outputs,
        skip_special_tokens=True
    )[0].split("Answer:")[-1].strip()

    return response

# Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Med Medical Analysis (CPU)")
    gr.Markdown("Official Microsoft LLaVA-Med 1.5-Mistral-7B implementation")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Medical Image", type="pil")
            question_input = gr.Textbox(
                label="Clinical Question",
                placeholder="Enter your medical question..."
            )
            submit_btn = gr.Button("Analyze")
        with gr.Column():
            output_text = gr.Textbox(label="Clinical Analysis", interactive=False)

    submit_btn.click(
        fn=analyze_medical_image,
        inputs=[image_input, question_input],
        outputs=output_text
    )

# queue() bounds concurrent requests, which matters on CPU; launch() blocks
demo.queue(max_size=5).launch()
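
# --- Optional smoke test (a minimal sketch) ---
# For quick iteration without the UI, analyze_medical_image can be called
# directly. The file name below is hypothetical; substitute any local
# medical image. Run this *instead of* demo.launch(), since launch()
# blocks until the server is stopped.
#
# img = Image.open("sample_chest_xray.png")  # hypothetical sample file
# print(analyze_medical_image(img, "Is there evidence of pleural effusion?"))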