Upload main.py with huggingface_hub
main.py
ADDED
import argparse
import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version

from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
from lib.eval import eval_ppl, eval_zero_shot

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="llm_weights"):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        cache_dir=cache_dir,
        low_cpu_mem_usage=True,
        device_map="auto"
    )
    model.seqlen = model.config.max_position_embeddings
    return model

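# Note: device_map="auto" lets accelerate shard the fp16 checkpoint across all
# visible GPUs; model.hf_device_map records which device each submodule is on.
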
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
    parser.add_argument("--cache_dir", default="llm_weights", type=str)
    parser.add_argument('--use_variant', action="store_true", help="Whether to use the Wanda variant described in the appendix.")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

    parser.add_argument("--eval_zero_shot", action="store_true")
    parser.add_argument('--module_name', type=str, default=None, help='Module to prune.')
    parser.add_argument('--permutate_mode', type=str, default=None, help='Full or LoRA or Eval.')
    parser.add_argument('--layer_id', type=int, default=None, help='Index of the layer to prune.')
    args = parser.parse_args()

    # Setting seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Handling n:m sparsity
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

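    # e.g. --sparsity_type 2:4 yields (prune_n, prune_m) = (2, 4); the assert
    # above pins the overall ratio to 0.5 for both supported n:m patterns.
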
    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model:
        # For 30b and 65b we use device_map to load onto multiple A6000 GPUs,
        # so run the pruning pass on whichever device holds lm_head.
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    if args.sparsity_ratio != 0:
        print("pruning starts")
        if args.prune_method == "wanda":
            prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "magnitude":
            prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "sparsegpt":
            prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif "ablate" in args.prune_method:
            prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)

+
################################################################
|
75 |
+
print("*"*30)
|
76 |
+
sparsity_ratio = check_sparsity(model)
|
77 |
+
print(f"sparsity sanity check {sparsity_ratio:.4f}")
|
78 |
+
print("*"*30)
|
79 |
+
################################################################
|
80 |
+
# ppl_test = eval_ppl(args, model, tokenizer, device)
|
81 |
+
# print(f"wikitext perplexity {ppl_test}")
|
82 |
+
|
    if args.save:  # --save defaults to None; guard so os.makedirs is not called on None
        if not os.path.exists(args.save):
            os.makedirs(args.save)
        save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
        # with open(save_filepath, "w") as f:
        #     print("method\tactual_sparsity\tppl_test", file=f, flush=True)
        #     print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)

    # if args.eval_zero_shot:
    #     accelerate = False
    #     if "30b" in args.model or "65b" in args.model or "70b" in args.model:
    #         accelerate = True
    #
    #     # task_list = ["boolq", "rte", "hellaswag", "winogrande", "arc_easy", "arc_challenge", "openbookqa"]
    #     task_list = ["winogrande"]
    #     num_shot = 0
    #     results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
    #     print("********************************")
    #     print("zero_shot evaluation results")
    #     print(results)

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)

if __name__ == '__main__':
    main()
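
For reference, a typical invocation might look like the following; the model
identifier and output directory are illustrative, and lib/prune.py and
lib/eval.py (not part of this commit) must be importable from the working
directory:

python main.py --model decapoda-research/llama-7b-hf --prune_method wanda --sparsity_ratio 0.5 --sparsity_type 2:4 --save out/llama_7b_2-4_wanda

Below is a minimal sketch, for intuition only, of how an n:m magnitude mask
can be built. The real logic lives in lib.prune and may differ;
nm_magnitude_mask is a hypothetical helper, not part of this repository:

import torch

def nm_magnitude_mask(weight, prune_n=2, prune_m=4):
    # True marks weights to drop: the prune_n smallest |w| in every contiguous
    # group of prune_m along the last dim (assumes the width divides evenly).
    mask = torch.zeros_like(weight, dtype=torch.bool)
    for col in range(0, weight.shape[1], prune_m):
        block = weight[:, col:col + prune_m].abs()
        _, idx = torch.topk(block, prune_n, dim=1, largest=False)
        mask.scatter_(1, col + idx, True)
    return mask

w = torch.randn(4, 8)
w[nm_magnitude_mask(w)] = 0  # 2:4 pattern: two zeros in every group of four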