JenetGhumman committed
Commit 860f09c · verified · Parent: 4c6dd48

Update tasks/text.py

Files changed (1):
  1. tasks/text.py +85 -35
tasks/text.py CHANGED
@@ -2,31 +2,28 @@ from fastapi import APIRouter
 from datetime import datetime
 from datasets import load_dataset
 from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.naive_bayes import MultinomialNB
 from sklearn.svm import SVC
 from sklearn.metrics import accuracy_score
 
 from .utils.evaluation import TextEvaluationRequest
 from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
+# Define the router for text tasks
 router = APIRouter()
 
-DESCRIPTION = "SVM Text Classifier with TF-IDF"
-ROUTE = "/text_svm"
+DESCRIPTION_NAIVE_BAYES = "Naive Bayes Text Classifier"
+DESCRIPTION_SVM = "SVM Text Classifier with TF-IDF"
 
-@router.post(ROUTE, tags=["Text Task"],
-             description=DESCRIPTION)
-async def evaluate_text_svm(request: TextEvaluationRequest):
+# Naive Bayes Endpoint
+@router.post("/text", tags=["Text Task"], description=DESCRIPTION_NAIVE_BAYES)
+async def evaluate_text(request: TextEvaluationRequest):
     """
-    Evaluate text classification for climate disinformation detection.
-
-    Current Model: SVM Classifier
-    - Uses TF-IDF for text vectorization
-    - Trains and evaluates a Support Vector Machine (SVM) model
+    Evaluate text classification using Naive Bayes.
     """
-    # Get space info
     username, space_url = get_space_info()
 
-    # Define the label mapping
+    # Label Mapping
     LABEL_MAPPING = {
         "0_not_relevant": 0,
         "1_not_happening": 1,
@@ -38,22 +35,82 @@ async def evaluate_text_svm(request: TextEvaluationRequest):
         "7_fossil_fuels_needed": 7
     }
 
-    # Load and prepare the dataset
+    # Load and prepare dataset
     dataset = load_dataset(request.dataset_name)
-
-    # Convert string labels to integers
     dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
 
-    # Split dataset
+    # Train-Test Split
     train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
-    train_dataset = train_test["train"]
-    test_dataset = train_test["test"]
+    train_texts = [x["text"] for x in train_test["train"]]
+    train_labels = [x["label"] for x in train_test["train"]]
+    test_texts = [x["text"] for x in train_test["test"]]
+    test_labels = [x["label"] for x in train_test["test"]]
+
+    # TF-IDF Vectorization
+    vectorizer = TfidfVectorizer(max_features=5000)
+    train_vectors = vectorizer.fit_transform(train_texts)
+    test_vectors = vectorizer.transform(test_texts)
+
+    # Train Naive Bayes Classifier
+    model = MultinomialNB()
+    model.fit(train_vectors, train_labels)
+
+    # Track emissions
+    tracker.start()
+    tracker.start_task("inference")
+    predictions = model.predict(test_vectors)
+    emissions_data = tracker.stop_task()
+
+    # Calculate Accuracy
+    accuracy = accuracy_score(test_labels, predictions)
+
+    return {
+        "username": username,
+        "space_url": space_url,
+        "submission_timestamp": datetime.now().isoformat(),
+        "model_description": DESCRIPTION_NAIVE_BAYES,
+        "accuracy": float(accuracy),
+        "energy_consumed_wh": emissions_data.energy_consumed * 1000,
+        "emissions_gco2eq": emissions_data.emissions * 1000,
+        "emissions_data": clean_emissions_data(emissions_data),
+        "api_route": "/text",
+        "dataset_config": {
+            "dataset_name": request.dataset_name,
+            "test_size": request.test_size,
+            "test_seed": request.test_seed
+        }
+    }
+
+# SVM Endpoint
+@router.post("/text_svm", tags=["Text Task"], description=DESCRIPTION_SVM)
+async def evaluate_text_svm(request: TextEvaluationRequest):
+    """
+    Evaluate text classification using SVM.
+    """
+    username, space_url = get_space_info()
+
+    # Label Mapping
+    LABEL_MAPPING = {
+        "0_not_relevant": 0,
+        "1_not_happening": 1,
+        "2_not_human": 2,
+        "3_not_bad": 3,
+        "4_solutions_harmful_unnecessary": 4,
+        "5_science_unreliable": 5,
+        "6_proponents_biased": 6,
+        "7_fossil_fuels_needed": 7
+    }
 
-    # Extract text and labels
-    train_texts = [x["text"] for x in train_dataset]
-    train_labels = [x["label"] for x in train_dataset]
-    test_texts = [x["text"] for x in test_dataset]
-    test_labels = [x["label"] for x in test_dataset]
+    # Load and prepare dataset
+    dataset = load_dataset(request.dataset_name)
+    dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
+
+    # Train-Test Split
+    train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
+    train_texts = [x["text"] for x in train_test["train"]]
+    train_labels = [x["label"] for x in train_test["train"]]
+    test_texts = [x["text"] for x in train_test["test"]]
+    test_labels = [x["label"] for x in train_test["test"]]
 
     # TF-IDF Vectorization
     vectorizer = TfidfVectorizer(max_features=5000)
@@ -64,35 +121,28 @@
     model = SVC(kernel="linear", probability=True)
     model.fit(train_vectors, train_labels)
 
-    # Start tracking emissions
+    # Track emissions
     tracker.start()
     tracker.start_task("inference")
-
-    # Inference
     predictions = model.predict(test_vectors)
-
-    # Stop tracking emissions
     emissions_data = tracker.stop_task()
 
-    # Calculate accuracy
+    # Calculate Accuracy
     accuracy = accuracy_score(test_labels, predictions)
 
-    # Prepare results dictionary
-    results = {
+    return {
         "username": username,
         "space_url": space_url,
         "submission_timestamp": datetime.now().isoformat(),
-        "model_description": DESCRIPTION,
+        "model_description": DESCRIPTION_SVM,
         "accuracy": float(accuracy),
         "energy_consumed_wh": emissions_data.energy_consumed * 1000,
        "emissions_gco2eq": emissions_data.emissions * 1000,
         "emissions_data": clean_emissions_data(emissions_data),
-        "api_route": ROUTE,
+        "api_route": "/text_svm",
         "dataset_config": {
             "dataset_name": request.dataset_name,
             "test_size": request.test_size,
             "test_seed": request.test_seed
         }
     }
-
-    return results
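
For reference, a minimal client sketch for exercising the two routes this commit leaves in place, assuming the router is mounted in the Space's FastAPI app (e.g. via app.include_router(router)). The base URL and dataset identifier below are hypothetical placeholders, not values from this commit; the JSON fields mirror the TextEvaluationRequest attributes the handlers read (dataset_name, test_size, test_seed).

# Minimal client sketch. BASE_URL and DATASET are hypothetical placeholders.
import requests

BASE_URL = "http://localhost:7860"         # replace with the deployed Space URL
DATASET = "some-org/climate-text-dataset"  # hypothetical dataset identifier

payload = {
    "dataset_name": DATASET,
    "test_size": 0.2,  # fraction of the train split held out for evaluation
    "test_seed": 42,   # seed for the reproducible train/test split
}

# Naive Bayes endpoint added in this commit
nb = requests.post(f"{BASE_URL}/text", json=payload).json()
print(nb["model_description"], nb["accuracy"], nb["energy_consumed_wh"])

# SVM endpoint, now served at /text_svm
svm = requests.post(f"{BASE_URL}/text_svm", json=payload).json()
print(svm["model_description"], svm["accuracy"], svm["energy_consumed_wh"])

Both responses share the same result schema, so a caller can compare the two classifiers' accuracy and tracked inference emissions directly from the returned JSON.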