"""
train.py

A complete example of fine-tuning BLIP on 'agentsea/computer-thoughts' for image captioning.
All image/text processing is done on the fly in the collate function, which keeps the
script simple and avoids shape mismatches from preprocessing the dataset up front.
"""

import torch
from datasets import load_dataset, Image as HFImage
from transformers import (
    BlipProcessor,
    BlipForConditionalGeneration,
    TrainingArguments,
    Trainer
)

# 1. Load dataset
dataset = load_dataset("agentsea/computer-thoughts")

# 2. Rename "image_before" -> "image" and cast to HFImage so it becomes a PIL Image
dataset = dataset.rename_column("image_before", "image")
dataset = dataset.cast_column("image", HFImage())

# 3. Create a small subset for demo (just 5 examples). Remove this if you want the full data.
train_subset = dataset["train"].select(range(5))

# 4. Load the BLIP image-captioning model (large checkpoint) and its processor
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")

# 5. Define a collate_fn that transforms images+text on-the-fly
def collate_fn(examples):
    # examples is a list of dicts, each dict with keys: 
    #   'task', 'image', 'image_after', 'action', 'thought', 'bad_thought', 'subtask', 'bad_subtask', etc.
    # We'll use 'image' (PIL) and 'subtask' (string) as the caption.
    images = [ex["image"] for ex in examples]   # PIL images
    texts = [ex["subtask"] for ex in examples]  # or whichever text column you want

    inputs = processor(images=images, text=texts, return_tensors="pt", padding=True)
    
    # Add labels so the model can compute the cross-entropy loss. A basic approach
    # is labels = input_ids; pad tokens are set to -100 so the loss ignores them.
    labels = inputs["input_ids"].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100
    inputs["labels"] = labels

    return inputs

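# Optional sanity check (not part of the original recipe): build one small batch
# by hand and print the tensor shapes, to confirm the processor output and the
# labels line up before launching a full training run.
sample_batch = collate_fn([train_subset[0], train_subset[1]])
print({k: tuple(v.shape) for k, v in sample_batch.items()})
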
# 6. Define training arguments
training_args = TrainingArguments(
    output_dir="./my_blip_computer_thoughts",
    num_train_epochs=1,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch size of 4 per device
    logging_steps=5,
    save_steps=20,
    save_total_limit=2,
    remove_unused_columns=False  # keep all dataset columns so collate_fn still sees 'image' and 'subtask'
)

# 7. Create the Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_subset,   # or dataset["train"] for the full set
    data_collator=collate_fn,
)

# 8. Train
trainer.train()

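# Optional qualitative check (a sketch, not part of the original recipe): caption
# one training image with the fine-tuned weights to confirm the model generates
# sensible text before pushing it to the Hub.
model.eval()
with torch.no_grad():
    sample_image = train_subset[0]["image"]
    pixel_values = processor(images=sample_image, return_tensors="pt").pixel_values.to(model.device)
    generated_ids = model.generate(pixel_values=pixel_values, max_new_tokens=30)
    print("Sample caption:", processor.decode(generated_ids[0], skip_special_tokens=True))
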
# 9. Push the final model + processor to Hugging Face Hub
#    (Make sure you're logged in: huggingface-cli login)
model.push_to_hub("zeddotes/blip-computer-thoughts")
processor.push_to_hub("zeddotes/blip-computer-thoughts")

print("Done training and pushed model to zeddotes/blip-computer-thoughts!")
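
# To reuse the pushed checkpoint elsewhere (assuming the push above succeeded and
# the repo id is unchanged), it can be reloaded with from_pretrained, e.g.:
#   processor = BlipProcessor.from_pretrained("zeddotes/blip-computer-thoughts")
#   model = BlipForConditionalGeneration.from_pretrained("zeddotes/blip-computer-thoughts")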