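"""Slogan Finder Space: retrieves existing slogans similar to a product description via a
FAISS index and generates one new FLAN-T5 suggestion, scored for relevance and novelty."""
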
import os
import json
import re

import numpy as np
import pandas as pd
import gradio as gr
import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher

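# Raw dataset, prompt template, and cached-asset locations.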
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
PROMPT_PATH = "data/prompt.txt"

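# Sentence-embedding model used both for indexing and for scoring generated candidates.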
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True

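# FLAN-T5 generator and its sampling settings for candidate slogans.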
GEN_MODEL = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 12
MAX_NEW_TOKENS = 18
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.15

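# Candidate selection weights: reward similarity to the description, penalise similarity
# to the retrieved slogans; anything above DUPLICATE_MAX_SIM counts as a near-duplicate.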
RELEVANCE_WEIGHT = 0.7
NOVELTY_WEIGHT = 0.3
DUPLICATE_MAX_SIM = 0.92
NOVELTY_SIM_THRESHOLD = 0.80

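# Artifacts written by _build_assets() and loaded back at startup.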
META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")

def _log(m):
    print(f"[SLOGAN-SPACE] {m}", flush=True)

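
# Build the retrieval assets from the raw CSV: clean the rows, embed the chosen text
# column, build a FAISS index, and persist everything plus metadata under ASSETS_DIR.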
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)

    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)

    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")

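    # Embed descriptions when available; fall back to the tagline for rows missing one.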
if "description" in df.columns and df["description"].notna().any(): |
|
texts = df["description"].fillna(df["tagline"]).astype(str).tolist() |
|
text_col, fallback_col = "description", "tagline" |
|
else: |
|
texts = df["tagline"].astype(str).tolist() |
|
text_col, fallback_col = "tagline", "tagline" |
|
|
|
_log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) β¦") |
|
encoder = SentenceTransformer(MODEL_NAME) |
|
emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE) |
|
|
|
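    # With normalized embeddings, inner product equals cosine similarity; otherwise use L2.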
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)

_log("Persisting assets β¦") |
|
df.to_parquet(PARQUET_PATH, index=False) |
|
faiss.write_index(index, INDEX_PATH) |
|
np.save(EMB_PATH, emb) |
|
|
|
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")

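
# Rebuild the assets if any cached file is missing or the parquet cannot be read.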
def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()

_ensure_assets()

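# Retriever over the prebuilt index, plus the same sentence encoder for scoring candidates.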
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
with open(META_PATH) as f:
    meta = json.load(f)
_encoder = SentenceTransformer(meta["model_name"])

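# FLAN-T5 tokenizer and model used for slogan generation.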
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)

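# Terms blocked during generation and punctuation stripped from candidates.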
_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
_PUNCT_CHARS = ":;–—-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")

_MIN_WORDS, _MAX_WORDS = 2, 8

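
# Load the prompt template from disk, falling back to a built-in default.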
def _load_prompt():
    if os.path.exists(PROMPT_PATH):
        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
            return f.read()
    return (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n{description}\nSlogan:"
    )

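
# Fill the template with the user's description and list retrieved slogans as negative examples.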
def _render_prompt(description: str, retrieved=None) -> str:
    tmpl = _load_prompt()
    if "{description}" in tmpl:
        prompt = tmpl.replace("{description}", description)
    else:
        prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
    if retrieved:
        prompt += "\n\nDo NOT copy these existing slogans:\n"
        for s in retrieved[:3]:
            prompt += f"- {s}\n"
    return prompt

def _title_case(s: str) -> str:
    small = {"and", "or", "for", "of", "the", "to", "in", "on", "with", "a", "an"}
    words = [w for w in s.split() if w]
    out = []
    for i, w in enumerate(words):
        lw = w.lower()
        if i > 0 and lw in small:
            out.append(lw)
        else:
            out.append(lw.capitalize())
    return " ".join(out)


def _strip_punct(s: str) -> str:
    return _PUNCT_RE.sub("", s)

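
# Strict filter: word count within bounds, no banned terms, not a bare article.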
def _strict_ok(s: str) -> bool:
    if not s:
        return False
    wc = len(s.split())
    if wc < _MIN_WORDS or wc > _MAX_WORDS:
        return False
    lo = s.lower()
    if any(term in lo for term in _BANNED_TERMS):
        return False
    if lo in {"the", "a", "an"}:
        return False
    return True

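
# Strict post-processing: strip the "Slogan:" label, quotes and punctuation, title-case,
# deduplicate, and keep only candidates that pass the strict filter.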
def _postprocess_strict(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
        s = " ".join(s.split())
        s = _strip_punct(s)
        s = _title_case(s)
        if _strict_ok(s):
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned

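
# Relaxed fallback used when the strict pass rejects everything: only the word-count
# bounds are enforced.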
def _postprocess_relaxed(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.strip().strip('"').strip("'")
        s = _strip_punct(s)
        s = " ".join(s.split())
        wc = len(s.split())
        if _MIN_WORDS <= wc <= _MAX_WORDS:
            s = _title_case(s)
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned

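
# Sample several candidate slogans from FLAN-T5, blocking banned terms at decode time,
# then post-process them (strict first, relaxed as a fallback).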
def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
    prompt = _render_prompt(description, retrieved_texts)

    bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids

    inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        no_repeat_ngram_size=3,
        repetition_penalty=REPETITION_PENALTY,
        bad_words_ids=bad_ids if bad_ids else None,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    cands = _postprocess_strict(texts)
    if not cands:
        cands = _postprocess_relaxed(texts)
    return cands

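
# Pick the candidate that best matches the description while avoiding near-duplicates
# of the retrieved slogans; fall back to the least-duplicative candidate otherwise.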
def _pick_best(candidates, retrieved_texts, description):
    """Weighted relevance to description minus duplication vs retrieved."""
    if not candidates:
        return None
    c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
    rel = c_emb @ d_emb

    if retrieved_texts:
        R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
        dup = np.max(R @ c_emb.T, axis=0)
    else:
        dup = np.zeros(len(candidates), dtype=np.float32)

    mask = dup < DUPLICATE_MAX_SIM
    if mask.any():
        scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
        best_idx = np.argmax(scores)
        return [c for i, c in enumerate(candidates) if mask[i]][best_idx]

    scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
    order = np.argsort(-scores)
    for i in order:
        if dup[i] < NOVELTY_SIM_THRESHOLD:
            return candidates[i]
    return candidates[order[0]]

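
# End-to-end pipeline: retrieve the top-3 similar slogans, generate candidates,
# pick one, and format the Markdown shown in the UI.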
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
    chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
    lines = []
    lines.append("### 🔎 Top 3 similar slogans")
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("No similar slogans found.")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(chosen)
    return "\n".join(lines)

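
# Gradio UI: a single description box in, Markdown (retrieved slogans + suggestion) out.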
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()