Commit: Upload folder using huggingface_hub
app.py CHANGED
@@ -2,7 +2,7 @@
 import os, json, numpy as np, pandas as pd
 import gradio as gr
 import faiss
-
+import re
 from sentence_transformers import SentenceTransformer
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
@@ -111,19 +111,17 @@ _encoder = SentenceTransformer(meta["model_name"])
 _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
 _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
 
-# …
-_BANNED_TERMS = {
-    …
-    …
-    …
-_PUNCT = {":",";","—","–","-",".",",","!","?","“","”","\"","'"}
+# keep this list small so we don't nuke relevant outputs
+_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
+_PUNCT_CHARS = ":;—–-,.!?“”\"'`"
+_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
+
 _MIN_WORDS, _MAX_WORDS = 2, 8
 
 def _load_prompt():
     if os.path.exists(PROMPT_PATH):
         with open(PROMPT_PATH, "r", encoding="utf-8") as f:
             return f.read()
-    # fallback if no prompt file shipped
     return (
         "You are a professional slogan writer.\n"
         "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
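Side note on the `_PUNCT_RE` change: the old code rejected any candidate containing punctuation, while the new code strips it with a compiled character class. A minimal standalone sketch of that behavior, using the same constants:

```python
import re

_PUNCT_CHARS = ":;—–-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")

print(_PUNCT_RE.sub("", "Ship Faster, Worry Less!"))  # -> Ship Faster Worry Less
```

`re.escape` matters here: an unescaped `-` would otherwise act as a range operator inside the `[...]` class.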
@@ -137,7 +135,6 @@ def _render_prompt(description: str, retrieved=None) -> str:
         prompt = tmpl.replace("{description}", description)
     else:
         prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
-    # Optionally add negatives (don’t copy these)
     if retrieved:
         prompt += "\n\nDo NOT copy these existing slogans:\n"
         for s in retrieved[:3]:
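For context, with the built-in fallback template `_render_prompt` produces a prompt shaped roughly like the sketch below (hypothetical description and neighbors; the per-slogan line inside the loop is outside the hunk, so the bullet format is an assumption):

```python
description = "AI-powered meal planning for busy families"  # hypothetical input
retrieved = ["Eat Well Live Well", "Dinner Made Simple"]    # hypothetical neighbors

# mirrors the fallback branch of _render_prompt
prompt = (
    "You are a professional slogan writer.\n"
    "Write ONE original startup slogan under 8 words, Title Case, no punctuation."
    f"\n\nDescription:\n{description}\nSlogan:"
)
prompt += "\n\nDo NOT copy these existing slogans:\n"
for s in retrieved[:3]:
    prompt += f"- {s}\n"  # assumed line format; the loop body is not shown in the diff
print(prompt)
```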
@@ -154,10 +151,11 @@ def _title_case(s: str) -> str:
     else: out.append(lw.capitalize())
     return " ".join(out)
 
-def _looks_ok(s: str) -> bool:
+def _strip_punct(s: str) -> str:
+    return _PUNCT_RE.sub("", s)
+
+def _strict_ok(s: str) -> bool:
     if not s: return False
-    s = s.strip()
-    if any(p in s for p in _PUNCT): return False
     wc = len(s.split())
     if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
     lo = s.lower()
@@ -165,22 +163,40 @@ def _looks_ok(s: str) -> bool:
     if lo in {"the","a","an"}: return False
     return True
 
-def …
+def _postprocess_strict(texts):
     cleaned, seen = [], set()
-    for …
-    …
-    …
-    …
-    …
-    …
+    for t in texts:
+        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
+        s = " ".join(s.split())
+        s = _strip_punct(s)  # remove punctuation instead of rejecting
+        s = _title_case(s)
+        if _strict_ok(s):
+            k = s.lower()
             if k not in seen:
-                seen.add(k); cleaned.append(…)
+                seen.add(k); cleaned.append(s)
+    return cleaned
+
+def _postprocess_relaxed(texts):
+    # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
+    cleaned, seen = [], set()
+    for t in texts:
+        s = t.strip().strip('"').strip("'")
+        s = _strip_punct(s)
+        s = " ".join(s.split())
+        wc = len(s.split())
+        if _MIN_WORDS <= wc <= _MAX_WORDS:
+            s = _title_case(s)
+            k = s.lower()
+            if k not in seen:
+                seen.add(k); cleaned.append(s)
     return cleaned
 
 def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
     prompt = _render_prompt(description, retrieved_texts)
-
+
+    # only block very generic junk at decode time
     bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
+
     inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
     outputs = _gen_model.generate(
         **inputs,
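`bad_ids` is built in the shape `transformers` expects for decode-time blocking: a list of token-ID sequences the model may not emit. The middle of the `generate()` call falls outside the hunk, but presumably the list is passed as `bad_words_ids`. A self-contained sketch of the mechanism (the model name is a stand-in for GEN_MODEL):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tok = AutoTokenizer.from_pretrained("google/flan-t5-small")   # stand-in for GEN_MODEL
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

banned = ["portal", "e-commerce", "ecommerce", "shopping", "shop"]
bad_ids = tok(banned, add_special_tokens=False).input_ids     # list of token-id lists

inputs = tok(["Write a slogan for an online store"], return_tensors="pt")
out = model.generate(**inputs, bad_words_ids=bad_ids, max_new_tokens=16)
print(tok.batch_decode(out, skip_special_tokens=True))
```

One caveat: blocking is per token sequence, so subword variants (a capitalized or mid-word "shop", say) can still slip through, which fits the "keep this list small" comment in the diff.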
@@ -195,15 +211,19 @@ def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
         eos_token_id=_gen_tokenizer.eos_token_id,
     )
     texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    …
+
+    cands = _postprocess_strict(texts)
+    if not cands:
+        cands = _postprocess_relaxed(texts)  # <- graceful fallback
+    return cands
 
 def _pick_best(candidates, retrieved_texts, description):
-    """Weighted relevance to …
+    """Weighted relevance to description minus duplication vs retrieved."""
     if not candidates:
         return None
     c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
     d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
-    rel = c_emb @ d_emb  # cosine
+    rel = c_emb @ d_emb  # cosine sim to description
 
     if retrieved_texts:
         R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
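The net effect of the strict-then-relaxed chain is that one badly formatted batch no longer yields an empty candidate list. A quick illustration, assuming the helpers from this diff are in scope:

```python
texts = ['Slogan: "fast. simple. done."', "the"]

# strict pass: strips the "Slogan:" prefix and punctuation, Title Cases, validates
cands = _postprocess_strict(texts)
# expected: ["Fast Simple Done"]; "the" fails the 2-word minimum
if not cands:
    cands = _postprocess_relaxed(texts)  # relaxed keeps any 2-8 word string
```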
@@ -217,7 +237,8 @@ def _pick_best(candidates, retrieved_texts, description):
         scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
         best_idx = np.argmax(scores)
         return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
-
+
+    # else: pick most relevant that still clears a basic novelty bar, else top score
     scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
     order = np.argsort(-scores)
     for i in order:
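The scoring in `_pick_best` is a plain weighted trade-off on unit-normalized embeddings: cosine relevance to the description minus similarity to the retrieved slogans. A minimal numpy sketch; the weights are stand-ins for RELEVANCE_WEIGHT / NOVELTY_WEIGHT, and computing `dup` as the max similarity per candidate is an assumption consistent with the visible `dup[mask]` usage:

```python
import numpy as np

rng = np.random.default_rng(0)

def unit(x):
    return x / np.linalg.norm(x, axis=-1, keepdims=True)

C = unit(rng.normal(size=(5, 8)))     # candidate embeddings (stand-in for _encoder output)
d = unit(rng.normal(size=(1, 8)))[0]  # description embedding
R = unit(rng.normal(size=(3, 8)))     # retrieved-slogan embeddings

rel = C @ d                   # cosine similarity to the description
dup = (C @ R.T).max(axis=1)   # highest similarity to any retrieved slogan

RELEVANCE_WEIGHT, NOVELTY_WEIGHT = 1.0, 0.5   # stand-in weights
scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
print(int(np.argmax(scores)))  # index of the best candidate
```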