import torch
from transformers import (
    GPT2LMHeadModel,
    GPT2Tokenizer,
    GPT2Config,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    XLNetLMHeadModel,
    XLNetTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    CTRLLMHeadModel,
    CTRLTokenizer,
)
# Registry of the models that can be served. "size" is the approximate GPU
# memory footprint in MB, which the GPU handler uses to pack models onto devices.
model_metadata = {
    "gpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "gpt2",
        "identifier": "gpt2/small",
    },
    "gpt": {
        "tokenizer": OpenAIGPTTokenizer,
        "model": OpenAIGPTLMHeadModel,
        "size": 550,
        "checkpoint": "openai-community/openai-gpt",
        "identifier": "gpt",
    },
    "xlnet": {
        "tokenizer": XLNetTokenizer,
        "model": XLNetLMHeadModel,
        "size": 550,
        "checkpoint": "xlnet-base-cased",
        "identifier": "xlnet",
    },
    "gpt2/arxiv-nlp": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 550,
        "checkpoint": "arxiv-nlp-v1",
        "identifier": "gpt2/arxiv-nlp",
    },
    "gpt2/medium": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 1500,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "gpt2/medium",
    },
    "gpt2/large": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 3300,
        "checkpoint": "openai-community/gpt2-large",
        "identifier": "gpt2/large",
    },
    "distilgpt2/small": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 350,
        "checkpoint": "distilgpt2",
        "identifier": "distilgpt2/small",
    },
    "ctrl": {
        "tokenizer": CTRLTokenizer,
        "model": CTRLLMHeadModel,
        "size": 6300,
        "checkpoint": "Salesforce/ctrl",
        "identifier": "ctrl",
    },
    "gpt2/xl": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 7000,
        "checkpoint": "openai-community/gpt2-xl",
        "identifier": "gpt2/xl",
    },
    "pplm": {
        "tokenizer": GPT2Tokenizer,
        "model": GPT2LMHeadModel,
        "size": 4000,
        "checkpoint": "openai-community/gpt2-medium",
        "identifier": "pplm",
        "configuration_options": {
            "config": GPT2Config,
            "options": {
                "output_hidden_states": True,
            },
        },
    },
}
# Fixed per-GPU memory overhead (CUDA context, buffers, etc.), in MB.
memory_overhead = 500
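# Illustrative sketch, not part of the original app: how a model_metadata
# entry could be turned into a live tokenizer/model pair once the GPU handler
# has assigned it a "device". The helper name `load_model` and its exact
# keyword handling are assumptions for illustration.
def load_model(metadata):
    tokenizer = metadata["tokenizer"].from_pretrained(metadata["checkpoint"])
    configuration_options = metadata.get("configuration_options")
    if configuration_options:
        # Apply extra config options (e.g. output_hidden_states for pplm).
        config = configuration_options["config"].from_pretrained(
            metadata["checkpoint"], **configuration_options["options"])
        model = metadata["model"].from_pretrained(
            metadata["checkpoint"], config=config)
    else:
        model = metadata["model"].from_pretrained(metadata["checkpoint"])
    # "device" is set by GPU.register_model; fall back to CPU if absent.
    model.to(metadata.get("device", "cpu"))
    model.eval()
    return tokenizer, model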
class GPU:
    """Tracks the models assigned to a single CUDA device."""

    def __init__(self, id):
        self.id = id
        self.models = []
        # Usable memory in MB, keeping a 1,000 MB safety margin.
        self.total_memory = torch.cuda.get_device_properties(
            "cuda:{}".format(id)).total_memory / 1_000_000 - 1_000
        print("INIT GPU WITH DEVICE", "cuda:{}".format(id))

    def register_model(self, model, cached_path=None):
        # Accept the model only if it leaves this GPU under its memory budget.
        if self.total_memory_used() + model["size"] < self.total_memory:
            model["device"] = "cuda:{}".format(self.id)
            if cached_path:
                model["cached_path"] = cached_path
            self.models.append(model)
            return True
        return False

    def total_memory_used(self):
        return sum(model["size"] for model in self.models) + memory_overhead

    def __repr__(self):
        return str(
            [(model["checkpoint"], model["size"]) for model in self.models]
            + [str(round(100 * (self.total_memory_used() / self.total_memory))) + "%"]
            + ["cuda:{}".format(self.id)]
        )
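# Worked example of the accounting above (assumed numbers): with "gpt2"
# (550 MB) and "openai-community/gpt2-medium" (1,500 MB) registered,
# total_memory_used() returns 550 + 1500 + 500 (overhead) = 2550 MB, and on a
# card with roughly 11,000 MB usable, __repr__ would print:
# [('gpt2', 550), ('openai-community/gpt2-medium', 1500), '23%', 'cuda:0']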
class GPUHandler:
    """Distributes the requested models across the available GPUs, first-fit."""

    def __init__(self, model_list, gpu_ids, cached_models=None):
        if cached_models is None:
            cached_models = {}
        self.gpus = [GPU(id) for id in gpu_ids]
        print("GPU handler initiated with {} gpus.".format(len(self.gpus)))
        # Verify that every requested model can fit before touching real state.
        self.sanity_check([model_metadata[model] for model in model_list])
        for model in model_list:
            self.register_model(model_metadata[model], cached_models.get(model))

    def register_model(self, model, cached_path=None):
        for gpu in self.gpus:
            if gpu.register_model(model, cached_path):
                print("Registered model", model["identifier"], "in GPU", gpu)
                break
        else:
            # No GPU had room left for this model.
            raise ValueError("Could not load model", model["checkpoint"])

    def sanity_check(self, model_list):
        # Dry run on throwaway GPU objects; register shallow copies so the
        # real metadata dicts are not mutated with temporary device strings.
        temp_gpus = [GPU(gpu.id) for gpu in self.gpus]
        for model in model_list:
            current_gpu_index = 0
            while current_gpu_index < len(temp_gpus):
                if temp_gpus[current_gpu_index].register_model(dict(model)):
                    break
                current_gpu_index += 1
            if current_gpu_index >= len(temp_gpus):
                raise RuntimeError("SANITY CHECK FAILED")
        print("Current layout", temp_gpus)

    def __repr__(self):
        return f"NO. GPUS: {len(self.gpus)}.\n{self.gpus}"