resolverkatla commited on
Commit
8450003
Β·
verified Β·
1 Parent(s): b89ac72

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -31
app.py CHANGED
@@ -1,19 +1,17 @@
1
- import gradio as gr
2
  import re
3
- from datasets import load_dataset
 
4
  from sklearn.pipeline import make_pipeline
5
  from sklearn.feature_extraction.text import TfidfVectorizer
6
  from sklearn.linear_model import LogisticRegression
7
- from sklearn.model_selection import train_test_split
8
  from sklearn.metrics import accuracy_score
9
  from collections import Counter
10
 
11
- # 1. Load dataset
12
- dataset = load_dataset("sms_spam", split="train")
13
- texts = dataset["sms"]
14
- labels = [1 if label == "spam" else 0 for label in dataset["label"]]
15
-
16
- print("Label distribution:", Counter(labels)) # Debug check
17
 
18
  # 2. Clean text
19
  def clean_text(text):
@@ -21,30 +19,26 @@ def clean_text(text):
21
  text = re.sub(r"\W+", " ", text)
22
  return text.strip()
23
 
24
- texts_cleaned = [clean_text(t) for t in texts]
25
 
26
- # 3. Train/test split with stratification
27
  X_train, X_test, y_train, y_test = train_test_split(
28
- texts_cleaned, labels, test_size=0.2, random_state=42, stratify=labels
29
  )
30
 
31
- print("Train labels:", Counter(y_train)) # Debug check
32
- print("Test labels:", Counter(y_test)) # Debug check
33
-
34
- # 4. Build model pipeline
35
  model = make_pipeline(
36
- TfidfVectorizer(ngram_range=(1, 2), stop_words="english", max_df=0.9),
37
  LogisticRegression(max_iter=1000, class_weight="balanced")
38
  )
39
 
40
- # 5. Train model
41
  model.fit(X_train, y_train)
42
 
43
- # 6. Evaluate
44
- y_pred = model.predict(X_test)
45
- print("Validation Accuracy:", accuracy_score(y_test, y_pred))
46
 
47
- # 7. Predict function
48
  def predict_spam(message):
49
  cleaned = clean_text(message)
50
  pred = model.predict([cleaned])[0]
@@ -52,14 +46,11 @@ def predict_spam(message):
52
  label = "🚫 Spam" if pred == 1 else "πŸ“© Not Spam (Ham)"
53
  return f"{label} (Confidence: {prob:.2%})"
54
 
55
- # 8. Gradio app
56
- iface = gr.Interface(
57
  fn=predict_spam,
58
- inputs=gr.Textbox(lines=4, label="Enter your SMS message"),
59
  outputs=gr.Text(label="Prediction"),
60
- title="πŸ“¬ SMS Spam Detector (Improved)",
61
- description="Detect spam in SMS messages using Logistic Regression with TF-IDF bi-grams. Trained on a balanced dataset from Hugging Face."
62
- )
63
-
64
- if __name__ == "__main__":
65
- iface.launch(share=False)
 
1
+ import pandas as pd
2
  import re
3
+ import gradio as gr
4
+ from sklearn.model_selection import train_test_split
5
  from sklearn.pipeline import make_pipeline
6
  from sklearn.feature_extraction.text import TfidfVectorizer
7
  from sklearn.linear_model import LogisticRegression
 
8
  from sklearn.metrics import accuracy_score
9
  from collections import Counter
10
 
11
+ # 1. Load and clean data
12
+ df = pd.read_csv("spam.csv", encoding="latin1")[["v1", "v2"]]
13
+ df.columns = ["label", "text"]
14
+ df["label"] = df["label"].map({"ham": 0, "spam": 1})
 
 
15
 
16
  # 2. Clean text
17
  def clean_text(text):
 
19
  text = re.sub(r"\W+", " ", text)
20
  return text.strip()
21
 
22
+ df["text"] = df["text"].apply(clean_text)
23
 
24
+ # 3. Split data
25
  X_train, X_test, y_train, y_test = train_test_split(
26
+ df["text"], df["label"], test_size=0.2, stratify=df["label"], random_state=42
27
  )
28
 
29
+ # 4. Build and train model
 
 
 
30
  model = make_pipeline(
31
+ TfidfVectorizer(ngram_range=(1, 2), stop_words="english"),
32
  LogisticRegression(max_iter=1000, class_weight="balanced")
33
  )
34
 
 
35
  model.fit(X_train, y_train)
36
 
37
+ # 5. Evaluate
38
+ accuracy = accuracy_score(y_test, model.predict(X_test))
39
+ print(f"Validation Accuracy: {accuracy:.2%}")
40
 
41
+ # 6. Gradio prediction function
42
  def predict_spam(message):
43
  cleaned = clean_text(message)
44
  pred = model.predict([cleaned])[0]
 
46
  label = "🚫 Spam" if pred == 1 else "πŸ“© Not Spam (Ham)"
47
  return f"{label} (Confidence: {prob:.2%})"
48
 
49
+ # 7. Gradio UI
50
+ gr.Interface(
51
  fn=predict_spam,
52
+ inputs=gr.Textbox(lines=4, label="Enter SMS Message"),
53
  outputs=gr.Text(label="Prediction"),
54
+ title="SMS Spam Detector",
55
+ description=f"Detects spam in SMS messages. Trained on uploaded CSV (Accuracy: {accuracy:.2%})."
56
+ ).launch()