Spaces:
Runtime error
Runtime error
davebulaval
commited on
Commit
Β·
008aa62
1
Parent(s):
35bc035
add fix for matching elements
Browse files- meaningbert.py +11 -1
meaningbert.py
CHANGED
@@ -118,6 +118,10 @@ class MeaningBERT(evaluate.Metric):
|
|
118 |
), "The number of document is different of the number of simplifications."
|
119 |
hashcode = _HASH
|
120 |
|
|
|
|
|
|
|
|
|
121 |
# We load the MeaningBERT pretrained model
|
122 |
scorer = AutoModelForSequenceClassification.from_pretrained(
|
123 |
"davebulaval/MeaningBERT"
|
@@ -140,8 +144,14 @@ class MeaningBERT(evaluate.Metric):
|
|
140 |
# We process the text
|
141 |
scores = scorer(**tokenize_text)
|
142 |
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
output_dict = {
|
144 |
-
"scores": scores
|
145 |
"hashcode": hashcode,
|
146 |
}
|
147 |
return output_dict
|
|
|
118 |
), "The number of document is different of the number of simplifications."
|
119 |
hashcode = _HASH
|
120 |
|
121 |
+
matching_index = [
|
122 |
+
i for i, item in enumerate(documents) if item in simplifications
|
123 |
+
]
|
124 |
+
|
125 |
# We load the MeaningBERT pretrained model
|
126 |
scorer = AutoModelForSequenceClassification.from_pretrained(
|
127 |
"davebulaval/MeaningBERT"
|
|
|
144 |
# We process the text
|
145 |
scores = scorer(**tokenize_text)
|
146 |
|
147 |
+
scores = scores.logits.tolist()
|
148 |
+
|
149 |
+
if len(matching_index) > 0:
|
150 |
+
for matching_element_index in matching_index:
|
151 |
+
scores[matching_element_index] = 100
|
152 |
+
|
153 |
output_dict = {
|
154 |
+
"scores": scores,
|
155 |
"hashcode": hashcode,
|
156 |
}
|
157 |
return output_dict
|