Omartificial-Intelligence-Space committed
Commit 898b1fc · verified · 1 Parent(s): 04438ea

update submit

Files changed (1): src/submission/submit.py (+14, -4)
src/submission/submit.py CHANGED
@@ -29,14 +29,24 @@ def get_top_prediction(text, tokenizer, model):
 
     with torch.no_grad():
         outputs = model(**inputs)
-        logits = outputs.logits[0, -1]
+        logits = outputs.logits[0, -1]  # Get logits of the last token
 
     options = [' A', ' B', ' C', ' D']
     option_logits = []
+
+    # Iterate through each option
     for option in options:
-        option_id = tokenizer(option).input_ids[-1]
-        option_logit = logits[option_id]
-        option_logits.append((option_logit.item(), option.strip()))
+        option_ids = tokenizer(option).input_ids
+        # Ensure option_ids are within range and not empty
+        if option_ids and option_ids[-1] < logits.size(0):
+            option_id = option_ids[-1]
+            option_logit = logits[option_id]
+            option_logits.append((option_logit.item(), option.strip()))
+        else:
+            print(f"Skipping option '{option}' due to index out of range.")
+
+    if not option_logits:
+        return "No valid options"
 
     # Get the option with the highest logit
     top_option = max(option_logits, key=lambda x: x[0])[1]
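
For context, the hunk above lives in get_top_prediction, which answers a multiple-choice question by comparing the model's logits at the last token position for the four answer letters and returning the highest-scoring one. Below is a minimal sketch of how the patched function might look end to end; the input encoding, the final return, and the usage lines at the bottom are assumptions for illustration and are not part of this commit.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_top_prediction(text, tokenizer, model):
    # Assumption: the diff does not show how `inputs` is built; a plain
    # encoding of the prompt moved to the model's device is used here.
    inputs = tokenizer(text, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0, -1]  # Get logits of the last token

    options = [' A', ' B', ' C', ' D']
    option_logits = []

    # Iterate through each option (body taken from the patched hunk)
    for option in options:
        option_ids = tokenizer(option).input_ids
        # Ensure option_ids are within range and not empty
        if option_ids and option_ids[-1] < logits.size(0):
            option_id = option_ids[-1]
            option_logit = logits[option_id]
            option_logits.append((option_logit.item(), option.strip()))
        else:
            print(f"Skipping option '{option}' due to index out of range.")

    if not option_logits:
        return "No valid options"

    # Get the option with the highest logit
    top_option = max(option_logits, key=lambda x: x[0])[1]
    return top_option


# Hypothetical usage; the model name and prompt are placeholders, not from the repo.
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
prompt = "Question: 2 + 2 = ?\nA. 3\nB. 4\nC. 5\nD. 6\nAnswer:"
print(get_top_prediction(prompt, tokenizer, model))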