alexnasa commited on
Commit
3ad111c
·
verified ·
1 Parent(s): 7a1adef

Update OmniAvatar/models/model_manager.py

Browse files
Files changed (1) hide show
  1. OmniAvatar/models/model_manager.py +0 -42
OmniAvatar/models/model_manager.py CHANGED
@@ -254,49 +254,8 @@ class ModelDetectorFromSplitedSingleFile(ModelDetectorFromSingleFile):
254
  loaded_model_names += loaded_model_names_
255
  loaded_models += loaded_models_
256
  return loaded_model_names, loaded_models
257
-
258
-
259
-
260
- class ModelDetectorFromHuggingfaceFolder:
261
- def __init__(self, model_loader_configs=[]):
262
- self.architecture_dict = {}
263
- for metadata in model_loader_configs:
264
- self.add_model_metadata(*metadata)
265
 
266
 
267
- def add_model_metadata(self, architecture, huggingface_lib, model_name, redirected_architecture):
268
- self.architecture_dict[architecture] = (huggingface_lib, model_name, redirected_architecture)
269
-
270
-
271
- def match(self, file_path="", state_dict={}):
272
- if not isinstance(file_path, str) or os.path.isfile(file_path):
273
- return False
274
- file_list = os.listdir(file_path)
275
- if "config.json" not in file_list:
276
- return False
277
- with open(os.path.join(file_path, "config.json"), "r") as f:
278
- config = json.load(f)
279
- if "architectures" not in config and "_class_name" not in config:
280
- return False
281
- return True
282
-
283
-
284
- def load(self, file_path="", state_dict={}, device="cuda", torch_dtype=torch.float16, **kwargs):
285
- with open(os.path.join(file_path, "config.json"), "r") as f:
286
- config = json.load(f)
287
- loaded_model_names, loaded_models = [], []
288
- architectures = config["architectures"] if "architectures" in config else [config["_class_name"]]
289
- for architecture in architectures:
290
- huggingface_lib, model_name, redirected_architecture = self.architecture_dict[architecture]
291
- if redirected_architecture is not None:
292
- architecture = redirected_architecture
293
- model_class = importlib.import_module(huggingface_lib).__getattribute__(architecture)
294
- loaded_model_names_, loaded_models_ = load_model_from_huggingface_folder(file_path, [model_name], [model_class], torch_dtype, device)
295
- loaded_model_names += loaded_model_names_
296
- loaded_models += loaded_models_
297
- return loaded_model_names, loaded_models
298
-
299
-
300
 
301
  class ModelDetectorFromPatchedSingleFile:
302
  def __init__(self, model_loader_configs=[]):
@@ -357,7 +316,6 @@ class ModelManager:
357
  self.model_detector = [
358
  ModelDetectorFromSingleFile(model_loader_configs),
359
  ModelDetectorFromSplitedSingleFile(model_loader_configs),
360
- ModelDetectorFromHuggingfaceFolder(huggingface_model_loader_configs),
361
  ]
362
  self.load_models(downloaded_files + file_path_list)
363
 
 
254
  loaded_model_names += loaded_model_names_
255
  loaded_models += loaded_models_
256
  return loaded_model_names, loaded_models
 
 
 
 
 
 
 
 
257
 
258
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
259
 
260
  class ModelDetectorFromPatchedSingleFile:
261
  def __init__(self, model_loader_configs=[]):
 
316
  self.model_detector = [
317
  ModelDetectorFromSingleFile(model_loader_configs),
318
  ModelDetectorFromSplitedSingleFile(model_loader_configs),
 
319
  ]
320
  self.load_models(downloaded_files + file_path_list)
321