|
import streamlit as st |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
|
import torch |
|
|
|
|
|
st.title("🤖 Marketing Text Generator") |
|
st.markdown("*Ein KI-Tool für kreative Marketing-Texte mit verschiedenen Sprachmodellen*") |
|
|
|
|
|
MODELS = { |
|
"GPT-2 (schnell & ressourcensparend)": "gpt2", |
|
"Mistral-7B (ausgewogen)": "mistralai/Mistral-7B-v0.1", |
|
"LLAMA-2 (leistungsstark)": "meta-llama/Llama-2-7b-hf", |
|
"Falcon (kreativ)": "tiiuae/falcon-7b" |
|
} |
|
|
|
@st.cache_resource |
|
def load_model(model_name): |
|
""" |
|
Lädt das ausgewählte Modell und den zugehörigen Tokenizer. |
|
Verwendet Caching für bessere Performance. |
|
""" |
|
try: |
|
if model_name == "gpt2": |
|
|
|
return pipeline('text-generation', model=model_name, device=-1) |
|
else: |
|
|
|
tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
model = AutoModelForCausalLM.from_pretrained( |
|
model_name, |
|
device_map="auto", |
|
trust_remote_code=True, |
|
load_in_8bit=True |
|
) |
|
return (model, tokenizer) |
|
except Exception as e: |
|
st.error(f"Fehler beim Laden des Modells: {str(e)}") |
|
return None |
|
|
|
def generate_text(model_name, prompt, max_length=200): |
|
""" |
|
Generiert Text basierend auf dem ausgewählten Modell und Prompt. |
|
Behandelt verschiedene Modelltypen unterschiedlich. |
|
""" |
|
try: |
|
if model_name == "gpt2": |
|
generator = load_model(model_name) |
|
response = generator(prompt, max_length=max_length, num_return_sequences=1) |
|
return response[0]['generated_text'] |
|
else: |
|
model, tokenizer = load_model(model_name) |
|
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) |
|
|
|
with torch.no_grad(): |
|
outputs = model.generate( |
|
**inputs, |
|
max_length=max_length, |
|
num_return_sequences=1, |
|
temperature=0.7, |
|
top_p=0.9 |
|
) |
|
|
|
return tokenizer.decode(outputs[0], skip_special_tokens=True) |
|
except Exception as e: |
|
st.error(f"Fehler bei der Textgenerierung: {str(e)}") |
|
return None |
|
|
|
def main(): |
|
|
|
with st.sidebar: |
|
st.header("Modell-Einstellungen") |
|
selected_model = st.selectbox( |
|
"Wählen Sie ein Sprachmodell:", |
|
list(MODELS.keys()), |
|
help="Verschiedene Modelle haben unterschiedliche Stärken und Performance-Eigenschaften" |
|
) |
|
|
|
st.markdown("---") |
|
st.markdown(""" |
|
**Modell-Informationen:** |
|
- GPT-2: Schnell, aber basic |
|
- Mistral: Guter Allrounder |
|
- LLAMA-2: Sehr leistungsfähig |
|
- Falcon: Besonders kreativ |
|
""") |
|
|
|
|
|
with st.form("text_generation_form"): |
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
product_name = st.text_input( |
|
"Produktname", |
|
help="Name des Produkts, für das Text generiert werden soll" |
|
) |
|
with col2: |
|
target_audience = st.text_input( |
|
"Zielgruppe", |
|
help="Beschreiben Sie Ihre Zielgruppe" |
|
) |
|
|
|
key_features = st.text_area( |
|
"Hauptmerkmale", |
|
help="Listen Sie die wichtigsten Eigenschaften des Produkts auf (durch Kommas getrennt)" |
|
) |
|
|
|
tone_options = ["Professionell", "Casual", "Luxuriös", "Jugendlich", "Technisch"] |
|
tone = st.select_slider( |
|
"Tonalität", |
|
options=tone_options, |
|
value="Professionell" |
|
) |
|
|
|
submit_button = st.form_submit_button("Text generieren") |
|
|
|
if submit_button: |
|
if not product_name or not key_features: |
|
st.warning("Bitte füllen Sie mindestens Produktname und Hauptmerkmale aus.") |
|
return |
|
|
|
|
|
with st.spinner(f'Generiere Text mit {selected_model.split(" ")[0]}...'): |
|
|
|
prompt = f""" |
|
Erstelle einen überzeugenden Marketing-Text mit folgendem Kontext: |
|
Produkt: {product_name} |
|
Zielgruppe: {target_audience} |
|
Hauptmerkmale: {key_features} |
|
Tonalität: {tone} |
|
|
|
Der Text sollte die USPs hervorheben und die Zielgruppe direkt ansprechen. |
|
""" |
|
|
|
|
|
model_name = MODELS[selected_model] |
|
response = generate_text(model_name, prompt) |
|
|
|
if response: |
|
st.success("Text wurde generiert!") |
|
st.markdown("### Generierter Marketing-Text:") |
|
st.markdown(response) |
|
|
|
|
|
col1, col2 = st.columns(2) |
|
with col1: |
|
if st.button("Text kopieren"): |
|
st.text_area("Kopieren Sie den Text:", value=response) |
|
with col2: |
|
st.download_button( |
|
"Als TXT herunterladen", |
|
response, |
|
file_name="marketing_text.txt" |
|
) |
|
|
|
if __name__ == "__main__": |
|
main() |