Spaces:
Running
on
Zero
Running
on
Zero
Update OmniAvatar/models/model_manager.py
Browse files
OmniAvatar/models/model_manager.py
CHANGED
@@ -254,49 +254,8 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
|
|
254 |
loaded_model_names += loaded_model_names_
|
255 |
loaded_models += loaded_models_
|
256 |
return loaded_model_names, loaded_models
|
257 |
-
|
258 |
-
|
259 |
-
|
260 |
-
class ModelDetectorFromHuggingfaceFolder:
|
261 |
-
def __init__(self, model_loader_configs=[]):
|
262 |
-
self.architecture_dict = {}
|
263 |
-
for metadata in model_loader_configs:
|
264 |
-
self.add_model_metadata(*metadata)
|
265 |
|
266 |
|
267 |
-
def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
|
268 |
-
self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
|
269 |
-
|
270 |
-
|
271 |
-
def match(self, file_path="", state_dict={}):
|
272 |
-
if not isinstance(file_path, str) or os.path.isfile(file_path):
|
273 |
-
return False
|
274 |
-
file_list = os.listdir(file_path)
|
275 |
-
if "config.json" not in file_list:
|
276 |
-
return False
|
277 |
-
with open(os.path.join(file_path, "config.json"), "r") as f:
|
278 |
-
config = json.load(f)
|
279 |
-
if "architectures" not in config and "_class_name" not in config:
|
280 |
-
return False
|
281 |
-
return True
|
282 |
-
|
283 |
-
|
284 |
-
def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
|
285 |
-
with open(os.path.join(file_path, "config.json"), "r") as f:
|
286 |
-
config = json.load(f)
|
287 |
-
loaded_model_names, loaded_models = [], []
|
288 |
-
architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
|
289 |
-
for architecture in architectures:
|
290 |
-
huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
|
291 |
-
if redirected_architecture is not None:
|
292 |
-
architecture = redirected_architecture
|
293 |
-
model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
|
294 |
-
loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
|
295 |
-
loaded_model_names += loaded_model_names_
|
296 |
-
loaded_models += loaded_models_
|
297 |
-
return loaded_model_names, loaded_models
|
298 |
-
|
299 |
-
|
300 |
|
301 |
class ModelDetectorFromPatchedSingleFile:
|
302 |
def __init__(self, model_loader_configs=[]):
|
@@ -357,7 +316,6 @@ class ModelManager:
|
|
357 |
self.model_detector = [
|
358 |
ModelDetectorFromSingleFile(model_loader_configs),
|
359 |
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
360 |
-
ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
|
361 |
]
|
362 |
self.load_models(downloaded_files + file_path_list)
|
363 |
|
|
|
254 |
loaded_model_names += loaded_model_names_
|
255 |
loaded_models += loaded_models_
|
256 |
return loaded_model_names, loaded_models
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
|
258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
259 |
|
260 |
class ModelDetectorFromPatchedSingleFile:
|
261 |
def __init__(self, model_loader_configs=[]):
|
|
|
316 |
self.model_detector = [
|
317 |
ModelDetectorFromSingleFile(model_loader_configs),
|
318 |
ModelDetectorFromSplitedSingleFile(model_loader_configs),
|
|
|
319 |
]
|
320 |
self.load_models(downloaded_files + file_path_list)
|
321 |
|