mrfakename committed · verified
Commit 51e0928 · 1 Parent(s): 9abfb86

Update app.py

Files changed (1)
  1. app.py +127 -35
app.py CHANGED
@@ -1,37 +1,129 @@
- import torch, spaces
- import gradio as gr
- from diffusers import FluxPipeline
-
- MODELS = {
-     # 'FLUX.1 [dev]': 'black-forest-labs/FLUX.1-dev',
-     'FLUX.1 [schnell]': 'black-forest-labs/FLUX.1-schnell',
-     'OpenFLUX.1': 'ostris/OpenFLUX.1',
- }
- MODEL_CACHE = {}
- for id, model in MODELS.items():
-     print(f"Loading model {model}...")
-     MODEL_CACHE[id] = FluxPipeline.from_pretrained(model, torch_dtype=torch.bfloat16)
-     MODEL_CACHE[id].enable_model_cpu_offload()  # save some VRAM by offloading the model to CPU. Remove this if you have enough GPU power
-     print(f"Loaded model {model}")

  @spaces.GPU
- def generate(text):
-     prompt = "A cat holding a sign that says hello world"
-     image = MODEL_CACHE['OpenFLUX.1'](
-         prompt,
-         height=1024,
-         width=1024,
-         guidance_scale=3.5,
-         num_inference_steps=50,
-         max_sequence_length=512,
-         generator=torch.Generator("cpu").manual_seed(0)
-     ).images[0]
-     return image
-     # image.save("flux-dev.png")
-
- with gr.Blocks() as demo:
-     prompt = gr.Textbox(label="Prompt")
-     btn = gr.Button("Generate", variant="primary")
-     out = gr.Image(label="Generated image", interactive=False)
-     btn.click(generate, inputs=prompt, outputs=out)
- demo.queue().launch()
 
+ import torch as T
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchaudio
+ import gradio as gr  # needed for the gr.Info/gr.Error calls below
+ import spaces  # needed for the @spaces.GPU decorator
+ from utils import load_ckpt, print_colored
+ from tokenizer import make_tokenizer
+ from model import get_hertz_dev_config
+ import matplotlib.pyplot as plt  # only used by the commented-out waveform plot
+
+ device = 'cuda' if T.cuda.is_available() else 'cpu'
+ if device == 'cuda':
+     T.cuda.set_device(0)  # set_device would fail on a CPU-only machine
+ print_colored(f"Using device: {device}", "grey")
+
+ audio_tokenizer = make_tokenizer(device)
+
+ TWO_SPEAKER = False
+
+ model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)
+
+ # Instantiate the generator from its config; inference runs in bfloat16
+ generator = model_config()
+ generator = generator.eval().to(T.bfloat16).to(device)
+
+
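+ # Note on units, inferred from the constants used below (an assumption, not
+ # stated explicitly in this commit): audio is handled at 16 kHz and the
+ # tokenizer yields 8 latent frames per second, so one latent frame spans
+ # 16000 / 8 = 2000 samples. run() converts seconds -> frames with * 8, and
+ # get_completion() converts frames -> samples with * 2000.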
+ ##############
+ # Load audio
+
+ def load_and_preprocess_audio(audio_path):
+     gr.Info("Loading and preprocessing audio...")
+     # Load audio file
+     audio_tensor, sr = torchaudio.load(audio_path)
+     gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
+
+     if TWO_SPEAKER:
+         # Two-speaker mode expects one speaker per channel
+         if audio_tensor.shape[0] == 1:
+             gr.Info("Converting mono to stereo...")
+             audio_tensor = audio_tensor.repeat(2, 1)
+             gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
+     else:
+         if audio_tensor.shape[0] == 2:
+             gr.Info("Converting stereo to mono...")
+             audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
+             gr.Info(f"Mono audio shape: {audio_tensor.shape}")
+
+     # Resample to 16kHz if needed
+     if sr != 16000:
+         gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
+         resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
+         audio_tensor = resampler(audio_tensor)
+
+     # Reject prompts longer than 5 minutes
+     max_samples = 16000 * 60 * 5
+     if audio_tensor.shape[1] > max_samples:
+         # gr.Info("Clipping audio to 5 minutes...")
+         raise gr.Error("Maximum prompt is 5 minutes")
+         # audio_tensor = audio_tensor[:, :max_samples]
+
+     duration_seconds = audio_tensor.shape[1] / 16000  # audio is at 16 kHz by this point
+
+     gr.Info("Audio preprocessing complete!")
+     return audio_tensor.unsqueeze(0), duration_seconds
+
+ ##############
+ # Return audio to gradio
+
+ def display_audio(audio_tensor):
+     audio_tensor = audio_tensor.cpu().squeeze()
+     if audio_tensor.ndim == 1:
+         audio_tensor = audio_tensor.unsqueeze(0)
+     audio_tensor = audio_tensor.float()
+
+     # Make a waveform plot
+     # plt.figure(figsize=(4, 1))
+     # plt.plot(audio_tensor.numpy()[0], linewidth=0.5)
+     # plt.axis('off')
+     # plt.show()
+
+     # Make an audio player: Gradio renders a (sample_rate, ndarray) tuple as playable audio
+     return (16000, audio_tensor.numpy())
+
+ def get_completion(encoded_prompt_audio, prompt_len):
+     prompt_len_seconds = prompt_len / 8  # 8 latent frames per second
+     gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")
+     with T.autocast(device_type='cuda', dtype=T.bfloat16):
+         completed_audio_batch = generator.completion(
+             encoded_prompt_audio,
+             temps=(.8, (0.5, 0.1)),  # (token_temp, (categorical_temp, gaussian_temp))
+             use_cache=True)
+
+         completed_audio = completed_audio_batch
+         print_colored("Decoding completion...", "blue")
+         if TWO_SPEAKER:
+             # Split the latent in half along the last dim and decode one speaker per half
+             decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
+             decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
+             decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
+         else:
+             decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
+         gr.Info(f"Decoded completion shape: {decoded_completion.shape}")
+
+     gr.Info("Preparing audio for playback...")
+
+     audio_tensor = decoded_completion.cpu().squeeze()
+     if audio_tensor.ndim == 1:
+         audio_tensor = audio_tensor.unsqueeze(0)
+     audio_tensor = audio_tensor.float()
+
+     # Normalize if the decoder output exceeds [-1, 1]
+     if audio_tensor.abs().max() > 1:
+         audio_tensor = audio_tensor / audio_tensor.abs().max()
+
+     # Drop the prompt from the output (2000 samples per latent frame), keeping its last second
+     return audio_tensor[:, max(prompt_len * 2000 - 16000, 0):]

  @spaces.GPU
+ def run(audio_path):
+     prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
+     prompt_len = int(prompt_len_seconds * 8)  # seconds -> latent frames; int so it can be used as a slice index
+     gr.Info("Encoding prompt...")
+     with T.autocast(device_type='cuda', dtype=T.bfloat16):
+         if TWO_SPEAKER:
+             # Encode each channel separately, then concatenate along the latent dim
+             encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
+             encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
+             encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
+         else:
+             encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
+     gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
+     gr.Info("Prompt encoded successfully!")
+     # num_completions = 10
+     completion = get_completion(encoded_prompt_audio, prompt_len)
+     return display_audio(completion)
+
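The hunk ends here: run() is decorated with @spaces.GPU, but no Gradio interface is shown being wired to it. For orientation only, a minimal sketch of how run() could be hooked up, in the style of the Blocks UI the removed FLUX code used; the variable names and labels below are illustrative, not part of this commit:

    with gr.Blocks() as demo:
        # type="filepath" passes a path string to the handler, matching run(audio_path)
        inp = gr.Audio(label="Prompt audio", type="filepath")
        btn = gr.Button("Generate continuation", variant="primary")
        # run() returns a (sample_rate, ndarray) tuple, which gr.Audio accepts as output
        out = gr.Audio(label="Generated continuation", interactive=False)
        btn.click(run, inputs=inp, outputs=out)
    demo.queue().launch()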