update change_lora func
Haoxin Chen committed · Commit 62b6d65 · Parent(s): 3ab49e3

Files changed:
- app.py +1 -1
- lvdm/models/modules/lora.py +77 -0
- videocrafter_test.py +5 -4
app.py CHANGED
@@ -12,7 +12,7 @@ def videocrafter_demo(result_dir='./tmp/'):
     with gr.Tab(label="VideoCrafter"):
         input_text = gr.Text()
         model_choices=['origin','vangogh','frozen','yourname', 'coco']
-        trigger_word_list=['','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']
+        trigger_word_list=[' ','Loving Vincent style', 'frozenmovie style', 'MakotoShinkaiYourName style', 'coco style']
 
         with gr.Row():
             model_index = gr.Dropdown(label='Models', elem_id=f"model", choices=model_choices, value=model_choices[0], type="index",interactive=True)
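The base-model entry in trigger_word_list changes from an empty string to a single space, and the matching comparison in videocrafter_test.py (below) is updated to check against ' '. A minimal sketch of how the selected entry feeds the prompt, mirroring get_prompt's logic; the build_prompt helper is hypothetical, not in the repo:

    trigger_word_list = [' ', 'Loving Vincent style', 'frozenmovie style',
                         'MakotoShinkaiYourName style', 'coco style']

    def build_prompt(input_text, model_index):
        # ' ' is the sentinel for the base 'origin' model: append nothing
        trigger_word = trigger_word_list[model_index]
        if trigger_word != ' ':
            input_text = input_text + ', ' + trigger_word
        return input_text

    print(build_prompt('a corgi running on the beach', 1))
    # a corgi running on the beach, Loving Vincent style
    print(build_prompt('a corgi running on the beach', 0))
    # a corgi running on the beach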
lvdm/models/modules/lora.py CHANGED
@@ -680,6 +680,83 @@ def change_lora(model, inject_lora=False, lora_scale=1.0, lora_path='', last_tim
     net_load_lora(model, lora_path, alpha=lora_scale)
 
 
+def net_load_lora_v2(net, checkpoint_path, alpha=1.0, remove=False, origin_weight=None):
+    visited = []
+    state_dict = torch.load(checkpoint_path)
+    for k, v in state_dict.items():
+        state_dict[k] = v.to(net.device)
+    # import pdb;pdb.set_trace()
+    for key in state_dict:
+        if ".alpha" in key or key in visited:
+            continue
+        layer_infos = key.split(".")[:-2]  # remove lora_up and down weight
+        curr_layer = net
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            curr_layer = curr_layer.__getattr__(temp_name)
+            if len(layer_infos) > 0:
+                temp_name = layer_infos.pop(0)
+            elif len(layer_infos) == 0:
+                break
+        if curr_layer.__class__ not in [nn.Linear, nn.Conv2d]:
+            print('missing param at:', key)
+            continue
+        pair_keys = []
+        if "lora_down" in key:
+            pair_keys.append(key.replace("lora_down", "lora_up"))
+            pair_keys.append(key)
+        else:
+            pair_keys.append(key)
+            pair_keys.append(key.replace("lora_up", "lora_down"))
+
+        # storage weight
+        if origin_weight is None:
+            origin_weight = dict()
+            storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
+            origin_weight[storage_key] = curr_layer.weight.data.clone()
+        else:
+            storage_key = key.replace("lora_down", "lora").replace("lora_up", "lora")
+            if storage_key not in origin_weight.keys():
+                origin_weight[storage_key] = curr_layer.weight.data.clone()
+
+        # update
+        if len(state_dict[pair_keys[0]].shape) == 4:
+            # for conv
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+        else:
+            # for linear
+            if remove:
+                curr_layer.weight.data = origin_weight[storage_key].clone()
+            else:
+                weight_up = state_dict[pair_keys[0]].to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
+
+        # update visited list
+        for item in pair_keys:
+            visited.append(item)
+    print('load_weight_num:', len(visited))
+    return origin_weight
+
+
+def change_lora_v2(model, inject_lora=False, lora_scale=1.0, lora_path='', last_time_lora='', last_time_lora_scale=1.0, origin_weight=None):
+    # remove lora
+    if last_time_lora != '':
+        origin_weight = net_load_lora_v2(model, last_time_lora, alpha=last_time_lora_scale, remove=True, origin_weight=origin_weight)
+    # add new lora
+    if inject_lora:
+        origin_weight = net_load_lora_v2(model, lora_path, alpha=lora_scale, origin_weight=origin_weight)
+    return origin_weight
+
+
 def load_safeloras(path, device="cpu"):
     safeloras = safe_open(path, framework="pt", device=device)
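What v2 adds over net_load_lora is the origin_weight cache: the first time a layer is touched, its pristine weight is cloned under a storage key shared by the lora_up/lora_down pair, so a later remove=True call restores the exact original tensor instead of arithmetically undoing the merge. A minimal sketch of that merge-and-restore cycle on a single nn.Linear; shapes are illustrative, not taken from any checkpoint:

    import torch
    import torch.nn as nn

    layer = nn.Linear(8, 8)
    weight_up = torch.randn(8, 4)    # stand-in for a '...lora_up.weight' tensor
    weight_down = torch.randn(4, 8)  # stand-in for a '...lora_down.weight' tensor

    # first touch: cache the pristine weight, like origin_weight[storage_key]
    origin = layer.weight.data.clone()

    # inject: W += alpha * up @ down (the linear branch above, with alpha=1.0)
    layer.weight.data += 1.0 * torch.mm(weight_up, weight_down)

    # remove: exact restore from the cache, no floating-point residue
    layer.weight.data = origin.clone()
    assert torch.equal(layer.weight.data, origin)

Conv weights get the same treatment with the 1x1 spatial dims squeezed out before the matrix product and unsqueezed back before the add; change_lora_v2 then threads the cache through a remove-then-inject pair so styles can be swapped repeatedly without drift.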
videocrafter_test.py CHANGED
@@ -13,7 +13,7 @@ from lvdm.utils.common_utils import str2bool
 from lvdm.utils.saving_utils import npz_to_video_grid, npz_to_imgsheet_5d
 from scripts.sample_text2video import sample_text2video
 from scripts.sample_utils import load_model, get_conditions, make_model_input_shape, torch_to_np
-from lvdm.models.modules.lora import change_lora
+from lvdm.models.modules.lora import change_lora, change_lora_v2
 
 from huggingface_hub import hf_hub_download
 
@@ -50,13 +50,14 @@ class Text2Video():
         self.result_dir = result_dir
         self.save_fps = 8
         self.ddim_sampler = DDIMSampler(model)
+        self.origin_weight = None
 
     def get_prompt(self, input_text, steps=50, model_index=0, eta=1.0, cfg_scale=15.0, lora_scale=1.0, trigger_word=''):
-        if trigger_word != '':
+        if trigger_word != ' ':
             input_text = input_text + ', ' + trigger_word
         inject_lora = model_index > 0
-        change_lora(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
-                    last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale)
+        self.origin_weight = change_lora_v2(self.model, inject_lora=inject_lora, lora_scale=lora_scale, lora_path=self.lora_path_list[model_index],
+                                            last_time_lora=self.last_time_lora, last_time_lora_scale=self.last_time_lora_scale, origin_weight=self.origin_weight)
 
         all_videos = sample_text2video(self.model, input_text, n_samples=1, batch_size=1,
                                        sample_type='ddim', sampler=self.ddim_sampler,