Spaces:
Runtime error
Runtime error
fix
Browse files- scripts/anime.py +1 -1
- scripts/data.py +1 -1
- scripts/model.py +1 -4
scripts/anime.py
CHANGED
|
@@ -19,7 +19,7 @@ model = None
|
|
| 19 |
def init_model(use_local=False):
|
| 20 |
global model
|
| 21 |
model_opt = "default"
|
| 22 |
-
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 23 |
model = create_model(model_opt, use_local).to(device)
|
| 24 |
model.eval()
|
| 25 |
|
|
|
|
| 19 |
def init_model(use_local=False):
|
| 20 |
global model
|
| 21 |
model_opt = "default"
|
| 22 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # issue: nevetherless, use_gpu is False, it still uses GPU
|
| 23 |
model = create_model(model_opt, use_local).to(device)
|
| 24 |
model.eval()
|
| 25 |
|
scripts/data.py
CHANGED
|
@@ -40,7 +40,7 @@ def get_transform(load_size=0, grayscale=False, method=bic, convert=True):
|
|
| 40 |
transform_list.append(transforms.Grayscale(1))
|
| 41 |
if load_size > 0:
|
| 42 |
osize = [load_size, load_size]
|
| 43 |
-
transform_list.append(transforms.Resize(osize, method))
|
| 44 |
if convert:
|
| 45 |
# transform_list += [transforms.ToTensor()]
|
| 46 |
if grayscale:
|
|
|
|
| 40 |
transform_list.append(transforms.Grayscale(1))
|
| 41 |
if load_size > 0:
|
| 42 |
osize = [load_size, load_size]
|
| 43 |
+
transform_list.append(transforms.Resize(osize, method, antialias=False))
|
| 44 |
if convert:
|
| 45 |
# transform_list += [transforms.ToTensor()]
|
| 46 |
if grayscale:
|
scripts/model.py
CHANGED
|
@@ -154,8 +154,7 @@ def create_model(model, use_local):
|
|
| 154 |
|
| 155 |
import os
|
| 156 |
if model == 'default':
|
| 157 |
-
model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder))
|
| 158 |
-
("netG.pth", "models/Anime2Sketch")
|
| 159 |
# model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
|
| 160 |
|
| 161 |
ckpt = torch.load(model_path)
|
|
@@ -176,8 +175,6 @@ def create_model(model, use_local):
|
|
| 176 |
base = base.model[3]
|
| 177 |
|
| 178 |
net.load_state_dict(ckpt)
|
| 179 |
-
|
| 180 |
-
os.chdir(cwd) # 元のディレクトリに戻る
|
| 181 |
|
| 182 |
else:
|
| 183 |
raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
|
|
|
|
| 154 |
|
| 155 |
import os
|
| 156 |
if model == 'default':
|
| 157 |
+
model_path = (lambda filename, subfolder: os.path.join(subfolder, filename) if use_local else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch")
|
|
|
|
| 158 |
# model_path = ((filename, subfolder) => if (use_local) os.path.join(subfolder, filename) else download_file(filename, subfolder))("netG.pth", "models/Anime2Sketch") // JavaScript
|
| 159 |
|
| 160 |
ckpt = torch.load(model_path)
|
|
|
|
| 175 |
base = base.model[3]
|
| 176 |
|
| 177 |
net.load_state_dict(ckpt)
|
|
|
|
|
|
|
| 178 |
|
| 179 |
else:
|
| 180 |
raise ValueError(f"model should be one of ['default', 'improved'], but got {model}")
|