Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
|
@@ -21,8 +21,6 @@ from lora_w2w import LoRAw2w
|
|
| 21 |
from huggingface_hub import snapshot_download
|
| 22 |
import spaces
|
| 23 |
|
| 24 |
-
|
| 25 |
-
|
| 26 |
global device
|
| 27 |
global generator
|
| 28 |
global unet
|
|
@@ -36,13 +34,13 @@ device = "cuda"
|
|
| 36 |
|
| 37 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 38 |
|
| 39 |
-
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16()
|
| 40 |
-
std = torch.load(f"{models_path}/files/std.pt").bfloat16()
|
| 41 |
-
v = torch.load(f"{models_path}/files/V.pt").bfloat16()
|
| 42 |
-
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16()
|
| 43 |
df = torch.load(f"{models_path}/files/identity_df.pt")
|
| 44 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
| 45 |
-
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt")
|
| 46 |
|
| 47 |
|
| 48 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
|
@@ -51,7 +49,10 @@ def sample_model():
|
|
| 51 |
global unet
|
| 52 |
del unet
|
| 53 |
global network
|
| 54 |
-
|
|
|
|
|
|
|
|
|
|
| 55 |
unet, _, _, _, _ = load_models(device)
|
| 56 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
| 57 |
|
|
|
|
| 21 |
from huggingface_hub import snapshot_download
|
| 22 |
import spaces
|
| 23 |
|
|
|
|
|
|
|
| 24 |
global device
|
| 25 |
global generator
|
| 26 |
global unet
|
|
|
|
| 34 |
|
| 35 |
models_path = snapshot_download(repo_id="Snapchat/w2w")
|
| 36 |
|
| 37 |
+
mean = torch.load(f"{models_path}/files/mean.pt").bfloat16()#.to(device)
|
| 38 |
+
std = torch.load(f"{models_path}/files/std.pt").bfloat16()#.to(device)
|
| 39 |
+
v = torch.load(f"{models_path}/files/V.pt").bfloat16()#.to(device)
|
| 40 |
+
proj = torch.load(f"{models_path}/files/proj_1000pc.pt").bfloat16()#.to(device)
|
| 41 |
df = torch.load(f"{models_path}/files/identity_df.pt")
|
| 42 |
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
|
| 43 |
+
pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt")#.bfloat16()#.to(device)
|
| 44 |
|
| 45 |
|
| 46 |
unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
|
|
|
|
| 49 |
global unet
|
| 50 |
del unet
|
| 51 |
global network
|
| 52 |
+
mean.to(device)
|
| 53 |
+
std.to(device)
|
| 54 |
+
v.to(device)
|
| 55 |
+
proj.to(device)
|
| 56 |
unet, _, _, _, _ = load_models(device)
|
| 57 |
network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)
|
| 58 |
|