Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from collections import defaultdict | |
| from pathlib import Path | |
| import pandas as pd | |
| from rouge_cli import calculate_rouge_path | |
| from utils import calculate_rouge | |
# Candidate summaries passed as the *prediction* argument to `calculate_rouge`
# in the tests of this module. Each element is one multi-sentence summary;
# element i is scored against TGT[i].
# Having several sentences per summary is what lets the rougeLsum /
# `newline_sep` tests observe a difference between the two settings.
# NOTE(review): content looks like model output on CNN/DailyMail articles —
# provenance not shown here.
PRED = [
    'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of the'
    ' final seconds on board Flight 9525. The Germanwings co-pilot says he had a "previous episode of severe'
    " depression\" German airline confirms it knew of Andreas Lubitz's depression years before he took control.",
    "The Palestinian Authority officially becomes the 123rd member of the International Criminal Court. The formal"
    " accession was marked with a ceremony at The Hague, in the Netherlands. The Palestinians signed the ICC's"
    " founding Rome Statute in January. Israel and the United States opposed the Palestinians' efforts to join the"
    " body.",
    "Amnesty International releases its annual report on the death penalty. The report catalogs the use of"
    " state-sanctioned killing as a punitive measure across the globe. At least 607 people were executed around the"
    " world in 2014, compared to 778 in 2013. The U.S. remains one of the worst offenders for imposing capital"
    " punishment.",
]
# Reference summaries passed as the *target* argument to `calculate_rouge`,
# aligned index-by-index with PRED (same three stories, same order).
# Sentences end with a space before the period (" ."), i.e. the text is
# already tokenized. Keep these literals byte-exact, because the tests compare
# scores computed from them.
# NOTE(review): this matches the CNN/DailyMail "highlights" style — presumably
# taken from that dataset; confirm if it matters.
TGT = [
    'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .'
    ' Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz'
    " had informed his Lufthansa training school of an episode of severe depression, airline says .",
    "Membership gives the ICC jurisdiction over alleged crimes committed in Palestinian territories since last June ."
    " Israel and the United States opposed the move, which could open the door to war crimes investigations against"
    " Israelis .",
    "Amnesty's annual death penalty report catalogs encouraging signs, but setbacks in numbers of those sentenced to"
    " death . Organization claims that governments around the world are using the threat of terrorism to advance"
    " executions . The number of executions worldwide has gone down by almost 22% compared with 2013, but death"
    " sentences up by 28% .",
]
def test_disaggregated_scores_are_determinstic():
    """Per-example rouge2 scores must not depend on which other keys are requested.

    With ``bootstrap_aggregation=False``, ``calculate_rouge`` must return a
    ``defaultdict``. Each key maps to score records that load into a DataFrame
    with an ``fmeasure`` column. The mean rouge2 F-measure computed together
    with rougeL must exactly equal the one computed for rouge2 alone.
    """
    both_keys = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2", "rougeL"])
    assert isinstance(both_keys, defaultdict)

    rouge2_only = calculate_rouge(PRED, TGT, bootstrap_aggregation=False, rouge_keys=["rouge2"])

    # Exact float equality is intentional: the same inputs must yield identical numbers.
    mean_alongside_rouge_l = pd.DataFrame(both_keys["rouge2"]).fmeasure.mean()
    mean_alone = pd.DataFrame(rouge2_only["rouge2"]).fmeasure.mean()
    assert mean_alongside_rouge_l == mean_alone
def test_newline_cnn_improvement():
    """rougeLsum must be strictly higher with ``newline_sep=True`` on the PRED/TGT fixtures.

    Those fixtures hold multi-sentence summaries, so the summary-level LCS
    metric is the one expected to react to the separator setting.
    """
    metric = "rougeLsum"
    score_by_flag = {
        flag: calculate_rouge(PRED, TGT, newline_sep=flag, rouge_keys=[metric])[metric]
        for flag in (True, False)
    }
    assert score_by_flag[True] > score_by_flag[False]
def test_newline_irrelevant_for_other_metrics():
    """Toggling ``newline_sep`` must leave rouge1, rouge2 and rougeL results unchanged."""
    metrics = ["rouge1", "rouge2", "rougeL"]
    with_sep, without_sep = [
        calculate_rouge(PRED, TGT, newline_sep=use_sep, rouge_keys=metrics) for use_sep in (True, False)
    ]
    assert with_sep == without_sep
def test_single_sent_scores_dont_depend_on_newline_sep():
    """With one sentence per prediction/target, ``newline_sep`` must not change any score.

    Uses ``calculate_rouge``'s default ``rouge_keys``.
    """
    predictions = [
        "Her older sister, Margot Frank, died in 1945, a month earlier than previously thought.",
        'Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports .',
    ]
    references = [
        "Margot Frank, died in 1945, a month earlier than previously thought.",
        'Prosecutor: "No videos were used in the crash investigation" German papers say they saw a cell phone video of'
        " the final seconds on board Flight 9525.",
    ]

    scores_with_sep = calculate_rouge(predictions, references, newline_sep=True)
    scores_without_sep = calculate_rouge(predictions, references, newline_sep=False)
    assert scores_with_sep == scores_without_sep
def test_pegasus_newline():
    """rougeLsum with the default ``newline_sep`` must beat ``newline_sep=False``.

    The prediction contains a Pegasus-style ``<n>`` sentence marker.
    NOTE(review): this presumes the default settings of ``calculate_rouge``
    turn ``<n>`` into a sentence break; that helper lives in ``utils`` and is
    not shown here.
    """
    predictions = [
        """" "a person who has such a video needs to immediately give it to the investigators," prosecutor says .<n> "it is a very disturbing scene," editor-in-chief of bild online tells "erin burnett: outfront" """
    ]
    references = [
        """ Marseille prosecutor says "so far no videos were used in the crash investigation" despite media reports . Journalists at Bild and Paris Match are "very confident" the video clip is real, an editor says . Andreas Lubitz had informed his Lufthansa training school of an episode of severe depression, airline says ."""
    ]
    metric = "rougeLsum"

    baseline = calculate_rouge(predictions, references, rouge_keys=[metric], newline_sep=False)[metric]
    with_default_sep = calculate_rouge(predictions, references, rouge_keys=[metric])[metric]
    assert with_default_sep > baseline
def test_rouge_cli():
    """Smoke-test ``calculate_rouge_path`` on the small wmt_en_ro fixture files.

    The fixture's ``test.source`` is scored against its ``test.target``. Only
    the return types are checked:

    - a ``dict`` with the default ``bootstrap_aggregation``;
    - a ``defaultdict`` with ``bootstrap_aggregation=False``.
    """
    # The original hard-coded path is resolved against the current working
    # directory, so the fixtures were only found when pytest was launched from
    # the repository root. Prefer a directory next to this file, and fall back
    # to the original CWD-relative path so existing invocations keep working.
    # NOTE(review): assumes test_data/ sits beside this module — confirm layout.
    candidates = (
        Path(__file__).resolve().parent / "test_data" / "wmt_en_ro",
        Path("examples/seq2seq/test_data/wmt_en_ro"),
    )
    data_dir = next((d for d in candidates if d.is_dir()), candidates[-1])
    source_path = data_dir / "test.source"
    target_path = data_dir / "test.target"

    metrics = calculate_rouge_path(source_path, target_path)
    assert isinstance(metrics, dict)

    metrics_default_dict = calculate_rouge_path(source_path, target_path, bootstrap_aggregation=False)
    assert isinstance(metrics_default_dict, defaultdict)