Jan Kirenz commited on
Commit
4235103
·
1 Parent(s): b0ce7e8

added gpt2 fallback

Browse files
Files changed (3) hide show
  1. app.py +59 -51
  2. app_model_picker.py +0 -157
  3. requirements.txt +2 -1
app.py CHANGED
@@ -1,36 +1,57 @@
1
  import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
- st.title("🚀 Marketing Text Generator mit Falcon")
6
- st.markdown("*Generiere kreative Marketing-Texte mit dem Falcon-7B Modell*")
7
 
8
  @st.cache_resource
9
- def load_falcon_model():
10
  """
11
- Lädt das Falcon Modell in einer CPU-freundlichen Konfiguration.
12
- Diese Version verwendet weniger Speicher und läuft auf allen Systemen.
13
  """
14
  try:
15
- model_name = "tiiuae/falcon-7b"
16
-
17
- # Tokenizer bleibt unverändert
18
- tokenizer = AutoTokenizer.from_pretrained(model_name)
19
-
20
- # Angepasste Modellkonfiguration für CPU-Systeme
21
  model = AutoModelForCausalLM.from_pretrained(
22
- model_name,
23
  trust_remote_code=True,
24
  device_map="auto",
25
- torch_dtype=torch.float32, # Standard-Datentyp statt 8-bit
26
- low_cpu_mem_usage=True # Speicheroptimierung für CPU
 
27
  )
28
-
29
- return model, tokenizer
30
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  except Exception as e:
32
- st.error(f"Fehler beim Laden des Modells: {str(e)}")
33
- return None, None
34
 
35
  # Hauptbereich für die Eingabe
36
  with st.form("marketing_form"):
@@ -41,48 +62,34 @@ with st.form("marketing_form"):
41
 
42
  key_features = st.text_area(
43
  "Produktmerkmale",
44
- help="Beschreiben Sie die wichtigsten Eigenschaften (durch Kommas getrennt)"
45
  )
46
 
47
- # Zusätzliche Kontrolle über die Textlänge
48
  max_length = st.slider(
49
  "Maximale Textlänge",
50
  min_value=50,
51
  max_value=200,
52
- value=100,
53
- help="Längere Texte benötigen mehr Verarbeitungszeit"
54
  )
55
 
56
  submit = st.form_submit_button("Marketing-Text generieren")
57
 
58
  if submit and product_name and key_features:
59
- with st.spinner("Lade Falcon Modell... (Dies kann einige Minuten dauern)"):
60
- model, tokenizer = load_falcon_model()
61
 
62
- if model and tokenizer:
63
- prompt = f"""
64
- Erstelle einen kurzen, überzeugenden Marketing-Text für folgendes Produkt:
65
- Produkt: {product_name}
66
- Merkmale: {key_features}
67
-
68
- Der Text sollte professionell und ansprechend sein.
69
- """
 
 
70
 
71
- with st.spinner("Generiere Marketing-Text..."):
72
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
73
-
74
- with torch.no_grad():
75
- outputs = model.generate(
76
- **inputs,
77
- max_length=max_length,
78
- temperature=0.7,
79
- top_p=0.9,
80
- num_return_sequences=1,
81
- pad_token_id=tokenizer.eos_token_id # Verbesserte Token-Handhabung
82
- )
83
-
84
- generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
-
86
  st.success("Text wurde erfolgreich generiert!")
87
  st.markdown("### Ihr Marketing-Text:")
88
  st.markdown(generated_text)
@@ -92,13 +99,14 @@ if submit and product_name and key_features:
92
  generated_text,
93
  file_name="marketing_text.txt"
94
  )
 
95
  elif submit:
96
  st.warning("Bitte füllen Sie alle Felder aus.")
97
 
98
  st.markdown("---")
99
  st.markdown("""
100
  **Wichtige Hinweise:**
101
- - Die erste Generierung dauert länger, da das Modell geladen werden muss
102
- - Auf Systemen ohne GPU kann die Verarbeitung mehrere Minuten in Anspruch nehmen
103
  - Kürzere Texte werden schneller generiert
104
  """)
 
1
  import streamlit as st
2
+ from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
 
5
+ st.title("🚀 Marketing Text Generator mit KI")
6
+ st.markdown("*Generiere kreative Marketing-Texte für deine Produkte*")
7
 
8
  @st.cache_resource
9
+ def load_model():
10
  """
11
+ Versucht zuerst das Falcon-Modell zu laden. Falls dies nicht möglich ist,
12
+ wird automatisch auf das kleinere GPT-2 Modell zurückgegriffen.
13
  """
14
  try:
15
+ # Erster Versuch: Optimiertes Falcon
16
+ tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-7b")
 
 
 
 
17
  model = AutoModelForCausalLM.from_pretrained(
18
+ "tiiuae/falcon-7b",
19
  trust_remote_code=True,
20
  device_map="auto",
21
+ torch_dtype=torch.float32,
22
+ low_cpu_mem_usage=True,
23
+ max_memory={0: "4GB"} # Begrenzt den Speicherverbrauch
24
  )
25
+ return ("falcon", model, tokenizer)
26
+ except Exception as e:
27
+ st.warning("Falcon konnte nicht geladen werden. Verwende GPT-2 als Alternative.")
28
+ # Fallback: GPT-2
29
+ generator = pipeline('text-generation', model='gpt2', device=-1)
30
+ return ("gpt2", generator, None)
31
+
32
+ def generate_text(model_type, model, tokenizer, prompt, max_length):
33
+ """
34
+ Generiert Text abhängig vom geladenen Modelltyp.
35
+ """
36
+ try:
37
+ if model_type == "gpt2":
38
+ response = model(prompt, max_length=max_length, num_return_sequences=1)
39
+ return response[0]['generated_text']
40
+ else:
41
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
42
+ with torch.no_grad():
43
+ outputs = model.generate(
44
+ **inputs,
45
+ max_length=max_length,
46
+ temperature=0.7,
47
+ top_p=0.9,
48
+ num_return_sequences=1,
49
+ pad_token_id=tokenizer.eos_token_id
50
+ )
51
+ return tokenizer.decode(outputs[0], skip_special_tokens=True)
52
  except Exception as e:
53
+ st.error(f"Fehler bei der Textgenerierung: {str(e)}")
54
+ return None
55
 
56
  # Hauptbereich für die Eingabe
57
  with st.form("marketing_form"):
 
62
 
63
  key_features = st.text_area(
64
  "Produktmerkmale",
65
+ help="Beschreiben Sie die wichtigsten Eigenschaften"
66
  )
67
 
 
68
  max_length = st.slider(
69
  "Maximale Textlänge",
70
  min_value=50,
71
  max_value=200,
72
+ value=100
 
73
  )
74
 
75
  submit = st.form_submit_button("Marketing-Text generieren")
76
 
77
  if submit and product_name and key_features:
78
+ with st.spinner("Lade KI-Modell..."):
79
+ model_type, model, tokenizer = load_model()
80
 
81
+ prompt = f"""
82
+ Erstelle einen überzeugenden Marketing-Text auf Deutsch für folgendes Produkt:
83
+ Produkt: {product_name}
84
+ Merkmale: {key_features}
85
+
86
+ Der Text sollte professionell, kreativ und verkaufsfördernd sein.
87
+ """
88
+
89
+ with st.spinner("Generiere Marketing-Text..."):
90
+ generated_text = generate_text(model_type, model, tokenizer, prompt, max_length)
91
 
92
+ if generated_text:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  st.success("Text wurde erfolgreich generiert!")
94
  st.markdown("### Ihr Marketing-Text:")
95
  st.markdown(generated_text)
 
99
  generated_text,
100
  file_name="marketing_text.txt"
101
  )
102
+
103
  elif submit:
104
  st.warning("Bitte füllen Sie alle Felder aus.")
105
 
106
  st.markdown("---")
107
  st.markdown("""
108
  **Wichtige Hinweise:**
109
+ - Die erste Generierung dauert länger (Modell-Ladezeit)
110
+ - Bei Speicherproblemen wird automatisch ein kleineres Modell verwendet
111
  - Kürzere Texte werden schneller generiert
112
  """)
app_model_picker.py DELETED
@@ -1,157 +0,0 @@
1
- import streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
3
- import torch
4
-
5
- # Titel der Anwendung mit erklärender Unterzeile
6
- st.title("🤖 Marketing Text Generator")
7
- st.markdown("*Ein KI-Tool für kreative Marketing-Texte mit verschiedenen Sprachmodellen*")
8
-
9
- # Konfiguration der verfügbaren Modelle
10
- MODELS = {
11
- "GPT-2 (schnell & ressourcensparend)": "gpt2",
12
- "Mistral-7B (ausgewogen)": "mistralai/Mistral-7B-v0.1",
13
- "LLAMA-2 (leistungsstark)": "meta-llama/Llama-2-7b-hf",
14
- "Falcon (kreativ)": "tiiuae/falcon-7b"
15
- }
16
-
17
- @st.cache_resource
18
- def load_model(model_name):
19
- """
20
- Lädt das ausgewählte Modell und den zugehörigen Tokenizer.
21
- Verwendet Caching für bessere Performance.
22
- """
23
- try:
24
- if model_name == "gpt2":
25
- # GPT-2 ist einfacher zu laden und benötigt weniger Ressourcen
26
- return pipeline('text-generation', model=model_name, device=-1)
27
- else:
28
- # Fortgeschrittene Modelle benötigen spezielle Konfiguration
29
- tokenizer = AutoTokenizer.from_pretrained(model_name)
30
- model = AutoModelForCausalLM.from_pretrained(
31
- model_name,
32
- device_map="auto",
33
- trust_remote_code=True,
34
- load_in_8bit=True # Speicheroptimierung
35
- )
36
- return (model, tokenizer)
37
- except Exception as e:
38
- st.error(f"Fehler beim Laden des Modells: {str(e)}")
39
- return None
40
-
41
- def generate_text(model_name, prompt, max_length=200):
42
- """
43
- Generiert Text basierend auf dem ausgewählten Modell und Prompt.
44
- Behandelt verschiedene Modelltypen unterschiedlich.
45
- """
46
- try:
47
- if model_name == "gpt2":
48
- generator = load_model(model_name)
49
- response = generator(prompt, max_length=max_length, num_return_sequences=1)
50
- return response[0]['generated_text']
51
- else:
52
- model, tokenizer = load_model(model_name)
53
- inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
54
-
55
- with torch.no_grad():
56
- outputs = model.generate(
57
- **inputs,
58
- max_length=max_length,
59
- num_return_sequences=1,
60
- temperature=0.7, # Kreativität kontrollieren
61
- top_p=0.9 # Vielfalt der Ausgabe steuern
62
- )
63
-
64
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
65
- except Exception as e:
66
- st.error(f"Fehler bei der Textgenerierung: {str(e)}")
67
- return None
68
-
69
- def main():
70
- # Seitenleiste für Modellauswahl und Erklärungen
71
- with st.sidebar:
72
- st.header("Modell-Einstellungen")
73
- selected_model = st.selectbox(
74
- "Wählen Sie ein Sprachmodell:",
75
- list(MODELS.keys()),
76
- help="Verschiedene Modelle haben unterschiedliche Stärken und Performance-Eigenschaften"
77
- )
78
-
79
- st.markdown("---")
80
- st.markdown("""
81
- **Modell-Informationen:**
82
- - GPT-2: Schnell, aber basic
83
- - Mistral: Guter Allrounder
84
- - LLAMA-2: Sehr leistungsfähig
85
- - Falcon: Besonders kreativ
86
- """)
87
-
88
- # Hauptbereich für Eingabe und Generierung
89
- with st.form("text_generation_form"):
90
- # Strukturierte Eingabefelder
91
- col1, col2 = st.columns(2)
92
- with col1:
93
- product_name = st.text_input(
94
- "Produktname",
95
- help="Name des Produkts, für das Text generiert werden soll"
96
- )
97
- with col2:
98
- target_audience = st.text_input(
99
- "Zielgruppe",
100
- help="Beschreiben Sie Ihre Zielgruppe"
101
- )
102
-
103
- key_features = st.text_area(
104
- "Hauptmerkmale",
105
- help="Listen Sie die wichtigsten Eigenschaften des Produkts auf (durch Kommas getrennt)"
106
- )
107
-
108
- tone_options = ["Professionell", "Casual", "Luxuriös", "Jugendlich", "Technisch"]
109
- tone = st.select_slider(
110
- "Tonalität",
111
- options=tone_options,
112
- value="Professionell"
113
- )
114
-
115
- submit_button = st.form_submit_button("Text generieren")
116
-
117
- if submit_button:
118
- if not product_name or not key_features:
119
- st.warning("Bitte füllen Sie mindestens Produktname und Hauptmerkmale aus.")
120
- return
121
-
122
- # Fortschrittsanzeige
123
- with st.spinner(f'Generiere Text mit {selected_model.split(" ")[0]}...'):
124
- # Marketing-spezifischer Prompt
125
- prompt = f"""
126
- Erstelle einen überzeugenden Marketing-Text mit folgendem Kontext:
127
- Produkt: {product_name}
128
- Zielgruppe: {target_audience}
129
- Hauptmerkmale: {key_features}
130
- Tonalität: {tone}
131
-
132
- Der Text sollte die USPs hervorheben und die Zielgruppe direkt ansprechen.
133
- """
134
-
135
- # Modellname aus dem Dictionary abrufen
136
- model_name = MODELS[selected_model]
137
- response = generate_text(model_name, prompt)
138
-
139
- if response:
140
- st.success("Text wurde generiert!")
141
- st.markdown("### Generierter Marketing-Text:")
142
- st.markdown(response)
143
-
144
- # Zusätzliche Aktionen anbieten
145
- col1, col2 = st.columns(2)
146
- with col1:
147
- if st.button("Text kopieren"):
148
- st.text_area("Kopieren Sie den Text:", value=response)
149
- with col2:
150
- st.download_button(
151
- "Als TXT herunterladen",
152
- response,
153
- file_name="marketing_text.txt"
154
- )
155
-
156
- if __name__ == "__main__":
157
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  streamlit
2
  transformers
3
  torch
4
- accelerate
 
 
1
  streamlit
2
  transformers
3
  torch
4
+ accelerate
5
+ protobuf