# computer-thoughts / train.py
# Uploaded by zeddotes, commit 5c85d5d ("updated")
"""
train.py
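A complete example of fine-tuning BLIP for image captioning on the 'agentsea/computer-thoughts' dataset.
All preprocessing (image transforms and tokenization) happens on the fly in the collate function,
which is simpler than pre-processing the whole dataset and avoids tensor shape mismatches.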
"""
import torch
from datasets import load_dataset, Image as HFImage
from transformers import (
BlipProcessor,
BlipForConditionalGeneration,
TrainingArguments,
Trainer
)
# 1. Load the dataset from the Hugging Face Hub.
dataset = load_dataset("agentsea/computer-thoughts")
# 2. Rename "image_before" -> "image" and cast it to the datasets Image feature
#    so that rows yield PIL images.
dataset = dataset.rename_column("image_before", "image")
dataset = dataset.cast_column("image", HFImage())
# 3. Take a small subset for a quick demo run.
#    Set NUM_DEMO_EXAMPLES = None to train on the full "train" split.
NUM_DEMO_EXAMPLES = 5
train_split = dataset["train"]
if NUM_DEMO_EXAMPLES is None:
    train_subset = train_split
else:
    # min() keeps the requested indices in range even if the split has fewer
    # rows than NUM_DEMO_EXAMPLES.
    train_subset = train_split.select(range(min(NUM_DEMO_EXAMPLES, len(train_split))))
# 4. Load the BLIP *large* captioning checkpoint and its matching processor.
#    Both must come from the same checkpoint id so the tokenizer/image
#    preprocessing matches the model weights.
MODEL_ID = "Salesforce/blip-image-captioning-large"
processor = BlipProcessor.from_pretrained(MODEL_ID)
model = BlipForConditionalGeneration.from_pretrained(MODEL_ID)
# 5. Define a collate_fn that transforms images + text on the fly.
def collate_fn(examples, text_column="subtask"):
    """Build one BLIP training batch from a list of dataset rows.

    All preprocessing happens here, using the module-level ``processor``.

    Args:
        examples: List of row dicts from the dataset. Only the ``"image"`` key
            (a PIL image, given the ``cast_column`` to ``HFImage``) and the
            ``text_column`` key are read. Other columns such as ``"task"``,
            ``"image_after"``, ``"action"`` or ``"thought"`` are ignored.
        text_column: Name of the string column used as the target caption.
            Defaults to ``"subtask"``. To train on another column, pass e.g.
            ``functools.partial(collate_fn, text_column="thought")`` as the
            Trainer's ``data_collator``.

    Returns:
        The processor output as PyTorch tensors, with text padded to the
        longest caption in the batch, plus an added ``"labels"`` tensor for the
        cross-entropy loss.
    """
    images = [ex["image"] for ex in examples]
    texts = [ex[text_column] for ex in examples]
    inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)
    # Labels are the caption token ids themselves. Padding positions
    # (attention_mask == 0) are set to -100, the default ignore_index of
    # torch.nn.CrossEntropyLoss. This keeps pad tokens out of the loss once a
    # batch holds captions of different lengths. masked_fill returns a new
    # tensor, so input_ids is left untouched.
    # NOTE(review): assumes BLIP's loss keeps the default ignore_index -- confirm.
    inputs["labels"] = inputs["input_ids"].masked_fill(
        inputs["attention_mask"] == 0, -100
    )
    return inputs
# 6. Define training arguments.
training_args = TrainingArguments(
    output_dir="./my_blip_computer_thoughts",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per device
    # NOTE(review): on the 5-example demo subset (batch 1, accumulation 4) there
    # are only one or two optimizer steps. logging_steps=5 and save_steps=20
    # therefore never trigger; lower them or train on the full split if you
    # need logs or checkpoints.
    logging_steps=5,
    save_steps=20,
    save_total_limit=2,
    # Keep every dataset column. By default Trainer drops columns that are not
    # arguments of model.forward (e.g. "image", "subtask"), but collate_fn
    # reads them to build each batch.
    remove_unused_columns=False,
)
# 7. Create the Trainer.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,  # or dataset["train"] for the full set
    data_collator=collate_fn,
)
# 8. Train.
trainer.train()
# 9. Push the final model + processor to the Hugging Face Hub.
#    (Make sure you're logged in first: `huggingface-cli login`.)
HUB_REPO_ID = "zeddotes/blip-computer-thoughts"
model.push_to_hub(HUB_REPO_ID)
processor.push_to_hub(HUB_REPO_ID)
print(f"Done training and pushed model to {HUB_REPO_ID}!")