Update app.py
Browse files
app.py
CHANGED
@@ -33,9 +33,14 @@ model, tokenizer = FastLanguageModel.from_pretrained(
|
|
33 |
)
|
34 |
print("Model and tokenizer loaded successfully.")
|
35 |
|
|
|
|
|
|
|
|
|
|
|
36 |
print("Configuring PEFT model...")
|
37 |
model = FastLanguageModel.get_peft_model(
|
38 |
-
model,
|
39 |
r=16,
|
40 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
41 |
lora_alpha=16,
|
@@ -51,24 +56,18 @@ print("PEFT model configured.")
|
|
51 |
# Updated alpaca_prompt for different types
|
52 |
alpaca_prompt = {
|
53 |
"learning_from": """Below is a CVE definition.
|
54 |
-
|
55 |
### CVE definition:
|
56 |
{}
|
57 |
-
|
58 |
### detail CVE:
|
59 |
{}""",
|
60 |
"definition": """Below is a definition about software vulnerability. Explain it.
|
61 |
-
|
62 |
### Definition:
|
63 |
{}
|
64 |
-
|
65 |
### Explanation:
|
66 |
{}""",
|
67 |
"code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
|
68 |
-
|
69 |
### Code Snippet:
|
70 |
{}
|
71 |
-
|
72 |
### Vulnerability solution:
|
73 |
{}"""
|
74 |
}
|
@@ -111,7 +110,7 @@ print("Formatting function applied.")
|
|
111 |
|
112 |
print("Initializing trainer...")
|
113 |
trainer = SFTTrainer(
|
114 |
-
model=model,
|
115 |
tokenizer=tokenizer,
|
116 |
train_dataset=dataset,
|
117 |
dataset_text_field="text",
|
@@ -145,11 +144,16 @@ num += 1
|
|
145 |
uploads_models = f"cybersentinal-3.0"
|
146 |
|
147 |
print("Saving the trained model...")
|
148 |
-
model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
149 |
print("Model saved successfully.")
|
150 |
|
151 |
print("Pushing the model to the hub...")
|
152 |
-
model.push_to_hub_merged(
|
|
|
|
|
|
|
|
|
|
|
153 |
uploads_models,
|
154 |
tokenizer,
|
155 |
save_method="merged_16bit",
|
|
|
33 |
)
|
34 |
print("Model and tokenizer loaded successfully.")
|
35 |
|
36 |
+
# Wrap the model in DataParallel to use all GPUs
|
37 |
+
if torch.cuda.device_count() > 1:
|
38 |
+
print(f"Using {torch.cuda.device_count()} GPUs!")
|
39 |
+
model = torch.nn.DataParallel(model)
|
40 |
+
|
41 |
print("Configuring PEFT model...")
|
42 |
model = FastLanguageModel.get_peft_model(
|
43 |
+
model.module if isinstance(model, torch.nn.DataParallel) else model,
|
44 |
r=16,
|
45 |
target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
|
46 |
lora_alpha=16,
|
|
|
56 |
# Updated alpaca_prompt for different types
|
57 |
alpaca_prompt = {
|
58 |
"learning_from": """Below is a CVE definition.
|
|
|
59 |
### CVE definition:
|
60 |
{}
|
|
|
61 |
### detail CVE:
|
62 |
{}""",
|
63 |
"definition": """Below is a definition about software vulnerability. Explain it.
|
|
|
64 |
### Definition:
|
65 |
{}
|
|
|
66 |
### Explanation:
|
67 |
{}""",
|
68 |
"code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
|
|
|
69 |
### Code Snippet:
|
70 |
{}
|
|
|
71 |
### Vulnerability solution:
|
72 |
{}"""
|
73 |
}
|
|
|
110 |
|
111 |
print("Initializing trainer...")
|
112 |
trainer = SFTTrainer(
|
113 |
+
model=model.module if isinstance(model, torch.nn.DataParallel) else model,
|
114 |
tokenizer=tokenizer,
|
115 |
train_dataset=dataset,
|
116 |
dataset_text_field="text",
|
|
|
144 |
uploads_models = f"cybersentinal-3.0"
|
145 |
|
146 |
print("Saving the trained model...")
|
147 |
+
model.module.save_pretrained_merged("model", tokenizer, save_method="merged_16bit") if isinstance(model, torch.nn.DataParallel) else model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
|
148 |
print("Model saved successfully.")
|
149 |
|
150 |
print("Pushing the model to the hub...")
|
151 |
+
model.module.push_to_hub_merged(
|
152 |
+
uploads_models,
|
153 |
+
tokenizer,
|
154 |
+
save_method="merged_16bit",
|
155 |
+
token=hf_token
|
156 |
+
) if isinstance(model, torch.nn.DataParallel) else model.push_to_hub_merged(
|
157 |
uploads_models,
|
158 |
tokenizer,
|
159 |
save_method="merged_16bit",
|