dad1909 committed on
Commit
400fcf7
·
verified ·
1 Parent(s): 3b49ded

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -33,9 +33,14 @@ model, tokenizer = FastLanguageModel.from_pretrained(
33
  )
34
  print("Model and tokenizer loaded successfully.")
35
 
 
 
 
 
 
36
  print("Configuring PEFT model...")
37
  model = FastLanguageModel.get_peft_model(
38
- model,
39
  r=16,
40
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
41
  lora_alpha=16,
@@ -51,24 +56,18 @@ print("PEFT model configured.")
51
  # Updated alpaca_prompt for different types
52
  alpaca_prompt = {
53
  "learning_from": """Below is a CVE definition.
54
-
55
  ### CVE definition:
56
  {}
57
-
58
  ### detail CVE:
59
  {}""",
60
  "definition": """Below is a definition about software vulnerability. Explain it.
61
-
62
  ### Definition:
63
  {}
64
-
65
  ### Explanation:
66
  {}""",
67
  "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
68
-
69
  ### Code Snippet:
70
  {}
71
-
72
  ### Vulnerability solution:
73
  {}"""
74
  }
@@ -111,7 +110,7 @@ print("Formatting function applied.")
111
 
112
  print("Initializing trainer...")
113
  trainer = SFTTrainer(
114
- model=model,
115
  tokenizer=tokenizer,
116
  train_dataset=dataset,
117
  dataset_text_field="text",
@@ -145,11 +144,16 @@ num += 1
145
  uploads_models = f"cybersentinal-3.0"
146
 
147
  print("Saving the trained model...")
148
- model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
149
  print("Model saved successfully.")
150
 
151
  print("Pushing the model to the hub...")
152
- model.push_to_hub_merged(
 
 
 
 
 
153
  uploads_models,
154
  tokenizer,
155
  save_method="merged_16bit",
 
33
  )
34
  print("Model and tokenizer loaded successfully.")
35
 
36
+ # Wrap the model in DataParallel to use all GPUs
37
+ if torch.cuda.device_count() > 1:
38
+ print(f"Using {torch.cuda.device_count()} GPUs!")
39
+ model = torch.nn.DataParallel(model)
40
+
41
  print("Configuring PEFT model...")
42
  model = FastLanguageModel.get_peft_model(
43
+ model.module if isinstance(model, torch.nn.DataParallel) else model,
44
  r=16,
45
  target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
46
  lora_alpha=16,
 
56
  # Updated alpaca_prompt for different types
57
  alpaca_prompt = {
58
  "learning_from": """Below is a CVE definition.
 
59
  ### CVE definition:
60
  {}
 
61
  ### detail CVE:
62
  {}""",
63
  "definition": """Below is a definition about software vulnerability. Explain it.
 
64
  ### Definition:
65
  {}
 
66
  ### Explanation:
67
  {}""",
68
  "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
 
69
  ### Code Snippet:
70
  {}
 
71
  ### Vulnerability solution:
72
  {}"""
73
  }
 
110
 
111
  print("Initializing trainer...")
112
  trainer = SFTTrainer(
113
+ model=model.module if isinstance(model, torch.nn.DataParallel) else model,
114
  tokenizer=tokenizer,
115
  train_dataset=dataset,
116
  dataset_text_field="text",
 
144
  uploads_models = f"cybersentinal-3.0"
145
 
146
  print("Saving the trained model...")
147
+ model.module.save_pretrained_merged("model", tokenizer, save_method="merged_16bit") if isinstance(model, torch.nn.DataParallel) else model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
148
  print("Model saved successfully.")
149
 
150
  print("Pushing the model to the hub...")
151
+ model.module.push_to_hub_merged(
152
+ uploads_models,
153
+ tokenizer,
154
+ save_method="merged_16bit",
155
+ token=hf_token
156
+ ) if isinstance(model, torch.nn.DataParallel) else model.push_to_hub_merged(
157
  uploads_models,
158
  tokenizer,
159
  save_method="merged_16bit",