import logging
from typing import Dict, List, Tuple

import streamlit as st
from streamlit_tags import st_tags

from llm_guard.input_scanners.anonymize import default_entity_types
from llm_guard.output_scanners import (
    BanSubstrings,
    BanTopics,
    Bias,
    Code,
    Deanonymize,
    MaliciousURLs,
    NoRefusal,
    Refutation,
    Regex,
    Relevance,
    Sensitive,
)
from llm_guard.output_scanners.sentiment import Sentiment
from llm_guard.output_scanners.toxicity import Toxicity
from llm_guard.vault import Vault

# Named logger for this demo page; get_scanner logs scanner initialization at DEBUG level.
logger = logging.getLogger("llm-guard-demo")


def _non_empty_lines(text: str) -> List[str]:
    """Split multi-line text-area input into lines, dropping blank ones.

    A trailing newline (as in the BanSubstrings default value) or a blank line
    typed by the user would otherwise produce an empty entry. An empty
    substring or regex pattern matches any text, so handing it to a scanner
    would presumably make it flag every output. ``splitlines`` also copes
    with ``\\r\\n`` line endings.
    """
    return [line for line in text.splitlines() if line.strip()]


def init_settings() -> Tuple[List[str], Dict[str, Dict]]:
    """Render the output-scanner configuration widgets in the sidebar.

    Shows a multiselect of all output scanners and, for every enabled scanner
    that has tunable options, an expander holding its widgets.

    Returns:
        A tuple ``(enabled_scanners, settings)``. ``enabled_scanners`` is the
        list of scanner names selected in the sidebar. ``settings`` maps a
        scanner name to the keyword settings read by ``get_scanner``.
        Scanners without options (``Deanonymize``) have no entry.
    """
    all_scanners = [
        "BanSubstrings",
        "BanTopics",
        "Bias",
        "Code",
        "Deanonymize",
        "MaliciousURLs",
        "NoRefusal",
        "Refutation",
        "Regex",
        "Relevance",
        "Sensitive",
        "Sentiment",
        "Toxicity",
    ]

    st_enabled_scanners = st.sidebar.multiselect(
        "Select scanners",
        options=all_scanners,
        default=all_scanners,
        help="The list can be found here: https://laiyer-ai.github.io/llm-guard/output_scanners/bias/",
    )

    settings = {}

    if "BanSubstrings" in st_enabled_scanners:
        st_bs_expander = st.sidebar.expander(
            "Ban Substrings",
            expanded=False,
        )

        with st_bs_expander:
            # Blank lines are dropped: an empty substring would match everything.
            st_bs_substrings = _non_empty_lines(
                st.text_area(
                    "Enter substrings to ban (one per line)",
                    value="test\nhello\nworld\n",
                    height=200,
                )
            )

            st_bs_match_type = st.selectbox("Match type", ["str", "word"])
            st_bs_case_sensitive = st.checkbox("Case sensitive", value=False)

        settings["BanSubstrings"] = {
            "substrings": st_bs_substrings,
            "match_type": st_bs_match_type,
            "case_sensitive": st_bs_case_sensitive,
        }

    if "BanTopics" in st_enabled_scanners:
        st_bt_expander = st.sidebar.expander(
            "Ban Topics",
            expanded=False,
        )

        with st_bt_expander:
            st_bt_topics = st_tags(
                label="List of topics",
                text="Type and press enter",
                value=["politics", "religion", "money", "crime"],
                suggestions=[],
                maxtags=30,
                key="bt_topics",
            )

            st_bt_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="ban_topics_threshold",
            )

        settings["BanTopics"] = {"topics": st_bt_topics, "threshold": st_bt_threshold}

    if "Bias" in st_enabled_scanners:
        st_bias_expander = st.sidebar.expander(
            "Bias",
            expanded=False,
        )

        with st_bias_expander:
            st_bias_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="bias_threshold",
            )

        settings["Bias"] = {"threshold": st_bias_threshold}

    if "Code" in st_enabled_scanners:
        st_cd_expander = st.sidebar.expander(
            "Code",
            expanded=False,
        )

        with st_cd_expander:
            st_cd_languages = st.multiselect(
                "Programming languages",
                options=["python", "java", "javascript", "go", "php", "ruby"],
                default=["python"],
            )

            st_cd_mode = st.selectbox("Mode", ["allowed", "denied"], index=0)

        settings["Code"] = {"languages": st_cd_languages, "mode": st_cd_mode}

    # Deanonymize has no options here; get_scanner only passes it the vault.

    if "MaliciousURLs" in st_enabled_scanners:
        st_murls_expander = st.sidebar.expander(
            "Malicious URLs",
            expanded=False,
        )

        with st_murls_expander:
            st_murls_threshold = st.slider(
                label="Threshold",
                value=0.75,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="murls_threshold",
            )

        settings["MaliciousURLs"] = {"threshold": st_murls_threshold}

    if "NoRefusal" in st_enabled_scanners:
        st_no_ref_expander = st.sidebar.expander(
            "No refusal",
            expanded=False,
        )

        with st_no_ref_expander:
            st_no_ref_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="no_ref_threshold",
            )

        settings["NoRefusal"] = {"threshold": st_no_ref_threshold}

    if "Refutation" in st_enabled_scanners:
        st_refu_expander = st.sidebar.expander(
            "Refutation",
            expanded=False,
        )

        with st_refu_expander:
            st_refu_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=0.0,
                max_value=1.0,
                step=0.05,
                key="refu_threshold",
            )

        settings["Refutation"] = {"threshold": st_refu_threshold}

    if "Regex" in st_enabled_scanners:
        st_regex_expander = st.sidebar.expander(
            "Regex",
            expanded=False,
        )

        with st_regex_expander:
            # Blank lines are dropped: an empty regex pattern matches any text.
            st_regex_patterns = _non_empty_lines(
                st.text_area(
                    "Enter patterns to ban (one per line)",
                    value="Bearer [A-Za-z0-9-._~+/]+",
                    height=200,
                )
            )

            st_regex_type = st.selectbox(
                "Match type",
                ["good", "bad"],
                index=1,
                help="good: allow only good patterns, bad: ban bad patterns",
            )

        settings["Regex"] = {"patterns": st_regex_patterns, "type": st_regex_type}

    if "Relevance" in st_enabled_scanners:
        st_rele_expander = st.sidebar.expander(
            "Relevance",
            expanded=False,
        )

        with st_rele_expander:
            st_rele_threshold = st.slider(
                label="Threshold",
                value=0.5,
                min_value=-1.0,
                max_value=1.0,
                step=0.05,
                key="rele_threshold",
                help="The minimum cosine similarity (-1 to 1) between the prompt and output for the output to be considered relevant.",
            )

        settings["Relevance"] = {"threshold": st_rele_threshold}

    if "Sensitive" in st_enabled_scanners:
        st_sens_expander = st.sidebar.expander(
            "Sensitive",
            expanded=False,
        )

        with st_sens_expander:
            st_sens_entity_types = st_tags(
                label="Sensitive entities",
                text="Type and press enter",
                value=default_entity_types,
                suggestions=default_entity_types
                + ["DATE_TIME", "NRP", "LOCATION", "MEDICAL_LICENSE", "US_PASSPORT"],
                maxtags=30,
                key="sensitive_entity_types",
            )
            st.caption(
                "Check all supported entities: https://microsoft.github.io/presidio/supported_entities/#list-of-supported-entities"
            )

        settings["Sensitive"] = {"entity_types": st_sens_entity_types}

    if "Sentiment" in st_enabled_scanners:
        st_sent_expander = st.sidebar.expander(
            "Sentiment",
            expanded=False,
        )

        with st_sent_expander:
            st_sent_threshold = st.slider(
                label="Threshold",
                value=-0.1,
                min_value=-1.0,
                max_value=1.0,
                step=0.1,
                key="sentiment_threshold",
                help="Negative values are negative sentiment, positive values are positive sentiment",
            )

        settings["Sentiment"] = {"threshold": st_sent_threshold}

    if "Toxicity" in st_enabled_scanners:
        st_tox_expander = st.sidebar.expander(
            "Toxicity",
            expanded=False,
        )

        with st_tox_expander:
            st_tox_threshold = st.slider(
                label="Threshold",
                value=0.0,
                min_value=-1.0,
                max_value=1.0,
                step=0.05,
                key="toxicity_threshold",
                help="A negative value (closer to 0 as the label output) indicates toxicity in the text, while a positive logit (closer to 1 as the label output) suggests non-toxicity.",
            )

        settings["Toxicity"] = {"threshold": st_tox_threshold}

    return st_enabled_scanners, settings


def get_scanner(scanner_name: str, vault: Vault, settings: Dict):
    """Instantiate the output scanner named ``scanner_name``.

    Args:
        scanner_name: One of the scanner names offered by ``init_settings``.
        vault: Vault handed to ``Deanonymize``; unused by the other scanners.
        settings: Keyword settings for this scanner as built by
            ``init_settings``. It may be empty for scanners without options.

    Returns:
        The configured scanner instance.

    Raises:
        ValueError: If ``scanner_name`` is not a known output scanner.
        KeyError: If ``settings`` lacks a key the chosen scanner reads.
    """
    # Lazy %-formatting: the message is only built when DEBUG is enabled.
    logger.debug("Initializing %s scanner", scanner_name)

    if scanner_name == "BanSubstrings":
        return BanSubstrings(
            substrings=settings["substrings"],
            match_type=settings["match_type"],
            case_sensitive=settings["case_sensitive"],
        )

    if scanner_name == "BanTopics":
        return BanTopics(topics=settings["topics"], threshold=settings["threshold"])

    if scanner_name == "Bias":
        return Bias(threshold=settings["threshold"])

    if scanner_name == "Deanonymize":
        return Deanonymize(vault=vault)

    if scanner_name == "Code":
        mode = settings["mode"]

        # The UI offers a single language list plus a mode. Route the list to
        # whichever of allowed/denied the mode selects and leave the other None.
        allowed_languages = None
        denied_languages = None
        if mode == "allowed":
            allowed_languages = settings["languages"]
        elif mode == "denied":
            denied_languages = settings["languages"]

        return Code(allowed=allowed_languages, denied=denied_languages)

    if scanner_name == "MaliciousURLs":
        return MaliciousURLs(threshold=settings["threshold"])

    if scanner_name == "NoRefusal":
        return NoRefusal(threshold=settings["threshold"])

    if scanner_name == "Refutation":
        return Refutation(threshold=settings["threshold"])

    if scanner_name == "Regex":
        match_type = settings["type"]

        # Same routing as Code: one pattern list, sent to good or bad by type.
        good_patterns = None
        bad_patterns = None
        if match_type == "good":
            good_patterns = settings["patterns"]
        elif match_type == "bad":
            bad_patterns = settings["patterns"]

        return Regex(good_patterns=good_patterns, bad_patterns=bad_patterns)

    if scanner_name == "Relevance":
        return Relevance(threshold=settings["threshold"])

    if scanner_name == "Sensitive":
        return Sensitive(entity_types=settings["entity_types"])

    if scanner_name == "Sentiment":
        return Sentiment(threshold=settings["threshold"])

    if scanner_name == "Toxicity":
        return Toxicity(threshold=settings["threshold"])

    raise ValueError(f"Unknown scanner name: {scanner_name!r}")


def scan(
    vault: Vault, enabled_scanners: List[str], settings: Dict, prompt: str, text: str
) -> Tuple[str, Dict[str, bool], Dict[str, float]]:
    """Run every enabled output scanner over ``text``, one after another.

    Each scanner receives the output as sanitized by the previous one, so the
    returned text reflects all scanners' changes. Progress is displayed inside
    a Streamlit status container.

    Args:
        vault: Vault forwarded to ``get_scanner`` (used by ``Deanonymize``).
        enabled_scanners: Scanner names in the order they should run.
        settings: Mapping of scanner name to its settings, as produced by
            ``init_settings``.
        prompt: The prompt that produced ``text``; passed to every scanner.
        text: The model output to scan.

    Returns:
        A tuple ``(sanitized_output, results_valid, results_score)``.
        ``results_valid`` holds each scanner's validity flag, keyed by scanner
        name. ``results_score`` holds each scanner's risk score, keyed the same way.
    """
    sanitized_output = text
    results_valid: Dict[str, bool] = {}
    results_score: Dict[str, float] = {}

    with st.status("Scanning output...", expanded=True) as status:
        for scanner_name in enabled_scanners:
            st.write(f"{scanner_name} scanner...")
            # Scanners without options (e.g. Deanonymize) have no settings entry.
            scanner = get_scanner(scanner_name, vault, settings.get(scanner_name, {}))
            sanitized_output, is_valid, risk_score = scanner.scan(prompt, sanitized_output)
            results_valid[scanner_name] = is_valid
            results_score[scanner_name] = risk_score
        status.update(label="Scanning complete", state="complete", expanded=False)

    return sanitized_output, results_valid, results_score