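"""Fine-tune Salesforce/codet5-small on a small slice of CodeSearchNet (Python)
to generate Python functions from docstring-style prompts."""
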
import re
from typing import Dict

import torch
from datasets import load_dataset
from torch.optim import AdamW
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, get_scheduler


def setup_model():
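    """Load the CodeT5-small tokenizer and sequence-to-sequence model."""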
    # Using a smaller CodeT5 model suitable for the free tier
    model_name = "Salesforce/codet5-small"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
    return model, tokenizer


def prepare_dataset():
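    """Load a small slice of CodeSearchNet (Python) and build prompt/target pairs."""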
    # Load Python subset of CodeSearchNet
    dataset = load_dataset(
        "code_search_net", "python", split="train[:1000]", trust_remote_code=True
    )  # Limited to 1000 examples for free tier

    def extract_function_info(example: Dict) -> Dict:
        """Extract clean function definitions and docstrings."""
        code = example["whole_func_string"]

        # Basic filtering for API-style functions
        if not code.strip().startswith("def "):
            # Return empty fields so this example can be filtered out after mapping.
            return {
                "function": "",
                "documentation": "",
                "input": "",
                "output": ""
            }

        # Remove multiple newlines and standardize spacing
        code = re.sub(r"\n\s*\n", "\n", code)
        docstring = example["func_documentation_string"].strip()

        return {
            "function": code,
            "documentation": docstring,
            "input": f"Write a Python function that: {docstring}",
            "output": code,
        }

    # Process and filter the dataset
    processed_dataset = dataset.map(extract_function_info)
    # Filter out empty entries after mapping
    processed_dataset = processed_dataset.filter(lambda x: x["function"] != "")

    return processed_dataset


def tokenize_data(examples, tokenizer, max_length=512):
    """Tokenize prompts and target code for seq2seq training."""
    # Batch tokenization for inputs
    model_inputs = tokenizer(
        examples["input"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )

    # Batch tokenization for outputs (targets)
    labels = tokenizer(
        text_target=examples["output"],
        max_length=max_length,
        padding="max_length",
        truncation=True,
    ).input_ids

    # Replace padding token ids with -100 so the loss ignores padded positions
    labels = [
        [(token if token != tokenizer.pad_token_id else -100) for token in label]
        for label in labels
    ]

    model_inputs["labels"] = labels
    return model_inputs


def train():
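    """Fine-tune the model on the processed dataset, saving a checkpoint after each epoch."""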
    model, tokenizer = setup_model()
    dataset = prepare_dataset()

    # Training configuration: prefer CUDA, then Apple MPS, otherwise CPU
    device = torch.device(
        "cuda" if torch.cuda.is_available()
        else "mps" if torch.backends.mps.is_available()
        else "cpu"
    )
    model.to(device)

    # Hyperparameters
    batch_size = 8
    num_epochs = 3
    learning_rate = 5e-5
    max_length = 512

    # Tokenize the processed dataset in batches
    tokenized_dataset = dataset.map(
        lambda x: tokenize_data(x, tokenizer, max_length),
        batched=True,
        batch_size=16,  # Explicit batch size for processing
        remove_columns=dataset.column_names,
    )

    def collate_fn(examples):
        # Stack per-example token lists into batch tensors and move them to the device
        return {
            key: torch.stack([torch.tensor(example[key]) for example in examples]).to(device)
            for key in ("input_ids", "attention_mask", "labels")
        }

    train_dataloader = DataLoader(
        tokenized_dataset,
        shuffle=True,
        batch_size=batch_size,
        collate_fn=collate_fn
    )

    # Initialize optimizer and scheduler
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    num_training_steps = num_epochs * len(train_dataloader)
    lr_scheduler = get_scheduler(
        name="linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=num_training_steps,
    )

    # Training loop
    progress_bar = tqdm(range(num_training_steps))
    model.train()

    for epoch in range(num_epochs):
        for batch in train_dataloader:
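            # Forward pass: the model computes the seq2seq loss from input_ids, attention_mask, and labels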
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            progress_bar.set_description(f"Loss: {loss.item():.4f}")

        # Save checkpoint after each epoch
        model.save_pretrained(f"checkpoint-epoch-{epoch}")
        tokenizer.save_pretrained(f"checkpoint-epoch-{epoch}")

    print("Training completed!")


if __name__ == "__main__":
    train()