Omartificial-Intelligence-Space committed on
Commit
560f753
·
verified ·
1 Parent(s): 898b1fc

update submit

Browse files
Files changed (1) hide show
  1. src/submission/submit.py +21 -13
src/submission/submit.py CHANGED
@@ -26,32 +26,40 @@ def get_top_prediction(text, tokenizer, model):
26
  if torch.cuda.is_available():
27
  model = model.cuda()
28
  inputs = {k: v.cuda() for k, v in inputs.items()}
 
 
29
 
30
  with torch.no_grad():
31
  outputs = model(**inputs)
32
- logits = outputs.logits[0, -1] # Get logits of the last token
 
 
33
 
34
- options = [' A', ' B', ' C', ' D']
35
  option_logits = []
36
-
37
- # Iterate through each option
38
  for option in options:
39
- option_ids = tokenizer(option).input_ids
40
- # Ensure option_ids are within range and not empty
41
- if option_ids and option_ids[-1] < logits.size(0):
42
- option_id = option_ids[-1]
43
- option_logit = logits[option_id]
44
- option_logits.append((option_logit.item(), option.strip()))
45
- else:
46
- print(f"Skipping option '{option}' due to index out of range.")
 
 
 
 
47
 
48
  if not option_logits:
49
- return "No valid options"
 
50
 
51
  # Get the option with the highest logit
52
  top_option = max(option_logits, key=lambda x: x[0])[1]
53
  return top_option
54
 
 
55
  def evaluate_model_accuracy(model_name, num_examples):
56
  try:
57
  # Load the model and tokenizer
 
26
  if torch.cuda.is_available():
27
  model = model.cuda()
28
  inputs = {k: v.cuda() for k, v in inputs.items()}
29
+ else:
30
+ model = model.cpu()
31
 
32
  with torch.no_grad():
33
  outputs = model(**inputs)
34
+ # outputs.logits shape: [batch_size, seq_len, vocab_size]
35
+ # We want the logits for the last token
36
+ logits = outputs.logits[0, -1, :] # Shape: [vocab_size]
37
 
38
+ options = ['A', 'B', 'C', 'D']
39
  option_logits = []
 
 
40
  for option in options:
41
+ # Encode the option without adding special tokens
42
+ option_ids = tokenizer.encode(option, add_special_tokens=False)
43
+ if not option_ids:
44
+ print(f"Option '{option}' could not be tokenized.")
45
+ continue
46
+ option_id = option_ids[0]
47
+ vocab_size = logits.size(0)
48
+ if option_id >= vocab_size:
49
+ print(f"Option ID {option_id} is out of bounds for vocabulary size {vocab_size}")
50
+ continue
51
+ option_logit = logits[option_id]
52
+ option_logits.append((option_logit.item(), option))
53
 
54
  if not option_logits:
55
+ print("No valid options found.")
56
+ return None
57
 
58
  # Get the option with the highest logit
59
  top_option = max(option_logits, key=lambda x: x[0])[1]
60
  return top_option
61
 
62
+
63
  def evaluate_model_accuracy(model_name, num_examples):
64
  try:
65
  # Load the model and tokenizer