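"""Gradio Space: slogan search + generation.

Builds (or reuses) a FAISS index over a cleaned slogan dataset, retrieves the
three most similar existing slogans for a query, and generates one new
suggestion with FLAN-T5, ranked by relevance to the description and novelty
against the retrieved slogans.
"""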
import os, json, numpy as np, pandas as pd
import gradio as gr
import faiss
import re
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

from logic.cleaning import clean_dataframe
from logic.search import SloganSearcher

# -------------------- Config --------------------
ASSETS_DIR   = "assets"
DATA_PATH    = "data/slogan.csv"
PROMPT_PATH  = "data/prompt.txt"

MODEL_NAME   = "sentence-transformers/all-MiniLM-L6-v2"
NORMALIZE    = True

GEN_MODEL    = "google/flan-t5-base"
NUM_GEN_CANDIDATES = 12
MAX_NEW_TOKENS     = 18
TEMPERATURE        = 0.7
TOP_P              = 0.9
REPETITION_PENALTY = 1.15
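# The sampling settings above produce NUM_GEN_CANDIDATES raw completions; post-processing and _pick_best() keep one.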

# choose the most relevant yet non-duplicate candidate
RELEVANCE_WEIGHT   = 0.7
NOVELTY_WEIGHT     = 0.3
DUPLICATE_MAX_SIM  = 0.92
NOVELTY_SIM_THRESHOLD = 0.80  # keep some distance from retrieved

META_PATH    = os.path.join(ASSETS_DIR, "meta.json")
PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
INDEX_PATH   = os.path.join(ASSETS_DIR, "faiss.index")
EMB_PATH     = os.path.join(ASSETS_DIR, "embeddings.npy")
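# Derived artifacts: written by _build_assets() on first run, reloaded on later runs.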

def _log(m): print(f"[SLOGAN-SPACE] {m}", flush=True)

# -------------------- Asset build --------------------
def _build_assets():
    if not os.path.exists(DATA_PATH):
        raise FileNotFoundError(f"Dataset not found at {DATA_PATH} (CSV with columns: 'tagline', 'description').")
    os.makedirs(ASSETS_DIR, exist_ok=True)

    _log(f"Loading dataset: {DATA_PATH}")
    df = pd.read_csv(DATA_PATH)

    _log(f"Rows before cleaning: {len(df)}")
    df = clean_dataframe(df)
    _log(f"Rows after cleaning: {len(df)}")

    if "description" in df.columns and df["description"].notna().any():
        texts = df["description"].fillna(df["tagline"]).astype(str).tolist()
        text_col, fallback_col = "description", "tagline"
    else:
        texts = df["tagline"].astype(str).tolist()
        text_col, fallback_col = "tagline", "tagline"

    _log(f"Encoding with {MODEL_NAME} (normalize={NORMALIZE}) …")
    encoder = SentenceTransformer(MODEL_NAME)
    emb = encoder.encode(texts, batch_size=64, convert_to_numpy=True, normalize_embeddings=NORMALIZE)

    dim = emb.shape[1]
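    # Normalized vectors + inner-product index give cosine similarity; otherwise fall back to plain L2 distance.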
    index = faiss.IndexFlatIP(dim) if NORMALIZE else faiss.IndexFlatL2(dim)
    index.add(emb)

    _log("Persisting assets …")
    df.to_parquet(PARQUET_PATH, index=False)
    faiss.write_index(index, INDEX_PATH)
    np.save(EMB_PATH, emb)

    meta = {
        "model_name": MODEL_NAME,
        "dim": int(dim),
        "normalized": NORMALIZE,
        "metric": "ip" if NORMALIZE else "l2",
        "row_count": int(len(df)),
        "text_col": text_col,
        "fallback_col": fallback_col,
    }
    with open(META_PATH, "w") as f:
        json.dump(meta, f, indent=2)
    _log("Assets built successfully.")

def _ensure_assets():
    need = False
    for p in (META_PATH, PARQUET_PATH, INDEX_PATH):
        if not os.path.exists(p):
            _log(f"Missing asset: {p}")
            need = True
    if need:
        _log("Building assets from scratch …")
        _build_assets()
        return
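    # All files exist; still confirm the parquet is readable and rebuild on corruption.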
    try:
        pd.read_parquet(PARQUET_PATH)
    except Exception as e:
        _log(f"Parquet read failed ({e}); rebuilding assets.")
        _build_assets()

# Build before UI
_ensure_assets()

# -------------------- Retrieval --------------------
searcher = SloganSearcher(assets_dir=ASSETS_DIR, use_rerank=False)
with open(META_PATH) as f:
    meta = json.load(f)
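# Reuse the index's embedding model so generated candidates are scored in the same vector space.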
_encoder = SentenceTransformer(meta["model_name"])

# -------------------- Generator --------------------
_gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL)
_gen_model     = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL)
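# Tokenizer and model are loaded once at startup; no explicit device placement is requested.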

# keep this list small so we don't nuke relevant outputs
_BANNED_TERMS = {"portal", "e-commerce", "ecommerce", "shopping", "shop"}
_PUNCT_CHARS = ":;—–-,.!?“”\"'`"
_PUNCT_RE = re.compile(f"[{re.escape(_PUNCT_CHARS)}]")

_MIN_WORDS, _MAX_WORDS = 2, 8

def _load_prompt():
    if os.path.exists(PROMPT_PATH):
        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
            return f.read()
    return (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n{description}\nSlogan:"
    )

def _render_prompt(description: str, retrieved=None) -> str:
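    # Fill the template's {description} placeholder (or append the description) and list retrieved slogans as negative examples.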
    tmpl = _load_prompt()
    if "{description}" in tmpl:
        prompt = tmpl.replace("{description}", description)
    else:
        prompt = f"{tmpl}\n\nDescription:\n{description}\nSlogan:"
    if retrieved:
        prompt += "\n\nDo NOT copy these existing slogans:\n"
        for s in retrieved[:3]:
            prompt += f"- {s}\n"
    return prompt

def _title_case(s: str) -> str:
    small = {"and","or","for","of","the","to","in","on","with","a","an"}
    words = [w for w in s.split() if w]
    out = []
    for i,w in enumerate(words):
        lw = w.lower()
        if i>0 and lw in small: out.append(lw)
        else: out.append(lw.capitalize())
    return " ".join(out)

def _strip_punct(s: str) -> str:
    return _PUNCT_RE.sub("", s)

def _strict_ok(s: str) -> bool:
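    # Reject empty strings, word counts outside [_MIN_WORDS, _MAX_WORDS], banned terms, and bare articles.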
    if not s: return False
    wc = len(s.split())
    if wc < _MIN_WORDS or wc > _MAX_WORDS: return False
    lo = s.lower()
    if any(term in lo for term in _BANNED_TERMS): return False
    if lo in {"the","a","an"}: return False
    return True

def _postprocess_strict(texts):
    cleaned, seen = [], set()
    for t in texts:
        s = t.replace("Slogan:", "").strip().strip('"').strip("'")
        s = " ".join(s.split())
        s = _strip_punct(s)          # remove punctuation instead of rejecting
        s = _title_case(s)
        if _strict_ok(s):
            k = s.lower()
            if k not in seen:
                seen.add(k); cleaned.append(s)
    return cleaned

def _postprocess_relaxed(texts):
    # fallback if strict returns nothing: keep 2–8 words, strip punctuation, Title Case
    cleaned, seen = [], set()
    for t in texts:
        s = t.strip().strip('"').strip("'")
        s = _strip_punct(s)
        s = " ".join(s.split())
        wc = len(s.split())
        if _MIN_WORDS <= wc <= _MAX_WORDS:
            s = _title_case(s)
            k = s.lower()
            if k not in seen:
                seen.add(k); cleaned.append(s)
    return cleaned

def _generate_candidates(description: str, retrieved_texts, n: int = NUM_GEN_CANDIDATES):
    prompt = _render_prompt(description, retrieved_texts)

    # only block very generic junk at decode time
    bad_ids = _gen_tokenizer(list(_BANNED_TERMS), add_special_tokens=False).input_ids

    inputs = _gen_tokenizer([prompt], return_tensors="pt", padding=True, truncation=True)
    outputs = _gen_model.generate(
        **inputs,
        do_sample=True,
        temperature=TEMPERATURE,
        top_p=TOP_P,
        num_return_sequences=n,
        max_new_tokens=MAX_NEW_TOKENS,
        no_repeat_ngram_size=3,
        repetition_penalty=REPETITION_PENALTY,
        bad_words_ids=bad_ids if bad_ids else None,
        eos_token_id=_gen_tokenizer.eos_token_id,
    )
    texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)

    cands = _postprocess_strict(texts)
    if not cands:
        cands = _postprocess_relaxed(texts)  # <- graceful fallback
    return cands

def _pick_best(candidates, retrieved_texts, description):
    """Weighted relevance to description minus duplication vs retrieved."""
    if not candidates:
        return None
    c_emb = _encoder.encode(candidates, convert_to_numpy=True, normalize_embeddings=True)
    d_emb = _encoder.encode([description], convert_to_numpy=True, normalize_embeddings=True)[0]
    rel = c_emb @ d_emb  # cosine sim to description

    if retrieved_texts:
        R = _encoder.encode(retrieved_texts, convert_to_numpy=True, normalize_embeddings=True)
        dup = np.max(R @ c_emb.T, axis=0)  # max sim to any retrieved
    else:
        dup = np.zeros(len(candidates), dtype=np.float32)

    # penalize near-duplicates outright
    mask = dup < DUPLICATE_MAX_SIM
    if mask.any():
        scores = RELEVANCE_WEIGHT * rel[mask] - NOVELTY_WEIGHT * dup[mask]
        best_idx = np.argmax(scores)
        return [c for i, c in enumerate(candidates) if mask[i]][best_idx]

    # else: pick most relevant that still clears a basic novelty bar, else top score
    scores = RELEVANCE_WEIGHT * rel - NOVELTY_WEIGHT * dup
    order = np.argsort(-scores)
    for i in order:
        if dup[i] < NOVELTY_SIM_THRESHOLD:
            return candidates[i]
    return candidates[order[0]]

# -------------------- Inference pipeline --------------------
def run_pipeline(user_description: str):
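    # Retrieve the 3 closest slogans, then over-generate and select one new suggestion.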
    if not user_description or not user_description.strip():
        return "Please enter a description."
    retrieved_df = searcher.search(user_description, top_k=3, rerank_top_n=10)
    retrieved_texts = retrieved_df["display"].tolist() if not retrieved_df.empty else []
    gens = _generate_candidates(user_description, retrieved_texts, NUM_GEN_CANDIDATES)
    chosen = _pick_best(gens, retrieved_texts, user_description) or (gens[0] if gens else "—")
    lines = []
    lines.append("### πŸ”Ž Top 3 similar slogans")
    if retrieved_texts:
        for i, s in enumerate(retrieved_texts, 1):
            lines.append(f"{i}. {s}")
    else:
        lines.append("No similar slogans found.")
    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(chosen)
    return "\n".join(lines)

# -------------------- UI --------------------
with gr.Blocks(title="Slogan Finder") as demo:
    gr.Markdown("# πŸ”Ž Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
    query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
    btn = gr.Button("Get slogans", variant="primary")
    out = gr.Markdown()
    btn.click(run_pipeline, inputs=[query], outputs=out)

demo.queue(max_size=64).launch()