import os
import gradio as gr
import json
from rxnim import RXNIM
from getReaction import generate_combined_image
import torch
from rxn.reaction import Reaction
from rdkit import Chem
from rdkit.Chem import rdChemReactions
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem.Draw import rdMolDraw2D
import cairosvg

# Directory that holds the task prompt files (*.txt) listed in the UI.
PROMPT_DIR = "prompts/"
# Checkpoint for the `Reaction` model; the model is loaded once, at import
# time, on the CPU, and reused by every request.
ckpt_path = "./rxn/model/model.ckpt"
model = Reaction(ckpt_path, device=torch.device('cpu'))

# Map prompt file names to friendly display names; files not listed here get
# a generated "Task: <stem>" name.
PROMPT_NAMES = {
    "2_RxnOCR.txt": "Reaction Image Parsing Workflow",
}
# Image shown in the "Schematic Diagram" panel (also returned after each run).
example_diagram = "examples/exp.png"
# Placeholder image shown when there are no reaction SMILES to render.
rdkit_image = "examples/rdkit.png"

def list_prompt_files_with_names(prompt_dir=None):
    """Map friendly task names to the prompt files found in the prompt directory.

    Every ``*.txt`` file in *prompt_dir* is listed. Files named in
    ``PROMPT_NAMES`` get their predefined friendly name; any other file gets a
    generated name of the form ``"Task: <file stem>"``.

    Args:
        prompt_dir: directory to scan; defaults to the module-level
            ``PROMPT_DIR``.

    Returns:
        dict mapping friendly name -> file name (not a full path). Entries are
        in sorted file-name order, because ``os.listdir`` returns names in
        arbitrary order and the result feeds the task radio in the UI.
    """
    if prompt_dir is None:
        prompt_dir = PROMPT_DIR
    return {
        PROMPT_NAMES.get(f, f"Task: {os.path.splitext(f)[0]}"): f
        for f in sorted(os.listdir(prompt_dir))
        if f.endswith(".txt")
    }

def parse_reactions(output_json):
    """Turn RxnIM's JSON output into HTML summaries and plain reaction SMILES.

    Args:
        output_json: JSON text whose top-level object may hold a "reactions"
            list. Each reaction may carry "reaction_id", "reactants",
            "conditions" and "products"; missing fields fall back to
            "Unknown ID" / "Unknown".

    Returns:
        (detailed_output, smiles_output):
        - detailed_output: one HTML snippet per reaction for the gr.HTML
          panel -- reactants in blue, conditions in red, products in orange,
          plus an all-black "Full Reaction" line.
        - smiles_output: one plain "reactants>>products" string per reaction
          (raw values, not HTML-escaped).

    Raises:
        json.JSONDecodeError: if output_json is not valid JSON.
    """
    import html  # only this function needs HTML escaping

    def esc(value):
        # Model output is displayed through gr.HTML; escape it so a stray "<",
        # ">" or "&" (e.g. in condition text) cannot break or inject markup.
        return html.escape(str(value))

    def span(color, text):
        return f"<span style='color:{color}'>{text}</span>"

    reactions_data = json.loads(output_json)
    reactions_list = reactions_data.get("reactions", [])
    detailed_output = []
    smiles_output = []

    for reaction in reactions_list:
        reaction_id = esc(reaction.get("reaction_id", "Unknown ID"))
        reactant_smiles = [r.get("smiles", "Unknown") for r in reaction.get("reactants", [])]
        product_smiles = [p.get("smiles", "Unknown") for p in reaction.get("products", [])]
        # A condition is labelled by its SMILES (else its text) plus its role.
        condition_labels = [
            esc(f"{c.get('smiles', c.get('text', 'Unknown'))}[{c.get('role', 'Unknown')}]")
            for c in reaction.get("conditions", [])
        ]
        reactants_html = [esc(s) for s in reactant_smiles]
        products_html = [esc(s) for s in product_smiles]

        # Full reaction line: reactants>>products | conditions, all in black.
        products_black = ".".join(span("black", p) for p in products_html)
        conditions_black = ", ".join(span("black", c) for c in condition_labels)
        full_reaction = span(
            "black", f"{'.'.join(reactants_html)}>>{products_black} | {conditions_black}"
        )

        detailed_output.append(
            f"<b>Reaction: </b> {reaction_id}<br>"
            f"  Reactants: {span('blue', ', '.join(reactants_html))}<br>"
            f"  Conditions: {', '.join(span('red', c) for c in condition_labels)}<br>"
            f"  Products: {', '.join(span('orange', p) for p in products_html)}<br>"
            f"  <b>Full Reaction:</b> {full_reaction}<br>"
            "<br>"
        )
        smiles_output.append(f"{'.'.join(reactant_smiles)}>>{'.'.join(product_smiles)}")

    return detailed_output, smiles_output

def process_chem_image(image, selected_task):
    """Run the reaction-image parsing pipeline for one uploaded image.

    Args:
        image: PIL image from the upload component (``gr.Image(type="pil")``),
            or None when nothing has been uploaded.
        selected_task: friendly task name picked in the radio; expected to be a
            key of the module-level ``prompts_with_names`` mapping.

    Returns:
        Tuple matching the Run button's outputs: (reaction HTML joined with
        blank lines, list of reaction SMILES strings, the value returned by
        ``generate_combined_image`` -- presumably an image path, schematic
        diagram path, path of the saved JSON file).

    Raises:
        gr.Error: if no image was uploaded or no known task was selected, so
            the user sees a message instead of an AttributeError/KeyError.
    """
    if image is None:
        raise gr.Error("Please upload a reaction image first.")
    if selected_task not in prompts_with_names:
        raise gr.Error("Please select a predefined task.")

    chem_mllm = RXNIM()

    # Translate the friendly task name into the prompt file on disk.
    prompt_path = os.path.join(PROMPT_DIR, prompts_with_names[selected_task])
    image_path = "temp_image.png"
    image.save(image_path)

    # RXNIM returns JSON text; turn it into HTML summaries and SMILES strings.
    rxnim_result = chem_mllm.process(image_path, prompt_path)
    detailed_reactions, smiles_output = parse_reactions(rxnim_result)

    # Run the Reaction model on the same image and build the combined picture.
    predictions = model.predict_image_file(image_path, molscribe=True, ocr=True)
    combined_image_path = generate_combined_image(predictions, image_path)

    # Save the RXNIM result, pretty-printed, for the download component.
    json_file_path = "output.json"
    with open(json_file_path, "w", encoding="utf-8") as json_file:
        json.dump(json.loads(rxnim_result), json_file, indent=4)

    return "\n\n".join(detailed_reactions), smiles_output, combined_image_path, example_diagram, json_file_path


# Friendly task name -> prompt file name, shared by the task radio and the
# Run handler.
prompts_with_names = list_prompt_files_with_names()

# Example inputs for gr.Examples: [image path, task name] for the bundled
# examples/reaction1.png ... examples/reaction4.png.
examples = [
    [f"examples/reaction{n}.png", "Reaction Image Parsing Workflow"]
    for n in range(1, 5)
]

# 定义 Gradio 界面
def _split_smiles_value(value):
    """Split the "Reaction SMILES" textbox value into individual reaction SMILES.

    process_chem_image sends a Python list to that textbox, and the previous
    parsing here (stripping "[", "]" and quotes) indicates the value arrives as
    that list's str() repr, e.g. "['A>>B', 'C>>D']". Reading it back with
    ast.literal_eval keeps bracket atoms such as "[Na+]" intact, which the old
    per-item character stripping corrupted. Text that is not a list literal is
    split on commas instead.

    Returns:
        list of non-empty, whitespace-stripped SMILES strings.
    """
    import ast  # only needed for this parsing step

    if isinstance(value, (list, tuple)):
        items = value
    else:
        text = str(value).strip()
        try:
            items = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError):
            items = None
        if not isinstance(items, (list, tuple)):
            items = [part.strip(" '\"") for part in text.split(",")]
    cleaned = [str(item).strip() for item in items]
    return [item for item in cleaned if item]


def _remove_atom_maps(rxn):
    """Reset the atom-map number of every atom in *rxn*'s templates to 0, in place."""
    for template in rxn.GetReactants():
        for atom in template.GetAtoms():
            atom.SetAtomMapNum(0)
    for template in rxn.GetProducts():
        for atom in template.GetAtoms():
            atom.SetAtomMapNum(0)
    return rxn


def _draw_reaction_png(smiles, index):
    """Draw one reaction SMILES with RDKit and save it as ``reaction_<index>.png``.

    The intermediate SVG is written to ``reaction<index>.svg`` (same file names
    as before). Returns the PNG path, or None if the reaction cannot be parsed
    or drawn, so one bad prediction does not break rendering of the others.
    """
    try:
        parsed = rdChemReactions.ReactionFromSmarts(smiles, useSmiles=True)
        if not parsed:
            return None

        rxn = AllChem.ChemicalReaction()
        # Rebuild every template through a MolBlock round trip (kept from the
        # original implementation; presumably to get sanitized molecules with
        # 2D coordinates for drawing).
        for mol in parsed.GetReactants():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:
                return None
            rxn.AddReactantTemplate(rebuilt)
        for mol in parsed.GetProducts():
            rebuilt = Chem.MolFromMolBlock(Chem.MolToMolBlock(mol))
            if rebuilt is None:
                return None
            rxn.AddProductTemplate(rebuilt)
        _remove_atom_maps(rxn)

        # Scale the ACS 1996 drawing mode from the first of the first two
        # reactants that has any bonds; fall back to 1.0 otherwise.
        bond_length_reference = 1.0
        for k in range(min(2, rxn.GetNumReactantTemplates())):
            template = rxn.GetReactantTemplate(k)
            if template.GetNumBonds() > 0:
                bond_length_reference = Draw.MeanBondLength(template)
                break

        drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
        dopts = drawer.drawOptions()
        dopts.padding = 0.1
        dopts.includeRadicals = True
        Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
        dopts.bondLineWidth = 1.5
        drawer.DrawReaction(rxn)
        drawer.FinishDrawing()
        svg_content = drawer.GetDrawingText()

        svg_file = f"reaction{index}.svg"
        with open(svg_file, "w", encoding="utf-8") as f:
            f.write(svg_content)
        png_file = f"reaction_{index}.png"
        cairosvg.svg2png(url=svg_file, write_to=png_file)
        return png_file
    except Exception as err:
        # UI boundary: model output may contain SMILES that RDKit cannot parse
        # or draw; report it and let the caller skip this image instead of
        # failing the whole render.
        print(f"Could not draw reaction {index} ({smiles!r}): {err}")
        return None


# Gradio interface definition.
with gr.Blocks() as demo:
    gr.Markdown(
"""

    <center> <h1>Towards Large-scale Chemical Reaction Image Parsing via a Multimodal Large Language Model</h1></center>

    Upload a reaction image and select a predefined task prompt.
    """)

    # Top area: inputs on the left, outputs in the other columns.
    with gr.Row(equal_height=False):
        with gr.Column(scale=1):  # left column: inputs and parsed reactions
            image_input = gr.Image(type="pil", label="Upload Reaction Image")
            task_radio = gr.Radio(
                choices=list(prompts_with_names.keys()),
                label="Select a predefined task",
            )
            with gr.Row():  # Clear and Run buttons on the same row
                clear_button = gr.Button("Clear")
                process_button = gr.Button("Run", elem_id="submit-btn")

            gr.Markdown("### Reaction Image Parsing Output")
            reaction_output = gr.HTML(label="Reaction outputs")

        with gr.Column(scale=1):  # middle column: visualizations
            gr.Markdown("### Reaction Extraction Output")
            visualization_output = gr.Image(label="Visualization Output")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")

        with gr.Column(scale=1):  # right column: machine-readable outputs
            gr.Markdown("### Machine-readable Data Output")
            # Hidden holder for the reaction SMILES list; show_split displays it.
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False,
            )

            # Dynamically rendered from smiles_output: one SMILES textbox plus
            # one RDKit image per reaction.
            @gr.render(inputs=smiles_output)
            def show_split(inputs):
                """Create a SMILES textbox and an RDKit drawing for each reaction.

                With no reactions, a placeholder textbox and image are shown.
                A reaction RDKit cannot draw still gets its textbox, just no
                image.
                """
                smiles_list = _split_smiles_value(inputs) if inputs else []
                if not smiles_list:
                    return gr.Textbox(label= "SMILES of Reaction i"), gr.Image(value=rdkit_image, label= "RDKit Image of Reaction i",height=100)

                components = []
                for i, smiles in enumerate(smiles_list):
                    components.append(gr.Textbox(value=smiles, label=f"Reaction {i + 1} SMILES", show_copy_button=True, interactive=False))
                    png_file = _draw_reaction_png(smiles, i + 1)
                    if png_file is not None:
                        components.append(gr.Image(value=png_file, label=f"Reaction {i + 1} RDKit Image"))
                return components

            download_json = gr.File(label="Download JSON File",)

    # Example inputs.
    gr.Examples(
        examples=examples,
        inputs=[image_input, task_radio],
        outputs=[reaction_output, smiles_output, visualization_output],
    )

    # Event wiring. Clear also resets the JSON download so a stale file from a
    # previous run is not left behind.
    clear_button.click(
        lambda: (None, None, None, None, None, None),
        inputs=[],
        outputs=[
            image_input,
            task_radio,
            reaction_output,
            smiles_output,
            visualization_output,
            download_json,
        ],
    )

    process_button.click(
        process_chem_image,
        inputs=[image_input, task_radio],
        outputs=[
            reaction_output,
            smiles_output,
            visualization_output,
            schematic_diagram,
            download_json,
        ],
    )

# Custom CSS for the Run button, which is created with elem_id="submit-btn".
# NOTE(review): Gradio usually receives CSS via gr.Blocks(css=...); confirm
# that assigning demo.css after the Blocks context has closed takes effect in
# the installed Gradio version.
demo.css = """
#submit-btn {
    background-color: #FF914D;
    color: white;
    font-weight: bold;
}
"""

# Launch only when run as a script, so importing this module does not start
# the web server as a side effect.
if __name__ == "__main__":
    demo.launch()