#embedding_utils.py

from sentence_transformers import SentenceTransformer
from llama_index.core import SimpleDirectoryReader
from huggingface_hub import login
from typing import List, Tuple
from dotenv import load_dotenv
import numpy as np
import os
import tempfile
from docx import Document
import tempfile
import os
import logging


# Loading variables from a .env file is currently disabled; re-enable the
# call below to read HUGGINGFACE_HUB_TOKEN from a local .env file.
#load_dotenv()

# Point the Hugging Face cache at a writable directory. This overrides any
# HF_HOME value already present in the environment.
os.environ["HF_HOME"] = "/tmp/huggingface_cache"

# Ensure the cache directory exists (exist_ok avoids a check-then-create race).
cache_dir = os.environ["HF_HOME"]
os.makedirs(cache_dir, exist_ok=True)

# Authenticate with the Hugging Face Hub; importing this module fails fast
# with ValueError when no token is configured.
huggingface_token = os.getenv('HUGGINGFACE_HUB_TOKEN')

if huggingface_token:
    # NOTE(review): `write_permission` is deprecated in recent huggingface_hub
    # releases — confirm it is still needed for the pinned version.
    login(token=huggingface_token, add_to_git_credential=True, write_permission=True)
else:
    raise ValueError("Hugging Face token is not set. Please set the HUGGINGFACE_HUB_TOKEN environment variable.")

# Embedding model shared by the functions in this module.
# trust_remote_code=True lets the model repository's custom modelling code
# run, so only point this at a trusted repository.
model_name = 'nomic-ai/nomic-embed-text-v1.5'
model = SentenceTransformer(model_name, trust_remote_code=True)
# Inputs longer than this many tokens are truncated by SentenceTransformer.
model.max_seq_length = 4096
model.tokenizer.padding_side = "right"


def read_document(file_content: bytes, file_id: int, file_format: str) -> str:
    """Extract the text content of a document supplied as raw bytes.

    The bytes are written to a file inside a fresh temporary directory,
    which is removed on exit. The file is then parsed according to
    ``file_format``:

    - ``docx``: parsed by ``extract_text_from_docx``.
    - ``pdf``: parsed by ``extract_text_from_pdf``.
    - ``txt``/``md``/``csv``: parsed by llama_index's ``SimpleDirectoryReader``;
      only the first returned document is used.

    Args:
        file_content: Raw bytes of the document.
        file_id: Identifier used only to name the temporary file.
        file_format: File extension, case-insensitive, with or without a
            leading ``'.'``.

    Returns:
        The extracted, non-empty text.

    Raises:
        ValueError: If the format is unsupported or no text was extracted.
        Exception: Any other error raised while writing or parsing the file
            is logged and re-raised unchanged.
    """
    try:
        # Normalise and validate the format *before* touching the disk. An
        # unsupported or malformed value (e.g. one containing a path
        # separator) then fails with ValueError without writing any file.
        # It also keeps the temporary file name inside temp_dir.
        fmt = file_format.lower().lstrip('.')
        if fmt not in ('docx', 'pdf', 'txt', 'md', 'csv'):
            raise ValueError(f"Unsupported file format: {file_format}")

        with tempfile.TemporaryDirectory() as temp_dir:
            file_path = os.path.join(temp_dir, f"document_{file_id}.{fmt}")

            # Save the content so the path-based parsers below can read it.
            with open(file_path, 'wb') as temp_file:
                temp_file.write(file_content)

            if fmt == 'docx':
                text_content = extract_text_from_docx(file_path)
            elif fmt == 'pdf':
                text_content = extract_text_from_pdf(file_path)
            else:  # txt / md / csv
                reader = SimpleDirectoryReader(input_files=[file_path])
                documents = reader.load_data()
                text_content = documents[0].text if documents else ''

        if not text_content:
            raise ValueError("No content extracted from the document.")
        return text_content

    except Exception as e:
        logging.error(f"Error reading document: {e}")
        raise

def extract_text_from_docx(file_path: str) -> str:
    """Return the text of a DOCX file's paragraphs, one paragraph per line.

    Only ``Document.paragraphs`` is read. In python-docx this does not
    include text inside tables, headers or footers.

    Any parsing error is logged and re-raised unchanged.
    """
    try:
        document = Document(file_path)
        return '\n'.join(paragraph.text for paragraph in document.paragraphs)
    except Exception as exc:
        logging.error(f"Error extracting text from DOCX file: {exc}")
        raise

def extract_text_from_pdf(file_path: str) -> str:
    """Extract the text of every page of a PDF file.

    Page texts are joined with newlines, and the result is stripped of
    leading and trailing whitespace. ``pdfplumber`` is imported inside the
    function, so it is only required when a PDF is actually processed.

    Args:
        file_path: Path to the PDF file.

    Returns:
        The concatenated page text. This may be empty if no page contains
        extractable text.

    Raises:
        ImportError: If ``pdfplumber`` is not installed.
        Exception: Any parsing error is logged and re-raised unchanged.
    """
    import pdfplumber
    try:
        with pdfplumber.open(file_path) as pdf:
            # Some pdfplumber versions return None for pages without text;
            # treat those as empty so the join below cannot raise TypeError.
            page_texts = [page.extract_text() or '' for page in pdf.pages]
        return '\n'.join(page_texts).strip()
    except Exception as e:
        logging.error(f"Error extracting text from PDF file: {e}")
        raise



def cumulative_semantic_chunking(
    text: str,
    max_chunk_size: int,
    similarity_threshold: float,
    embedding_model: SentenceTransformer = model,
) -> List[str]:
    """Split ``text`` into chunks by greedily merging similar sentences.

    The text is split on ``'.'``, and every piece is embedded once with
    ``embedding_model``. Walking the pieces in order, the next piece is
    merged into the running chunk when both conditions hold:

    - The cosine similarity between the running chunk's embedding and the
      word-count-weighted average of (running chunk, next piece) is at least
      ``similarity_threshold``.
    - The merged text is at most ``max_chunk_size`` characters long.

    Otherwise the running chunk is emitted and a new one is started. The
    running chunk's embedding is that weighted average; it is not
    re-encoded.

    Args:
        text: Input text. Sentences are delimited by ``'.'`` only; the
            delimiter is re-inserted as ``'. '`` between merged pieces.
        max_chunk_size: Maximum length, in characters, of a merged chunk. A
            single piece longer than this is still emitted as its own chunk.
        similarity_threshold: Minimum cosine similarity required to merge.
        embedding_model: Model used to embed the pieces. Defaults to the
            module-level ``model``.

    Returns:
        The stripped, non-empty chunks, in document order.
    """
    sentences = text.split('.')

    # Honour the caller-supplied model (previously the global was always used).
    sentence_embeddings = embedding_model.encode(sentences)

    chunks: List[str] = []
    current_chunk = sentences[0]
    current_embedding = sentence_embeddings[0]

    for sentence, embedding in zip(sentences[1:], sentence_embeddings[1:]):
        combined_chunk = current_chunk + '. ' + sentence

        current_words = len(current_chunk.split())
        sentence_words = len(sentence.split())
        total_words = current_words + sentence_words
        if total_words:
            combined_embedding = (
                current_embedding * current_words + embedding * sentence_words
            ) / total_words
        else:
            # Both pieces are empty/whitespace (e.g. text like ".."): avoid a
            # 0/0 division that would produce a NaN embedding.
            combined_embedding = current_embedding

        similarity = np.dot(current_embedding, combined_embedding) / (
            np.linalg.norm(current_embedding) * np.linalg.norm(combined_embedding)
        )

        if similarity >= similarity_threshold and len(combined_chunk) <= max_chunk_size:
            current_chunk = combined_chunk
            current_embedding = combined_embedding
        else:
            chunks.append(current_chunk.strip())
            current_chunk = sentence
            current_embedding = embedding

    chunks.append(current_chunk.strip())

    # Drop empty chunks that empty pieces (leading '.', "..") can produce.
    return [chunk for chunk in chunks if chunk]

# def embed_chunks(chunks: List[str]) -> List[np.ndarray]:
#     """Embed the chunks using the SentenceTransformer model."""
#     return model.encode(chunks)

def embed_chunks(
    chunks: List[str],
    embedding_model: SentenceTransformer = model,
) -> Tuple[List[np.ndarray], int]:
    """Embed ``chunks`` and count the tokens they contain.

    Args:
        chunks: Text chunks to embed.
        embedding_model: Model used for tokenizing and embedding. Defaults to
            the module-level ``model``.

    Returns:
        A tuple ``(embeddings, total_tokens)``:

        - ``embeddings[i]`` is the 1-D embedding of ``chunks[i]``.
        - ``total_tokens`` is the number of tokenizer tokens across all
          chunks, excluding special tokens.

        ``tokenizer.encode`` is called without truncation. The count can
        therefore exceed what the model actually saw for chunks longer than
        ``embedding_model.max_seq_length``.
    """
    if not chunks:
        return [], 0

    tokenizer = embedding_model.tokenizer
    total_tokens = sum(
        len(tokenizer.encode(chunk, add_special_tokens=False)) for chunk in chunks
    )

    # Encode all chunks in a single call so SentenceTransformer can batch
    # them, instead of running one forward pass per chunk. list() turns the
    # (n, dim) array into n 1-D arrays, matching the previous return shape.
    embeddings = list(embedding_model.encode(chunks))
    return embeddings, total_tokens