resolverkatla committed on
Commit
50944f0
·
verified ·
1 Parent(s): 45c0da1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -13
app.py CHANGED
@@ -1,38 +1,54 @@
1
  import gradio as gr
 
2
  from datasets import load_dataset
3
- from sklearn.feature_extraction.text import TfidfVectorizer
4
- from sklearn.naive_bayes import MultinomialNB
5
  from sklearn.pipeline import make_pipeline
 
 
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.metrics import accuracy_score
8
 
9
  # 1. Load dataset
10
  dataset = load_dataset("ucirvine/sms_spam", split="train")
11
  texts = dataset["sms"]
12
- labels = [1 if label == "spam" else 0 for label in dataset["label"]] # spam=1, ham=0
 
 
 
 
 
 
13
 
14
- # 2. Train/test split
15
- X_train, X_test, y_train, y_test = train_test_split(texts, labels, test_size=0.2, random_state=42)
16
 
17
- # 3. Create model pipeline (TF-IDF + Naive Bayes)
18
- model = make_pipeline(TfidfVectorizer(), MultinomialNB())
 
 
 
 
 
 
19
  model.fit(X_train, y_train)
20
 
21
- # 4. Accuracy for reference
22
  y_pred = model.predict(X_test)
23
  print("Validation Accuracy:", accuracy_score(y_test, y_pred))
24
 
25
- # 5. Gradio interface
26
  def predict_spam(message):
27
- pred = model.predict([message])[0]
28
- return "📩 Not Spam (Ham)" if pred == 0 else "🚫 Spam"
 
 
 
29
 
 
30
  iface = gr.Interface(
31
  fn=predict_spam,
32
  inputs=gr.Textbox(lines=4, label="Enter your SMS message"),
33
  outputs=gr.Text(label="Prediction"),
34
- title="📬 SMS Spam Detector",
35
- description="Classifies whether an SMS message is spam or not using a Naive Bayes model."
36
  )
37
 
38
  if __name__ == "__main__":
 
1
  import gradio as gr
2
+ import re
3
  from datasets import load_dataset
 
 
4
  from sklearn.pipeline import make_pipeline
5
+ from sklearn.feature_extraction.text import TfidfVectorizer
6
+ from sklearn.linear_model import LogisticRegression
7
  from sklearn.model_selection import train_test_split
8
  from sklearn.metrics import accuracy_score
9
 
10
  # 1. Load dataset
11
  dataset = load_dataset("ucirvine/sms_spam", split="train")
12
  texts = dataset["sms"]
13
+ labels = [1 if label == "spam" else 0 for label in dataset["label"]]
14
+
15
+ # 2. Clean text
16
+ def clean_text(text):
17
+ text = text.lower()
18
+ text = re.sub(r"\W+", " ", text)
19
+ return text.strip()
20
 
21
+ texts_cleaned = [clean_text(t) for t in texts]
 
22
 
23
+ # 3. Train/test split
24
+ X_train, X_test, y_train, y_test = train_test_split(texts_cleaned, labels, test_size=0.2, random_state=42)
25
+
26
+ # 4. Build model: TF-IDF + Logistic Regression
27
+ model = make_pipeline(
28
+ TfidfVectorizer(ngram_range=(1, 2), stop_words="english", max_df=0.9),
29
+ LogisticRegression(max_iter=1000, class_weight="balanced")
30
+ )
31
  model.fit(X_train, y_train)
32
 
33
+ # 5. Show validation accuracy
34
  y_pred = model.predict(X_test)
35
  print("Validation Accuracy:", accuracy_score(y_test, y_pred))
36
 
37
+ # 6. Prediction function
38
  def predict_spam(message):
39
+ cleaned = clean_text(message)
40
+ pred = model.predict([cleaned])[0]
41
+ prob = model.predict_proba([cleaned])[0][pred]
42
+ label = "🚫 Spam" if pred == 1 else "📩 Not Spam (Ham)"
43
+ return f"{label} (Confidence: {prob:.2%})"
44
 
45
+ # 7. Gradio UI
46
  iface = gr.Interface(
47
  fn=predict_spam,
48
  inputs=gr.Textbox(lines=4, label="Enter your SMS message"),
49
  outputs=gr.Text(label="Prediction"),
50
+ title="📬 Improved SMS Spam Detector",
51
+ description="Detects spam in SMS messages using Logistic Regression with TF-IDF bi-grams. Now with higher accuracy!"
52
  )
53
 
54
  if __name__ == "__main__":