import evaluate
import datasets
import numpy as np

# TODO: Add BibTeX citation
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
"""

# TODO: Add description of the module here
_DESCRIPTION = """\
This metric evaluates the quality of relation extraction output by computing precision, recall, and micro/macro F1 scores over the predicted relations.
"""


_KWARGS_DESCRIPTION = """
Calculates precision, recall, and F1 score for predicted relations against reference relations.
Args:
    predictions (list of list of dictionary): A list of predicted relations from the model.
    references (list of list of dictionary): A list of ground-truth or reference relations to compare the predictions against.
    
Returns:
    **output** (`dictionary` of `dictionary`s): A dictionary mapping each relation type (plus the aggregate key "ALL") to its scores; when detailed_scores=False only the "ALL" entry is returned.
- **relation type** (`dictionary`): scores of the selected relation type
  - **tp** : true positive count
  - **fp** : false positive count
  - **fn** : false negative count
  - **p** : precision
  - **r** : recall
  - **f1** : F1 score (micro-averaged in the "ALL" entry)
  - **Macro_f1** : macro F1 score ("ALL" entry only)
  - **Macro_p** : macro precision ("ALL" entry only)
  - **Macro_r** : macro recall ("ALL" entry only)
Examples:
    metric_path = "Ikala-allen/relation_extraction"
    module = evaluate.load(metric_path)
    references = [
      [
        {"head": "phipigments", "head_type": "brand", "type": "sell", "tail": "國際認證之色乳", "tail_type": "product"},
        {"head": "tinadaviespigments", "head_type": "brand", "type": "sell", "tail": "國際認證之色乳", "tail_type": "product"},
        {'head': 'A醛賦活緊緻精華', 'tail': 'Serum', 'head_type': 'product', 'tail_type': 'category', 'type': 'belongs_to'},
      ]
    ]
    predictions = [
      [
        {"head": "phipigments", "head_type": "product", "type": "sell", "tail": "國際認證之色乳", "tail_type": "product"},
        {"head": "tinadaviespigments", "head_type": "brand", "type": "sell", "tail": "國際認證之色乳", "tail_type": "product"},
      ]
    ]
    evaluation_scores = module.compute(predictions=predictions, references=references, mode="strict", detailed_scores=False, relation_types=[])
    print(evaluation_scores)
    {'tp': 1, 'fp': 1, 'fn': 2, 'p': 50.0, 'r': 33.333333333333336, 'f1': 40.0, 'Macro_f1': 25.0, 'Macro_p': 25.0, 'Macro_r': 25.0}
"""


def convert_format(data: list):
    """
    Convert column-wise relation records into lists of per-relation dictionaries.

    Args:
        data (list): list of dictionaries, each holding parallel lists of entity elements, e.g.
        [
        {'head': ['phipigments', 'tinadaviespigments', ...],
        'head_type': ['product', 'brand', ...],
        'type': ['sell', 'sell', ...],
        'tail': ['國際認證之色乳', '國際認證之色乳', ...],
        'tail_type': ['product', 'product', ...]},

        {'head': ['SABONTAIWAN', 'SNTAIWAN', ...],
        'head_type': ['brand', 'brand', ...],
        'type': ['sell', 'sell', ...],
        'tail': ['大馬士革玫瑰有機光燦系列', '大馬士革玫瑰有機光燦系列', ...],
        'tail_type': ['product', 'product', ...]}
        ...
        ]

    Returns:
        list: one list per input record, each containing a dictionary per relation with
        the keys 'head', 'head_type', 'type', 'tail' and 'tail_type'.
    """
    predictions = []
    for item in data:
        prediction_group = []
        for i in range(len(item['head'])):
            prediction = {
                'head': item['head'][i],
                'head_type': item['head_type'][i],
                'type': item['type'][i],
                'tail': item['tail'][i],
                'tail_type': item['tail_type'][i],
            }
            prediction_group.append(prediction)
        predictions.append(prediction_group)
    return predictions
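
# Illustration of convert_format (values taken from the module docstring example above):
# each column-wise record in `data`, e.g.
#     {"head": ["phipigments", "tinadaviespigments"],
#      "head_type": ["brand", "brand"],
#      "type": ["sell", "sell"],
#      "tail": ["國際認證之色乳", "國際認證之色乳"],
#      "tail_type": ["product", "product"]}
# becomes a list with one dictionary per relation:
#     [{"head": "phipigments", "head_type": "brand", "type": "sell",
#       "tail": "國際認證之色乳", "tail_type": "product"},
#      {"head": "tinadaviespigments", "head_type": "brand", "type": "sell",
#       "tail": "國際認證之色乳", "tail_type": "product"}]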

@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class relation_extraction(evaluate.Metric):
    """evaluating the quality of relation extraction output"""

    def _info(self):
        # TODO: Specifies the evaluate.EvaluationModuleInfo object
        return evaluate.MetricInfo(
            # This is the description that will appear on the modules page.
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            # This defines the format of each prediction and reference
            features=datasets.Features({
                'predictions': datasets.Sequence({
                    "head": datasets.Value("string"),
                    "head_type": datasets.Value("string"),
                    "type": datasets.Value("string"),
                    "tail": datasets.Value("string"),
                    "tail_type": datasets.Value("string"),
                }),
                'references': datasets.Sequence({
                    "head": datasets.Value("string"),
                    "head_type": datasets.Value("string"),
                    "type": datasets.Value("string"),
                    "tail": datasets.Value("string"),
                    "tail_type": datasets.Value("string"),
                }),
            }),
            # Homepage of the module for documentation
            homepage="http://module.homepage",
            # Additional links to the codebase or references
            codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
            reference_urls=["http://path.to.reference.url/new_module"]
        )

    def _download_and_prepare(self, dl_manager):
        pass

    def _compute(self, predictions, references, mode="strict", detailed_scores=False, relation_types=None):
        """
        This method computes and returns various scoring metrics for the prediction model based on the mode specified, including Precision, Recall, F1-Score and others. It evaluates the model's predictions against the provided reference data. 
        
        Parameters:
            predictions: A list of predicted relations from the model.
            references: A list of ground-truth or reference relations to compare the predictions against.
            mode: Evaluation mode - 'strict' or 'boundaries'. 'strict' mode takes the entity types into account when
                  matching relations, while 'boundaries' mode only considers the entity spans.
            detailed_scores: Boolean; if True, returns the scores for each relation type in addition to the overall
                             scores, if False, returns only the overall scores.
            relation_types: A list of relation types to consider while evaluating. If not provided, relation types will be constructed 
                            from the ground truth or reference data.
        
        Returns:
            A dictionary mapping each entity type to its respective scoring metrics such as Precision, Recall, F1 Score.
        """
    
        predictions = convert_format(predictions)
        references = convert_format(references)
        
        if mode not in ("strict", "boundaries"):
            raise ValueError("mode must be either 'strict' or 'boundaries'")

        # construct relation_types from the references if not given
        if not relation_types:
            relation_types = []
            for triplets in references:
                for triplet in triplets:
                    relation = triplet["type"]
                    if relation not in relation_types:
                        relation_types.append(relation)
        
        scores = {rel: {"tp": 0, "fp": 0, "fn": 0} for rel in relation_types + ["ALL"]}
        
        # Count reference (GT) relations and predicted relations
        n_sents = len(references)
        n_rels = sum(len(sent) for sent in references)
        n_found = sum(len(sent) for sent in predictions)

        # Count TP, FP and FN per type
        for pred_sent, gt_sent in zip(predictions, references):
            for rel_type in relation_types:
                # strict mode takes argument types into account
                if mode == "strict":
                    pred_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in pred_sent if
                                rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["head_type"], rel["tail"], rel["tail_type"]) for rel in gt_sent if
                            rel["type"] == rel_type}

                # boundaries mode only takes argument spans into account
                elif mode == "boundaries":
                    pred_rels = {(rel["head"], rel["tail"]) for rel in pred_sent if rel["type"] == rel_type}
                    gt_rels = {(rel["head"], rel["tail"]) for rel in gt_sent if rel["type"] == rel_type}

                scores[rel_type]["tp"] += len(pred_rels & gt_rels)
                scores[rel_type]["fp"] += len(pred_rels - gt_rels)
                scores[rel_type]["fn"] += len(gt_rels - pred_rels)

        # Compute per-relation-type precision / recall / F1
        for rel_type in scores.keys():
            if scores[rel_type]["tp"]:
                scores[rel_type]["p"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fp"] + scores[rel_type]["tp"])
                scores[rel_type]["r"] = 100 * scores[rel_type]["tp"] / (scores[rel_type]["fn"] + scores[rel_type]["tp"])
            else:
                scores[rel_type]["p"], scores[rel_type]["r"] = 0, 0

            if not scores[rel_type]["p"] + scores[rel_type]["r"] == 0:
                scores[rel_type]["f1"] = 2 * scores[rel_type]["p"] * scores[rel_type]["r"] / (
                        scores[rel_type]["p"] + scores[rel_type]["r"])
            else:
                scores[rel_type]["f1"] = 0

        # Compute micro F1 Scores
        tp = sum(scores[rel_type]["tp"] for rel_type in relation_types)
        fp = sum(scores[rel_type]["fp"] for rel_type in relation_types)
        fn = sum(scores[rel_type]["fn"] for rel_type in relation_types)

        if tp:
            precision = 100 * tp / (tp + fp)
            recall = 100 * tp / (tp + fn)
            f1 = 2 * precision * recall / (precision + recall)
        else:
            precision, recall, f1 = 0, 0, 0

        scores["ALL"]["p"] = precision
        scores["ALL"]["r"] = recall
        scores["ALL"]["f1"] = f1
        scores["ALL"]["tp"] = tp
        scores["ALL"]["fp"] = fp
        scores["ALL"]["fn"] = fn


        # Compute macro precision / recall / F1 over relation types
        scores["ALL"]["Macro_f1"] = np.mean([scores[rel_type]["f1"] for rel_type in relation_types])
        scores["ALL"]["Macro_p"] = np.mean([scores[rel_type]["p"] for rel_type in relation_types])
        scores["ALL"]["Macro_r"] = np.mean([scores[rel_type]["r"] for rel_type in relation_types])

        if detailed_scores:
            return scores
            
        return scores["ALL"]