import time
import warnings

import torch
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoModelForCausalLM, GPT2TokenizerFast

from quant import perform_quantization

warnings.filterwarnings("ignore")

device = "cuda"

model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
tokenizer = GPT2TokenizerFast.from_pretrained("openai-community/gpt2")

# WikiText-2 validation set, concatenated into one long text and tokenized once.
test = load_dataset("wikitext", "wikitext-2-raw-v1", split="validation")
encodings = tokenizer("\n\n".join(test["text"]), return_tensors="pt")


def run_experiment(model):
    """Compute sliding-window perplexity on WikiText-2 and report memory usage
    and wall-clock time."""
    print(f'Memory usage of model alone = {model.get_memory_footprint() / 10**6} MB')

    max_length = model.config.n_positions  # 1024 for GPT-2
    stride = 512
    seq_len = encodings.input_ids.size(1)

    nlls = []
    start_time = time.time()
    prev_end_loc = 0
    for begin_loc in tqdm(range(0, seq_len, stride)):
        end_loc = min(begin_loc + max_length, seq_len)
        trg_len = end_loc - prev_end_loc  # may be different from stride on the last loop
        input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Only the last trg_len tokens are scored; the rest serve as context.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # loss is calculated using CrossEntropyLoss, which averages over valid labels
            neg_log_likelihood = outputs.loss

        if begin_loc == 0:
            print(f'Memory usage at forward pass = {torch.cuda.memory_allocated(0) / 10**6} MB')

        nlls.append(neg_log_likelihood)
        prev_end_loc = end_loc
        if end_loc == seq_len:
            break

    ppl = torch.exp(torch.stack(nlls).mean())
    print(f'Perplexity = {ppl.item()}')
    print(f'Time taken: {time.time() - start_time}')


model_type = 0

# Baseline: unquantized GPT-2
if model_type == 0:
    print('Normal model')
    run_experiment(model)
    print()

# Full model quantization (including lm_head)
if model_type == 0:
    print('Full model quant')
    perform_quantization(model)
    torch.save(model, 'q1-full-quant.pt')
    run_experiment(model)
    print()

# Transformer blocks only, without lm_head
if model_type == 0:
    print('Full model without lm_head')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"transformer\.h\.\d+\.[a-zA-Z]+")
    run_experiment(model)
    print()

# Only lm_head
if model_type == 0:
    print('Only LM head')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"[\w.]*lm_head[\w.]*")
    run_experiment(model)
    print()

# Last 4 transformer blocks
if model_type == 0:
    print('Last 4 attention layers')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"transformer\.h\.(8|9|10|11)\.[a-zA-Z]+")
    run_experiment(model)
    print()

# Only attention projections (q, k, v and output)
if model_type == 0:
    print('Only q,k,v')
    model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
    perform_quantization(model, regex=r"[\w.]*attn[\w.]*")
    run_experiment(model)
    print()
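

# ---------------------------------------------------------------------------
# quant.py is not included above. The sketch below shows what a regex-driven,
# weight-only int8 perform_quantization(model, regex=...) *could* look like;
# it is an assumption for illustration only, not the actual quant.py, and the
# names here (Int8WeightOnly) are made up for this sketch. It swaps every
# nn.Linear / Conv1D whose qualified name matches the regex for a wrapper that
# stores an int8 copy of the weight plus a per-tensor scale and dequantizes on
# the fly in forward().
# ---------------------------------------------------------------------------
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers.pytorch_utils import Conv1D  # GPT-2 blocks use Conv1D; lm_head is nn.Linear


class Int8WeightOnly(nn.Module):
    """Weight-only replacement: keeps the weight as an int8 buffer (1 byte per
    element) and dequantizes to the activation dtype inside forward()."""

    def __init__(self, module):
        super().__init__()
        weight = module.weight.data
        scale = weight.abs().max() / 127.0  # symmetric per-tensor scale
        self.register_buffer(
            "weight_int8",
            torch.clamp((weight / scale).round(), -128, 127).to(torch.int8),
        )
        self.register_buffer("scale", scale)
        self.bias = module.bias  # reuse the original bias parameter (may be None)
        self.is_conv1d = isinstance(module, Conv1D)

    def forward(self, x):
        weight = self.weight_int8.to(x.dtype) * self.scale  # dequantize on the fly
        if self.is_conv1d:  # Conv1D stores weight as [in, out]: y = x @ W (+ b)
            out = x @ weight
        else:               # nn.Linear stores weight as [out, in]
            out = F.linear(x, weight)
        if self.bias is not None:
            out = out + self.bias
        return out


def perform_quantization(model, regex=r".*"):
    """Replace, in place, every Linear/Conv1D submodule whose qualified name
    matches `regex` (via re.search) with an Int8WeightOnly wrapper."""
    pattern = re.compile(regex)
    for parent_name, parent in list(model.named_modules()):
        for child_name, child in list(parent.named_children()):
            full_name = f"{parent_name}.{child_name}" if parent_name else child_name
            if isinstance(child, (nn.Linear, Conv1D)) and pattern.search(full_name):
                setattr(parent, child_name, Int8WeightOnly(child))
    return model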