from argparse import ArgumentParser
from loguru import logger

from weakly_supervised_parser.settings import TRAINED_MODEL_PATH
from weakly_supervised_parser.utils.prepare_dataset import DataLoaderHelper
from weakly_supervised_parser.utils.populate_chart import PopulateCKYChart
from weakly_supervised_parser.tree.evaluate import calculate_F1_for_spans, tree_to_spans
from weakly_supervised_parser.model.trainer import InsideOutsideStringClassifier
from weakly_supervised_parser.settings import PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH, PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH

from weakly_supervised_parser.model.span_classifier import LightningModel


class Predictor:
    """Produce the highest-scoring bracketing for a single whitespace-tokenized sentence."""

    def __init__(self, sentence):
        """Store the raw sentence and its tokens.

        Args:
            sentence: Input sentence; tokens are obtained with ``str.split()``.
        """
        self.sentence = sentence
        self.sentence_list = sentence.split()

    def obtain_best_parse(self, predict_type, model, scale_axis, predict_batch_size, return_df=False):
        """Fill the CKY chart with model scores and decode the best parse tree.

        Args:
            predict_type: Forwarded to ``PopulateCKYChart.fill_chart`` (``main`` passes
                ``"inside"`` or ``"outside"``).
            model: Span classifier forwarded to ``fill_chart``.
            scale_axis: Forwarded to ``fill_chart``.
            predict_batch_size: Forwarded to ``fill_chart``.
            return_df: When true, also return the DataFrame produced by ``fill_chart``.

        Returns:
            The bracketed parse string, or ``(best_parse, df)`` when ``return_df`` is true.
        """
        unique_tokens_flag, span_scores, df = PopulateCKYChart(sentence=self.sentence).fill_chart(
            predict_type=predict_type,
            model=model,
            scale_axis=scale_axis,
            predict_batch_size=predict_batch_size,
        )

        if unique_tokens_flag:
            # When fill_chart raises this flag, CKY decoding is skipped and a flat tree is
            # emitted: every token becomes its own (S ...) constituent under one (S ...) root.
            best_parse = "(S " + " ".join(f"(S {token})" for token in self.sentence_list) + ")"
            # loguru substitutes positional args with str.format, so the message needs a "{}"
            # placeholder; without it the parse was silently dropped from the log line.
            logger.info("BEST PARSE: {}", best_parse)
        else:
            # NOTE(review): decoding uses a fresh PopulateCKYChart, as before; reusing the
            # instance above may be possible but depends on fill_chart's internal state.
            best_parse = PopulateCKYChart(sentence=self.sentence).best_parse_tree(span_scores)

        if return_df:
            return best_parse, df
        return best_parse


def process_test_sample(index, sentence, gold_file_path, predict_type, model, scale_axis, predict_batch_size, return_df=False):
    """Parse one test sentence, log its sentence-level F1 against the gold tree, and return the parse.

    Args:
        index: Position of the sentence in the test set; used to look up the gold tree.
        sentence: The whitespace-tokenized sentence to parse.
        gold_file_path: Source handed to ``DataLoaderHelper`` for the gold trees.
        predict_type, model, scale_axis, predict_batch_size: Forwarded to
            ``Predictor.obtain_best_parse``.
        return_df: When true, also return the DataFrame produced while filling the chart.

    Returns:
        The predicted bracketing, or ``(best_parse, df)`` when ``return_df`` is true.
    """
    predictor = Predictor(sentence=sentence)
    best_parse, df = predictor.obtain_best_parse(
        predict_type=predict_type,
        model=model,
        scale_axis=scale_axis,
        predict_batch_size=predict_batch_size,
        return_df=True,
    )

    # NOTE(review): the gold file is re-read for every sentence; consider loading it once.
    gold_tree = DataLoaderHelper(input_file_object=gold_file_path)[index]
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_tree), tree_to_spans(best_parse))

    # Low-scoring sentences (F1 below 25) are surfaced at warning level.
    emit = logger.warning if sentence_f1 < 25.0 else logger.info
    emit(f"Index: {index} <> F1: {sentence_f1:.2f}")

    return (best_parse, df) if return_df else best_parse


def process_co_train_test_sample(index, sentence, gold_file_path, inside_model, outside_model, return_df=False):
    """Parse one test sentence with co-trained inside/outside models and log its F1.

    Inside and outside span scores are computed separately and multiplied. The product is
    used to fill the CKY chart, from which the best tree is decoded.

    Args:
        index: Position of the sentence in the test set; used to look up the gold tree.
        sentence: The whitespace-tokenized sentence to parse.
        gold_file_path: Source handed to ``DataLoaderHelper`` for the gold trees.
        inside_model: Model used for ``predict_type="inside"`` scoring.
        outside_model: Model used for ``predict_type="outside"`` scoring.
        return_df: When true, also return the DataFrame produced by ``fill_chart``.
            This was previously accepted but ignored; it now mirrors ``process_test_sample``.

    Returns:
        The predicted bracketing, or ``(best_parse, df)`` when ``return_df`` is true.
    """
    _, df_inside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="inside", model=inside_model, return_df=True)
    _, df_outside = PopulateCKYChart(sentence=sentence).compute_scores(predict_type="outside", model=outside_model, return_df=True)

    # Combine both views by multiplying the per-span scores.
    # NOTE(review): these are presumably pandas DataFrames, so the product aligns on the
    # index; this assumes both frames list the same spans under the same index -- confirm.
    df = df_inside.copy()
    df["scores"] = df_inside["scores"] * df_outside["scores"]

    _, span_scores, df = PopulateCKYChart(sentence=sentence).fill_chart(data=df)
    best_parse = PopulateCKYChart(sentence=sentence).best_parse_tree(span_scores)

    gold_standard = DataLoaderHelper(input_file_object=gold_file_path)
    sentence_f1 = calculate_F1_for_spans(tree_to_spans(gold_standard[index]), tree_to_spans(best_parse))
    # Low-scoring sentences (F1 below 25) are surfaced at warning level.
    if sentence_f1 < 25.0:
        logger.warning(f"Index: {index} <> F1: {sentence_f1:.2f}")
    else:
        logger.info(f"Index: {index} <> F1: {sentence_f1:.2f}")

    if return_df:
        return best_parse, df
    return best_parse


def main():
    parser = ArgumentParser(description="Inference Pipeline for the Inside Outside String Classifier", add_help=True)

    group = parser.add_mutually_exclusive_group(required=True)

    group.add_argument("--use_inside", action="store_true", help="Whether to predict using inside model")

    group.add_argument("--use_inside_self_train", action="store_true", help="Whether to predict using inside model with self-training")

    group.add_argument("--use_outside", action="store_true", help="Whether to predict using outside model")

    group.add_argument("--use_inside_outside_co_train", action="store_true", help="Whether to predict using inside-outside model with co-training")

    parser.add_argument("--model_name_or_path", type=str, default="roberta-base", help="Path to the model identifier from huggingface.co/models")

    parser.add_argument("--save_path", type=str, required=True, help="Path to save the final trees")
    
    parser.add_argument("--scale_axis", choices=[None, 1], default=None, help="Whether to scale axis globally (None) or sequentially (1) across batches during softmax computation")
    
    parser.add_argument("--predict_batch_size", type=int, help="Batch size during inference")

    parser.add_argument(
        "--inside_max_seq_length", default=256, type=int, help="The maximum total input sequence length after tokenization for the inside model"
    )

    parser.add_argument(
        "--outside_max_seq_length", default=64, type=int, help="The maximum total input sequence length after tokenization for the outside model"
    )

    args = parser.parse_args()

    if args.use_inside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model.ckpt"
        max_seq_length = args.inside_max_seq_length

    if args.use_inside_self_train:
        pre_trained_model_path = TRAINED_MODEL_PATH + "inside_model_self_trained.onnx"
        max_seq_length = args.inside_max_seq_length

    if args.use_outside:
        pre_trained_model_path = TRAINED_MODEL_PATH + "outside_model.onnx"
        max_seq_length = args.outside_max_seq_length

    if args.use_inside_outside_co_train:
        inside_pre_trained_model_path = "inside_model_co_trained.onnx"
        inside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.inside_max_seq_length)
        inside_model.load_model(pre_trained_model_path=inside_pre_trained_model_path)

        outside_pre_trained_model_path = "outside_model_co_trained.onnx"
        outside_model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=args.outside_max_seq_length)
        outside_model.load_model(pre_trained_model_path=outside_pre_trained_model_path)
    else:
        # model = InsideOutsideStringClassifier(model_name_or_path=args.model_name_or_path, max_seq_length=max_seq_length)
        # model.load_model(pre_trained_model_path=pre_trained_model_path)
        
        model = LightningModel.load_from_checkpoint(checkpoint_path=pre_trained_model_path)

    if args.use_inside or args.use_inside_self_train:
        predict_type = "inside"

    if args.use_outside:
        predict_type = "outside"

    with open(args.save_path, "w") as out_file:
        test_sentences = DataLoaderHelper(input_file_object=PTB_TEST_SENTENCES_WITHOUT_PUNCTUATION_PATH).read_lines()
        test_gold_file_path = PTB_TEST_GOLD_WITHOUT_PUNCTUATION_ALIGNED_PATH
        for test_index, test_sentence in enumerate(test_sentences):
            if args.use_inside_outside_co_train:
                best_parse = process_co_train_test_sample(
                    test_index, test_sentence, test_gold_file_path, inside_model=inside_model, outside_model=outside_model
                )
            else:
                best_parse = process_test_sample(test_index, test_sentence, test_gold_file_path, predict_type=predict_type, model=model,
                                                 scale_axis=args.scale_axis, predict_batch_size=args.predict_batch_size)

            out_file.write(best_parse + "\n")


# Run the inference pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()