Spaces:
Runtime error
Runtime error
Commit
·
725545d
1
Parent(s):
3e5a852
update space
Browse files
app.py
CHANGED
@@ -5,7 +5,7 @@ from functools import partial
|
|
5 |
from typing import Optional
|
6 |
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
7 |
from shap_e.diffusion.sample import sample_latents
|
8 |
-
from shap_e.models.download import load_model, load_config
|
9 |
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
|
10 |
import trimesh
|
11 |
import torch.nn as nn
|
@@ -275,10 +275,25 @@ def main():
|
|
275 |
"""
|
276 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
277 |
print("device:", device)
|
278 |
-
latent_model = load_model('text300M', device=device)
|
|
|
|
|
|
|
|
|
|
|
|
|
279 |
print("loaded latent model")
|
280 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
281 |
print("loaded transmitter")
|
|
|
|
|
282 |
diffusion = diffusion_from_config(load_config('diffusion'))
|
283 |
freeze_params(xm.parameters())
|
284 |
models = dict()
|
|
|
5 |
from typing import Optional
|
6 |
from shap_e.diffusion.gaussian_diffusion import diffusion_from_config
|
7 |
from shap_e.diffusion.sample import sample_latents
|
8 |
+
from shap_e.models.download import load_model, load_config, load_checkpoint
|
9 |
from shap_e.util.notebooks import create_pan_cameras, decode_latent_mesh
|
10 |
import trimesh
|
11 |
import torch.nn as nn
|
|
|
275 |
"""
|
276 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
277 |
print("device:", device)
|
278 |
+
# latent_model = load_model('text300M', device=device)
|
279 |
+
|
280 |
+
latent_model = model_from_config(load_config('text300M'), device=device)
|
281 |
+
# print(model_name, kwargs)
|
282 |
+
# print(model)
|
283 |
+
latent_model.load_state_dict(load_checkpoint('text300M', device='cpu'))
|
284 |
+
latent_model.eval()
|
285 |
print("loaded latent model")
|
286 |
+
latent_model.to(device)
|
287 |
+
# xm = load_model('transmitter', device=device)
|
288 |
+
|
289 |
+
xm = model_from_config(load_config('transmitter'), device=device)
|
290 |
+
# print(model_name, kwargs)
|
291 |
+
# print(model)
|
292 |
+
xm.load_state_dict(load_checkpoint('transmitter', device='cpu'))
|
293 |
+
xm.eval()
|
294 |
print("loaded transmitter")
|
295 |
+
xm.to(device)
|
296 |
+
|
297 |
diffusion = diffusion_from_config(load_config('diffusion'))
|
298 |
freeze_params(xm.parameters())
|
299 |
models = dict()
|