Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
66cfd91
1
Parent(s):
9d8246c
Make the generator CPU
Browse files
app.py
CHANGED
@@ -160,7 +160,7 @@ def save_slim_model(infinity_model_path, save_file=None, device='cpu', key='gpt_
|
|
160 |
return save_file
|
161 |
|
162 |
def load_tokenizer(t5_path =''):
|
163 |
-
print(
|
164 |
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
|
165 |
text_tokenizer.model_max_length = 512
|
166 |
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
|
@@ -186,7 +186,7 @@ def load_infinity(
|
|
186 |
use_flex_attn=False,
|
187 |
bf16=False,
|
188 |
):
|
189 |
-
print(
|
190 |
|
191 |
# Set device if not provided
|
192 |
if device is None:
|
@@ -232,13 +232,13 @@ def load_infinity(
|
|
232 |
infinity_test.eval()
|
233 |
infinity_test.requires_grad_(False)
|
234 |
|
235 |
-
print(
|
236 |
state_dict = torch.load(model_path, map_location=device)
|
237 |
print(infinity_test.load_state_dict(state_dict))
|
238 |
|
239 |
# Initialize random number generator on the correct device
|
240 |
-
infinity_test.rng = torch.Generator(device=
|
241 |
-
|
242 |
return infinity_test
|
243 |
|
244 |
def transform(pil_img, tgt_h, tgt_w):
|
|
|
160 |
return save_file
|
161 |
|
162 |
def load_tokenizer(t5_path =''):
|
163 |
+
print('[Loading tokenizer and text encoder]')
|
164 |
text_tokenizer: T5TokenizerFast = AutoTokenizer.from_pretrained(t5_path, revision=None, legacy=True)
|
165 |
text_tokenizer.model_max_length = 512
|
166 |
text_encoder: T5EncoderModel = T5EncoderModel.from_pretrained(t5_path, torch_dtype=torch.float16)
|
|
|
186 |
use_flex_attn=False,
|
187 |
bf16=False,
|
188 |
):
|
189 |
+
print('[Loading Infinity]')
|
190 |
|
191 |
# Set device if not provided
|
192 |
if device is None:
|
|
|
232 |
infinity_test.eval()
|
233 |
infinity_test.requires_grad_(False)
|
234 |
|
235 |
+
print('[Load Infinity weights]')
|
236 |
state_dict = torch.load(model_path, map_location=device)
|
237 |
print(infinity_test.load_state_dict(state_dict))
|
238 |
|
239 |
# Initialize random number generator on the correct device
|
240 |
+
infinity_test.rng = torch.Generator(device="cpu")
|
241 |
+
|
242 |
return infinity_test
|
243 |
|
244 |
def transform(pil_img, tgt_h, tgt_w):
|