s3777091 commited on
Commit
4deeced
·
1 Parent(s): 9c1bd77
Files changed (2) hide show
  1. app.py +153 -0
  2. requirements.txt +11 -0
app.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from unsloth import FastLanguageModel, is_bfloat16_supported
4
+ from trl import SFTTrainer
5
+ from transformers import TrainingArguments
6
+ from datasets import load_dataset
7
+ import gradio as gr
8
+
9
+
10
+ max_seq_length = 4096
11
+ dtype = None
12
+ load_in_4bit = True
13
+ hf_token = os.getenv("Token")
14
+
15
+ print("Starting model and tokenizer loading...")
16
+
17
+ # Load the model and tokenizer
18
+ model, tokenizer = FastLanguageModel.from_pretrained(
19
+ model_name="unsloth/llama-3-8b-Instruct-bnb-4bit",
20
+ max_seq_length=max_seq_length,
21
+ dtype=dtype,
22
+ load_in_4bit=load_in_4bit,
23
+ token=hf_token
24
+ )
25
+ print("Model and tokenizer loaded successfully.")
26
+
27
+ print("Configuring PEFT model...")
28
+ model = FastLanguageModel.get_peft_model(
29
+ model,
30
+ r=16,
31
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
32
+ lora_alpha=16,
33
+ lora_dropout=0,
34
+ bias="none",
35
+ use_gradient_checkpointing="unsloth",
36
+ random_state=3407,
37
+ use_rslora=False,
38
+ loftq_config=None,
39
+ )
40
+ print("PEFT model configured.")
41
+
42
+ # Updated alpaca_prompt for different types
43
+ alpaca_prompt = {
44
+ "learning_from": """Below is a CVE definition.
45
+
46
+ ### CVE definition:
47
+ {}
48
+
49
+ ### detail CVE:
50
+ {}""",
51
+ "definition": """Below is a definition about software vulnerability. Explain it.
52
+
53
+ ### Definition:
54
+ {}
55
+
56
+ ### Explanation:
57
+ {}""",
58
+ "code_vulnerability": """Below is a code snippet. Identify the line of code that is vulnerable and describe the type of software vulnerability.
59
+
60
+ ### Code Snippet:
61
+ {}
62
+
63
+ ### Vulnerability solution:
64
+ {}"""
65
+ }
66
+
67
+ EOS_TOKEN = tokenizer.eos_token
68
+
69
+ def detect_prompt_type(instruction):
70
+ if instruction.startswith("what is code vulnerable of this code:"):
71
+ return "code_vulnerability"
72
+ elif instruction.startswith("Learning from"):
73
+ return "learning_from"
74
+ elif instruction.startswith("what is"):
75
+ return "definition"
76
+ else:
77
+ return "unknown"
78
+
79
+ def formatting_prompts_func(examples):
80
+ instructions = examples["instruction"]
81
+ outputs = examples["output"]
82
+ texts = []
83
+
84
+ for instruction, output in zip(instructions, outputs):
85
+ prompt_type = detect_prompt_type(instruction)
86
+ if prompt_type in alpaca_prompt:
87
+ prompt = alpaca_prompt[prompt_type].format(instruction, output)
88
+ else:
89
+ prompt = instruction + "\n\n" + output
90
+ text = prompt + EOS_TOKEN
91
+ texts.append(text)
92
+
93
+ return {"text": texts}
94
+
95
+ print("Loading dataset...")
96
+ dataset = load_dataset("dad1909/DCSV", split="train")
97
+ print("Dataset loaded successfully.")
98
+
99
+ print("Applying formatting function to the dataset...")
100
+ dataset = dataset.map(formatting_prompts_func, batched=True)
101
+ print("Formatting function applied.")
102
+
103
+ print("Initializing trainer...")
104
+ trainer = SFTTrainer(
105
+ model=model,
106
+ tokenizer=tokenizer,
107
+ train_dataset=dataset,
108
+ dataset_text_field="text",
109
+ max_seq_length=max_seq_length,
110
+ dataset_num_proc=2,
111
+ packing=False,
112
+ args=TrainingArguments(
113
+ per_device_train_batch_size=2,
114
+ gradient_accumulation_steps=2,
115
+ learning_rate=2e-4,
116
+ fp16=not is_bfloat16_supported(),
117
+ bf16=is_bfloat16_supported(),
118
+ warmup_steps=5,
119
+ logging_steps=10,
120
+ optim="adamw_8bit",
121
+ weight_decay=0.01,
122
+ lr_scheduler_type="linear",
123
+ seed=3407,
124
+ output_dir="outputs",
125
+ ),
126
+ )
127
+ print("Trainer initialized.")
128
+
129
+ print("Starting training...")
130
+ trainer_stats = trainer.train()
131
+ print("Training completed.")
132
+
133
+ print("Saving the trained model...")
134
+ model.save_pretrained_merged("model", tokenizer, save_method="merged_16bit")
135
+ print("Model saved successfully.")
136
+
137
+ print("Pushing the model to the hub...")
138
+ model.push_to_hub_merged(
139
+ "CyberSentinel-16bit",
140
+ tokenizer,
141
+ save_method="merged_16bit",
142
+ token=True
143
+ )
144
+ print("Model pushed to hub successfully.")
145
+
146
+ # Gradio app
147
+ print("Launching Gradio app...")
148
+ def greet(name):
149
+ return "Hello " + name + "!!"
150
+
151
+ demo = gr.Interface(fn=greet, inputs="text", outputs="text")
152
+ demo.launch()
153
+ print("Gradio app launched.")
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unsloth[colab-new] @ git+https://github.com/unslothai/unsloth.git
2
+ xformers[cuda]
3
+ torch
4
+ transformers
5
+ datasets
6
+ gradio
7
+ trl
8
+ peft
9
+ accelerate
10
+ bitsandbytes
11
+ huggingface_hub