import os, shutil
import re
import numpy as np

# Per-character mask values used by the segmentation helpers in this module:
# PRESERVE marks text that must stay untouched (not sent for rewriting),
# TRANSFORM marks text that may be rewritten.
PRESERVE = 0
TRANSFORM = 1

pj = os.path.join  # short alias for os.path.join


class LinkedListNode:
    """
    A single segment in the singly linked list that partitions a document.

    Each node owns a run of text plus a flag telling whether that run must be
    kept verbatim (``preserve=True``) or may be rewritten (``preserve=False``).
    """

    def __init__(self, string, preserve=True) -> None:
        # Link to the following segment; None marks the end of the list.
        self.next: "LinkedListNode | None" = None
        # [first_line, last_line] window, filled in later by post_process().
        self.range: "list[int] | None" = None
        # The segment's text and whether it is protected from rewriting.
        self.string = string
        self.preserve = preserve


def convert_to_linklist(text, mask):
    """
    Split ``text`` into a linked list of runs that share the same mask value.

    ``text`` and ``mask`` are walked in lockstep (extra items of the longer one
    are ignored). A new node starts whenever a character's mask value does not
    continue the current run: PRESERVE continues a preserve node, TRANSFORM
    continues a transform node, anything else always starts a new node.

    The returned root is always a PRESERVE node that starts empty, so leading
    PRESERVE characters are appended to the root itself.
    """
    root = LinkedListNode("", preserve=True)
    current_node = root
    # Buffer the characters of the current run and join once per node:
    # `node.string += c` on an attribute is quadratic for long runs.
    pieces = []
    for char, flag in zip(text, mask):
        continues_run = (flag == PRESERVE and current_node.preserve) or (
            flag == TRANSFORM and not current_node.preserve
        )
        if continues_run:
            pieces.append(char)
            continue
        current_node.string += "".join(pieces)
        current_node.next = LinkedListNode(char, preserve=(flag == PRESERVE))
        current_node = current_node.next
        pieces = []
    current_node.string += "".join(pieces)
    return root


def _find_unbalanced_close_brace(string):
    """
    Return the index of the first "}" in ``string`` that closes no earlier "{", or -1.

    Prints "stack fix" when such a brace is found.
    """
    depth = 0
    for i, c in enumerate(string):
        if c == "{":
            depth += 1
        elif c == "}":
            if depth == 0:
                print("stack fix")
                return i
            depth -= 1
    return -1


def post_process(root):
    """
    Clean up a segment list (as built by convert_to_linklist) in place.

    Passes, in order:
    1. Split every non-preserve node at its first unmatched "}" so the stray
       brace ends up alone in its own node.
    2. Mark nodes whose text, ignoring leading/trailing newlines, is shorter
       than 42 characters as preserve, then merge every run of consecutive
       preserve nodes into the first node of the run.
    3. Move leading/trailing whitespace of non-preserve nodes into the
       neighbouring preserve nodes.
    4. Set ``node.range`` to the node's line window widened by 2 lines on each
       side (the original comment describes it as the range to fall back to on
       failure).

    Returns the same ``root`` object.
    """
    # Pass 1: isolate unmatched "}" in non-preserve nodes. The freshly created
    # tail node is visited next, so repeated stray braces are all split off.
    node = root
    while node is not None:
        if not node.preserve:
            string = node.string
            bp = _find_unbalanced_close_brace(string)
            if bp != -1:
                # bp == 0 means the text starts with "}": split right after it.
                split_at = bp if bp > 0 else 1
                node.string = string[:split_at]
                q = LinkedListNode(string[split_at:], False)
                q.next = node.next
                node.next = q
        node = node.next

    # Pass 2a: empty lines and short sentences are not worth rewriting.
    node = root
    while node is not None:
        if len(node.string.strip("\n")) < 42:
            node.preserve = True
        node = node.next

    # Pass 2b: merge runs of consecutive preserve nodes. Stay on `node` after a
    # merge so a run of three or more nodes collapses into one.
    node = root
    while node is not None:
        nxt = node.next
        if nxt is not None and node.preserve and nxt.preserve:
            node.string += nxt.string
            node.next = nxt.next
        else:
            node = nxt

    # Pass 3: detach leading/trailing whitespace (incl. line breaks) from
    # non-preserve nodes and hand it to adjacent preserve nodes.
    prev_node = None
    node = root
    while node is not None:
        if not node.preserve:
            lstripped = node.string.lstrip()
            if (
                prev_node is not None
                and prev_node.preserve
                and len(lstripped) != len(node.string)
            ):
                # Slice by prefix length; `[:-len(lstripped)]` would be empty
                # (and lose the text) when the node is all whitespace.
                prev_node.string += node.string[: len(node.string) - len(lstripped)]
                node.string = lstripped
            rstripped = node.string.rstrip()
            if (
                node.next is not None
                and node.next.preserve
                and len(rstripped) != len(node.string)
            ):
                node.next.string = node.string[len(rstripped) :] + node.next.string
                node.string = rstripped
        prev_node = node
        node = node.next

    # Pass 4: annotate each node with its line window.
    n_line = 0
    expansion = 2
    node = root
    while node is not None:
        n_l = node.string.count("\n")
        node.range = [n_line - expansion, n_line + n_l + expansion]
        n_line += n_l
        node = node.next
    return root


"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
Latex segmentation with a binary mask (PRESERVE=0, TRANSFORM=1)
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
"""


def set_forbidden_text(text, mask, pattern, flags=0):
    r"""
    Mark every match of ``pattern`` as a preserve area (mask = PRESERVE).

    Preserved text is untouchable for GPT. ``pattern`` is a regex string, or a
    list of regex strings that are joined with "|". ``flags`` is passed to
    ``re.compile``.

    e.g. with pattern = r"\\begin\{algorithm\}(.*?)\\end\{algorithm\}" and
    re.DOTALL, everything from "\begin{algorithm}" to "\end{algorithm}" is
    preserved.

    Returns ``(text, mask)``; ``mask`` is modified in place.
    """
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    for res in re.compile(pattern, flags).finditer(text):
        start, end = res.span()
        mask[start:end] = PRESERVE
    return text, mask


def reverse_forbidden_text(text, mask, pattern, flags=0, forbid_wrapper=True):
    r"""
    Move matched areas out of the preserve area (make the text editable for GPT).

    ``pattern`` is a regex string, or a list of regex strings joined with "|".
    With ``forbid_wrapper=False`` the whole match becomes TRANSFORM. With
    ``forbid_wrapper=True`` only capture group 1 becomes TRANSFORM, and the rest
    of the match (the wrapper) becomes PRESERVE.

    e.g. pattern = r"\\begin\{abstract\}(.*?)\\end\{abstract\}" with re.DOTALL on
    "\begin{abstract} blablabla. \end{abstract}" keeps both wrapper commands
    preserved and makes " blablabla. " editable.

    NOTE(review): with a list pattern and ``forbid_wrapper=True``, only the first
    alternative owns group 1. Matches of the other alternatives have
    span(1) == (-1, -1) and yield odd slices, so confirm callers avoid that combination.

    Returns ``(text, mask)``; ``mask`` is modified in place.
    """
    if isinstance(pattern, list):
        pattern = "|".join(pattern)
    for res in re.compile(pattern, flags).finditer(text):
        start, end = res.span()
        if not forbid_wrapper:
            mask[start:end] = TRANSFORM
            continue
        inner_start, inner_end = res.span(1)
        mask[start:inner_start] = PRESERVE  # opening wrapper, e.g. '\begin{abstract}'
        mask[inner_start:inner_end] = TRANSFORM  # the content itself
        mask[inner_end:end] = PRESERVE  # closing wrapper, e.g. '\end{abstract}'
    return text, mask


def set_forbidden_text_careful_brace(text, mask, pattern, flags=0):
    r"""
    Add a preserve area (text becomes untouchable for GPT), counting braces.

    For each match of ``pattern``, scanning starts at the beginning of the match.
    It counts braces until the "}" that closes the first "{", so nested groups
    such as \caption{blablablablabla\texbf{blablabla}blablabla.} are captured
    completely. The match start through that closing brace is set to PRESERVE.

    The scan is capped at 16 KiB per match. If the text ends before the closing
    brace, everything to the end of the text is preserved, where the old code
    raised IndexError.

    Returns ``(text, mask)``; ``mask`` is modified in place.
    """
    pattern_compile = re.compile(pattern, flags)
    for res in pattern_compile.finditer(text):
        begin = p = res.start()
        # -1: the first "{" encountered brings the level to 0 (the outer group).
        brace_level = -1
        limit = min(begin + 1024 * 16, len(text))
        while p < limit:
            ch = text[p]
            if ch == "}" and brace_level == 0:
                break
            if ch == "}":
                brace_level -= 1
            elif ch == "{":
                brace_level += 1
            p += 1
        end = p + 1  # include the closing brace
        mask[begin:end] = PRESERVE
    return text, mask


def reverse_forbidden_text_careful_brace(
    text, mask, pattern, flags=0, forbid_wrapper=True
):
    r"""
    Move a brace-delimited argument out of the preserve area (make it editable for GPT).

    For each match of ``pattern``, scanning starts at the start of capture
    group 1. It counts braces until the "}" that closes the argument, so nested
    groups such as \caption{blablablablabla\texbf{blablabla}blablabla.} are
    captured completely. The argument becomes TRANSFORM. With
    ``forbid_wrapper``, the part of the match before the argument, and the part
    from the closing brace to the end of the match, become PRESERVE.

    The scan is capped at 16 KiB per match. A match whose argument is not closed
    before the end of ``text`` is skipped (mask unchanged) instead of raising
    IndexError.

    Returns ``(text, mask)``; ``mask`` is modified in place.
    """
    pattern_compile = re.compile(pattern, flags)
    n = len(text)
    for res in pattern_compile.finditer(text):
        begin = p = res.start(1)
        brace_level = 0
        limit = begin + 1024 * 16
        while p < limit and p < n:
            ch = text[p]
            if ch == "}" and brace_level == 0:
                break
            if ch == "}":
                brace_level -= 1
            elif ch == "{":
                brace_level += 1
            p += 1
        if p == n and p < limit:
            # Ran off the end of the text: no closing brace, leave this match alone.
            continue
        end = p
        mask[begin:end] = TRANSFORM
        if forbid_wrapper:
            mask[res.start(0) : begin] = PRESERVE
            mask[end : res.end(0)] = PRESERVE
    return text, mask


def set_forbidden_text_begin_end(text, mask, pattern, flags=0, limit_n_lines=42):
    r"""
    Preserve every \begin{...} ... \end{...} block shorter than ``limit_n_lines`` lines.

    ``pattern`` must capture the environment name in group 1 and the body in
    group 2. A block is preserved whole (mask = PRESERVE) unless its name is in
    the white list or its body has at least ``limit_n_lines`` newlines. In those
    cases the body is searched recursively for smaller blocks instead.

    Returns ``(text, mask)``; ``mask`` is modified in place.
    """
    pattern_compile = re.compile(pattern, flags)
    # Environments that are always descended into rather than preserved whole.
    white_list = frozenset(
        {
            "document",
            "abstract",
            "lemma",
            "definition",
            "sproof",
            "em",
            "emph",
            "textit",
            "textbf",
            "itemize",
            "enumerate",
        }
    )

    def search_with_line_limit(text, mask):
        for res in pattern_compile.finditer(text):
            cmd = res.group(1)  # begin{what}
            this = res.group(2)  # content between begin and end
            body_start, body_end = res.span(2)
            this_mask = mask[body_start:body_end]
            if cmd in white_list or this.count("\n") >= limit_n_lines:
                # Too long (or a container environment): recurse into the body.
                this, this_mask = search_with_line_limit(this, this_mask)
                mask[body_start:body_end] = this_mask
            else:
                mask[res.start(0) : res.end(0)] = PRESERVE
        return text, mask

    return search_with_line_limit(text, mask)


"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
Latex Merge File
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
"""


def find_main_tex_file(file_manifest, mode):
    r"""
    Pick the main file of a multi-file LaTeX project.

    Only files containing ``\documentclass`` qualify, and files whose basename
    starts with "merge" are skipped. A single candidate is returned directly.

    With several candidates, each is scored on its comment-stripped content:
    -1 per template-like word and +1 per ``\input`` / ``\ref`` / ``\cite``. The
    first highest-scoring candidate is returned. This guards against LaTeX
    templates shipped alongside the paper. ``mode`` is not used.

    Raises:
        RuntimeError: when no candidate contains ``\documentclass``.
    """

    def load(path):
        with open(path, "r", encoding="utf8", errors="ignore") as handle:
            return handle.read()

    candidates = [
        path
        for path in file_manifest
        if not os.path.basename(path).startswith("merge")
        and r"\documentclass" in load(path)
    ]
    if not candidates:
        raise RuntimeError("无法找到一个主Tex文件(包含documentclass关键字)")
    if len(candidates) == 1:
        return candidates[0]

    # Words typical of LaTeX templates (penalty) vs. real paper bodies (bonus).
    template_words = (
        "\\LaTeX",
        "manuscript",
        "Guidelines",
        "font",
        "citations",
        "rejected",
        "blind review",
        "reviewers",
    )
    body_words = ("\\input", "\\ref", "\\cite")

    def score(path):
        content = rm_comments(load(path))
        bonus = sum(word in content for word in body_words)
        penalty = sum(word in content for word in template_words)
        return bonus - penalty

    scores = [score(path) for path in candidates]
    return candidates[np.argmax(scores)]


def rm_comments(main_file):
    r"""
    Remove LaTeX comments from ``main_file``.

    Lines whose first non-blank character is "%" are dropped entirely. On the
    remaining lines, everything from an unescaped "%" (not preceded by "\") to
    the end of the line is removed. Lines are re-joined with "\n", so a trailing
    newline is not kept.
    """
    kept_lines = [
        line for line in main_file.splitlines() if not line.lstrip().startswith("%")
    ]
    without_comment_lines = "\n".join(kept_lines)
    # Strip trailing half-line comments; "\%" is a literal percent sign.
    return re.sub(r"(?<!\\)%.*", "", without_comment_lines)


def find_tex_file_ignore_case(fp):
    r"""
    Resolve a TeX file path the way ``\input`` arguments tend to be written.

    Lookup order:
    1. ``fp`` as given;
    2. ``fp`` with ".tex" appended, unless it already ends with ".tex";
    3. any "*.tex" file in the same directory whose name matches form 1 or 2
       case-insensitively.

    Returns the path found, or None.
    """
    import glob

    dir_name = os.path.dirname(fp)
    base_name = os.path.basename(fp)
    candidate = os.path.join(dir_name, base_name)
    if os.path.isfile(candidate):
        return candidate
    if not base_name.endswith(".tex"):
        base_name += ".tex"
    candidate = os.path.join(dir_name, base_name)
    if os.path.isfile(candidate):
        return candidate
    wanted = {os.path.basename(fp).lower(), base_name.lower()}
    # Escape the directory so "[", "*", "?" in folder names are literal, and use
    # os.path.join so an empty dir_name searches the current directory rather
    # than the filesystem root ("" + "/*.tex").
    for f in glob.glob(os.path.join(glob.escape(dir_name), "*.tex")):
        if os.path.basename(f).lower() in wanted:
            return f
    return None


def merge_tex_files_(project_foler, main_file, mode):
    r"""
    Merge a TeX project recursively by inlining every ``\input{...}``.

    Comments are stripped first with rm_comments. Each ``\input{name}`` is
    resolved relative to ``project_foler`` by find_tex_file_ignore_case, read,
    merged recursively, and spliced in place of the command. If a resolved file
    cannot be read, a visible warning text is inlined instead.

    Raises:
        RuntimeError: if a referenced file cannot be located.
    """
    main_file = rm_comments(main_file)
    # Process matches from last to first so earlier offsets stay valid.
    for s in reversed(list(re.finditer(r"\\input\{(.*?)\}", main_file, re.M))):
        fp = os.path.join(project_foler, s.group(1))
        fp_ = find_tex_file_ignore_case(fp)
        if not fp_:
            raise RuntimeError(f"找不到{fp},Tex源文件缺失!")
        try:
            with open(fp_, "r", encoding="utf-8", errors="replace") as fx:
                c = fx.read()
        except OSError:
            # Best effort: keep merging, but leave a marker in the output.
            c = "\n\nWarning from GPT-Academic: LaTex source file is missing!\n\n"
        c = merge_tex_files_(project_foler, c, mode)
        main_file = main_file[: s.start()] + c + main_file[s.end() :]
    return main_file


def find_title_and_abs(main_file):
    r"""
    Extract the paper title and abstract from LaTeX source.

    The abstract is taken from the first ``\abstract{...}``, falling back to the
    first ``\begin{abstract}...\end{abstract}``. The title comes from the first
    ``\title{...}``. Matching is non-greedy and spans newlines. Either value is
    None when not found.

    Returns ``(title, abstract)``.
    """

    def first_group(regex, source):
        found = re.search(regex, source, re.DOTALL)
        return found.group(1) if found else None

    abstract = first_group(r"\\abstract\{(.*?)\}", main_file)
    if abstract is None:
        abstract = first_group(
            r"\\begin\{abstract\}(.*?)\\end\{abstract\}", main_file
        )
    title = first_group(r"\\title\{(.*?)\}", main_file)
    return title, abstract


def merge_tex_files(project_foler, main_file, mode):
    r"""
    Merge a TeX project recursively into one string and strip LaTeX comments.

    When ``mode == "translate_zh"`` it also prepares the document for Chinese:
    - inserts ``\usepackage{ctex}`` after the ``\documentclass`` line, plus
      ``\usepackage{url}`` if "{url}" does not occur anywhere;
    - adds ``fontset=windows,UTF8`` to the ``\documentclass`` options;
    - inserts a placeholder abstract (insert_abstract) when none is found.

    Raises:
        AssertionError: in "translate_zh" mode, when no ``\documentclass`` line
            or no abstract can be found. Raised explicitly (not via ``assert``)
            so the checks still run under ``python -O``.
    """
    main_file = merge_tex_files_(project_foler, main_file, mode)
    main_file = rm_comments(main_file)

    if mode == "translate_zh":
        # Find the paper's \documentclass line; packages go right after it.
        match = re.compile(r"\\documentclass.*\n").search(main_file)
        if match is None:
            raise AssertionError("Cannot find documentclass statement!")
        position = match.end()
        add_ctex = "\\usepackage{ctex}\n"
        add_url = "\\usepackage{url}\n" if "{url}" not in main_file else ""
        main_file = main_file[:position] + add_ctex + add_url + main_file[position:]
        # fontset=windows: extend existing options, or add options if none.
        main_file = re.sub(
            r"\\documentclass\[(.*?)\]{(.*?)}",
            r"\\documentclass[\1,fontset=windows,UTF8]{\2}",
            main_file,
        )
        main_file = re.sub(
            r"\\documentclass{(.*?)}",
            r"\\documentclass[fontset=windows,UTF8]{\1}",
            main_file,
        )
        # Make sure the paper has an abstract section.
        pattern_opt1 = re.compile(r"\\begin\{abstract\}.*\n")
        pattern_opt2 = re.compile(r"\\abstract\{(.*?)\}", flags=re.DOTALL)

        def has_abstract(text):
            return (pattern_opt1.search(text) is not None) or (
                pattern_opt2.search(text) is not None
            )

        if not has_abstract(main_file):
            main_file = insert_abstract(main_file)
        if not has_abstract(main_file):
            raise AssertionError("Cannot find paper abstract section!")
    return main_file


# Placeholder abstract inserted by insert_abstract() when a paper has none.
insert_missing_abs_str = r"""
\begin{abstract}
The GPT-Academic program cannot find abstract section in this paper.
\end{abstract}
"""


def insert_abstract(tex_content):
    r"""
    Insert the placeholder abstract ``insert_missing_abs_str`` into ``tex_content``.

    The placeholder, surrounded by blank lines, goes on the line after the
    first ``\maketitle``. If there is no ``\maketitle``, it goes after the first
    ``\begin{document}``. If the marker sits on the last line (no trailing
    newline), the placeholder is appended at the end of the text. Returns
    ``tex_content`` unchanged when neither marker is present.
    """
    for anchor in ("\\maketitle", r"\begin{document}"):
        if anchor in tex_content:
            break
    else:
        return tex_content
    line_end = tex_content.find("\n", tex_content.index(anchor))
    # find() returns -1 when the marker is on the last line; slicing with -1 + 1
    # would put the abstract at the very start of the document.
    insert_at = len(tex_content) if line_end == -1 else line_end + 1
    return (
        tex_content[:insert_at]
        + "\n\n"
        + insert_missing_abs_str
        + "\n\n"
        + tex_content[insert_at:]
    )


"""
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
Post process
=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=
"""


def mod_inbraket(match):
    """
    ``re.sub`` callback that rebuilds ``\\cmd{arg}`` with punctuation in ``arg`` normalized.

    Why is this needed? ChatGPT tends to turn the commas inside ``\\cite{...}``
    (and similar command arguments) into Chinese full-width commas.

    ``match`` must carry the command name in group 1 and the argument in group 2.
    Returns "\\\\" + cmd + "{" + arg + "}" with the replacements applied.
    """
    # get the command name and its argument
    cmd = match.group(1)
    str_to_modify = match.group(2)
    # Intent (per the original comments): first argument of each replace() is the
    # full-width Chinese colon/comma, second is the ASCII one.
    # NOTE(review): as this file reads here, both arguments look like the same
    # ASCII character, which would make these calls no-ops. Confirm the
    # full-width characters survived the file's encoding.
    str_to_modify = str_to_modify.replace(":", ":")  # full-width colon -> ASCII colon
    str_to_modify = str_to_modify.replace(",", ",")  # full-width comma -> ASCII comma
    return "\\" + cmd + "{" + str_to_modify + "}"


def fix_content(final_tex, node_string):
    r"""
    Fix common GPT errors in a rewritten segment to increase the compile success rate.

    Args:
        final_tex: the text GPT returned for one segment.
        node_string: the original segment text.

    Returns:
        The repaired text. It falls back to ``node_string`` entirely when the
        output looks like an error report or the number of ``\begin`` commands
        changed. If GPT dropped ``\_`` escapes, every bare "_" is escaped. If
        the brace balance differs from the original, the original's tail is
        spliced back in.
    """
    # Escape bare "%" so rewritten text cannot comment out the rest of a line.
    final_tex = re.sub(r"(?<!\\)%", "\\%", final_tex)
    # Remove spaces GPT inserts in commands: "\cmd {" -> "\cmd{", "\ cmd{" -> "\cmd{".
    final_tex = re.sub(r"\\([a-z]{2,10})\ \{", r"\\\1{", string=final_tex)
    final_tex = re.sub(r"\\\ ([a-z]{2,10})\{", r"\\\1{", string=final_tex)
    # Normalize punctuation inside "\cmd{...}" arguments.
    final_tex = re.sub(r"\\([a-z]{2,10})\{([^\}]*?)\}", mod_inbraket, string=final_tex)

    if "Traceback" in final_tex and "[Local Message]" in final_tex:
        final_tex = node_string  # looks like an error report: restore the original
    if node_string.count("\\begin") != final_tex.count("\\begin"):
        final_tex = node_string  # environment structure changed: restore the original
    # Raw strings: "\_" is an invalid escape sequence in a plain literal.
    if node_string.count(r"\_") > final_tex.count(r"\_"):
        # Some "\_" escapes were lost: escape every "_" not preceded by "\".
        final_tex = re.sub(r"(?<!\\)_", "\\_", final_tex)

    def compute_brace_level(string):
        # Net brace depth: number of "{" minus number of "}".
        return string.count("{") - string.count("}")

    def join_most(tex_t, tex_o):
        # Walk the braces of the original in order, matching each against the
        # next identical brace in the translation. Keep the translation up to
        # the last matched brace and append the original from that point on.
        p_t = 0
        p_o = 0

        def find_next(string, chars, begin):
            # First position >= begin holding one of `chars`, with that char.
            p = begin
            while p < len(string):
                if string[p] in chars:
                    return p, string[p]
                p += 1
            return None, None

        while True:
            res1, char = find_next(tex_o, ["{", "}"], p_o)
            if res1 is None:
                break
            res2, char = find_next(tex_t, [char], p_t)
            if res2 is None:
                break
            p_o = res1 + 1
            p_t = res2 + 1
        return tex_t[:p_t] + tex_o[p_o:]

    if compute_brace_level(final_tex) != compute_brace_level(node_string):
        # Braces are unbalanced: restore part of the original so they match.
        final_tex = join_most(final_tex, node_string)
    return final_tex


def compile_latex_with_timeout(command, cwd, timeout=60):
    """
    Run a shell ``command`` in ``cwd``, giving up after ``timeout`` seconds.

    stdout/stderr are captured and discarded. Returns True when the command
    finished in time, regardless of its exit status. On timeout the process is
    killed, "Process timed out!" is printed, and False is returned.

    NOTE(review): the command runs with ``shell=True``, so it must come from a
    trusted source.
    """
    import subprocess

    proc = subprocess.Popen(
        command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=cwd
    )
    try:
        proc.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        proc.kill()
        proc.communicate()  # reap the killed process and drain its pipes
        print("Process timed out!")
        return False
    else:
        return True


def run_in_subprocess_wrapper_func(func, args, kwargs, return_dict, exception_dict):
    """
    Child-process entry point used by run_in_subprocess.

    Calls ``func(*args, **kwargs)`` and stores the result in
    ``return_dict["result"]``. If it raises, an ``(exc_type, exc_value, None)``
    triple is stored in ``exception_dict["exception"]``.

    The traceback slot is None because traceback objects cannot be pickled, and
    multiprocessing manager dicts pickle every stored value. Storing the
    ``sys.exc_info()`` triple therefore failed and the exception was lost. If
    the exception itself cannot be stored, a RuntimeError describing it
    (including the child's formatted traceback) is stored instead.
    """
    import sys
    import traceback

    try:
        result = func(*args, **kwargs)
        return_dict["result"] = result
    except Exception:
        exc_type, exc_value = sys.exc_info()[:2]
        tb_text = traceback.format_exc()  # capture before any nested exception
        try:
            exception_dict["exception"] = (exc_type, exc_value, None)
        except Exception:
            # The exception or its class is not picklable: ship a description.
            summary = "".join(
                traceback.format_exception_only(exc_type, exc_value)
            ).strip()
            exception_dict["exception"] = (
                RuntimeError,
                RuntimeError(f"{summary}\n\nSubprocess traceback:\n{tb_text}"),
                None,
            )


def run_in_subprocess(func):
    """
    Wrap ``func`` so that each call runs in a fresh child process.

    The wrapper blocks until the child exits. If ``func`` raised in the child,
    the exception stored by run_in_subprocess_wrapper_func is re-raised in the
    caller. Otherwise the child's return value is returned, or None if the
    child produced no result (e.g. it died).

    A single Manager (one server process) is used per call and shut down
    deterministically by its ``with`` block. Previously two Managers were
    started per call and left for garbage collection to stop.
    """
    import functools
    import multiprocessing

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        with multiprocessing.Manager() as manager:
            return_dict = manager.dict()
            exception_dict = manager.dict()
            process = multiprocessing.Process(
                target=run_in_subprocess_wrapper_func,
                args=(func, args, kwargs, return_dict, exception_dict),
            )
            process.start()
            process.join()
            process.close()
            # Copy values out of the proxies before the manager shuts down.
            exc_info = (
                exception_dict["exception"] if "exception" in exception_dict else None
            )
            has_result = "result" in return_dict.keys()
            result = return_dict["result"] if has_result else None
        if exc_info is not None:
            # ooops, the subprocess ran into an exception
            raise exc_info[1].with_traceback(exc_info[2])
        return result

    return wrapper


def _merge_pdfs(pdf1_path, pdf2_path, output_path):
    """
    Merge two PDFs side by side, page by page, and write the result to ``output_path``.

    Output page i holds page i of ``pdf1_path`` at the origin. Page i of
    ``pdf2_path`` is placed to its right, shifted left by 5% of its own width
    (``1 - Percent``) so the two overlap slightly. The combined page is
    ``width1 + 0.95 * width2`` wide and as tall as the taller page. When one
    file has fewer pages, a blank page created from ``pdf1_reader`` stands in
    for the missing one.

    NOTE(review): the blank page for a missing page of the *second* file is
    also created from ``pdf1_reader``; confirm whether ``pdf2_reader`` was
    intended.
    NOTE(review): this uses the camelCase PyPDF2 API (PdfFileReader, numPages,
    getPage, mediaBox, mergeTranslatedPage, addPage), which later PyPDF2/pypdf
    releases deprecated and removed. Confirm the pinned PyPDF2 version.
    """
    import PyPDF2  # PyPDF2 leaks memory badly; this function is run in a child process (see merge_pdfs) so the memory is released

    Percent = 0.95  # fraction of the second page's width that stays visible
    # raise RuntimeError('PyPDF2 has a serious memory leak problem, please use other tools to merge PDF files.')
    # Open the first PDF file
    with open(pdf1_path, "rb") as pdf1_file:
        pdf1_reader = PyPDF2.PdfFileReader(pdf1_file)
        # Open the second PDF file
        with open(pdf2_path, "rb") as pdf2_file:
            pdf2_reader = PyPDF2.PdfFileReader(pdf2_file)
            # Create a new PDF file to store the merged pages
            output_writer = PyPDF2.PdfFileWriter()
            # Determine the number of pages in each PDF file
            num_pages = max(pdf1_reader.numPages, pdf2_reader.numPages)
            # Merge the pages from the two PDF files
            for page_num in range(num_pages):
                # Add the page from the first PDF file
                if page_num < pdf1_reader.numPages:
                    page1 = pdf1_reader.getPage(page_num)
                else:
                    page1 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
                # Add the page from the second PDF file
                if page_num < pdf2_reader.numPages:
                    page2 = pdf2_reader.getPage(page_num)
                else:
                    page2 = PyPDF2.PageObject.createBlankPage(pdf1_reader)
                # Create a new empty page wide enough for both (second page 95% visible)
                new_page = PyPDF2.PageObject.createBlankPage(
                    width=int(
                        int(page1.mediaBox.getWidth())
                        + int(page2.mediaBox.getWidth()) * Percent
                    ),
                    height=max(page1.mediaBox.getHeight(), page2.mediaBox.getHeight()),
                )
                new_page.mergeTranslatedPage(page1, 0, 0)
                # x offset = width1 - 5% of width2, so page2 overlaps page1's right edge
                new_page.mergeTranslatedPage(
                    page2,
                    int(
                        int(page1.mediaBox.getWidth())
                        - int(page2.mediaBox.getWidth()) * (1 - Percent)
                    ),
                    0,
                )
                output_writer.addPage(new_page)
            # Save the merged PDF file
            with open(output_path, "wb") as output_file:
                output_writer.write(output_file)


# PyPDF2 has a serious memory-leak problem, so _merge_pdfs is wrapped to run in
# a child process; its memory is released when that process exits.
merge_pdfs = run_in_subprocess(_merge_pdfs)