Jannat24 commited on
Commit
e49aff4
·
verified ·
1 Parent(s): 038b7f1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -4
app.py CHANGED
@@ -23,7 +23,6 @@ from taming.models.vqgan import VQModel
23
  from omegaconf import OmegaConf
24
  from taming.models.vqgan import GumbelVQ
25
  import gradio as gr
26
- from modules.finetunedvqgan import Generator
27
  from modules.modelz import DeepfakeToSourceTransformer
28
  from modules.frameworkeval import DF
29
  from modules.segmentface import FaceSegmenter
@@ -31,6 +30,25 @@ from modules.denormalize import denormalize_bin, denormalize_tr, denormalize_ar
31
 
32
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  ##________________________Transformation______________________________
35
 
36
  transform = T.Compose([
@@ -54,10 +72,17 @@ def gen_sources(deepfake_img):
54
  #------------Initialize:Decoder-F------------------------
55
  config_path = "./models/config.yaml"
56
  checkpoint_path_f = "./models/model_vaq1_ff.pth"
57
- model_vaq_f = Generator(config_path, checkpoint_path_f, device)
 
 
 
 
58
  #------------Initialize:Decoder-G------------------------
59
  checkpoint_path_g = "./models/model_vaq2_gg.pth"
60
- model_vaq_g = Generator(config_path, checkpoint_path_g, device)
 
 
 
61
  ##------------------------Initialize Model-F-------------------------------------
62
  model_z1 = DeepfakeToSourceTransformer().to(device)
63
  model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
@@ -73,7 +98,6 @@ def gen_sources(deepfake_img):
73
  with torch.no_grad():
74
  # Load and preprocess input image
75
  #img = Image.open(deepfake_img).convert('RGB')
76
- #segimg = Image.open(deepfake_seg).convert('RGB')
77
  df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device) # Shape: (1, 3, 256, 256)
78
  seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
79
 
 
23
  from omegaconf import OmegaConf
24
  from taming.models.vqgan import GumbelVQ
25
  import gradio as gr
 
26
  from modules.modelz import DeepfakeToSourceTransformer
27
  from modules.frameworkeval import DF
28
  from modules.segmentface import FaceSegmenter
 
30
 
31
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
 
33
+ config = OmegaConf.load("./models/config.yaml")
34
+ # Extract parameters specific to GumbelVQ
35
+ vq_params = config.model.params
36
+ # Initialize the GumbelVQ models
37
+ model_vaq_f = GumbelVQ(
38
+ ddconfig=vq_params.ddconfig,
39
+ lossconfig=vq_params.lossconfig,
40
+ n_embed=vq_params.n_embed,
41
+ embed_dim=vq_params.embed_dim,
42
+ kl_weight=vq_params.kl_weight,
43
+ temperature_scheduler_config=vq_params.temperature_scheduler_config).to(device)
44
+ model_vaq_g = GumbelVQ(
45
+ ddconfig=vq_params.ddconfig,
46
+ lossconfig=vq_params.lossconfig,
47
+ n_embed=vq_params.n_embed,
48
+ embed_dim=vq_params.embed_dim,
49
+ kl_weight=vq_params.kl_weight,
50
+ temperature_scheduler_config=vq_params.temperature_scheduler_config).to(device)
51
+
52
  ##________________________Transformation______________________________
53
 
54
  transform = T.Compose([
 
72
  #------------Initialize:Decoder-F------------------------
73
  config_path = "./models/config.yaml"
74
  checkpoint_path_f = "./models/model_vaq1_ff.pth"
75
+ # Load model checkpoints
76
+ checkpoint_f = torch.load(checkpoint_path_f, map_location=device)
77
+ # Load the state dictionary into the models
78
+ model_vaq_f.load_state_dict(checkpoint_f, strict=True)
79
+ model_vaq_f.eval()
80
  #------------Initialize:Decoder-G------------------------
81
  checkpoint_path_g = "./models/model_vaq2_gg.pth"
82
+ checkpoint_g = torch.load(checkpoint_path_g, map_location=device)
83
+ # Load the state dictionary into the models
84
+ model_vaq_g.load_state_dict(checkpoint_g, strict=True)
85
+ model_vaq_g.eval()
86
  ##------------------------Initialize Model-F-------------------------------------
87
  model_z1 = DeepfakeToSourceTransformer().to(device)
88
  model_z1.load_state_dict(torch.load("./models/model_z1_ff.pth",map_location=device),strict=True)
 
98
  with torch.no_grad():
99
  # Load and preprocess input image
100
  #img = Image.open(deepfake_img).convert('RGB')
 
101
  df_img = transform(deepfake_img.convert('RGB')).unsqueeze(0).to(device) # Shape: (1, 3, 256, 256)
102
  seg_img = transform(deepfake_seg).unsqueeze(0).to(device)
103