Spaces: Runtime error
Commit 86d2837
Parent(s): 9c7e8e1
limit to 10 rows from 1 user for diversity.
app.py CHANGED
@@ -1,6 +1,6 @@
 
 
-
+# TODO unify/merge origin and this
 # TODO save & restart from (if it exists) dataframe parquet
 import torch
 
@@ -37,12 +37,9 @@ torch.set_grad_enabled(False)
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.allow_tf32 = True
 
-prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate'])
+prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating', 'latest_user_to_rate', 'from_user_id'])
 
 import spaces
-prompt_list = [p for p in list(set(
-    pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist())) if type(p) == str]
-
 start_time = time.time()
 
 ####################### Setup Model
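The widened prevs_df schema above adds a 'from_user_id' column so every generated row records which user's embedding spawned it; later hunks use that column to cap and filter rows per user. A minimal sketch of the bookkeeping with toy values (the column names match the diff; the helper name and sample data are illustrative only):

import pandas as pd

prevs_df = pd.DataFrame(columns=['paths', 'embeddings', 'ips', 'user:rating',
                                 'latest_user_to_rate', 'from_user_id'])

def append_generated_row(df, path, emb, uid):
    # One row per generated clip; 'user:rating' starts as the {' ': ' '} placeholder used in app.py.
    row = pd.DataFrame({'paths': [path], 'embeddings': [emb],
                        'user:rating': [{' ': ' '}], 'from_user_id': [uid]})
    return pd.concat((df, row), ignore_index=True)

prevs_df = append_generated_row(prevs_df, './img_1.mp4', [0.1, 0.2], uid='user-a')
print(prevs_df[['paths', 'from_user_id']])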
@@ -55,13 +52,13 @@ from transformers import CLIPVisionModelWithProjection
 import uuid
 import av
 
-def write_video(file_name, images, fps=17):
+def write_video_av(file_name, images, fps=17):
     print('Saving')
     container = av.open(file_name, mode="w")
 
     stream = container.add_stream("h264", rate=fps)
     # stream.options = {'preset': 'faster'}
-    stream.thread_count =
+    stream.thread_count = -1
     stream.width = 512
     stream.height = 512
     stream.pix_fmt = "yuv420p"
@@ -79,8 +76,16 @@ def write_video(file_name, images, fps=17):
     container.close()
     print('Saved')
 
+def write_video(file_name, images, fps=15):
+    writer = imageio.get_writer(file_name, fps=fps)
+
+    for im in images:
+        writer.append_data(np.array(im))
+    writer.close()
+
 
-image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype
+image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="sdxl_models/image_encoder", torch_dtype=dtype,
+                                                              device_map='cpu')
 #vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=dtype)
 
 # vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder", torch_dtype=dtype)
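The second write_video above drops the PyAV/h264 path in favor of imageio.get_writer, which takes numpy frames directly. A small self-contained usage sketch under the same assumptions (512x512 frames at fps=15; the random frames are placeholders, and writing .mp4 assumes the imageio-ffmpeg backend is installed):

import numpy as np
import imageio

def write_video(file_name, images, fps=15):
    writer = imageio.get_writer(file_name, fps=fps)
    for im in images:
        writer.append_data(np.array(im))  # accepts PIL images or ndarrays
    writer.close()

frames = [np.random.randint(0, 255, (512, 512, 3), dtype=np.uint8) for _ in range(30)]
write_video('demo.mp4', frames, fps=15)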
@@ -91,8 +96,9 @@ image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter",
 #text_encoder = CLIPTextModel.from_pretrained(finetune_path+'/text_encoder/').to(dtype)
 
 
-unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet').to(dtype)
-text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder'
+unet = UNet2DConditionModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='unet',).to(dtype).to('cpu')
+text_encoder = CLIPTextModel.from_pretrained('rynmurdock/Sea_Claws', subfolder='text_encoder',
+                                             device_map='cpu').to(dtype)
 
 adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM")
 pipe = AnimateDiffPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", motion_adapter=adapter, image_encoder=image_encoder, torch_dtype=dtype, unet=unet, text_encoder=text_encoder)
@@ -101,6 +107,7 @@ pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_
 pipe.set_adapters(["lcm-lora"], [.9])
 pipe.fuse_lora()
 
+
 #pipe = AnimateDiffPipeline.from_pretrained('emilianJR/epiCRealism', torch_dtype=dtype, image_encoder=image_encoder)
 #pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
 #repo = "ByteDance/AnimateDiff-Lightning"
@@ -116,8 +123,7 @@ pipe.unet.fuse_qkv_projections()
 pipe.to(device=DEVICE)
 #pipe.unet = torch.compile(pipe.unet)
 #pipe.vae = torch.compile(pipe.vae)
-
-
+# TODO cannot compile on Spaces or we time out; don't run leave_imb stuff either
 #im_embs = torch.zeros(1, 1, 1, 1280, device=DEVICE, dtype=dtype)
 #output = pipe(prompt='a person', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[im_embs], num_inference_steps=STEPS)
 #leave_im_emb, _ = pipe.encode_image(
@@ -126,13 +132,13 @@ pipe.to(device=DEVICE)
 #assert len(output.frames[0]) == 16
 #leave_im_emb.detach().to('cpu')
 
-@spaces.GPU(duration=
+@spaces.GPU(duration=10)
 def generate_gpu(in_im_embs):
     print('start gen')
     in_im_embs = in_im_embs.to('cuda').unsqueeze(0).unsqueeze(0)
     #im_embs = torch.cat((torch.zeros(1, 1280, device=DEVICE, dtype=dtype), in_im_embs), 0)
 
-    output = pipe(prompt='
+    output = pipe(prompt='', guidance_scale=0, added_cond_kwargs={}, ip_adapter_image_embeds=[in_im_embs], num_inference_steps=STEPS)
     print('image is made')
     im_emb, _ = pipe.encode_image(
         output.frames[0][len(output.frames[0])//2], 'cuda', 1, output_hidden_state
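@spaces.GPU(duration=10) above requests a short ZeroGPU allocation (roughly ten seconds) for each call to generate_gpu instead of the default window. A minimal sketch of the decorator's shape; the function body here is a stand-in, not the app's pipeline:

import spaces
import torch

@spaces.GPU(duration=10)  # ask ZeroGPU for ~10 s of GPU time per call
def embed_on_gpu(x: torch.Tensor) -> torch.Tensor:
    return (x.to('cuda') * 2).to('cpu')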
@@ -163,10 +169,6 @@ def generate(in_im_embs):
 
 #######################
 
-
-# TODO only generate ~5 new images ahead from a specific user embedding. Do this by tracking a column of who's embedding it was and
-# taking the intersection for unrated by that user and from that users' embedding. Then we keep styles less consistent for better variety.
-
 def get_user_emb(embs, ys):
     # handle case where every instance of calibration videos is 'Neither' or 'Like' or 'Dislike'
     if len(list(set(ys))) <= 1:
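The TODO removed above (track whose embedding spawned each clip, then take the intersection with that user's unrated rows) is what the new 'from_user_id' plumbing implements. A compact sketch of that intersection with toy data; the frame contents are illustrative:

import pandas as pd

prevs_df = pd.DataFrame({
    'paths': ['a.mp4', 'b.mp4', 'c.mp4'],
    'user:rating': [{'u1': 1}, {' ': ' '}, {'u1': 0}],
    'from_user_id': ['u1', 'u1', 'u2'],
})

uid = 'u1'
unrated_by_uid = [row['user:rating'].get(uid) is None for _, row in prevs_df.iterrows()]
spawned_by_uid = [row['from_user_id'] == uid for _, row in prevs_df.iterrows()]
both = [a and b for a, b in zip(unrated_by_uid, spawned_by_uid)]
print(prevs_df[both])  # rows generated for u1 that u1 has not rated yet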
@@ -245,7 +247,17 @@ def background_next_image():
     for uid in user_id_list:
         rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is not None for i in prevs_df.iterrows()]]
         not_rated_rows = prevs_df[[i[1]['user:rating'].get(uid, None) is None for i in prevs_df.iterrows()]]
-
+
+        # we need to intersect not_rated_rows from this user's embed > 7. Just add a new column on which user_id spawned the
+        # media.
+
+        from_user = prevs_df[[i[1]['from_user_id'] == uid for i in prevs_df.iterrows()]]
+        if len(from_user) >= 10:
+            oldest = from_user.iloc[-1]['paths']
+            print(f'User has {len(from_user)} rows. Popping oldest: {oldest}')
+            prevs_df = prevs_df[prevs_df['paths'] != oldest]
+
+        if len(rated_rows) < 4:
             print(f'latest user {uid} has < 4 rows') # or > 7 unrated rows')
             continue
 
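The block above is the commit's headline change: once a user has spawned 10 or more rows, one of that user's clips is dropped so no single user dominates prevs_df. A standalone sketch of the same cap with toy data (the cap of 10 and the column names come from the diff; the helper name and sample frame are illustrative):

import pandas as pd

def cap_user_rows(df, uid, cap=10):
    from_user = df[df['from_user_id'] == uid]          # rows spawned by this user's embedding
    if len(from_user) >= cap:
        popped = from_user.iloc[-1]['paths']           # app.py pops the last row of the selection
        print(f'User has {len(from_user)} rows. Popping: {popped}')
        df = df[df['paths'] != popped]
    return df

df = pd.DataFrame({'paths': [f'{i}.mp4' for i in range(12)],
                   'from_user_id': ['u1'] * 12})
df = cap_user_rows(df, 'u1')
print(len(df))  # 11, one clip was dropped for exceeding the cap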
@@ -260,6 +272,7 @@ def background_next_image():
         tmp_df['paths'] = [img]
         tmp_df['embeddings'] = [embs]
         tmp_df['user:rating'] = [{' ': ' '}]
+        tmp_df['from_user_id'] = [uid]
         prevs_df = pd.concat((prevs_df, tmp_df))
         # we can free up storage by deleting the image
         if len(prevs_df) > 50:
@@ -345,7 +358,9 @@ def choose(img, choice, calibrate_prompts, user_id, request: gr.Request):
         choice = 0
 
     row_mask = [p.split('/')[-1] in img for p in prevs_df['paths'].to_list()]
-
+
+
+    if len(prevs_df.loc[row_mask, 'user:rating']) > 0:
         prevs_df.loc[row_mask, 'user:rating'][0][user_id] = choice
         prevs_df.loc[row_mask, 'latest_user_to_rate'] = [user_id]
     img, calibrate_prompts = next_image(calibrate_prompts, user_id)
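The added length check above guards the rating write: if row_mask matches nothing (for example, the clip was already evicted by the per-user cap), indexing the empty selection with [0] would fail, so the write is skipped instead. A tiny illustration of the case being avoided, with a toy frame:

import pandas as pd

prevs_df = pd.DataFrame({'paths': ['a.mp4'], 'user:rating': [{}]})
row_mask = [False]  # the rated clip is no longer present in the frame

selection = prevs_df.loc[row_mask, 'user:rating']
if len(selection) > 0:
    selection[0]['some_user'] = 1   # safe: only index when the selection is non-empty
else:
    print('row already evicted; skipping the rating write')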
@@ -411,6 +426,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
     ''', elem_id="description")
     user_id = gr.State()
     print('USER_ID: ',user_id)
+    # calibration videos -- this is a misnomer now :D
     calibrate_prompts = gr.State([
         './first.mp4',
         './second.mp4',
@@ -429,7 +445,7 @@ Explore the latent space without text prompts based on your preferences. Learn m
         interactive=False,
         height=512,
         width=512,
-        include_audio=False,
+        #include_audio=False,
         elem_id="video_output"
     )
     img.play(l, js='''document.querySelector('[data-testid="Lightning-player"]').loop = true''')
@@ -471,12 +487,12 @@ log = logging.getLogger('log_here')
 log.setLevel(logging.ERROR)
 
 scheduler = BackgroundScheduler()
-scheduler.add_job(func=background_next_image, trigger="interval", seconds
+scheduler.add_job(func=background_next_image, trigger="interval", seconds=.1)
 scheduler.start()
 
 def encode_space(x):
     im_emb, _ = pipe.encode_image(
-        image,
+        image, DEVICE, 1, output_hidden_state
     )
     return im_emb.detach().to('cpu').to(torch.float32)
 
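The scheduler wiring above uses APScheduler's BackgroundScheduler to call background_next_image every 0.1 seconds on a daemon thread. A minimal standalone sketch of the same pattern; the tick job is a placeholder:

import time
from apscheduler.schedulers.background import BackgroundScheduler

def tick():
    print('generate the next clip here')

scheduler = BackgroundScheduler()
scheduler.add_job(func=tick, trigger="interval", seconds=.1)  # same interval as app.py
scheduler.start()

time.sleep(1)        # let a few intervals fire
scheduler.shutdown()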