Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +106 -139
src/streamlit_app.py
CHANGED
@@ -1,152 +1,119 @@
|
|
1 |
import streamlit as st
|
2 |
-
import
|
3 |
-
|
4 |
-
import
|
5 |
-
import
|
6 |
-
|
7 |
-
|
8 |
-
import
|
9 |
-
sys.setrecursionlimit(2000) # Aumentiamo il limite di ricorsione
|
10 |
-
|
11 |
-
# --- Configurazione del Dispositivo ---
|
12 |
-
# Questo rileva automaticamente se MPS (GPU Apple Silicon) è disponibile
|
13 |
-
# Per ora, useremo la CPU come fallback se MPS è problematico per Stable Audio
|
14 |
-
device = "mps" if torch.backends.mps.is_available() else "cpu"
|
15 |
-
# ******************** MODIFICA QUI: Forza device = "cpu" ********************
|
16 |
-
# Per superare i problemi di Stable Audio su MPS con float16/float32
|
17 |
-
# FORZA LA CPU PER TUTTI I MODELLI, per semplicità.
|
18 |
-
# Se la caption genera velocemente, potremmo tornare indietro e mettere il modello vit_gpt2 su MPS
|
19 |
-
device = "cpu"
|
20 |
-
# **************************************************************************
|
21 |
-
st.write(f"Utilizzo del dispositivo: {device}")
|
22 |
-
|
23 |
-
|
24 |
-
# --- 1. Caricamento dei Modelli AI (spostati qui, fuori dalle funzioni Streamlit) ---
|
25 |
-
@st.cache_resource
|
26 |
-
def load_models():
|
27 |
-
# Caricamento del modello per la captioning (ViT-GPT2)
|
28 |
-
from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModelForVision2Seq
|
29 |
-
st.write("Caricamento del modello ViT-GPT2 per la captioning dell'immagine...")
|
30 |
|
31 |
-
|
32 |
-
vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
|
33 |
-
|
34 |
-
# Questo modello andrà sulla CPU
|
35 |
-
vit_gpt2_model = AutoModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)
|
36 |
|
37 |
-
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
st.write("Caricamento del modello Stable Audio Open Small per la generazione del soundscape...")
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
|
52 |
-
# Carica i modelli all'avvio dell'app
|
53 |
-
vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline = load_models()
|
54 |
|
|
|
|
|
55 |
|
56 |
-
#
|
57 |
-
def
|
58 |
-
pixel_values = vit_gpt2_feature_extractor(images=image_pil
|
59 |
-
|
60 |
-
|
61 |
-
# Token di inizio per GPT-2, assicurandosi che sia su CPU
|
62 |
-
# Ottieni il decoder_start_token_id dal modello o dal tokenizer
|
63 |
-
if hasattr(vit_gpt2_model.config, "decoder_start_token_id"):
|
64 |
-
decoder_start_token_id = vit_gpt2_model.config.decoder_start_token_id
|
65 |
-
else:
|
66 |
-
if vit_gpt2_tokenizer.pad_token_id is not None:
|
67 |
-
decoder_start_token_id = vit_gpt2_tokenizer.pad_token_id
|
68 |
-
else:
|
69 |
-
decoder_start_token_id = 50256 # Default GPT-2 EOS token
|
70 |
-
|
71 |
-
# Crea un input_ids iniziale con il decoder_start_token_id e spostalo su CPU
|
72 |
-
input_ids = torch.ones((1, 1), device=device, dtype=torch.long) * decoder_start_token_id
|
73 |
-
|
74 |
-
|
75 |
-
output_ids = vit_gpt2_model.generate(
|
76 |
-
pixel_values=pixel_values,
|
77 |
-
input_ids=input_ids,
|
78 |
-
max_length=50,
|
79 |
-
do_sample=True,
|
80 |
-
top_k=50,
|
81 |
-
temperature=0.7,
|
82 |
-
no_repeat_ngram_size=2,
|
83 |
-
early_stopping=True
|
84 |
-
)
|
85 |
caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
86 |
return caption
|
87 |
|
88 |
-
|
89 |
-
def
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
if uploaded_file is not None:
|
122 |
-
|
123 |
-
|
124 |
-
st.
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
141 |
-
|
142 |
-
|
143 |
-
st.audio(audio_data_bytes, format='audio/wav', sample_rate=sample_rate)
|
144 |
|
145 |
-
|
146 |
-
label="Scarica Audio WAV",
|
147 |
-
data=audio_data_bytes,
|
148 |
-
file_name="paesaggio_sonoro_generato.wav",
|
149 |
-
mime="audio/wav"
|
150 |
-
)
|
151 |
-
else:
|
152 |
-
st.error("La generazione del paesaggio sonoro è fallita.")
|
|
|
1 |
import streamlit as st
|
2 |
+
import torch
|
3 |
+
import torchaudio
|
4 |
+
from transformers import ViTImageProcessor, AutoTokenizer, VisionEncoderDecoderModel
|
5 |
+
from einops import rearrange
|
6 |
+
from stable_audio_tools import get_pretrained_model
|
7 |
+
from stable_audio_tools.inference.generation import generate_diffusion_cond
|
8 |
+
import io # Per salvare l'audio in memoria per Streamlit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
+
st.set_page_config(layout="wide")
|
|
|
|
|
|
|
|
|
11 |
|
12 |
+
st.title("Image Captioning and Soundscape Generation")
|
13 |
|
14 |
+
# Funzione per caricare i modelli e metterli in cache
|
15 |
+
@st.cache_resource
|
16 |
+
def load_models():
|
17 |
+
# Imposta il dispositivo su "cpu" come da requisiti per lo Space
|
18 |
+
device = "cpu"
|
19 |
+
st.write(f"Utilizzo del dispositivo: {device}")
|
20 |
+
|
21 |
+
# Caricamento del modello ViT-GPT2 per la captioning dell'immagine
|
22 |
+
st.write("Caricamento del modello ViT-GPT2 per la captioning dell'immagine...")
|
23 |
+
try:
|
24 |
+
vit_gpt2_feature_extractor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache")
|
25 |
+
vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache")
|
26 |
+
vit_gpt2_model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning", cache_dir="/app/hf_cache").to(device)
|
27 |
+
st.write("Modello ViT-GPT2 caricato.")
|
28 |
+
except Exception as e:
|
29 |
+
st.error(f"Errore durante il caricamento del modello ViT-GPT2: {e}")
|
30 |
+
st.stop() # Ferma l'app se il modello essenziale non carica
|
31 |
+
|
32 |
+
# Caricamento del modello Stable Audio Open Small per la generazione del soundscape
|
33 |
st.write("Caricamento del modello Stable Audio Open Small per la generazione del soundscape...")
|
34 |
+
try:
|
35 |
+
# Carica il modello Stable Audio usando stable_audio_tools
|
36 |
+
stable_audio_model, stable_audio_config = get_pretrained_model("stabilityai/stable-audio-open-small", cache_dir="/app/hf_cache")
|
37 |
+
stable_audio_model = stable_audio_model.to(device)
|
38 |
+
st.write("Modello Stable Audio Open Small caricato.")
|
39 |
+
return vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_model, stable_audio_config
|
40 |
+
except Exception as e:
|
41 |
+
st.error(f"Errore durante il caricamento del modello Stable Audio Open Small: {e}")
|
42 |
+
st.stop() # Ferma l'app se il modello essenziale non carica
|
43 |
|
|
|
|
|
44 |
|
45 |
+
# Carica i modelli all'avvio dell'app
|
46 |
+
vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_model, stable_audio_config = load_models()
|
47 |
|
48 |
+
# Funzione per generare la caption dell'immagine
|
49 |
+
def generate_caption(image_pil):
|
50 |
+
pixel_values = vit_gpt2_feature_extractor(images=image_pil, return_tensors="pt").pixel_values
|
51 |
+
output_ids = vit_gpt2_model.generate(pixel_values, max_new_tokens=16)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
|
53 |
return caption
|
54 |
|
55 |
+
# Funzione per generare il soundscape
|
56 |
+
def generate_soundscape(prompt_text):
|
57 |
+
sample_size = stable_audio_config["sample_size"]
|
58 |
+
sample_rate = stable_audio_config["sample_rate"]
|
59 |
+
|
60 |
+
# Assicurati che il modello sia sulla CPU per la generazione
|
61 |
+
device = "cpu"
|
62 |
+
|
63 |
+
conditioning = [{
|
64 |
+
"prompt": prompt_text,
|
65 |
+
}]
|
66 |
+
|
67 |
+
# Genera audio
|
68 |
+
with st.spinner("Generazione audio in corso... (potrebbe richiedere un po' di tempo)"):
|
69 |
+
output = generate_diffusion_cond(
|
70 |
+
stable_audio_model,
|
71 |
+
conditioning=conditioning,
|
72 |
+
sample_size=sample_size,
|
73 |
+
device=device,
|
74 |
+
steps=100, # Numero di step di diffusione (puoi renderlo configurabile)
|
75 |
+
cfg_scale=7, # Scala di classifer-free guidance
|
76 |
+
sigma_min=0.03,
|
77 |
+
sigma_max=500,
|
78 |
+
sampler_type="dpmpp-3m-sde" # Tipo di sampler
|
79 |
+
)
|
80 |
+
|
81 |
+
# Riorganizza il batch audio in una singola sequenza
|
82 |
+
output = rearrange(output, "b d n -> d (b n)")
|
83 |
+
|
84 |
+
# Peak normalize, clip, converti in int16, e prepara per la riproduzione
|
85 |
+
output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
|
86 |
+
|
87 |
+
# Salva l'audio in un buffer di memoria per Streamlit
|
88 |
+
buffer = io.BytesIO()
|
89 |
+
torchaudio.save(buffer, output, sample_rate, format="wav")
|
90 |
+
return buffer.getvalue(), sample_rate
|
91 |
+
|
92 |
+
# Streamlit UI
|
93 |
+
uploaded_file = st.file_uploader("Carica un'immagine per la captioning:", type=["png", "jpg", "jpeg"])
|
94 |
+
|
95 |
+
caption = ""
|
96 |
if uploaded_file is not None:
|
97 |
+
from PIL import Image
|
98 |
+
image = Image.open(uploaded_file).convert("RGB")
|
99 |
+
st.image(image, caption="Immagine caricata.", use_column_width=True)
|
100 |
+
|
101 |
+
with st.spinner("Generazione della caption..."):
|
102 |
+
caption = generate_caption(image)
|
103 |
+
st.success(f"Caption generata: **{caption}**")
|
104 |
+
|
105 |
+
# Campo di input per il prompt del soundscape
|
106 |
+
st.header("Generazione Soundscape")
|
107 |
+
soundscape_prompt_input = st.text_input(
|
108 |
+
"Inserisci un prompt per il soundscape (es. 'A gentle rain with thunder and distant birds'):",
|
109 |
+
value=caption if caption else "A natural outdoor soundscape" # Pre-popola con la caption se disponibile
|
110 |
+
)
|
111 |
+
|
112 |
+
if st.button("Genera Soundscape Audio"):
|
113 |
+
if soundscape_prompt_input:
|
114 |
+
audio_bytes, sr = generate_soundscape(soundscape_prompt_input)
|
115 |
+
st.audio(audio_bytes, format='audio/wav', sample_rate=sr)
|
116 |
+
else:
|
117 |
+
st.warning("Per favore, inserisci un prompt per generare il soundscape.")
|
|
|
118 |
|
119 |
+
st.info("Nota: La generazione del soundscape può richiedere un po' di tempo a seconda della complessità del prompt e delle risorse disponibili.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|