Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -23,7 +23,6 @@ from taming.models.vqgan import VQModel
|
|
23 |
from omegaconf import OmegaConf
|
24 |
from taming.models.vqgan import GumbelVQ
|
25 |
import gradio as gr
|
26 |
-
from modules.finetunedvqgan import Generator
|
27 |
from modules.modelz import DeepfakeToSourceTransformer
|
28 |
from modules.frameworkeval import DF
|
29 |
from modules.segmentface import FaceSegmenter
|
@@ -31,6 +30,25 @@ from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
|
|
31 |
|
32 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
33 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
34 |
##________________________Transformation______________________________
|
35 |
|
36 |
transform = T.Compose([
|
@@ -54,10 +72,17 @@ def gen_sources(deepfake_img):
|
|
54 |
#------------Initialize:Decoder-F------------------------
|
55 |
config_path = "./models/config.yaml"
|
56 |
checkpoint_path_f = "./models/model_vaq1_ff.pth"
|
57 |
-
|
|
|
|
|
|
|
|
|
58 |
#------------Initialize:Decoder-G------------------------
|
59 |
checkpoint_path_g = "./models/model_vaq2_gg.pth"
|
60 |
-
|
|
|
|
|
|
|
61 |
##------------------------Initialize Model-F-------------------------------------
|
62 |
model_z1 = DeepfakeToSourceTransformer().to(device)
|
63 |
model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
|
@@ -73,7 +98,6 @@ def gen_sources(deepfake_img):
|
|
73 |
with torch.no_grad():
|
74 |
# Load and preprocess input image
|
75 |
#img = Image.open(deepfake_img).convert('RGB')
|
76 |
-
#segimg = Image.open(deepfake_seg).convert('RGB')
|
77 |
df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device) # Shape: (1, 3, 256, 256)
|
78 |
seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
|
79 |
|
|
|
23 |
from omegaconf import OmegaConf
|
24 |
from taming.models.vqgan import GumbelVQ
|
25 |
import gradio as gr
|
|
|
26 |
from modules.modelz import DeepfakeToSourceTransformer
|
27 |
from modules.frameworkeval import DF
|
28 |
from modules.segmentface import FaceSegmenter
|
|
|
30 |
|
31 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
|
33 |
+
config = OmegaConf.load(self.config_path)
|
34 |
+
# Extract parameters specific to GumbelVQ
|
35 |
+
vq_params = config.model.params
|
36 |
+
# Initialize the GumbelVQ models
|
37 |
+
model_vaq_f = GumbelVQ(
|
38 |
+
ddconfig=vq_params.ddconfig,
|
39 |
+
lossconfig=vq_params.lossconfig,
|
40 |
+
n_embed=vq_params.n_embed,
|
41 |
+
embed_dim=vq_params.embed_dim,
|
42 |
+
kl_weight=vq_params.kl_weight,
|
43 |
+
temperature_scheduler_config=vq_params.temperature_scheduler_config).to(device)
|
44 |
+
model_vaq_g = GumbelVQ(
|
45 |
+
ddconfig=vq_params.ddconfig,
|
46 |
+
lossconfig=vq_params.lossconfig,
|
47 |
+
n_embed=vq_params.n_embed,
|
48 |
+
embed_dim=vq_params.embed_dim,
|
49 |
+
kl_weight=vq_params.kl_weight,
|
50 |
+
temperature_scheduler_config=vq_params.temperature_scheduler_config).to(device)
|
51 |
+
|
52 |
##________________________Transformation______________________________
|
53 |
|
54 |
transform = T.Compose([
|
|
|
72 |
#------------Initialize:Decoder-F------------------------
|
73 |
config_path = "./models/config.yaml"
|
74 |
checkpoint_path_f = "./models/model_vaq1_ff.pth"
|
75 |
+
# Load model checkpoints
|
76 |
+
checkpoint_f = torch.load(self.checkpoint_path_f, map_location=self.device)
|
77 |
+
# Load the state dictionary into the models
|
78 |
+
model_vaq_f = model_vaq_f.load_state_dict(checkpoint_f, strict=True)
|
79 |
+
model_vaq_f.eval()
|
80 |
#------------Initialize:Decoder-G------------------------
|
81 |
checkpoint_path_g = "./models/model_vaq2_gg.pth"
|
82 |
+
checkpoint_g = torch.load(self.checkpoint_path_g, map_location=self.device)
|
83 |
+
# Load the state dictionary into the models
|
84 |
+
model_vaq_g = model_vaq_g.load_state_dict(checkpoint_g, strict=True)
|
85 |
+
model_vaq_g.eval()
|
86 |
##------------------------Initialize Model-F-------------------------------------
|
87 |
model_z1 = DeepfakeToSourceTransformer().to(device)
|
88 |
model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
|
|
|
98 |
with torch.no_grad():
|
99 |
# Load and preprocess input image
|
100 |
#img = Image.open(deepfake_img).convert('RGB')
|
|
|
101 |
df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device) # Shape: (1, 3, 256, 256)
|
102 |
seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
|
103 |
|