import numpy as np
from models import chat_with_model, embed
from prompts import create_gen_prompt, create_judge_prompt
import time
from concurrent.futures import ThreadPoolExecutor
import streamlit as st
import queue


def generate_answer(question, previous_answers, model_name, open_router_key, openai_api_key):
    """Generates an answer to a question using the specified language model."""
    gen_prompt = create_gen_prompt(question, previous_answers)
    try:
        new_answer = chat_with_model(prompt=gen_prompt, model=model_name, open_router_key=open_router_key,
                                     openai_api_key=openai_api_key)
        return new_answer
    except Exception as e:
        st.error(f"Error generating answer: {str(e)}")
        return None


def evaluate_answer(question, new_answer, open_router_key, openai_api_key):
    """Evaluates the coherence of an answer using a judge model."""
    judge_prompt = create_judge_prompt(question, new_answer)
    judge = "openai/gpt-4o-mini"
    try:
        judge_response = chat_with_model(prompt=judge_prompt, model=judge, open_router_key=open_router_key,
                                         openai_api_key=openai_api_key)
        # The judge is expected to wrap its rating in <coherence_score> tags.
        coherence_score = int(judge_response.split("<coherence_score>")[1].split("</coherence_score>")[0])
        return coherence_score
    except Exception as e:
        st.error(f"Error getting judge response: {str(e)}")
        return None


def process_question(question, model_name, open_router_key, openai_api_key, result_queue):
    start_time = time.time()
    # st.write(f"{question}", unsafe_allow_html=True)
    previous_answers = []
    question_novelty = 0
    coherence_score = None  # Track the most recent coherence score for the final return value
    try:
        while True:
            new_answer = generate_answer(question, previous_answers, model_name, open_router_key, openai_api_key)
            if new_answer is None:
                break
            coherence_score = evaluate_answer(question, new_answer, open_router_key, openai_api_key)
            if coherence_score is None:
                break
            if coherence_score <= 3:
                # st.write("Output is incoherent. Moving to next question.",
                #          unsafe_allow_html=True)
                break
            novelty_score = get_novelty_score(new_answer, previous_answers, openai_api_key)
            if novelty_score < 0.1:
                # Output is redundant (too similar to a previous answer); stop iterating on this question.
                break
            # Append results to the queue instead of using st.write
            result_queue.put({
                "type": "answer",
                "question": question,
                "answer": new_answer,
                "coherence_score": coherence_score,
                "novelty_score": novelty_score,
                "results": [
                    {
                        "question": question,
                        "answers": previous_answers.copy() + [new_answer],  # Include the new answer
                        "coherence_score": coherence_score,
                        "novelty_score": question_novelty + novelty_score  # Accumulate novelty score
                    }
                ]
            })
            previous_answers.append(new_answer)
            question_novelty += novelty_score
    except Exception as e:
        result_queue.put({"type": "error", "message": str(e)})
    time_taken = time.time() - start_time
    result_queue.put({
        "type": "summary",
        "question": question,
        "total_novelty": question_novelty,
        "time_taken": time_taken
    })
    return question_novelty, [
        {
            "question": question,
            "answers": previous_answers,
            "coherence_score": coherence_score,
            "novelty_score": question_novelty
        }
    ]


def get_novelty_score(new_answer: str, previous_answers: list, openai_api_key):
    """Returns 1 minus the highest cosine similarity between the new answer and any previous answer."""
    # If there are no previous answers, return maximum novelty
    if not previous_answers:
        return 1.0
    new_embedding = embed(new_answer, openai_api_key)
    previous_embeddings = [embed(answer, openai_api_key) for answer in previous_answers]
    # Cosine similarity between the new answer and each previous answer
    similarities = [
        np.dot(new_embedding, prev_embedding) /
        (np.linalg.norm(new_embedding) * np.linalg.norm(prev_embedding))
        for prev_embedding in previous_embeddings
    ]
    return 1 - max(similarities)
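
# Worked example of the score above (illustrative numbers, not from a real run): if the
# new answer's embedding has cosine similarities of 0.45, 0.71, and 0.92 with the three
# previous answers, the highest similarity is 0.92 and the novelty score is
# 1 - 0.92 = 0.08; since 0.08 < 0.1, process_question treats the answer as redundant.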


def benchmark_model_multithreaded(model_name, questions, open_router_key, openai_api_key, max_threads=None):
    novelty_score = 0
    results = []
    result_queue = queue.Queue()  # Create a queue for communication
    # Use max_threads if provided, otherwise default to the number of questions
    if max_threads is None:
        max_workers = len(questions)
    else:
        max_workers = max_threads
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit tasks to the thread pool
        future_to_question = {
            executor.submit(process_question, question, model_name, open_router_key, openai_api_key, result_queue): question
            for question in questions
        }
        # Process results from the queue in the main thread so that all Streamlit
        # calls happen here rather than inside worker threads.
        while True:
            try:
                result = result_queue.get(timeout=0.1)
            except queue.Empty:
                # Stop once every worker has finished and the queue has been fully drained.
                if all(future.done() for future in future_to_question) and result_queue.empty():
                    break
                continue
            if result["type"] == "answer":
                st.write(f"**Question:** {result['question']}")
                st.write(f"**New Answer:**\n{result['answer']}")
                st.write(f"**Coherence Score:** {result['coherence_score']}")
                st.write(f"**Novelty Score:** {result['novelty_score']}")
                results.extend(result["results"])
            elif result["type"] == "summary":
                novelty_score += result["total_novelty"]  # Accumulate per-question novelty for the final total
                st.write(f"Total novelty score for question '{result['question']}': {result['total_novelty']}")
                st.write(f"Time taken: {result['time_taken']:.2f} seconds")
            elif result["type"] == "error":
                st.error(f"Error in thread: {result['message']}")
    st.write(f"Final total novelty score across all questions: {novelty_score}")
    return results


def benchmark_model_sequential(model_name, questions, open_router_key, openai_api_key, progress=0, progress_lock=None):
    # progress and progress_lock are accepted but not used by this implementation.
    novelty_score = 0
    results = []
    result_queue = queue.Queue()  # process_question reports per-answer details through this queue
    for question in questions:
        question_novelty, question_results = process_question(question, model_name, open_router_key,
                                                              openai_api_key, result_queue)
        # Drain this question's messages: surface errors and timing; the answers themselves
        # are already contained in question_results.
        while not result_queue.empty():
            result = result_queue.get_nowait()
            if result["type"] == "error":
                st.error(f"Error processing question: {result['message']}")
            elif result["type"] == "summary":
                st.write(f"Time taken: {result['time_taken']:.2f} seconds")
        novelty_score += question_novelty
        results.extend(question_results)
        # Display progress after each question
        st.write(f"Total novelty score across processed questions: {novelty_score}")
    st.write(f"Final total novelty score across all questions: {novelty_score}")
    return results
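

# Illustrative usage sketch (an assumption, not part of the original module): one way a
# Streamlit page might wire these functions together. The widget labels, the sample
# question, and run_benchmark_page itself are placeholders.
def run_benchmark_page():
    model_name = st.text_input("Model name", value="openai/gpt-4o-mini")
    open_router_key = st.text_input("OpenRouter API key", type="password")
    openai_api_key = st.text_input("OpenAI API key", type="password")
    questions = ["Name a use for a paperclip other than holding paper."]  # placeholder question
    if st.button("Run benchmark"):
        benchmark_model_multithreaded(model_name, questions, open_router_key,
                                      openai_api_key, max_threads=4)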