VOIDER commited on
Commit
3e91ca9
·
verified ·
1 Parent(s): 112033c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -2
app.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn as nn
10
  from transformers import pipeline
11
  from PIL import Image
12
  import inspect
 
13
 
14
  # =============================================================================
15
  # Aesthetic-Shadow (using Hugging Face transformers pipeline)
@@ -73,8 +74,12 @@ def load_clip_models(name: str = "ViT-L/14", device='cuda'):
73
 
74
  def load_model(model_path: str, input_size=768, device: str = 'cuda', dtype=None):
75
  model = MLP(input_size=input_size)
76
- state = torch.load(model_path, map_location=device, weights_only=False)
77
- model.load_state_dict(state)
 
 
 
 
78
  model.to(device)
79
  if dtype:
80
  model = model.to(dtype=dtype)
 
10
  from transformers import pipeline
11
  from PIL import Image
12
  import inspect
13
+ import safetensors.torch
14
 
15
  # =============================================================================
16
  # Aesthetic-Shadow (using Hugging Face transformers pipeline)
 
74
 
75
  def load_model(model_path: str, input_size=768, device: str = 'cuda', dtype=None):
76
  model = MLP(input_size=input_size)
77
+ if model_path.endswith(".safetensors"):
78
+ state_dict = safetensors.torch.load_file(model_path, device=device)
79
+ else:
80
+ state = torch.load(model_path, map_location=device, weights_only=False)
81
+ state_dict = state
82
+ model.load_state_dict(state_dict)
83
  model.to(device)
84
  if dtype:
85
  model = model.to(dtype=dtype)