Update app.py
Browse files
app.py
CHANGED
@@ -127,6 +127,14 @@ def generate_music(genre_txt, lyrics_txt, max_new_tokens=5, run_n_segments=2, us
|
|
127 |
raw_output = None
|
128 |
stage1_output_set = []
|
129 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
130 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
131 |
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
132 |
guidance_scale = 1.5 if i <= 1 else 1.2
|
|
|
127 |
raw_output = None
|
128 |
stage1_output_set = []
|
129 |
|
130 |
+
class BlockTokenRangeProcessor(LogitsProcessor):
|
131 |
+
def __init__(self, start_id, end_id):
|
132 |
+
self.blocked_token_ids = list(range(start_id, end_id))
|
133 |
+
|
134 |
+
def __call__(self, input_ids, scores):
|
135 |
+
scores[:, self.blocked_token_ids] = -float("inf")
|
136 |
+
return scores
|
137 |
+
|
138 |
for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
|
139 |
section_text = p.replace('[start_of_segment]', '').replace('[end_of_segment]', '')
|
140 |
guidance_scale = 1.5 if i <= 1 else 1.2
|