Spaces:
Sleeping
Sleeping
import torch | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
import gradio as gr | |
# Charger le modèle et le tokenizer | |
model_name = "Moustapha91/bart_large_poetique-v02" | |
model = AutoModelForSeq2SeqLM.from_pretrained(model_name) | |
tokenizer = AutoTokenizer.from_pretrained(model_name) | |
# Fonction de génération | |
def generate_poem(prompt, temperature, top_k, top_p, num_beams, repetition_penalty, max_length): | |
""" | |
Génère un texte poétique à partir d'un prompt et de paramètres personnalisés. | |
""" | |
# Tokenisation du prompt | |
inputs = tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=512, | |
return_tensors="pt", | |
truncation=True, | |
) | |
input_ids = inputs.input_ids.to(model.device) | |
attention_mask = inputs.attention_mask.to(model.device) | |
# Génération avec les paramètres spécifiés par l'utilisateur | |
outputs = model.generate( | |
input_ids=input_ids, | |
attention_mask=attention_mask, | |
max_length=max_length, | |
temperature=temperature, | |
top_k=top_k, | |
top_p=top_p, | |
num_beams=num_beams, | |
repetition_penalty=repetition_penalty, | |
early_stopping=True, | |
no_repeat_ngram_size=3, # Empêche les répétitions courtes | |
) | |
# Décodage du texte généré | |
output_text = tokenizer.decode(outputs[0], skip_special_tokens=True) | |
return output_text | |
# Interface Gradio | |
def gradio_interface(): | |
# Définir les composants Gradio | |
title = "✨ Générateur de Texte Poétique" | |
description = ( | |
"Cette interface interactive vous permet de générer des textes poétiques " | |
"en entrant un 'prompt' et en ajustant les paramètres de génération." | |
) | |
examples = [ | |
# Chaque exemple correspond à [prompt, temperature, top_k, top_p, num_beams, repetition_penalty, max_length] | |
["Dans la clarté d'une nuit étoilée, un fleuve murmure des secrets oubliés.", 1.0, 50, 0.9, 4, 1.5, 150], | |
["L'amour est un feu doux, éclairant les âmes perdues dans l'ombre.", 1.2, 40, 0.85, 5, 2.0, 100], | |
["Un enfant rêve d'unifier un continent par des chants de liberté.", 0.8, 30, 0.95, 6, 1.8, 200], | |
] | |
interface = gr.Interface( | |
fn=generate_poem, | |
inputs=[ | |
gr.Textbox( | |
label="Prompt", | |
placeholder="Entrez votre prompt ici...", | |
lines=2, | |
), | |
gr.Slider( | |
label="Température (Créativité)", | |
minimum=0.5, | |
maximum=1.5, | |
value=1.0, | |
step=0.1, | |
), | |
gr.Slider( | |
label="Top-k (Mots possibles)", | |
minimum=0, | |
maximum=100, | |
value=50, | |
step=1, | |
), | |
gr.Slider( | |
label="Top-p (Nucleus Sampling)", | |
minimum=0.5, | |
maximum=1.0, | |
value=0.9, | |
step=0.1, | |
), | |
gr.Slider( | |
label="Nombre de faisceaux (Beam Search)", | |
minimum=1, | |
maximum=10, | |
value=4, | |
step=1, | |
), | |
gr.Slider( | |
label="Pénalité de répétition", | |
minimum=1.0, | |
maximum=5.0, | |
value=1.5, | |
step=0.1, | |
), | |
gr.Slider( | |
label="Longueur maximale", | |
minimum=50, | |
maximum=512, | |
value=150, | |
step=10, | |
), | |
], | |
outputs=gr.Textbox(label="Texte Poétique Généré", lines=10), | |
title=title, | |
description=description, | |
examples=examples, | |
) | |
return interface | |
# Lancer l'application Gradio | |
if __name__ == "__main__": | |
interface = gradio_interface() | |
interface.launch() | |