JenetGhumman committed on
Commit
f6107f3
·
verified ·
1 Parent(s): 5d2f9b2

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +23 -48
tasks/text.py CHANGED
@@ -1,17 +1,16 @@
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
 
 
4
  from sklearn.metrics import accuracy_score
5
- from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
6
- from transformers import Trainer, TrainingArguments
7
- import torch
8
 
9
  from .utils.evaluation import TextEvaluationRequest
10
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  router = APIRouter()
13
 
14
- DESCRIPTION = "DistilBERT Baseline"
15
  ROUTE = "/text"
16
 
17
  @router.post(ROUTE, tags=["Text Task"],
@@ -20,8 +19,9 @@ async def evaluate_text(request: TextEvaluationRequest):
20
  """
21
  Evaluate text classification for climate disinformation detection.
22
 
23
- Current Model: DistilBERT
24
- - Fine-tunes and evaluates a DistilBERT model on the given dataset
 
25
  """
26
  # Get space info
27
  username, space_url = get_space_info()
@@ -49,58 +49,33 @@ async def evaluate_text(request: TextEvaluationRequest):
49
  train_dataset = train_test["train"]
50
  test_dataset = train_test["test"]
51
 
52
- # Tokenizer and model
53
- tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
54
- model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)
55
-
56
- # Tokenize datasets
57
- def preprocess(examples):
58
- return tokenizer(examples["text"], truncation=True, padding=True, max_length=512)
59
-
60
- train_dataset = train_dataset.map(preprocess, batched=True)
61
- test_dataset = test_dataset.map(preprocess, batched=True)
62
-
63
- train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
64
- test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
65
-
66
- # Training arguments
67
- training_args = TrainingArguments(
68
- output_dir="./results",
69
- evaluation_strategy="epoch",
70
- learning_rate=5e-5,
71
- per_device_train_batch_size=16,
72
- per_device_eval_batch_size=16,
73
- num_train_epochs=3,
74
- weight_decay=0.01,
75
- logging_dir="./logs",
76
- logging_steps=10,
77
- )
78
-
79
- trainer = Trainer(
80
- model=model,
81
- args=training_args,
82
- train_dataset=train_dataset,
83
- eval_dataset=test_dataset,
84
- tokenizer=tokenizer,
85
- )
86
 
87
  # Start tracking emissions
88
  tracker.start()
89
  tracker.start_task("inference")
90
 
91
- # Train and evaluate the model
92
- trainer.train()
93
-
94
- # Perform inference
95
- predictions = trainer.predict(test_dataset).predictions
96
- predictions = torch.argmax(torch.tensor(predictions), axis=1).tolist()
97
- true_labels = test_dataset["label"]
98
 
99
  # Stop tracking emissions
100
  emissions_data = tracker.stop_task()
101
 
102
  # Calculate accuracy
103
- accuracy = accuracy_score(true_labels, predictions)
104
 
105
  # Prepare results dictionary
106
  results = {
 
1
  from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
+ from sklearn.feature_extraction.text import TfidfVectorizer
5
+ from sklearn.naive_bayes import MultinomialNB
6
  from sklearn.metrics import accuracy_score
 
 
 
7
 
8
  from .utils.evaluation import TextEvaluationRequest
9
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
10
 
11
  router = APIRouter()
12
 
13
+ DESCRIPTION = "Naive Bayes Text Classifier"
14
  ROUTE = "/text"
15
 
16
  @router.post(ROUTE, tags=["Text Task"],
 
19
  """
20
  Evaluate text classification for climate disinformation detection.
21
 
22
+ Current Model: Naive Bayes Classifier
23
+ - Uses TF-IDF for text vectorization
24
+ - Trains and evaluates a Multinomial Naive Bayes model
25
  """
26
  # Get space info
27
  username, space_url = get_space_info()
 
49
  train_dataset = train_test["train"]
50
  test_dataset = train_test["test"]
51
 
52
+ # Extract text and labels
53
+ train_texts = [x["text"] for x in train_dataset]
54
+ train_labels = [x["label"] for x in train_dataset]
55
+ test_texts = [x["text"] for x in test_dataset]
56
+ test_labels = [x["label"] for x in test_dataset]
57
+
58
+ # TF-IDF Vectorization
59
+ vectorizer = TfidfVectorizer(max_features=5000)
60
+ train_vectors = vectorizer.fit_transform(train_texts)
61
+ test_vectors = vectorizer.transform(test_texts)
62
+
63
+ # Train Naive Bayes Classifier
64
+ model = MultinomialNB()
65
+ model.fit(train_vectors, train_labels)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  # Start tracking emissions
68
  tracker.start()
69
  tracker.start_task("inference")
70
 
71
+ # Inference
72
+ predictions = model.predict(test_vectors)
 
 
 
 
 
73
 
74
  # Stop tracking emissions
75
  emissions_data = tracker.stop_task()
76
 
77
  # Calculate accuracy
78
+ accuracy = accuracy_score(test_labels, predictions)
79
 
80
  # Prepare results dictionary
81
  results = {