JenetGhumman commited on
Commit
e2f75a8
·
verified ·
1 Parent(s): f6107f3

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +8 -8
tasks/text.py CHANGED
@@ -2,7 +2,7 @@ from fastapi import APIRouter
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.feature_extraction.text import TfidfVectorizer
5
- from sklearn.naive_bayes import MultinomialNB
6
  from sklearn.metrics import accuracy_score
7
 
8
  from .utils.evaluation import TextEvaluationRequest
@@ -10,18 +10,18 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
10
 
11
  router = APIRouter()
12
 
13
- DESCRIPTION = "Naive Bayes Text Classifier"
14
- ROUTE = "/text"
15
 
16
  @router.post(ROUTE, tags=["Text Task"],
17
  description=DESCRIPTION)
18
- async def evaluate_text(request: TextEvaluationRequest):
19
  """
20
  Evaluate text classification for climate disinformation detection.
21
 
22
- Current Model: Naive Bayes Classifier
23
  - Uses TF-IDF for text vectorization
24
- - Trains and evaluates a Multinomial Naive Bayes model
25
  """
26
  # Get space info
27
  username, space_url = get_space_info()
@@ -60,8 +60,8 @@ async def evaluate_text(request: TextEvaluationRequest):
60
  train_vectors = vectorizer.fit_transform(train_texts)
61
  test_vectors = vectorizer.transform(test_texts)
62
 
63
- # Train Naive Bayes Classifier
64
- model = MultinomialNB()
65
  model.fit(train_vectors, train_labels)
66
 
67
  # Start tracking emissions
 
2
  from datetime import datetime
3
  from datasets import load_dataset
4
  from sklearn.feature_extraction.text import TfidfVectorizer
5
+ from sklearn.svm import SVC
6
  from sklearn.metrics import accuracy_score
7
 
8
  from .utils.evaluation import TextEvaluationRequest
 
10
 
11
  router = APIRouter()
12
 
13
+ DESCRIPTION = "SVM Text Classifier with TF-IDF"
14
+ ROUTE = "/text_svm"
15
 
16
  @router.post(ROUTE, tags=["Text Task"],
17
  description=DESCRIPTION)
18
+ async def evaluate_text_svm(request: TextEvaluationRequest):
19
  """
20
  Evaluate text classification for climate disinformation detection.
21
 
22
+ Current Model: SVM Classifier
23
  - Uses TF-IDF for text vectorization
24
+ - Trains and evaluates a Support Vector Machine (SVM) model
25
  """
26
  # Get space info
27
  username, space_url = get_space_info()
 
60
  train_vectors = vectorizer.fit_transform(train_texts)
61
  test_vectors = vectorizer.transform(test_texts)
62
 
63
+ # Train SVM Classifier
64
+ model = SVC(kernel="linear", probability=True)
65
  model.fit(train_vectors, train_labels)
66
 
67
  # Start tracking emissions