MedVLM-R1

Introduction

MedVLM-R1 is a medical Vision-Language Model built upon Qwen2-VL-2B and fine-tuned using the GRPO reinforcement learning framework. Although it was trained on only 600 MRI VQA samples from the HuatuoGPT-Vision dataset, MedVLM-R1 generalizes well to out-of-distribution CT and X-ray VQA tasks. Rather than providing only a final answer, it also produces explicit medical reasoning, which supports interpretability and trustworthiness in clinical applications.

Quick Start

1. Load the model

from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, GenerationConfig
from qwen_vl_utils import process_vision_info
import torch

MODEL_PATH = 'JZPeterPan/MedVLM-R1'

model = Qwen2VLForConditionalGeneration.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)

processor = AutoProcessor.from_pretrained(MODEL_PATH)

temp_generation_config = GenerationConfig(
    max_new_tokens=1024,
    do_sample=False,  # greedy decoding
    temperature=1,  # has no effect while do_sample=False
    num_return_sequences=1,
    pad_token_id=151643,
)
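The loading code above requests `flash_attention_2`, which requires the separate `flash_attn` package and a supported GPU. As a minimal sketch, assuming you want to fall back to PyTorch's built-in scaled-dot-product attention when that package is missing, the hypothetical helper `pick_attn_implementation` (not part of this repository) chooses a value for the `attn_implementation` argument:

```python
import importlib.util


def pick_attn_implementation() -> str:
    """Hypothetical helper: choose an attention backend for from_pretrained.

    Returns "flash_attention_2" when the flash_attn package is importable,
    otherwise "sdpa". It only checks that the package is installed, not
    that the current GPU actually supports it.
    """
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"


attn_implementation = pick_attn_implementation()
print(f"attention backend: {attn_implementation}")
```

The result can be passed as `attn_implementation=pick_attn_implementation()` in the `from_pretrained` call in place of the hard-coded string.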

2. Load the VQA Data

Pick one of the following examples. They are samples from the OmniMedVQA dataset, as bundled by HuatuoGPT-Vision. Each line assigns the same `question` variable, so run only the one you want to test.

question = {"image": ['images/successful_cases/mdb146.png'], "problem": "What content appears in this image?\nA) Cardiac tissue\nB) Breast tissue\nC) Liver tissue\nD) Skin tissue", "solution": "B", "answer": "Breast tissue"}

question = {"image": ["images/successful_cases/person19_virus_50.jpeg"], "problem": "What content appears in this image?\nA) Lungs\nB) Bladder\nC) Brain\nD) Heart", "solution": "A", "answer": "Lungs"}

question = {"image":["images/successful_cases/abd-normal023599.png"],"problem":"Is any abnormality evident in this image?\nA) No\nB) Yes.","solution":"A","answer":"No"}

question = {"image":["images/successful_cases/foot089224.png"],"problem":"Which imaging technique was utilized for acquiring this image?\nA) MRI\nB) Electroencephalogram (EEG)\nC) Ultrasound\nD) Angiography","solution":"A","answer":"MRI"}

question = {"image":["images/successful_cases/knee031316.png"],"problem":"What can be observed in this image?\nA) Chondral abnormality\nB) Bone density loss\nC) Synovial cyst formation\nD) Ligament tear","solution":"A","answer":"Chondral abnormality"}

question = {"image":["images/successful_cases/shoulder045906.png"],"problem":"What can be visually detected in this picture?\nA) Bone fracture\nB) Soft tissue fluid\nC) Blood clot\nD) Tendon tear","solution":"B","answer":"Soft tissue fluid"}

question = {"image":["images/successful_cases/brain003631.png"],"problem":"What attribute can be observed in this image?\nA) Focal flair hyperintensity\nB) Bone fracture\nC) Vascular malformation\nD) Ligament tear","solution":"A","answer":"Focal flair hyperintensity"}

question = {"image":["images/successful_cases/mrabd005680.png"],"problem":"What can be observed in this image?\nA) Pulmonary embolism\nB) Pancreatic abscess\nC) Intraperitoneal mass\nD) Cardiac tamponade","solution":"C","answer":"Intraperitoneal mass"}

3. Run the inference

QUESTION_TEMPLATE = """
    {Question} 
    Your task: 
    1. Think through the question step by step, enclose your reasoning process in <think>...</think> tags. 
    2. Then provide the correct single-letter choice (A, B, C, D,...) inside <answer>...</answer> tags.
    3. No extra information or text outside of these tags.
    """

message = [{
    "role": "user",
    "content": [{"type": "image", "image": f"file://{question['image'][0]}"}, {"type": "text","text": QUESTION_TEMPLATE.format(Question=question['problem'])}]
}]

text = processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True)
    
image_inputs, video_inputs = process_vision_info(message)
inputs = processor(
    text=text,
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
).to(model.device)

generated_ids = model.generate(**inputs, use_cache=True, generation_config=temp_generation_config)

generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]

output_text = processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(f'model output: {output_text[0]}')
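The prompt asks the model to put a single-letter choice inside `<answer>...</answer>` tags, so the decoded text can be compared against the `solution` field of the question. The following is a minimal sketch of that comparison, not the authors' evaluation code; the names `extract_choice` and `is_correct` are hypothetical, and the sample string stands in for a real model output.

```python
import re
from typing import Optional

# Matches the first letter that follows <answer>, allowing optional whitespace
# and an opening parenthesis, e.g. "<answer>B</answer>" or "<answer> (B) ...".
# The trailing \b rejects answers that start with a full word such as "Breast".
ANSWER_RE = re.compile(r"<answer>\s*\(?([A-Za-z])\b")


def extract_choice(output: str) -> Optional[str]:
    """Return the upper-cased choice letter, or None if no letter is found."""
    match = ANSWER_RE.search(output)
    return match.group(1).upper() if match else None


def is_correct(output: str, solution: str) -> bool:
    """Compare the extracted letter with the expected solution letter."""
    return extract_choice(output) == solution.strip().upper()


# Illustrative string only; a real run would pass output_text[0] instead.
sample_output = "<think>The image shows a mammogram.</think>\n<answer>B</answer>"
print(extract_choice(sample_output))   # B
print(is_correct(sample_output, "B"))  # True
```

With the variables from the steps above, the check would be `is_correct(output_text[0], question["solution"])`.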

Failure cases

MedVLM-R1's reasoning can fail on more difficult VQA examples. Although the model selects the correct choice in the following examples, its reasoning is either superficial or contradictory.

question = {"image":["images/failure_cases/mrabd021764.png"],"problem":"What is the observable finding in this image?\nA) Brain lesion\nB) Intestinal lesion\nC) Gallbladder lesion\nD) Pancreatic lesion","solution":"D","answer":"Pancreatic lesion"}

question = {"image":["images/failure_cases/spine010017.png"],"problem":"What can be observed in this image?\nA) Cystic lesions\nB) Fractured bones\nC) Inflamed tissue\nD) Nerve damage","solution":"A","answer":"Cystic lesions"}

question = {"image":["images/failure_cases/ankle056120.png"],"problem":"What attribute can be observed in this image?\nA) Bursitis\nB) Flexor pathology\nC) Tendonitis\nD) Joint inflammation","solution":"B","answer":"Flexor pathology"}

question = {"image":["images/failure_cases/lung067009.png"],"problem":"What is the term for the anomaly depicted in the image?\nA) Pulmonary embolism\nB) Airspace opacity\nC) Lung consolidation\nD) Atelectasis","solution":"B","answer":"Airspace opacity"}
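Because a correct letter can hide weak reasoning, it can help to pull the `<think>...</think>` text out of the output and read it alongside the answer. The sketch below is one hypothetical way to do that; `extract_reasoning`, `needs_review`, and the 15-word threshold are assumptions made for illustration. The length check only flags very short reasoning for manual review and cannot detect contradictions.

```python
import re

# Non-greedy match so only the first <think> block is captured;
# DOTALL lets the reasoning span multiple lines.
THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL)


def extract_reasoning(output: str) -> str:
    """Return the text inside the first <think> block, or "" if absent."""
    match = THINK_RE.search(output)
    return match.group(1).strip() if match else ""


def needs_review(output: str, min_words: int = 15) -> bool:
    """Flag outputs whose reasoning has fewer than min_words words.

    The threshold is an arbitrary assumption, not a validated criterion.
    """
    return len(extract_reasoning(output).split()) < min_words


# Illustrative string only; a real run would pass output_text[0] instead.
sample_output = "<think>It is an MRI.</think><answer>A</answer>"
print(extract_reasoning(sample_output))  # It is an MRI.
print(needs_review(sample_output))       # True
```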

Acknowledgement

We thank everyone in the machine learning and medical communities who has made codebases and datasets publicly available 🫶🫶🫶

If you find our work helpful, please consider citing it:

@article{pan2025medvlm,
  title={MedVLM-R1: Incentivizing Medical Reasoning Capability of Vision-Language Models (VLMs) via Reinforcement Learning},
  author={Pan, Jiazhen and Liu, Che and Wu, Junde and Liu, Fenglin and Zhu, Jiayuan and Li, Hongwei Bran and Chen, Chen and Ouyang, Cheng and Rueckert, Daniel},
  journal={arXiv preprint arXiv:2502.19634},
  year={2025}
}