# sloganAI/app.py
import os, json, numpy as np, pandas as pd
import gradio as gr
import faiss
import re
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher
# -------------------- Config --------------------
ASSETS_DIR = "assets"
DATA_PATH = "data/slogan.csv"
PROMPT_PATH = "data/prompt.txt"
MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE = True
GEN_MODEL = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 12
MAX_NEW_TOKENS = 18
TEMPERATURE = 0.7
TOP_P = 0.9
REPETITION_PENALTY = 1.15
# choose the most relevant yet non-duplicate candidate
RELEVANCE_WEIGHT = 0.7
NOVELTY_WEIGHT = 0.3
DUPLICATE_MAX_SIM = 0.92
NOVELTY_SIM_THRESHOLD = 0.80 # keep some distance from retrieved
META_PATH = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH = os.path.join(ASSETS_DIR, "embeddings.npy")
def _log(m): print(f"[SLOGAN-SPACE] {m}", flush=True)
# -------------------- Asset build --------------------
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)
    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)
    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")
    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"
    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)
    dim = emb.shape[1]
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)
    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)
    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")
def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()
# Build before UI
_ensure_assets()
# -------------------- Retrieval --------------------
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
meta = json.load(open(META_PATH))
_encoder = SentenceTransformer(meta["model_name"])
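# Reuse the embedding model recorded in meta.json (the same one that built the index),
# so candidate scoring in _pick_best happens in the same vector space as retrieval.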
# -------------------- Generator --------------------
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
_gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
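# Tokenizer and generator are loaded once at module import; each request only pays for decoding.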
# keep this list small so we don't nuke relevant outputs
_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
_PUNCT_CHARS = ":;—–-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")
_MIN_WORDS, _MAX_WORDS = 2, 8
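# Generated candidates are cleaned by the _postprocess_* helpers below: punctuation stripped,
# Title Case applied, and anything outside 2–8 words or containing a banned term is dropped.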
def _load_prompt():
    if os.path.exists(PROMPT_PATH):
        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
            return f.read()
    return (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n{description}\nSlogan:"
    )
def _render_prompt(description: str, retrieved=None) -> str:
    tmpl = _load_prompt()
    if "{description}" in tmpl:
        prompt = tmpl.replace("{description}", description)
    else:
        prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
    if retrieved:
        prompt += "\n\nDo NOT copy these existing slogans:\n"
        for s in retrieved[:3]:
            prompt += f"- {s}\n"
    return prompt
def _title_case(s: str) -> str:
    small = {"and", "or", "for", "of", "the", "to", "in", "on", "with", "a", "an"}
    words = [w for w in s.split() if w]
    out = []
    for i, w in enumerate(words):
        lw = w.lower()
        if i > 0 and lw in small:
            out.append(lw)
        else:
            out.append(lw.capitalize())
    return " ".join(out)
def _strip_punct(s: str) -> str:
    return _PUNCT_RE.sub("", s)
def _strict_ok(s: str) -> bool:
    if not s:
        return False
    wc = len(s.split())
    if wc < _MIN_WORDS or wc > _MAX_WORDS:
        return False
    lo = s.lower()
    if any(term in lo for term in _BANNED_TERMS):
        return False
    if lo in {"the", "a", "an"}:
        return False
    return True
def _postprocess_strict(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
        s = " ".join(s.split())
        s = _strip_punct(s)  # remove punctuation instead of rejecting
        s = _title_case(s)
        if _strict_ok(s):
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned
def _postprocess_relaxed(texts):
    # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
    cleaned, seen = [], set()
    for t in texts:
        s = t.strip().strip('"').strip("'")
        s = _strip_punct(s)
        s = " ".join(s.split())
        wc = len(s.split())
        if _MIN_WORDS <= wc <= _MAX_WORDS:
            s = _title_case(s)
            k = s.lower()
            if k not in seen:
                seen.add(k)
                cleaned.append(s)
    return cleaned
def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
    prompt = _render_prompt(description, retrieved_texts)
    # only block very generic junk at decode time
    bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids
    inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        no_repeat_ngram_size=3,
        repetition_penalty=REPETITION_PENALTY,
        bad_words_ids=bad_ids if bad_ids else None,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
    cands = _postprocess_strict(texts)
    if not cands:
        cands = _postprocess_relaxed(texts)  # <- graceful fallback
    return cands
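# Candidate selection below scores each candidate as
#   RELEVANCE_WEIGHT * sim(candidate, description) - NOVELTY_WEIGHT * max sim(candidate, retrieved),
# after first dropping candidates whose similarity to any retrieved slogan exceeds DUPLICATE_MAX_SIM.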
def _pick_best(candidates, retrieved_texts, description):
    """Weighted relevance to description minus duplication vs retrieved."""
    if not candidates:
        return None
    c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
    rel = c_emb @ d_emb  # cosine sim to description
    if retrieved_texts:
        R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
        dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
    else:
        dup = np.zeros(len(candidates), dtype=np.float32)
    # penalize near-duplicates outright
    mask = dup < DUPLICATE_MAX_SIM
    if mask.any():
        scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
        best_idx = np.argmax(scores)
        return [c for i, c in enumerate(candidates) if mask[i]][best_idx]
    # else: pick most relevant that still clears a basic novelty bar, else top score
    scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
    order = np.argsort(-scores)
    for i in order:
        if dup[i] < NOVELTY_SIM_THRESHOLD:
            return candidates[i]
    return candidates[order[0]]
# -------------------- Inference pipeline --------------------
def run_pipeline(user_description: str):
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
    chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
    lines = []
    lines.append("### 🔎 Top 3 similar slogans")
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("No similar slogans found.")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(chosen)
    return "\n".join(lines)
# -------------------- UI --------------------
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()