Spaces: Running on Zero
Update clip_slider_pipeline.py
clip_slider_pipeline.py  +7 -3  CHANGED
@@ -4,7 +4,7 @@ import random
 from tqdm import tqdm
 from constants import SUBJECTS, MEDIUMS
 from PIL import Image
-
+import time
 class CLIPSlider:
     def __init__(
         self,
@@ -214,7 +214,7 @@ class CLIPSliderXL(CLIPSlider):
     ):
         # if doing full sequence, [-0.3,0.3] work well, higher if correlation weighted is true
         # if pooler token only [-4,4] work well
-
+        start_time = time.time()
         text_encoders = [self.pipe.text_encoder, self.pipe.text_encoder_2]
         tokenizers = [self.pipe.tokenizer, self.pipe.tokenizer_2]
         with torch.no_grad():
@@ -282,9 +282,13 @@ class CLIPSliderXL(CLIPSlider):
 
         prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
         pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
-
+        end_time = time.time()
+        print(f"generation time - before pipe: {(end_time - start_time) * 1000:.2f} ms")
         torch.manual_seed(seed)
+        start_time = time.time()
         image = self.pipe(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds,
                           **pipeline_kwargs).images[0]
+        end_time = time.time()
+        print(f"generation time - pipe: {(end_time - start_time) * 1000:.2f} ms")
 
         return image
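The instrumentation above is the standard wall-clock bracketing pattern: record a start timestamp, run the stage, record an end timestamp, and print the difference (time.time() returns seconds, so the delta must be scaled by 1000 to report milliseconds). A minimal standalone sketch of the same idea, using a hypothetical timed() helper that is not part of clip_slider_pipeline.py and time.perf_counter(), which is better suited to measuring short intervals:

import time
from contextlib import contextmanager

@contextmanager
def timed(label):
    # Hypothetical helper (assumption, not in the commit): prints how long
    # the wrapped block took, in milliseconds.
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{label}: {elapsed_ms:.2f} ms")

# Usage mirroring the two stages instrumented in the diff:
# with timed("generation time - before pipe"):
#     ...  # build prompt_embeds / pooled_prompt_embeds
# with timed("generation time - pipe"):
#     ...  # run self.pipe(...)

A context manager like this avoids juggling start_time/end_time pairs across stages and keeps the printed labels consistent.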