napstablook911 committed
Commit c5c0d60 · verified · 1 Parent(s): 3490480

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +148 -106
src/streamlit_app.py CHANGED
@@ -1,119 +1,161 @@
  import streamlit as st
  import torch
- import torchaudio
- from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
- from einops import rearrange
- from stable_audio_tools import get_pretrained_model
- from stable_audio_tools.inference.generation import generate_diffusion_cond
- import io  # To save the audio in memory for Streamlit
-
- st.set_page_config(layout="wide")
-
- st.title("Image Captioning and Soundscape Generation")
-
- # Function to load the models and cache them
- @st.cache_resource
- def load_models():
-     # Set the device to "cpu", as required for this Space
-     device = "cpu"
-     st.write(f"Using device: {device}")
-
-     # Load the ViT-GPT2 image-captioning model
-     st.write("Loading the ViT-GPT2 image-captioning model...")
      try:
-         vit_gpt2_feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache")
-         vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache")
-         vit_gpt2_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache").to(device)
-         st.write("ViT-GPT2 model loaded.")
      except Exception as e:
-         st.error(f"Error while loading the ViT-GPT2 model: {e}")
-         st.stop()  # Stop the app if this essential model fails to load
-
-     # Load the Stable Audio Open Small model for soundscape generation
-     st.write("Loading the Stable Audio Open Small model for soundscape generation...")
      try:
-         # Load the Stable Audio model using stable_audio_tools
-         stable_audio_model, stable_audio_config = get_pretrained_model("stabilityai/stable-audio-open-small", cache_dir="/app/hf_cache")
-         stable_audio_model = stable_audio_model.to(device)
-         st.write("Stable Audio Open Small model loaded.")
-         return vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_model, stable_audio_config
      except Exception as e:
-         st.error(f"Error while loading the Stable Audio Open Small model: {e}")
-         st.stop()  # Stop the app if this essential model fails to load
-
-
- # Load the models at app startup
- vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_model, stable_audio_config = load_models()
-
- # Function to generate the image caption
- def generate_caption(image_pil):
-     pixel_values = vit_gpt2_feature_extractor(images=image_pil, return_tensors="pt").pixel_values
-     output_ids = vit_gpt2_model.generate(pixel_values, max_new_tokens=16)
-     caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
-     return caption
-
- # Function to generate the soundscape
- def generate_soundscape(prompt_text):
-     sample_size = stable_audio_config["sample_size"]
-     sample_rate = stable_audio_config["sample_rate"]
-
-     # Keep the model on the CPU for generation
-     device = "cpu"
-
-     conditioning = [{
-         "prompt": prompt_text,
-     }]
-
-     # Generate the audio
-     with st.spinner("Generating audio... (this may take a while)"):
-         output = generate_diffusion_cond(
-             stable_audio_model,
-             conditioning=conditioning,
-             sample_size=sample_size,
-             device=device,
-             steps=100,  # Number of diffusion steps (could be made configurable)
-             cfg_scale=7,  # Classifier-free guidance scale
-             sigma_min=0.03,
-             sigma_max=500,
-             sampler_type="dpmpp-3m-sde"  # Sampler type
-         )
-
-     # Rearrange the audio batch into a single sequence
-     output = rearrange(output, "b d n -> d (b n)")
-
-     # Peak-normalize, clip, convert to int16, and prepare for playback
-     output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
-
-     # Save the audio to an in-memory buffer for Streamlit
-     buffer = io.BytesIO()
-     torchaudio.save(buffer, output, sample_rate, format="wav")
-     return buffer.getvalue(), sample_rate
-
- # Streamlit UI
- uploaded_file = st.file_uploader("Upload an image for captioning:", type=["png", "jpg", "jpeg"])
-
- caption = ""
  if uploaded_file is not None:
-     from PIL import Image
-     image = Image.open(uploaded_file).convert("RGB")
-     st.image(image, caption="Uploaded image.", use_column_width=True)
-
-     with st.spinner("Generating the caption..."):
-         caption = generate_caption(image)
-     st.success(f"Generated caption: **{caption}**")
-
- # Input field for the soundscape prompt
- st.header("Soundscape Generation")
- soundscape_prompt_input = st.text_input(
-     "Enter a prompt for the soundscape (e.g. 'A gentle rain with thunder and distant birds'):",
-     value=caption if caption else "A natural outdoor soundscape"  # Pre-fill with the caption if available
- )
-
- if st.button("Generate Soundscape Audio"):
-     if soundscape_prompt_input:
-         audio_bytes, sr = generate_soundscape(soundscape_prompt_input)
-         st.audio(audio_bytes, format='audio/wav', sample_rate=sr)
-     else:
-         st.warning("Please enter a prompt to generate the soundscape.")
-
- st.info("Note: soundscape generation may take some time depending on prompt complexity and available resources.")
  import streamlit as st
+ from PIL import Image
+ import io
+ import soundfile as sf
+ import numpy as np
  import torch
+ from transformers import pipeline
+ from diffusers import StableAudioPipeline
+
+ # --- Configuration ---
+ # Pick the best device for model inference:
+ # prefer CUDA (NVIDIA GPUs), then MPS (Apple Silicon), and fall back to CPU.
+ DEVICE = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
+
+ # Use float16 for lower memory use and faster inference on compatible hardware (GPU/MPS);
+ # fall back to float32 on CPU for better stability.
+ TORCH_DTYPE = torch.float16 if DEVICE in ["cuda", "mps"] else torch.float32
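+ # Note (assumption): float16 on MPS can still be flaky for some diffusion ops;
+ # if generation yields NaNs or silence there, forcing torch.float32 is a safe fallback.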
+
+ # --- Cached Model Loading Functions ---
+ @st.cache_resource(show_spinner="Loading Image Captioning Model (BLIP)...")
+ def load_blip_model():
+     """
+     Load the BLIP image-captioning model via the Hugging Face transformers pipeline.
+     The model is cached to prevent reloading on every Streamlit rerun.
+     """
      try:
+         captioner = pipeline(
+             "image-to-text",
+             model="Salesforce/blip-image-captioning-base",
+             torch_dtype=TORCH_DTYPE,
+             device=DEVICE
+         )
+         return captioner
      except Exception as e:
+         st.error(f"Failed to load BLIP model: {e}")
+         return None
+
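+ # Reminder: an "image-to-text" pipeline accepts PIL images directly and returns a
+ # list of dicts like [{"generated_text": "..."}]; callers take element [0].
+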
+ @st.cache_resource(show_spinner="Loading Audio Generation Model (Stable Audio Open 1.0)...")
+ def load_stable_audio_model():
+     """
+     Load the Stable Audio Open 1.0 pipeline via Hugging Face diffusers.
+     The pipeline is cached to prevent reloading on every Streamlit rerun.
+     """
      try:
+         audio_pipeline = StableAudioPipeline.from_pretrained(
+             "stabilityai/stable-audio-open-1.0",
+             torch_dtype=TORCH_DTYPE
+         ).to(DEVICE)
+         return audio_pipeline
      except Exception as e:
+         st.error(f"Failed to load Stable Audio model: {e}")
+         return None
+
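+ # Note (assumption): stabilityai/stable-audio-open-1.0 is a gated repository on the Hub,
+ # so the first download may require accepting the model license and authenticating
+ # (e.g. via `huggingface-cli login` or an HF_TOKEN environment variable).
+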
+ # --- Audio Conversion Utility ---
+ def convert_numpy_to_wav_bytes(audio_array: np.ndarray, sample_rate: int) -> bytes:
+     """
+     Convert a NumPy audio array to an in-memory WAV byte stream.
+     This avoids writing temporary files to disk, which is efficient and
+     suits ephemeral environments like Hugging Face Spaces.
+     """
+     byte_io = io.BytesIO()
+
+     # Stable Audio's diffusers output is (channels, frames), while soundfile
+     # expects (frames, channels) for multi-channel audio, so transpose 2-D input.
+     if audio_array.ndim == 2 and audio_array.shape[0] == 2:  # stereo: channels-first
+         audio_array = audio_array.T  # transpose to (frames, channels)
+
+     # Write the NumPy array into the in-memory BytesIO object as a WAV file
+     sf.write(byte_io, audio_array, sample_rate, format='WAV', subtype='FLOAT')
+
+     # IMPORTANT: reset the stream position to the beginning before reading
+     byte_io.seek(0)
+     return byte_io.read()
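+
+ # Example (hypothetical): round-trip a one-second stereo test tone to sanity-check
+ # the (channels, frames) -> (frames, channels) handling:
+ #     t = np.linspace(0, 1.0, 44100, dtype=np.float32)
+ #     tone = np.stack([np.sin(2 * np.pi * 440 * t), np.sin(2 * np.pi * 660 * t)])  # (2, 44100)
+ #     wav_bytes = convert_numpy_to_wav_bytes(tone, 44100)  # playable via st.audio(wav_bytes)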
+
+ # --- Streamlit App Layout ---
+ st.set_page_config(layout="centered", page_title="Image-to-Soundscape Generator")
+ st.title("🏞️ Image-to-Soundscape Generator 🎶")
+ st.markdown("Upload a landscape image, and let AI transform it into a unique soundscape!")
+
+ # Initialize session state for persistence across reruns
+ if "audio_bytes" not in st.session_state:
+     st.session_state.audio_bytes = None
+ if "image_uploaded" not in st.session_state:
+     st.session_state.image_uploaded = False
+
+ # --- UI Components ---
+ uploaded_file = st.file_uploader("Choose a landscape image...", type=["jpg", "jpeg", "png"])
+
  if uploaded_file is not None:
+     st.session_state.image_uploaded = True
+     image = Image.open(uploaded_file).convert("RGB")  # Ensure the image is in RGB format
+     st.image(image, caption="Uploaded Image", use_column_width=True)
+
+     # Button that triggers the generation pipeline
+     if st.button("Generate Soundscape"):
+         st.session_state.audio_bytes = None  # Clear previous audio
+
+         with st.spinner("Generating soundscape... This may take a moment."):
+             try:
+                 # 1. Load the BLIP model and generate a caption (hidden from the user)
+                 captioner = load_blip_model()
+                 if captioner is None:
+                     st.error("Image captioning model could not be loaded. Please try again.")
+                     st.session_state.image_uploaded = False  # Reset to allow re-upload
+                     st.stop()
+
+                 # Generate the caption. The BLIP pipeline accepts a PIL Image directly
+                 # and returns a list of dicts, so take the first result's text.
+                 caption_results = captioner(image)
+                 generated_caption = caption_results[0]['generated_text']
+
+                 # Optional: rephrase the caption as a soundscape prompt to steer
+                 # the audio model towards environmental sounds.
+                 soundscape_prompt = f"A soundscape of {generated_caption}"
+
+                 # 2. Load the Stable Audio model and generate audio
+                 audio_pipeline = load_stable_audio_model()
+                 if audio_pipeline is None:
+                     st.error("Audio generation model could not be loaded. Please try again.")
+                     st.session_state.image_uploaded = False  # Reset to allow re-upload
+                     st.stop()
+
+                 # Generate audio with parameters tuned for speed:
+                 # num_inference_steps: lower is faster, higher gives better quality;
+                 # audio_end_in_s: shorter clips generate faster;
+                 # negative_prompt: helps improve perceived quality.
+                 audio_output = audio_pipeline(
+                     prompt=soundscape_prompt,
+                     num_inference_steps=50,  # tuned for faster generation
+                     audio_end_in_s=10.0,  # 10 seconds of audio
+                     negative_prompt="low quality, average quality, distorted"
+                 )
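+
+                 # Assumption: for reproducible output, a seeded generator could be passed
+                 # to the pipeline call, e.g. generator=torch.Generator(device=DEVICE).manual_seed(0).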
+
+                 # Extract the audio and sample rate from the diffusers output.
+                 # `audios` is a torch tensor of shape (batch, channels, samples), so take
+                 # the first batch item and convert it to a float32 NumPy array.
+                 audio_numpy_array = audio_output.audios[0].to(torch.float32).cpu().numpy()
+                 sample_rate = audio_pipeline.vae.sampling_rate
+
+                 # 3. Convert the NumPy array to WAV bytes and store it in session state
+                 st.session_state.audio_bytes = convert_numpy_to_wav_bytes(audio_numpy_array, sample_rate)
+
+                 st.success("Soundscape generated successfully!")
+
+             except Exception as e:
+                 st.error(f"An error occurred during generation: {e}")
+                 st.session_state.audio_bytes = None  # Clear any partial audio
+                 st.session_state.image_uploaded = False  # Reset to allow re-upload
+                 st.exception(e)  # Show the full traceback for debugging
+
+ # Display the generated soundscape if it is available in session state
+ if st.session_state.audio_bytes:
+     st.subheader("Generated Soundscape:")
+     st.audio(st.session_state.audio_bytes, format='audio/wav')
+     st.markdown("You can download the audio using the controls above.")
+
+ # Reset button for uploading a new image
+ if st.session_state.image_uploaded and st.button("Upload New Image"):
+     st.session_state.audio_bytes = None
+     st.session_state.image_uploaded = False
+     st.rerun()  # Rerun the app to clear the file uploader
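
To try the updated app locally, a minimal sketch (assuming the imports above are satisfied by streamlit, pillow, soundfile, numpy, torch, transformers, and diffusers):

    streamlit run src/streamlit_app.py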