File size: 6,352 Bytes
2a7d044
3a947b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2a7d044
3a947b7
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import streamlit as st
import io
from PIL import Image
import soundfile as sf
import librosa
import numpy as np
import torch # Importa torch
import sys
sys.setrecursionlimit(2000) # Aumentiamo il limite di ricorsione

# --- Configurazione del Dispositivo ---
# Questo rileva automaticamente se MPS (GPU Apple Silicon) è disponibile
# Per ora, useremo la CPU come fallback se MPS è problematico per Stable Audio
device = "mps" if torch.backends.mps.is_available() else "cpu"
# ******************** MODIFICA QUI: Forza device = "cpu" ********************
# Per superare i problemi di Stable Audio su MPS con float16/float32
# FORZA LA CPU PER TUTTI I MODELLI, per semplicità.
# Se la caption genera velocemente, potremmo tornare indietro e mettere il modello vit_gpt2 su MPS
device = "cpu"
# **************************************************************************
st.write(f"Utilizzo del dispositivo: {device}")


# --- 1. Caricamento dei Modelli AI (spostati qui, fuori dalle funzioni Streamlit) ---
@st.cache_resource
def load_models():
    # Caricamento del modello per la captioning (ViT-GPT2)
    from transformers import AutoFeatureExtractor, AutoTokenizer, AutoModelForVision2Seq
    st.write("Caricamento del modello ViT-GPT2 per la captioning dell'immagine...")

    vit_gpt2_feature_extractor = AutoFeatureExtractor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    vit_gpt2_tokenizer = AutoTokenizer.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
    
    # Questo modello andrà sulla CPU
    vit_gpt2_model = AutoModelForVision2Seq.from_pretrained("nlpconnect/vit-gpt2-image-captioning").to(device)

    st.write("Modello ViT-GPT2 caricato.")

    # Caricamento del modello Text-to-Audio (Stable Audio Open - 1.0)
    from diffusers import DiffusionPipeline
    st.write("Caricamento del modello Stable Audio Open - 1.0 per la generazione del soundscape...")
    # ******************** MODIFICA QUI ********************
    # Assicurati che non ci sia torch_dtype=torch.float16 e che vada sulla CPU
    stable_audio_pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", force_download=True).to(device) 
    # ******************************************************
    st.write("Modello Stable Audio Open 1.0 caricato.")

    return vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline

# Carica i modelli all'avvio dell'app
vit_gpt2_feature_extractor, vit_gpt2_model, vit_gpt2_tokenizer, stable_audio_pipeline = load_models()


# --- 2. Funzioni della Pipeline ---
def generate_image_caption(image_pil):
    pixel_values = vit_gpt2_feature_extractor(images=image_pil.convert("RGB"), return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device) # Sposta input su CPU
    
    # Token di inizio per GPT-2, assicurandosi che sia su CPU
    # Ottieni il decoder_start_token_id dal modello o dal tokenizer
    if hasattr(vit_gpt2_model.config, "decoder_start_token_id"):
        decoder_start_token_id = vit_gpt2_model.config.decoder_start_token_id
    else:
        if vit_gpt2_tokenizer.pad_token_id is not None:
            decoder_start_token_id = vit_gpt2_tokenizer.pad_token_id
        else:
            decoder_start_token_id = 50256 # Default GPT-2 EOS token

    # Crea un input_ids iniziale con il decoder_start_token_id e spostalo su CPU
    input_ids = torch.ones((1, 1), device=device, dtype=torch.long) * decoder_start_token_id


    output_ids = vit_gpt2_model.generate(
        pixel_values=pixel_values,
        input_ids=input_ids,
        max_length=50,
        do_sample=True,
        top_k=50,
        temperature=0.7,
        no_repeat_ngram_size=2,
        early_stopping=True
    )
    caption = vit_gpt2_tokenizer.decode(output_ids[0], skip_special_tokens=True)
    return caption


def generate_soundscape_from_caption(caption: str, duration_seconds: int = 10):
    st.write(f"Generazione soundscape per: '{caption}' (durata: {duration_seconds}s)")
    with st.spinner("Generazione audio in corso..."):
        try:
            # Assicurati che il modello sia già su CPU dal caricamento
            audio_output = stable_audio_pipeline(
                prompt=caption,
                audio_end_in_s=duration_seconds 
            ).audios

            audio_data = audio_output[0].cpu().numpy() 
            sample_rate = stable_audio_pipeline.sample_rate

            audio_data = audio_data.astype(np.float32)
            audio_data = librosa.util.normalize(audio_data)

            buffer = io.BytesIO()
            sf.write(buffer, audio_data, sample_rate, format='WAV')
            buffer.seek(0)
            return buffer.getvalue(), sample_rate

        except Exception as e:
            st.error(f"Errore durante la generazione dell'audio: {e}")
            return None, None


# --- 3. Interfaccia Streamlit ---
st.title("Generatore di Paesaggi Sonori da Immagini")
st.write("Carica un'immagine e otterrai una descrizione testuale e un paesaggio sonoro generato!")

uploaded_file = st.file_uploader("Scegli un'immagine...", type=["jpg", "jpeg", "png"])

if uploaded_file is not None:
    input_image = Image.open(uploaded_file)
    st.image(input_image, caption='Immagine Caricata.', use_column_width=True)
    st.write("")

    audio_duration = st.slider("Durata audio (secondi):", 5, 30, 10, key="audio_duration_slider")


    if st.button("Genera Paesaggio Sonoro"):
        st.subheader("Processo in Corso...")

        # PASSO 1: Genera la caption
        st.write("Generazione della caption...")
        caption = generate_image_caption(input_image)
        st.write(f"Caption generata: **{caption}**")

        # PASSO 2: Genera il soundscape
        st.write("Generazione del paesaggio sonoro...")
        audio_data_bytes, sample_rate = generate_soundscape_from_caption(caption, duration_seconds=audio_duration)

        if audio_data_bytes is not None:
            st.subheader("Paesaggio Sonoro Generato")
            st.audio(audio_data_bytes, format='audio/wav', sample_rate=sample_rate)

            st.download_button(
                label="Scarica Audio WAV",
                data=audio_data_bytes,
                file_name="paesaggio_sonoro_generato.wav",
                mime="audio/wav"
            )
        else:
            st.error("La generazione del paesaggio sonoro è fallita.")