yair319732 committed on
Commit
bcaa618
·
verified ·
1 Parent(s): 369b8da

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +46 -25
app.py CHANGED
@@ -2,7 +2,7 @@
2
  import os, json, numpy as np, pandas as pd
3
  import gradio as gr
4
  import faiss
5
-
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
 
@@ -111,19 +111,17 @@ _encoder = SentenceTransformer(meta["model_name"])
111
  _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
112
  _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
113
 
114
- # banned junk and formatting constraints
115
- _BANNED_TERMS = {
116
- "portal","platform","service","solution","assistant","product","company",
117
- "business","website","app","shopping","shop","e-commerce","ecommerce"
118
- }
119
- _PUNCT = {":",";","—","–","-",".",",","!","?","“","”","\"","'"}
120
  _MIN_WORDS, _MAX_WORDS = 2, 8
121
 
122
  def _load_prompt():
123
  if os.path.exists(PROMPT_PATH):
124
  with open(PROMPT_PATH, "r", encoding="utf-8") as f:
125
  return f.read()
126
- # fallback if no prompt file shipped
127
  return (
128
  "You are a professional slogan writer.\n"
129
  "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
@@ -137,7 +135,6 @@ def _render_prompt(description: str, retrieved=None) -> str:
137
  prompt = tmpl.replace("{description}", description)
138
  else:
139
  prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
140
- # Optionally add negatives (don’t copy these)
141
  if retrieved:
142
  prompt += "\n\nDo NOT copy these existing slogans:\n"
143
  for s in retrieved[:3]:
@@ -154,10 +151,11 @@ def _title_case(s: str) -> str:
154
  else: out.append(lw.capitalize())
155
  return " ".join(out)
156
 
157
- def _looks_ok(s: str) -> bool:
 
 
 
158
  if not s: return False
159
- s = s.strip()
160
- if any(p in s for p in _PUNCT): return False
161
  wc = len(s.split())
162
  if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
163
  lo = s.lower()
@@ -165,22 +163,40 @@ def _looks_ok(s: str) -> bool:
165
  if lo in {"the","a","an"}: return False
166
  return True
167
 
168
- def _postprocess(cands):
169
  cleaned, seen = [], set()
170
- for c in cands:
171
- c = c.replace("Slogan:", "").strip().strip('"').strip("'")
172
- c = " ".join(c.split())
173
- c = _title_case(c)
174
- if _looks_ok(c):
175
- k = c.lower()
 
176
  if k not in seen:
177
- seen.add(k); cleaned.append(c)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  return cleaned
179
 
180
  def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
181
  prompt = _render_prompt(description, retrieved_texts)
182
- # ban generic words at decode time
 
183
  bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
 
184
  inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
185
  outputs = _gen_model.generate(
186
  **inputs,
@@ -195,15 +211,19 @@ def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CAN
195
  eos_token_id=_gen_tokenizer.eos_token_id,
196
  )
197
  texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
198
- return _postprocess(texts)
 
 
 
 
199
 
200
  def _pick_best(candidates, retrieved_texts, description):
201
- """Weighted relevance to the description minus duplication vs retrieved."""
202
  if not candidates:
203
  return None
204
  c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
205
  d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
206
- rel = c_emb @ d_emb # cosine similarity to description
207
 
208
  if retrieved_texts:
209
  R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
@@ -217,7 +237,8 @@ def _pick_best(candidates, retrieved_texts, description):
217
  scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
218
  best_idx = np.argmax(scores)
219
  return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
220
- # if all are too close, pick the most relevant that still passes basic novelty threshold
 
221
  scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
222
  order = np.argsort(-scores)
223
  for i in order:
 
2
  import os, json, numpy as np, pandas as pd
3
  import gradio as gr
4
  import faiss
5
+ import re
6
  from sentence_transformers import SentenceTransformer
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
8
 
 
111
  _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
112
  _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
113
 
114
+ # keep this list small so we don't nuke relevant outputs
115
+ _BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
116
+ _PUNCT_CHARS = ":;—–-,.!?“”\"'`"
117
+ _PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
118
+
 
119
  _MIN_WORDS, _MAX_WORDS = 2, 8
120
 
121
  def _load_prompt():
122
  if os.path.exists(PROMPT_PATH):
123
  with open(PROMPT_PATH, "r", encoding="utf-8") as f:
124
  return f.read()
 
125
  return (
126
  "You are a professional slogan writer.\n"
127
  "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
 
135
  prompt = tmpl.replace("{description}", description)
136
  else:
137
  prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
 
138
  if retrieved:
139
  prompt += "\n\nDo NOT copy these existing slogans:\n"
140
  for s in retrieved[:3]:
 
151
  else: out.append(lw.capitalize())
152
  return " ".join(out)
153
 
154
+ def _strip_punct(s: str) -> str:
155
+ return _PUNCT_RE.sub("", s)
156
+
157
+ def _strict_ok(s: str) -> bool:
158
  if not s: return False
 
 
159
  wc = len(s.split())
160
  if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
161
  lo = s.lower()
 
163
  if lo in {"the","a","an"}: return False
164
  return True
165
 
166
+ def _postprocess_strict(texts):
167
  cleaned, seen = [], set()
168
+ for t in texts:
169
+ s = t.replace("Slogan:", "").strip().strip('"').strip("'")
170
+ s = " ".join(s.split())
171
+ s = _strip_punct(s) # remove punctuation instead of rejecting
172
+ s = _title_case(s)
173
+ if _strict_ok(s):
174
+ k = s.lower()
175
  if k not in seen:
176
+ seen.add(k); cleaned.append(s)
177
+ return cleaned
178
+
179
+ def _postprocess_relaxed(texts):
180
+ # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
181
+ cleaned, seen = [], set()
182
+ for t in texts:
183
+ s = t.strip().strip('"').strip("'")
184
+ s = _strip_punct(s)
185
+ s = " ".join(s.split())
186
+ wc = len(s.split())
187
+ if _MIN_WORDS <= wc <= _MAX_WORDS:
188
+ s = _title_case(s)
189
+ k = s.lower()
190
+ if k not in seen:
191
+ seen.add(k); cleaned.append(s)
192
  return cleaned
193
 
194
  def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
195
  prompt = _render_prompt(description, retrieved_texts)
196
+
197
+ # only block very generic junk at decode time
198
  bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
199
+
200
  inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
201
  outputs = _gen_model.generate(
202
  **inputs,
 
211
  eos_token_id=_gen_tokenizer.eos_token_id,
212
  )
213
  texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
214
+
215
+ cands = _postprocess_strict(texts)
216
+ if not cands:
217
+ cands = _postprocess_relaxed(texts) # <- graceful fallback
218
+ return cands
219
 
220
  def _pick_best(candidates, retrieved_texts, description):
221
+ """Weighted relevance to description minus duplication vs retrieved."""
222
  if not candidates:
223
  return None
224
  c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
225
  d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
226
+ rel = c_emb @ d_emb # cosine sim to description
227
 
228
  if retrieved_texts:
229
  R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
 
237
  scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
238
  best_idx = np.argmax(scores)
239
  return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
240
+
241
+ # else: pick most relevant that still clears a basic novelty bar, else top score
242
  scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
243
  order = np.argsort(-scores)
244
  for i in order: