awacke1 committed
Commit b4eeb2f · verified · 1 Parent(s): f3f1e60

Update app.py

Files changed (1):
  app.py  +34 -9
app.py CHANGED
@@ -42,7 +42,7 @@ class ModelConfig:
     def model_path(self):
         return f"models/{self.name}"
 
-# Custom Dataset for SFT
+# Custom Dataset for SFT (Fixed)
 class SFTDataset(Dataset):
     def __init__(self, data, tokenizer, max_length=128):
         self.data = data
@@ -56,18 +56,39 @@ class SFTDataset(Dataset):
         prompt = self.data[idx]["prompt"]
         response = self.data[idx]["response"]
 
-        prompt_encoding = self.tokenizer(prompt, max_length=self.max_length // 2, padding="max_length", truncation=True, return_tensors="pt")
+        # Tokenize the full sequence once
         full_text = f"{prompt} {response}"
-        full_encoding = self.tokenizer(full_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt")
+        full_encoding = self.tokenizer(
+            full_text,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt"
+        )
 
-        input_ids = prompt_encoding["input_ids"].squeeze()
-        attention_mask = prompt_encoding["attention_mask"].squeeze()
-        labels = full_encoding["input_ids"].squeeze()
+        # Tokenize prompt separately to get its length
+        prompt_encoding = self.tokenizer(
+            prompt,
+            max_length=self.max_length,
+            padding=False,  # No padding here, just to get length
+            truncation=True,
+            return_tensors="pt"
+        )
 
-        prompt_len = prompt_encoding["input_ids"].ne(self.tokenizer.pad_token_id).sum().item()
-        labels[:prompt_len] = -100
+        input_ids = full_encoding["input_ids"].squeeze()
+        attention_mask = full_encoding["attention_mask"].squeeze()
+        labels = input_ids.clone()  # Clone to avoid modifying input_ids
 
-        return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
+        # Mask prompt tokens in labels
+        prompt_len = prompt_encoding["input_ids"].shape[1]  # Actual length of prompt
+        if prompt_len < self.max_length:
+            labels[:prompt_len] = -100  # Ignore prompt in loss
+
+        return {
+            "input_ids": input_ids,
+            "attention_mask": attention_mask,
+            "labels": labels
+        }
 
 # Model Builder Class with Easter Egg Jokes
 class ModelBuilder:
@@ -111,6 +132,10 @@ class ModelBuilder:
                 input_ids = batch["input_ids"].to(device)
                 attention_mask = batch["attention_mask"].to(device)
                 labels = batch["labels"].to(device)
+
+                # Debug shapes
+                assert input_ids.shape[0] == labels.shape[0], f"Batch size mismatch: input_ids {input_ids.shape}, labels {labels.shape}"
+
                 outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                 loss = outputs.loss
                 loss.backward()
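
For reference, a minimal standalone sketch of the label-masking scheme this commit introduces in SFTDataset.__getitem__: the full prompt+response string is tokenized once, labels is a clone of input_ids, and the prompt positions are set to -100 so only the response tokens contribute to the loss. The GPT-2 tokenizer and the example prompt/response below are assumptions for illustration only; app.py supplies its own tokenizer and data.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

prompt, response = "What is SFT?", "Supervised fine-tuning on prompt/response pairs."
max_length = 128

# Tokenize the full sequence once, padded to max_length (as __getitem__ now does).
full_encoding = tokenizer(f"{prompt} {response}", max_length=max_length,
                          padding="max_length", truncation=True, return_tensors="pt")

# Tokenize the prompt without padding, only to measure its token length.
prompt_ids = tokenizer(prompt, max_length=max_length, padding=False,
                       truncation=True, return_tensors="pt")["input_ids"]

input_ids = full_encoding["input_ids"].squeeze()
attention_mask = full_encoding["attention_mask"].squeeze()
labels = input_ids.clone()                # same shape as input_ids

prompt_len = prompt_ids.shape[1]
if prompt_len < max_length:
    labels[:prompt_len] = -100            # prompt tokens ignored by the LM loss

assert input_ids.shape == attention_mask.shape == labels.shape
print(prompt_len, int((labels == -100).sum()))  # masked positions == prompt length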