import os, json, numpy as np, pandas as pd
import gradio as gr
import faiss
import re
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher

# -------------------- Config --------------------
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
PROMPT_PATH = "data/prompt.txt"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True

GEN_MODEL = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 12
MAX_NEW_TOKENS = 18
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.15

# choose the most relevant yet non-duplicate candidate
RELEVANCE_WEIGHT = 0.7
NOVELTY_WEIGHT = 0.3
DUPLICATE_MAX_SIM = 0.92
NOVELTY_SIM_THRESHOLD = 0.80  # keep some distance from retrieved

META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")

def _log(m):
    print(f"[SLOGAN-SPACE] {m}", flush=True)

# -------------------- Asset build --------------------
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)

    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")

    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"

    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)

    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)

    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")

def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()

# Build before UI
_ensure_assets()

# -------------------- Retrieval --------------------
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
meta = json.load(open(META_PATH))
_encoder = SentenceTransformer(meta["model_name"])

# -------------------- Generator --------------------
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)

# keep this list small so we don't nuke relevant outputs
_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
_PUNCT_CHARS = ":;—–-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
_MIN_WORDS, _MAX_WORDS = 2, 8
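
# Note: the sentence-transformer embeddings are L2-normalized, so the FAISS
# inner-product index built above (IndexFlatIP) ranks by cosine similarity.
# The banned terms are enforced twice: as bad_words_ids at decode time and
# again during post-processing, so a term that slips past one filter is
# still caught by the other.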

def _load_prompt():
    if os.path.exists(PROMPT_PATH):
        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
            return f.read()
    return (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n{description}\nSlogan:"
    )

def _render_prompt(description: str, retrieved=None) -> str:
    tmpl = _load_prompt()
    if "{description}" in tmpl:
        prompt = tmpl.replace("{description}", description)
    else:
        prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
    if retrieved:
        prompt += "\n\nDo NOT copy these existing slogans:\n"
        for s in retrieved[:3]:
            prompt += f"- {s}\n"
    return prompt

def _title_case(s: str) -> str:
    small = {"and", "or", "for", "of", "the", "to", "in", "on", "with", "a", "an"}
    words = [w for w in s.split() if w]
    out = []
    for i, w in enumerate(words):
        lw = w.lower()
        if i > 0 and lw in small:
            out.append(lw)
        else:
            out.append(lw.capitalize())
    return " ".join(out)

def _strip_punct(s: str) -> str:
    return _PUNCT_RE.sub("", s)

def _strict_ok(s: str) -> bool:
    if not s:
        return False
    wc = len(s.split())
    if wc < _MIN_WORDS or wc > _MAX_WORDS:
        return False
    lo = s.lower()
    if any(term in lo for term in _BANNED_TERMS):
        return False
    if lo in {"the", "a", "an"}:
        return False
    return True

def _postprocess_strict(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
        s = " ".join(s.split())
        s = _strip_punct(s)  # remove punctuation instead of rejecting
        s = _title_case(s)
        if _strict_ok(s):
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned

def _postprocess_relaxed(texts):
    # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
    cleaned, seen = [], set()
    for t in texts:
        s = t.strip().strip('"').strip("'")
        s = _strip_punct(s)
        s = " ".join(s.split())
        wc = len(s.split())
        if _MIN_WORDS <= wc <= _MAX_WORDS:
            s = _title_case(s)
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned

def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
    prompt = _render_prompt(description, retrieved_texts)
    # only block very generic junk at decode time
    bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
    inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        no_repeat_ngram_size=3,
        repetition_penalty=REPETITION_PENALTY,
        bad_words_ids=bad_ids if bad_ids else None,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    cands = _postprocess_strict(texts)
    if not cands:
        cands = _postprocess_relaxed(texts)  # <- graceful fallback
    return cands
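
# Candidate selection (see _pick_best below): each generated slogan is scored as
#   score = RELEVANCE_WEIGHT * cos(candidate, description) - NOVELTY_WEIGHT * max cos(candidate, retrieved)
# Candidates whose similarity to any retrieved slogan exceeds DUPLICATE_MAX_SIM are
# filtered out first; if everything is a near-duplicate, the highest-scoring candidate
# that still clears NOVELTY_SIM_THRESHOLD wins, and only then is the raw top score returned.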

def _pick_best(candidates, retrieved_texts, description):
    """Weighted relevance to description minus duplication vs retrieved."""
    if not candidates:
        return None
    c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
    rel = c_emb @ d_emb  # cosine sim to description
    if retrieved_texts:
        R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
        dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
    else:
        dup = np.zeros(len(candidates), dtype=np.float32)
    # penalize near-duplicates outright
    mask = dup < DUPLICATE_MAX_SIM
    if mask.any():
        scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
        best_idx = np.argmax(scores)
        return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
    # else: pick most relevant that still clears a basic novelty bar, else top score
    scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
    order = np.argsort(-scores)
    for i in order:
        if dup[i] < NOVELTY_SIM_THRESHOLD:
            return candidates[i]
    return candidates[order[0]]

# -------------------- Inference pipeline --------------------
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
    chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")

    lines = []
    lines.append("### 🔎 Top 3 similar slogans")
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("No similar slogans found.")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(chosen)
    return "\n".join(lines)

# -------------------- UI --------------------
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()
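
# Local run (a sketch, assuming this file is saved as app.py and the dependencies
# imported above plus the logic/ package and data/slogan.csv are in place):
#   python app.py
# The first run builds the assets/ directory (clean parquet, FAISS index, embeddings);
# subsequent runs reuse it.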