|
import gradio as gr |
|
from llava_med import LlavaMedProcessor, LlavaMedForCausalLM |
|
from PIL import Image |
|
import torch |
|
|
|
|
|
# Hugging Face hub id of the official LLaVA-Med 1.5 (Mistral-7B) checkpoint.
MODEL_ID = "microsoft/llava-med-v1.5-mistral-7b"

# CPU-only load: float32 is the safe dtype on CPU, and low_cpu_mem_usage
# reduces peak RAM while the weights are being materialized.
model = LlavaMedForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True,
    device_map="cpu",
)

# Matching processor (tokenizer + image preprocessing) for the same checkpoint.
processor = LlavaMedProcessor.from_pretrained(MODEL_ID)
|
|
|
def analyze_medical_image(image, question):
    """Answer a clinical question about a medical image with LLaVA-Med.

    Args:
        image: PIL image from the Gradio widget, or ``None`` when the user
            submitted without uploading anything.
        question: Free-text clinical question (may be empty).

    Returns:
        The model's answer as a string, or a short guidance message when
        required input is missing.
    """
    # Gradio delivers None / "" when the user clicks Analyze without
    # providing input; return a friendly message instead of crashing
    # inside the processor call.
    if image is None:
        return "Please upload a medical image."
    if not question or not question.strip():
        return "Please enter a clinical question."

    prompt = f"Question: {question} Answer:"

    inputs = processor(
        text=prompt,
        images=image,
        return_tensors="pt",
        padding=True,
    ).to("cpu")

    # Sampling decode; no_grad avoids building autograd state during inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )

    # Decode the batch and keep only the text after the last "Answer:" marker,
    # stripping the echoed prompt from the generation.
    response = processor.batch_decode(
        outputs,
        skip_special_tokens=True,
    )[0].split("Answer:")[-1].strip()

    return response
|
|
|
|
|
# ---- Gradio UI ---------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("# LLaVA-Med Medical Analysis (CPU)")
    gr.Markdown("Official Microsoft LLaVA-Med 1.5-Mistral-7B implementation")

    with gr.Row():
        # Left column: user inputs.
        with gr.Column():
            img_in = gr.Image(label="Medical Image", type="pil")
            question_in = gr.Textbox(
                label="Clinical Question",
                placeholder="Enter your medical question...",
            )
            run_btn = gr.Button("Analyze")

        # Right column: read-only model output.
        with gr.Column():
            answer_out = gr.Textbox(label="Clinical Analysis", interactive=False)

    # Wire the button to the inference function.
    run_btn.click(
        fn=analyze_medical_image,
        inputs=[img_in, question_in],
        outputs=answer_out,
    )

# A bounded request queue keeps concurrent users from overloading the CPU box.
demo.queue(max_size=5).launch()