Commit: "update submit" — changed file: src/submission/submit.py (+21 −13)
@@ -26,32 +26,40 @@ def get_top_prediction(text, tokenizer, model):
     if torch.cuda.is_available():
         model = model.cuda()
         inputs = {k: v.cuda() for k, v in inputs.items()}
+    else:
+        model = model.cpu()
 
     with torch.no_grad():
         outputs = model(**inputs)
-    [removed lines — content truncated in the page extraction]
+    # outputs.logits shape: [batch_size, seq_len, vocab_size]
+    # We want the logits for the last token
+    logits = outputs.logits[0, -1, :]  # Shape: [vocab_size]
 
-    options = ['  [rest of removed line truncated in the page extraction]
+    options = ['A', 'B', 'C', 'D']
     option_logits = []
-    # Iterate through each option
     for option in options:
-        [removed loop body — only an "if" fragment survived the page extraction]
+        # Encode the option without adding special tokens
+        option_ids = tokenizer.encode(option, add_special_tokens=False)
+        if not option_ids:
+            print(f"Option '{option}' could not be tokenized.")
+            continue
+        option_id = option_ids[0]
+        vocab_size = logits.size(0)
+        if option_id >= vocab_size:
+            print(f"Option ID {option_id} is out of bounds for vocabulary size {vocab_size}")
+            continue
+        option_logit = logits[option_id]
+        option_logits.append((option_logit.item(), option))
 
     if not option_logits:
-        [removed line — content truncated in the page extraction]
+        print("No valid options found.")
+        return None
 
     # Get the option with the highest logit
     top_option = max(option_logits, key=lambda x: x[0])[1]
     return top_option
 
+
 def evaluate_model_accuracy(model_name, num_examples):
     try:
         # Load the model and tokenizer

[NOTE(review): reconstructed from a garbled HTML-to-text extraction of the diff
viewer. The "+" (added) lines above are recovered verbatim from the after-side
rendering; indentation of the removed ("-") lines is inferred from Python
context and several removed lines were truncated by the extraction, so they are
marked as such rather than guessed. The diff ends mid-way through
evaluate_model_accuracy; the remainder of that function is not in this capture.]