resolverkatla commited on
Commit
9b36b48
·
1 Parent(s): bdc150f
Files changed (1) hide show
  1. app.py +12 -10
app.py CHANGED
@@ -2,29 +2,31 @@ import gradio as gr
2
  from transformers import pipeline
3
  import pandas as pd
4
 
5
- # Load dataset
6
- DATASET_PATH = "spam.csv"
7
- df = pd.read_csv(DATASET_PATH, encoding="latin1")
8
 
9
  # Load a spam classification model
10
  classifier = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
11
 
12
  def spam_detector(text):
 
13
  result = classifier(text)
14
- return "Spam" if result[0]['label'].lower() == "Spam" else "Not Spam"
 
15
 
16
- # Create Gradio UI
17
  app = gr.Interface(
18
  fn=spam_detector,
19
- inputs=gr.Textbox(label="Enter a message"),
20
  outputs=gr.Textbox(label="Prediction"),
21
- title="Spam Detector",
22
- description="Enter a message to check if it's spam or not."
 
23
  )
24
 
25
  # Run the app
26
  if __name__ == "__main__":
27
  print("Loaded dataset preview:")
28
  print(df.head())
29
- app.launch()
30
-
 
2
  from transformers import pipeline
3
  import pandas as pd
4
 
5
+ # Load dataset from Hugging Face Hub
6
+ dataset_path = "hf://datasets/ucirvine/sms_spam/plain_text/train-00000-of-00001.parquet"
7
+ df = pd.read_parquet(dataset_path)
8
 
9
  # Load a spam classification model
10
  classifier = pipeline("text-classification", model="mrm8488/bert-tiny-finetuned-sms-spam-detection")
11
 
12
  def spam_detector(text):
13
+ """Detect if a message is spam or not."""
14
  result = classifier(text)
15
+ label = result[0]['label'].lower()
16
+ return "Spam" if label == "spam" else "Not Spam"
17
 
18
+ # Create Gradio UI with enhanced styling
19
  app = gr.Interface(
20
  fn=spam_detector,
21
+ inputs=gr.Textbox(label="Enter a message", placeholder="Type your message here..."),
22
  outputs=gr.Textbox(label="Prediction"),
23
+ title="AI-Powered Spam Detector",
24
+ description="Enter a message to check if it's spam or not, using a fine-tuned BERT model.",
25
+ theme="huggingface"
26
  )
27
 
28
  # Run the app
29
  if __name__ == "__main__":
30
  print("Loaded dataset preview:")
31
  print(df.head())
32
+ app.launch(server_name="0.0.0.0", server_port=7860, share=True)