update submit
Browse files- src/submission/submit.py +14 -4
src/submission/submit.py
CHANGED
@@ -29,14 +29,24 @@ def get_top_prediction(text, tokenizer, model):
|
|
29 |
|
30 |
with torch.no_grad():
|
31 |
outputs = model(**inputs)
|
32 |
-
logits = outputs.logits[0, -1]
|
33 |
|
34 |
options = [' A', ' B', ' C', ' D']
|
35 |
option_logits = []
|
|
|
|
|
36 |
for option in options:
|
37 |
-
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
# Get the option with the highest logit
|
42 |
top_option = max(option_logits, key=lambda x: x[0])[1]
|
|
|
29 |
|
30 |
with torch.no_grad():
|
31 |
outputs = model(**inputs)
|
32 |
+
logits = outputs.logits[0, -1] # Get logits of the last token
|
33 |
|
34 |
options = [' A', ' B', ' C', ' D']
|
35 |
option_logits = []
|
36 |
+
|
37 |
+
# Iterate through each option
|
38 |
for option in options:
|
39 |
+
option_ids = tokenizer(option).input_ids
|
40 |
+
# Ensure option_ids are within range and not empty
|
41 |
+
if option_ids and option_ids[-1] < logits.size(0):
|
42 |
+
option_id = option_ids[-1]
|
43 |
+
option_logit = logits[option_id]
|
44 |
+
option_logits.append((option_logit.item(), option.strip()))
|
45 |
+
else:
|
46 |
+
print(f"Skipping option '{option}' due to index out of range.")
|
47 |
+
|
48 |
+
if not option_logits:
|
49 |
+
return "No valid options"
|
50 |
|
51 |
# Get the option with the highest logit
|
52 |
top_option = max(option_logits, key=lambda x: x[0])[1]
|