MohamedRashad commited on
Commit
66cfd91
·
1 Parent(s): 9d8246c

Make the generator CPU

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -160,7 +160,7 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
160
  return save_file
161
 
162
  def load_tokenizer(t5_path =''):
163
- print(f'[Loading tokenizer and text encoder]')
164
  text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
165
  text_tokenizer.model_max_length = 512
166
  text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
@@ -186,7 +186,7 @@ def load_infinity(
186
  use_flex_attn=False,
187
  bf16=False,
188
  ):
189
- print(f'[Loading Infinity]')
190
 
191
  # Set device if not provided
192
  if device is None:
@@ -232,13 +232,13 @@ def load_infinity(
232
  infinity_test.eval()
233
  infinity_test.requires_grad_(False)
234
 
235
- print(f'[Load Infinity weights]')
236
  state_dict = torch.load(model_path, map_location=device)
237
  print(infinity_test.load_state_dict(state_dict))
238
 
239
  # Initialize random number generator on the correct device
240
- infinity_test.rng = torch.Generator(device=device)
241
-
242
  return infinity_test
243
 
244
  def transform(pil_img, tgt_h, tgt_w):
 
160
  return save_file
161
 
162
  def load_tokenizer(t5_path =''):
163
+ print('[Loading tokenizer and text encoder]')
164
  text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
165
  text_tokenizer.model_max_length = 512
166
  text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
 
186
  use_flex_attn=False,
187
  bf16=False,
188
  ):
189
+ print('[Loading Infinity]')
190
 
191
  # Set device if not provided
192
  if device is None:
 
232
  infinity_test.eval()
233
  infinity_test.requires_grad_(False)
234
 
235
+ print('[Load Infinity weights]')
236
  state_dict = torch.load(model_path, map_location=device)
237
  print(infinity_test.load_state_dict(state_dict))
238
 
239
  # Initialize random number generator on the correct device
240
+ infinity_test.rng = torch.Generator(device="cpu")
241
+
242
  return infinity_test
243
 
244
  def transform(pil_img, tgt_h, tgt_w):