Muyumba commited on
Commit
ec96645
·
verified ·
1 Parent(s): 8e068a3

src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +72 -18
src/streamlit_app.py CHANGED
@@ -1,26 +1,80 @@
1
  import streamlit as st
2
- from huggingface_hub import hf_hub_download
3
  import torch
 
 
 
4
 
5
- # Télécharger le modèle
6
- model_path = hf_hub_download(
7
- repo_id="tencent/SongGeneration",
8
- filename="ckpt/songgeneration_base_zh/model.pt"
9
- )
10
 
11
- # Charger le modèle
12
- model = torch.load(model_path)
13
- model.eval()
 
 
 
 
14
 
15
- # Interface Streamlit
16
- st.title("🎵 Générateur de Chansons")
 
17
 
18
- description = st.text_area("Décrivez l'ambiance de la chanson")
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- if st.button("Générer la chanson"):
21
- if description:
22
- # Génération de la chanson (à adapter selon le modèle)
23
- chanson = model.generate(description)
24
- st.audio(chanson)
25
- else:
 
 
 
 
 
 
26
  st.warning("Veuillez fournir une description.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from huggingface_hub import hf_hub_download, set_access_token
3
  import torch
4
+ import os
5
+ import tempfile
6
+ import soundfile as sf
7
 
8
+ st.set_page_config(page_title="🎵 Générateur de Chansons Local", layout="centered")
9
+ st.title("🎵 Générateur de Chansons (Local CPU, Hugging Face)")
 
 
 
10
 
11
+ # -----------------------------
12
+ # Configuration Hugging Face Hub
13
+ # -----------------------------
14
+ # Token Hugging Face (optionnel si le repo est public)
15
+ HF_TOKEN = st.secrets.get("HF_TOKEN", None)
16
+ if HF_TOKEN:
17
+ set_access_token(HF_TOKEN)
18
 
19
+ # Forcer le cache local dans un dossier où on a les droits
20
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
21
+ os.makedirs(os.environ["HF_HOME"], exist_ok=True)
22
 
23
+ # -----------------------------
24
+ # Télécharger le modèle SongGeneration
25
+ # -----------------------------
26
+ @st.cache_resource
27
+ def load_song_model():
28
+ model_file = hf_hub_download(
29
+ repo_id="tencent/SongGeneration",
30
+ filename="ckpt/songgeneration_base_zh/model.pt"
31
+ )
32
+ model = torch.load(model_file, map_location="cpu")
33
+ model.eval()
34
+ return model
35
 
36
+ song_model = load_song_model()
37
+
38
+ # -----------------------------
39
+ # Interface utilisateur
40
+ # -----------------------------
41
+ description = st.text_area(
42
+ "Décrivez l'ambiance ou le thème de la chanson",
43
+ value="Une chanson nostalgique sur l’amour perdu, style pop moderne."
44
+ )
45
+
46
+ if st.button("🎛️ Générer la chanson"):
47
+ if not description.strip():
48
  st.warning("Veuillez fournir une description.")
49
+ else:
50
+ st.info("Génération en cours… (CPU, cela peut prendre du temps)")
51
+ try:
52
+ # -----------------------------
53
+ # Génération de la chanson (exemple simplifié)
54
+ # -----------------------------
55
+ with torch.no_grad():
56
+ # ⚠️ Adapter selon l'API exacte du modèle SongGeneration
57
+ # Ici on suppose qu'il y a une méthode .generate(text) qui renvoie un array audio
58
+ audio = song_model.generate(description) # numpy array ou torch tensor
59
+
60
+ # -----------------------------
61
+ # Sauvegarder l'audio temporaire pour Streamlit
62
+ # -----------------------------
63
+ tmp_wav = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
64
+ if isinstance(audio, torch.Tensor):
65
+ audio = audio.cpu().numpy()
66
+ sf.write(tmp_wav.name, audio, 44100)
67
+ tmp_wav.close()
68
+
69
+ # -----------------------------
70
+ # Affichage et téléchargement
71
+ # -----------------------------
72
+ st.audio(tmp_wav.name)
73
+ with open(tmp_wav.name, "rb") as f:
74
+ st.download_button("⬇️ Télécharger la chanson", f, "generated_song.wav")
75
+
76
+ st.success("✅ Chanson générée avec succès !")
77
+
78
+ except Exception as e:
79
+ st.error("❌ Erreur lors de la génération de la chanson :")
80
+ st.exception(e)