yair319732 committed
Commit c9eabf9 · verified · 1 Parent(s): 7b4250d

Upload folder using huggingface_hub

Files changed (3):
  1. app.py +21 -14
  2. data/prompt.txt +6 -0
  3. logic/cleaning.py +2 -2
app.py CHANGED
@@ -11,6 +11,7 @@ from logic.search import SloganSearcher
 
 ASSETS_DIR = "assets"
 DATA_PATH = "data/slogan.csv"
+PROMPT_PATH = "data/prompt.txt"
 
 MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
 NORMALIZE = True
@@ -20,7 +21,7 @@ NUM_GEN_CANDIDATES = 6
 MAX_NEW_TOKENS = 24
 TEMPERATURE = 0.9
 TOP_P = 0.95
-NOVELTY_SIM_THRESHOLD = 0.80
+NOVELTY_SIM_THRESHOLD = 0.80  # <-- fixed: use a float, not a set
 
 META_PATH = os.path.join(ASSETS_DIR, "meta.json")
 PARQUET_PATH = os.path.join(ASSETS_DIR, "slogans_clean.parquet")
@@ -100,16 +101,21 @@ _encoder = SentenceTransformer(meta["model_name"])
 _gen_tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
 _gen_model = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
 
-def _prompt_for(description: str) -> str:
-    return (
-        "You are a professional slogan writer. "
-        "Write ONE original, catchy startup slogan under 8 words, Title Case, no punctuation. "
-        "Do not copy examples. Description:\\n"
-        f"{description}\\nSlogan:"
-    )
+# ----- Load your prompt from file -----
+def _load_prompt():
+    if os.path.exists(PROMPT_PATH):
+        with open(PROMPT_PATH, "r", encoding="utf-8") as f:
+            return f.read()
+    # Fallback (shouldn't happen since we write it in Colab)
+    return "Write a short startup slogan for:\n{description}\nSlogan:"
+
+def _render_prompt(description: str) -> str:
+    tmpl = _load_prompt()
+    # Support {description} placeholder; leave other braces untouched
+    return tmpl.replace("{description}", description)
 
 def _generate_candidates(description: str, n: int = NUM_GEN_CANDIDATES):
-    prompt = _prompt_for(description)
+    prompt = _render_prompt(description)
     inputs = _gen_tokenizer([prompt]*n, return_tensors="pt", padding=True, truncation=True)
     outputs = _gen_model.generate(
         **inputs,
@@ -121,7 +127,7 @@ def _generate_candidates(description: str, n: int = NUM_GEN_CANDIDATES):
         eos_token_id=_gen_tokenizer.eos_token_id,
     )
     texts = _gen_tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    return [t.replace("Slogan:", "").strip().strip('"') for t in texts if t.strip()]
+    return [t.strip().strip('"') for t in texts if t.strip()]
 
 def _pick_most_novel(candidates, retrieved_texts):
     if not candidates:
@@ -136,7 +142,8 @@ def _pick_most_novel(candidates, retrieved_texts):
         sims = np.dot(R, c_emb[0])  # cosine
         max_sim = float(np.max(sims))
         novelty = 1.0 - max_sim
-        if (max_sim < {0.80} and novelty > best_novelty) or best is None and novelty > best_novelty:
+        # FIXED: compare to float threshold
+        if ((max_sim < NOVELTY_SIM_THRESHOLD) and (novelty > best_novelty)) or (best is None and novelty > best_novelty):
             best, best_novelty = c, novelty
     return best
 
@@ -154,12 +161,12 @@ def run_pipeline(user_description: str):
            lines.append(f"{i}. {s}")
    else:
        lines.append("_No similar slogans found._")
-    lines.append("\\n### ✨ AI-generated suggestion")
+    lines.append("\n### ✨ AI-generated suggestion")
    lines.append(generated)
-    return "\\n".join(lines)
+    return "\n".join(lines)
 
 with gr.Blocks(title="Slogan Finder") as demo:
-    gr.Markdown("# 🔎 Slogan Finder\\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
+    gr.Markdown("# 🔎 Slogan Finder\nDescribe your product/company; get 3 similar slogans + 1 AI-generated suggestion.")
     query = gr.Textbox(label="Describe your product/company", placeholder="AI-powered patient financial navigation platform...")
     btn = gr.Button("Get slogans", variant="primary")
     out = gr.Markdown()
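
As a sanity check on the threshold hunks above, the fixed novelty filter can be exercised in isolation. The sketch below is illustrative only and not part of the commit: the function name, toy 2-D unit-length embeddings, and candidate strings are made up, while the 0.80 float threshold and the selection condition mirror the new code.

    import numpy as np

    NOVELTY_SIM_THRESHOLD = 0.80  # the float constant the commit introduces

    def pick_most_novel(candidates, cand_embs, retrieved_embs):
        # Embeddings are assumed unit-normalized, so a dot product is cosine similarity.
        best, best_novelty = None, -1.0
        R = np.asarray(retrieved_embs)
        for c, c_emb in zip(candidates, np.asarray(cand_embs)):
            max_sim = float(np.max(R @ c_emb))
            novelty = 1.0 - max_sim
            # The old code compared max_sim against the set literal {0.80}, which raises
            # TypeError at runtime; comparing against the float threshold is the fix.
            if (max_sim < NOVELTY_SIM_THRESHOLD and novelty > best_novelty) or (best is None and novelty > best_novelty):
                best, best_novelty = c, novelty
        return best

    # Hypothetical embeddings: the second candidate is farther from every retrieved slogan.
    retrieved = [[1.0, 0.0], [0.0, 1.0]]
    print(pick_most_novel(["Close Copy", "Fresh Angle"],
                          [[0.95, 0.312], [0.7071, 0.7071]],
                          retrieved))  # -> Fresh Angle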
data/prompt.txt ADDED
@@ -0,0 +1,6 @@
+You are a professional slogan writer.
+Write ONE original startup slogan under 8 words, Title Case, no punctuation.
+Do not copy examples.
+Description:
+{description}
+Slogan:
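
The template keeps a single {description} placeholder, which app.py fills with str.replace rather than str.format, so any other braces in the file need no escaping. A minimal sketch of that rendering, with the template inlined here instead of read from data/prompt.txt:

    TEMPLATE = (
        "You are a professional slogan writer.\n"
        "Write ONE original startup slogan under 8 words, Title Case, no punctuation.\n"
        "Do not copy examples.\n"
        "Description:\n"
        "{description}\n"
        "Slogan:"
    )

    def render_prompt(description: str) -> str:
        # Same substitution _render_prompt performs: a plain replace leaves
        # any other braces in the template untouched.
        return TEMPLATE.replace("{description}", description)

    print(render_prompt("AI-powered patient financial navigation platform"))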
logic/cleaning.py CHANGED
@@ -60,8 +60,8 @@ def _ascii_only(s: str) -> bool:
 
 def _dupe_key(s: str) -> str:
     s = s.lower()
-    s = PUNCT_RE.sub(" ", s)
-    s = WS_RE.sub(" ", s).strip()
+    s = re.sub(r"[^\\w\\s]+", " ", s)
+    s = re.sub(r"\\s+", " ", s).strip()
     return s
 
 def clean_dataframe(df: pd.DataFrame) -> pd.DataFrame:
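
For reference, the rewritten _dupe_key is meant to lowercase a slogan, strip punctuation, and collapse whitespace so near-identical entries deduplicate to one row. A standalone sketch of that behavior, assuming the intended character classes are \w and \s written with single backslashes, and using made-up slogans:

    import re

    def dupe_key(s: str) -> str:
        s = s.lower()
        s = re.sub(r"[^\w\s]+", " ", s)     # drop punctuation
        s = re.sub(r"\s+", " ", s).strip()  # collapse runs of whitespace
        return s

    # Both variants map to the same key, so clean_dataframe can drop one as a duplicate.
    print(dupe_key("Just Do It!"))      # -> just do it
    print(dupe_key("  just   do it "))  # -> just do it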