Tonic committed on
Commit
b4aa97d
·
unverified ·
1 Parent(s): 2f1ac51

simplify model loading

Browse files
Files changed (1) hide show
  1. tasks/text.py +11 -16
tasks/text.py CHANGED
@@ -59,30 +59,26 @@ async def evaluate_text(request: TextEvaluationRequest):
59
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
 
61
  # Model and tokenizer paths
62
- path_model = 'Tonic/climate-guard-toxic-agent'
63
- path_tokenizer = "Tonic/climate-guard-toxic-agent"
64
 
65
- # Load tokenizer
66
- tokenizer = AutoTokenizer.from_pretrained(path_tokenizer)
67
-
68
- # Load model
69
  model = AutoModelForSequenceClassification.from_pretrained(
70
- path_model,
71
- trust_remote_code=True,
72
- num_labels=8,
73
- problem_type="single_label_classification",
74
- ignore_mismatched_sizes=True
75
  ).to(device)
76
 
77
- # Convert to half precision and eval mode
78
- model = model.half()
79
  model.eval()
80
 
81
  # Preprocess function
82
  def preprocess_function(examples):
83
  return tokenizer(
84
  examples["quote"],
 
85
  truncation=True,
 
86
  return_tensors=None
87
  )
88
 
@@ -110,11 +106,10 @@ async def evaluate_text(request: TextEvaluationRequest):
110
  with torch.no_grad():
111
  for batch in test_loader:
112
  # Move batch to device
113
- input_ids = batch['input_ids'].to(device)
114
- attention_mask = batch['attention_mask'].to(device)
115
 
116
  # Get model outputs
117
- outputs = model(input_ids=input_ids, attention_mask=attention_mask)
118
  preds = torch.argmax(outputs.logits, dim=-1)
119
 
120
  # Add batch predictions to list
 
59
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
 
61
  # Model and tokenizer paths
62
+ model_name = "Tonic/climate-guard-toxic-agent"
 
63
 
64
+ # Load tokenizer and model
65
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
66
  model = AutoModelForSequenceClassification.from_pretrained(
67
+ model_name,
68
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
69
+ trust_remote_code=True
 
 
70
  ).to(device)
71
 
72
+ # Set model to evaluation mode
 
73
  model.eval()
74
 
75
  # Preprocess function
76
  def preprocess_function(examples):
77
  return tokenizer(
78
  examples["quote"],
79
+ padding=False,
80
  truncation=True,
81
+ max_length=512,
82
  return_tensors=None
83
  )
84
 
 
106
  with torch.no_grad():
107
  for batch in test_loader:
108
  # Move batch to device
109
+ batch = {k: v.to(device) for k, v in batch.items()}
 
110
 
111
  # Get model outputs
112
+ outputs = model(**batch)
113
  preds = torch.argmax(outputs.logits, dim=-1)
114
 
115
  # Add batch predictions to list