Niklas Hoepner
commited on
Commit
·
9be3985
1
Parent(s):
e40a01a
Check edge cases
Browse files- L3Score.py +5 -4
L3Score.py
CHANGED
@@ -109,10 +109,9 @@ class L3Score(evaluate.Metric):
|
|
109 |
"""Optional: download external resources useful to compute the scores"""
|
110 |
pass
|
111 |
|
112 |
-
def _verify_input(self,
|
113 |
"""Verify the input parameters"""
|
114 |
|
115 |
-
print(provider)
|
116 |
if provider not in PROVIDER_WITH_TOP_LOGPROBS:
|
117 |
raise ValueError(
|
118 |
"Provider must offer top_logprobs to use this metric, pick from {}".format(
|
@@ -120,6 +119,8 @@ class L3Score(evaluate.Metric):
|
|
120 |
)
|
121 |
)
|
122 |
|
|
|
|
|
123 |
|
124 |
def _get_llm(self, model, api_key):
|
125 |
"""Get the LLM"""
|
@@ -137,10 +138,10 @@ class L3Score(evaluate.Metric):
|
|
137 |
model="gpt-4o-mini",
|
138 |
):
|
139 |
"""Returns the scores"""
|
|
|
140 |
|
141 |
-
print("Inside compute")
|
142 |
# Check whether llm can be initialized
|
143 |
-
self._verify_input(
|
144 |
|
145 |
# Initialize the LLM
|
146 |
llm = self._get_llm(model, api_key)
|
|
|
109 |
"""Optional: download external resources useful to compute the scores"""
|
110 |
pass
|
111 |
|
112 |
+
def _verify_input(self, questions, predictions, references, provider):
|
113 |
"""Verify the input parameters"""
|
114 |
|
|
|
115 |
if provider not in PROVIDER_WITH_TOP_LOGPROBS:
|
116 |
raise ValueError(
|
117 |
"Provider must offer top_logprobs to use this metric, pick from {}".format(
|
|
|
119 |
)
|
120 |
)
|
121 |
|
122 |
+
assert len(questions) == len(predictions) == len(references), "Questions, predictions and references must have the same length"
|
123 |
+
|
124 |
|
125 |
def _get_llm(self, model, api_key):
|
126 |
"""Get the LLM"""
|
|
|
138 |
model="gpt-4o-mini",
|
139 |
):
|
140 |
"""Returns the scores"""
|
141 |
+
print(questions,predictions,references)
|
142 |
|
|
|
143 |
# Check whether llm can be initialized
|
144 |
+
self._verify_input(questions, predictions, references, provider)
|
145 |
|
146 |
# Initialize the LLM
|
147 |
llm = self._get_llm(model, api_key)
|