Spaces: Runtime error
stakelovelace committed
Commit 635201d · 1 Parent(s): d16807d
commit from tesla
app.py CHANGED
@@ -33,7 +33,8 @@ def load_data_and_config(data_path):
 
 def train_model(model, tokenizer, data, device):
     """Trains the model using the Hugging Face Trainer API."""
-    inputs = [tokenizer(d['text'], max_length=256, truncation=True, padding='max_length', return_tensors="pt") for d in data]
+    # inputs = [tokenizer(d['text'], max_length=256, truncation=True, padding='max_length', return_tensors="pt") for d in data]
+    inputs = [tokenizer(d['text'], max_length=256, truncation=True, padding='max_length', return_tensors="pt").to(torch.float16) for d in data]
     dataset = Dataset.from_dict({
         'input_ids': [x['input_ids'].squeeze() for x in inputs],
         'labels': [x['input_ids'].squeeze() for x in inputs]
@@ -43,7 +44,7 @@ def train_model(model, tokenizer, data, device):
         output_dir='./results',
         num_train_epochs=3,
         per_device_train_batch_size=1,
-        gradient_accumulation_steps=
+        gradient_accumulation_steps=4,
         fp16=True, # Enable mixed precision
         warmup_steps=500,
         weight_decay=0.01,
@@ -61,9 +62,8 @@ def train_model(model, tokenizer, data, device):
     trainer.train()
 
     # Optionally clear cache if using GPU or MPS
-    print(torch.cuda.memory_summary(device=None, abbreviated=False))
-
     if torch.cuda.is_available():
+        print(torch.cuda.memory_summary(device=None, abbreviated=False))
         torch.cuda.empty_cache()
     elif torch.has_mps:
         torch.mps.empty_cache()
@@ -85,6 +85,9 @@ def main(api_name, base_url):
     #model = BertLMHeadModel.from_pretrained('google/codegemma-2b', is_decoder=True)
     # Example assuming you have a prepared dataset for classification
     #model = BertForSequenceClassification.from_pretrained('thenlper/gte-small', num_labels=2, is_decoder=True) # binary classification
+    # Example: Offloading embeddings to CPU
+    model.embeddings.to('cpu')
+
     model.to(device) # Move model to the appropriate device
 
     train_model(model, tokenizer, data, device)
@@ -104,7 +107,8 @@ def generate_api_query(model, tokenizer, prompt, desired_output, api_name, base_
     input_ids = input_ids.to(model.device)
 
     # Generate query using model with temperature for randomness
-    output = model.generate(input_ids, max_length=128,
+    output = model.generate(input_ids, max_length=128, temperature=0.001, do_sample=True)
+
 
     # Decode the generated query tokens
     query = tokenizer.decode(output[0], skip_special_tokens=True)
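A note on the first hunk: the dataset wiring reuses each example's input_ids as its own labels, the standard causal-LM fine-tuning setup. Worth flagging that input_ids are integer token indices and an embedding lookup requires integer inputs, so the new float16 cast of the encodings is likely to fail at train time; mixed precision is normally driven by the fp16=True flag this commit also sets. A minimal, self-contained shape check of the dataset construction (the token ids here are made up for illustration):

from datasets import Dataset

# Hypothetical token ids, just to show the shape Dataset.from_dict expects.
encoded = {
    'input_ids': [[101, 2023, 2003, 102]],
    'labels':    [[101, 2023, 2003, 102]],  # labels mirror input_ids
}
dataset = Dataset.from_dict(encoded)
print(dataset[0])  # {'input_ids': [...], 'labels': [...]}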
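On the gradient_accumulation_steps hunk: with per_device_train_batch_size=1 and gradient_accumulation_steps=4, the Trainer runs four forward/backward passes before each optimizer step, so the effective batch size is 4 while peak memory stays at batch size 1. A sketch of the arguments after this commit, assuming the transformers Trainer API the file already uses:

from transformers import TrainingArguments

# Effective batch size = per_device_train_batch_size * gradient_accumulation_steps
#                      = 1 * 4 = 4.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # value the commit fills in
    fp16=True,                      # mixed precision on supported GPUs
    warmup_steps=500,
    weight_decay=0.01,
)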
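On the cache-clearing hunk: torch.cuda.memory_summary is only meaningful on CUDA machines, which is presumably why the commit moves the print inside the torch.cuda.is_available() branch. A standalone sketch of the resulting logic, using torch.backends.mps.is_available(), the current spelling of the older torch.has_mps attribute kept in the diff:

import torch

if torch.cuda.is_available():
    # Only valid with CUDA; printing here avoids a crash on CPU/MPS hosts.
    print(torch.cuda.memory_summary(device=None, abbreviated=False))
    torch.cuda.empty_cache()
elif torch.backends.mps.is_available():
    # Release cached memory held by the Metal allocator on Apple silicon.
    torch.mps.empty_cache()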
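On the embedding-offload hunk: model.embeddings.to('cpu') moves a BERT-style model's embedding table to the CPU, but the very next line, model.to(device), moves the entire module tree, embeddings included, so the offload as committed is immediately undone; a submodule offload only sticks if applied after the full move (and the forward pass must then shuttle tensors between devices itself). A minimal sketch of the ordering, using a stand-in module rather than app.py's model:

import torch
from torch import nn

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Stand-in for a model exposing an .embeddings submodule.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(100, 16)
        self.head = nn.Linear(16, 2)

model = TinyModel()
model.embeddings.to('cpu')  # offload...
model.to(device)            # ...undone: .to() recurses into every submodule
print(next(model.embeddings.parameters()).device)  # -> device, not cpu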
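On the final hunk: passing do_sample=True with temperature=0.001 makes the softmax so sharp that sampling collapses to near-greedy decoding, which fits the comment about wanting only mild randomness. A self-contained check of the call shape, using distilgpt2 purely as a stand-in (app.py loads its own model and tokenizer):

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')  # stand-in model
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

input_ids = tokenizer('Build an API query for', return_tensors='pt').input_ids
# temperature=0.001 sharpens the distribution so the sampled token is almost
# always the argmax token; raise the temperature for real diversity.
output = model.generate(input_ids, max_length=128, temperature=0.001, do_sample=True)
print(tokenizer.decode(output[0], skip_special_tokens=True))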