Zkli committed
Commit 4e03346 · verified · 1 parent: 8afac5b

Upload main.py with huggingface_hub

Files changed (1): main.py +108 -0
main.py ADDED
@@ -0,0 +1,108 @@
import argparse
import os
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from importlib.metadata import version

from lib.prune import prune_wanda, prune_magnitude, prune_sparsegpt, prune_ablate, check_sparsity, find_layers
from lib.eval import eval_ppl, eval_zero_shot

print('torch', version('torch'))
print('transformers', version('transformers'))
print('accelerate', version('accelerate'))
print('# of gpus: ', torch.cuda.device_count())

def get_llm(model_name, cache_dir="llm_weights"):
    # Load the model in fp16; device_map="auto" shards it across the available GPUs.
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, cache_dir=cache_dir, low_cpu_mem_usage=True, device_map="auto")

    # Sequence length used downstream for calibration and evaluation.
    model.seqlen = model.config.max_position_embeddings
    return model

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, help='LLaMA model')
    parser.add_argument('--seed', type=int, default=0, help='Seed for sampling the calibration data.')
    parser.add_argument('--nsamples', type=int, default=1, help='Number of calibration samples.')
    parser.add_argument('--sparsity_ratio', type=float, default=0, help='Sparsity level')
    parser.add_argument("--sparsity_type", type=str, choices=["unstructured", "4:8", "2:4"])
    parser.add_argument("--prune_method", type=str, choices=["magnitude", "wanda", "sparsegpt",
                        "ablate_mag_seq", "ablate_wanda_seq", "ablate_mag_iter", "ablate_wanda_iter", "search"])
    parser.add_argument("--cache_dir", default="llm_weights", type=str)
    parser.add_argument('--use_variant', action="store_true", help="Whether to use the Wanda variant described in the appendix.")
    parser.add_argument('--save', type=str, default=None, help='Path to save results.')
    parser.add_argument('--save_model', type=str, default=None, help='Path to save the pruned model.')

    parser.add_argument("--eval_zero_shot", action="store_true")
    parser.add_argument('--module_name', type=str, default=None, help='Module to prune.')
    parser.add_argument('--permutate_mode', type=str, default=None, help='Full or LoRA or Eval.')
    parser.add_argument('--layer_id', type=int, default=None, help='Index of the layer to prune.')
    args = parser.parse_args()

    # Set seeds for reproducibility
    np.random.seed(args.seed)
    torch.random.manual_seed(args.seed)

    # Handle N:M structured sparsity
    prune_n, prune_m = 0, 0
    if args.sparsity_type != "unstructured":
        assert args.sparsity_ratio == 0.5, "sparsity ratio must be 0.5 for structured N:M sparsity"
        prune_n, prune_m = map(int, args.sparsity_type.split(":"))

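    # With "2:4", prune_n=2 and prune_m=4: two out of every four consecutive
    # weights are pruned, so the resulting sparsity is exactly 0.5, which is
    # what the assertion above enforces (likewise for "4:8").
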
    model_name = args.model.split("/")[-1]
    print(f"loading llm model {args.model}")
    model = get_llm(args.model, args.cache_dir)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)

    device = torch.device("cuda:0")
    if "30b" in args.model or "65b" in args.model:  # 30b and 65b are sharded across multiple A6000 GPUs by device_map, so use the device that holds lm_head
        device = model.hf_device_map["lm_head"]
    print("use device ", device)

    if args.sparsity_ratio != 0:
        print("pruning starts")
        if args.prune_method == "wanda":
            prune_wanda(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "magnitude":
            prune_magnitude(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif args.prune_method == "sparsegpt":
            prune_sparsegpt(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)
        elif "ablate" in args.prune_method:
            prune_ablate(args, model, tokenizer, device, prune_n=prune_n, prune_m=prune_m)

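    # The methods differ mainly in how they score weights for removal:
    # magnitude ranks by |W| alone, Wanda scales |W| by input activation
    # norms, and SparseGPT uses a second-order (Hessian-based) weight
    # reconstruction; the ablate_* choices appear to be ablation variants.
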
    ################################################################
    print("*"*30)
    sparsity_ratio = check_sparsity(model)
    print(f"sparsity sanity check {sparsity_ratio:.4f}")
    print("*"*30)
    ################################################################
    # ppl_test = eval_ppl(args, model, tokenizer, device)
    # print(f"wikitext perplexity {ppl_test}")

    # args.save defaults to None, and os.makedirs(None) would raise, so guard on it.
    if args.save:
        os.makedirs(args.save, exist_ok=True)
        save_filepath = os.path.join(args.save, f"log_{args.prune_method}.txt")
        # with open(save_filepath, "w") as f:
        #     print("method\tactual_sparsity\tppl_test", file=f, flush=True)
        #     print(f"{args.prune_method}\t{sparsity_ratio:.4f}\t{ppl_test:.4f}", file=f, flush=True)

    # if args.eval_zero_shot:
    #     accelerate = False
    #     if "30b" in args.model or "65b" in args.model or "70b" in args.model:
    #         accelerate = True

    #     # task_list = ["boolq", "rte", "hellaswag", "winogrande", "arc_easy", "arc_challenge", "openbookqa"]
    #     task_list = ["winogrande"]
    #     num_shot = 0
    #     results = eval_zero_shot(args.model, model, tokenizer, task_list, num_shot, accelerate)
    #     print("********************************")
    #     print("zero_shot evaluation results")
    #     print(results)

    if args.save_model:
        model.save_pretrained(args.save_model)
        tokenizer.save_pretrained(args.save_model)


if __name__ == '__main__':
    main()
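
For reference, a plausible invocation of the script using the flags defined above (the model identifier and output directory are placeholders, not paths from this commit):

python main.py \
    --model path/to/llama-7b \
    --prune_method wanda \
    --sparsity_ratio 0.5 \
    --sparsity_type 2:4 \
    --save out/llama_7b_2-4_wanda/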
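
For intuition about the structured modes, here is a minimal self-contained sketch (an illustration, not the lib.prune implementation) of the mask a 2:4 magnitude criterion produces: within every group of four consecutive weights, the two smallest-magnitude entries are zeroed, matching the 0.5 sparsity the script asserts.

import torch

def nm_magnitude_mask(weight, prune_n=2, prune_m=4):
    # Hypothetical helper: view the weights in groups of prune_m along the last dimension.
    w = weight.reshape(-1, prune_m)
    # Indices of the prune_n smallest-magnitude entries in each group.
    drop = w.abs().argsort(dim=1)[:, :prune_n]
    # Build a 0/1 mask with those positions zeroed.
    mask = torch.ones_like(w)
    mask.scatter_(1, drop, 0.0)
    return mask.reshape(weight.shape)

w = torch.randn(2, 8)
print(nm_magnitude_mask(w))  # every group of 4 has exactly 2 zeros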