Spaces:
Runtime error
Runtime error
Commit
Β·
793fc05
1
Parent(s):
30a6bb8
update
Browse files- app.py +2 -2
- diffrhythm/infer/infer.py +5 -1
- diffrhythm/infer/infer_utils.py +1 -2
- requirements.txt +3 -2
app.py
CHANGED
@@ -51,7 +51,7 @@ def infer_music(lrc, ref_audio_path, steps, file_type, max_frames=2048):
|
|
51 |
start_time=start_time,
|
52 |
file_type=file_type
|
53 |
)
|
54 |
-
|
55 |
gc.collect()
|
56 |
print(">4")
|
57 |
|
@@ -207,7 +207,7 @@ with gr.Blocks(css=css) as demo:
|
|
207 |
interactive=True,
|
208 |
elem_id="step_slider"
|
209 |
)
|
210 |
-
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="
|
211 |
|
212 |
|
213 |
|
|
|
51 |
start_time=start_time,
|
52 |
file_type=file_type
|
53 |
)
|
54 |
+
devicetorch.empty_cache(torch)
|
55 |
gc.collect()
|
56 |
print(">4")
|
57 |
|
|
|
207 |
interactive=True,
|
208 |
elem_id="step_slider"
|
209 |
)
|
210 |
+
file_type = gr.Dropdown(["wav", "mp3", "ogg"], label="Output Format", value="mp3")
|
211 |
|
212 |
|
213 |
|
diffrhythm/infer/infer.py
CHANGED
@@ -134,7 +134,11 @@ if __name__ == "__main__":
|
|
134 |
parser.add_argument('--output-dir', type=str, default="example/output")
|
135 |
args = parser.parse_args()
|
136 |
|
137 |
-
device =
|
|
|
|
|
|
|
|
|
138 |
|
139 |
audio_length = args.audio_length
|
140 |
if audio_length == 95:
|
|
|
134 |
parser.add_argument('--output-dir', type=str, default="example/output")
|
135 |
args = parser.parse_args()
|
136 |
|
137 |
+
device = "cpu"
|
138 |
+
if torch.cuda.is_available():
|
139 |
+
device = "cuda"
|
140 |
+
elif torch.mps.is_available():
|
141 |
+
device = "mps"
|
142 |
|
143 |
audio_length = args.audio_length
|
144 |
if audio_length == 95:
|
diffrhythm/infer/infer_utils.py
CHANGED
@@ -169,8 +169,7 @@ def get_lrc_token(text, tokenizer, device):
|
|
169 |
return lrc_emb, normalized_start_time
|
170 |
|
171 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
172 |
-
|
173 |
-
model = model.half()
|
174 |
|
175 |
ckpt_type = ckpt_path.split(".")[-1]
|
176 |
if ckpt_type == "safetensors":
|
|
|
169 |
return lrc_emb, normalized_start_time
|
170 |
|
171 |
def load_checkpoint(model, ckpt_path, device, use_ema=True):
|
172 |
+
model = model.half()
|
|
|
173 |
|
174 |
ckpt_type = ckpt_path.split(".")[-1]
|
175 |
if ckpt_type == "safetensors":
|
requirements.txt
CHANGED
@@ -12,7 +12,8 @@ pandas==2.2.3
|
|
12 |
pylance==0.23.2
|
13 |
ema-pytorch==0.7.7
|
14 |
prefigure==0.0.10
|
15 |
-
bitsandbytes==0.45.3
|
|
|
16 |
muq==0.1.0
|
17 |
mutagen==1.47.0
|
18 |
pyopenjtalk==0.4.0
|
@@ -21,7 +22,7 @@ jieba==0.42.1
|
|
21 |
cn2an==0.5.23
|
22 |
pypinyin==0.53.0
|
23 |
#onnxruntime==1.20.1
|
24 |
-
onnxruntime-gpu
|
25 |
Unidecode==1.3.8
|
26 |
phonemizer==3.3.0
|
27 |
LangSegment==0.3.5
|
|
|
12 |
pylance==0.23.2
|
13 |
ema-pytorch==0.7.7
|
14 |
prefigure==0.0.10
|
15 |
+
#bitsandbytes==0.45.3
|
16 |
+
bitsandbytes
|
17 |
muq==0.1.0
|
18 |
mutagen==1.47.0
|
19 |
pyopenjtalk==0.4.0
|
|
|
22 |
cn2an==0.5.23
|
23 |
pypinyin==0.53.0
|
24 |
#onnxruntime==1.20.1
|
25 |
+
#onnxruntime-gpu
|
26 |
Unidecode==1.3.8
|
27 |
phonemizer==3.3.0
|
28 |
LangSegment==0.3.5
|