Update app.py
app.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn as nn
 from transformers import pipeline
 from PIL import Image
 import inspect
+import safetensors.torch
 
 # =============================================================================
 # Aesthetic-Shadow (using Hugging Face transformers pipeline)
@@ -73,8 +74,12 @@ def load_clip_models(name: str = "ViT-L/14", device='cuda'):
 
 def load_model(model_path: str, input_size=768, device: str = 'cuda', dtype=None):
     model = MLP(input_size=input_size)
-
-
+    if model_path.endswith(".safetensors"):
+        state_dict = safetensors.torch.load_file(model_path, device=device)
+    else:
+        state = torch.load(model_path, map_location=device, weights_only=False)
+        state_dict = state
+    model.load_state_dict(state_dict)
     model.to(device)
     if dtype:
         model = model.to(dtype=dtype)
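For context, the change above lets load_model read both .safetensors files and legacy torch checkpoints by branching on the file extension. The sketch below shows that pattern in isolation, under stated assumptions: TinyMLP, load_weights, and the example file name are hypothetical stand-ins introduced here for illustration, not names from app.py, whose actual MLP class and load_model signature live elsewhere in the file.

import torch
import torch.nn as nn
import safetensors.torch


class TinyMLP(nn.Module):
    # Hypothetical stand-in for the MLP class defined in app.py.
    def __init__(self, input_size: int = 768):
        super().__init__()
        self.layers = nn.Sequential(nn.Linear(input_size, 1))

    def forward(self, x):
        return self.layers(x)


def load_weights(model: nn.Module, model_path: str, device: str = "cpu") -> nn.Module:
    # Same branching as the diff: .safetensors files go through
    # safetensors.torch.load_file, anything else through torch.load.
    if model_path.endswith(".safetensors"):
        state_dict = safetensors.torch.load_file(model_path, device=device)
    else:
        # weights_only=False mirrors the diff and accepts legacy pickled checkpoints.
        state_dict = torch.load(model_path, map_location=device, weights_only=False)
    model.load_state_dict(state_dict)
    return model.to(device)


# Hypothetical usage: save a state dict as safetensors, then reload it.
model = TinyMLP()
safetensors.torch.save_file(model.state_dict(), "aesthetic_head.safetensors")
reloaded = load_weights(TinyMLP(), "aesthetic_head.safetensors", device="cpu")

On a CUDA machine the device argument could be "cuda", matching the default in app.py's load_model.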