import streamlit as st
import datasets
import numpy as np
import html
# Datasets whose examples are multiple-choice questions: their choices are
# appended to the question text and option markers in model rows are split
# onto separate lines.
_MCQ_DATASETS = ('CN-College-Listen-MCQ-Test', 'DREAM-TTS-MCQ-Test')

# Markers that start a new line inside a rendered question cell.
_CHOICE_MARKERS = ('Choices:', '(A)', '(B)', '(C)', '(D)')

# Model predictions have angle brackets and asterisks removed (so they cannot
# inject tags or Markdown emphasis) and newlines made visible, so each
# prediction stays on one table row.
_PREDICTION_CLEANUP = str.maketrans({'<': None, '>': None, '*': None, '\n': '(newline)'})


def _table_cell(text, *, split_choices=False):
    """Return *text* HTML-escaped and made safe for one Markdown table cell.

    When *split_choices* is true, a ``<br>`` is inserted before every marker in
    ``_CHOICE_MARKERS`` (after escaping, so the tag itself is not escaped).
    """
    cell = html.escape(str(text))
    if split_choices:
        for marker in _CHOICE_MARKERS:
            cell = cell.replace(marker, f'<br>{marker}')
    # A raw newline would end the table row and a raw pipe would start a new column.
    return cell.replace('\n', '<br>').replace('|', '\\|')


def _format_question(row, dataset_name):
    """Return the instruction text of *row*, followed by its choices for MCQ datasets."""
    instruction = row['instruction']['text']
    if dataset_name not in _MCQ_DATASETS:
        return f'{instruction}'
    choices = row['other_attributes']['choices']
    if isinstance(choices, (list, tuple)):
        choices_text = ' '.join(map(str, choices))
    elif choices is None:
        choices_text = ''
    else:
        choices_text = str(choices)
    return f'{instruction} {choices_text}'


def _model_table_rows(row, dataset_name, models, display_model_names):
    """Return one Markdown table row per model in *models* that has output in *row*.

    Models without an entry in *row* (or in *display_model_names*) are skipped
    and reported on stdout.
    """
    split_choices = dataset_name in _MCQ_DATASETS
    lines = []
    for model in models:
        try:
            result = row[model]
            label = display_model_names[model]
            question = result['text']
            prediction = result['model_prediction'].translate(_PREDICTION_CLEANUP)
        except (KeyError, TypeError, AttributeError):
            # Best effort: a model absent from this dataset is simply not shown.
            print(f"{model} is not in {dataset_name}")
            continue
        lines.append(
            f"| {_table_cell(label)} "
            f"| {_table_cell(question, split_choices=split_choices)} "
            f"| {_table_cell(prediction)} |"
        )
    return lines


def show_examples(category_name, dataset_name, model_lists, display_model_names):
    """Render every stored example of a dataset together with each model's output.

    Loads the dataset saved at ``./examples/{category_name}/{dataset_name}``.
    For each row it shows an audio player for ``sample_{index}.wav``, followed
    by a Markdown table with these rows:

    * a REFERENCE row holding the question and the reference answer;
    * one row per model in ``model_lists`` (in sorted order) holding the
      question as seen by that model and its prediction.

    Models missing from the dataset are skipped and reported on stdout.

    Args:
        category_name: Sub-folder of ``./examples`` that contains the dataset.
        dataset_name: Dataset folder name; also selects multiple-choice formatting.
        model_lists: Model column names to display. The list is not modified.
        display_model_names: Mapping from model column name to display label.
    """
    st.divider()
    sample_folder = f"./examples/{category_name}/{dataset_name}"

    dataset = datasets.load_from_disk(sample_folder)
    # Sort a copy once, rather than re-sorting the caller's list for every example.
    models = sorted(model_lists)

    for index, row in enumerate(dataset):
        with st.container():
            st.markdown(f'##### Example-{index+1}')
            st.audio(f'{sample_folder}/sample_{index}.wav', format="audio/wav")

            with st.container():
                question = _format_question(row, dataset_name)
                table = [
                    '| MODEL | QUESTION | MODEL PREDICTION |',
                    '| --- | --- | --- |',
                    f"| REFERENCE | {_table_cell(question, split_choices=True)} "
                    f"| {_table_cell(row['answer']['text'])} |",
                ]
                table.extend(_model_table_rows(row, dataset_name, models, display_model_names))
                # Rows must start at column 0; Markdown treats lines indented
                # by 4+ spaces as a code block instead of a table.
                st.markdown('\n'.join(table), unsafe_allow_html=True)

                st.text("")

        st.divider()