import os
import asyncio
import discord
from discord.ext import commands
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import logging

# Logging setup
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Hugging Face model id of the chat model served by this bot.
MODEL = "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct"
# Required environment configuration.
# NOTE(review): if DISCORD_CHANNEL_ID is unset, int(None) raises TypeError at
# import time with a confusing message — consider failing with a clearer error.
DISCORD_TOKEN = os.getenv("DISCORD_TOKEN")
DISCORD_CHANNEL_ID = int(os.getenv("DISCORD_CHANNEL_ID"))

logger.info(f"Discord Token: {'Set' if DISCORD_TOKEN else 'Not Set'}")
logger.info(f"Discord Channel ID: {DISCORD_CHANNEL_ID}")

# Prefer GPU when available; note that device_map="auto" below performs its
# own device placement, so this is mainly used for moving input tensors.
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Using device: {device}")

logger.info("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL)
logger.info("Tokenizer loaded successfully")

logger.info("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL,
    torch_dtype=torch.bfloat16,   # bf16 halves memory vs fp32 — presumably matches the published weights; TODO confirm
    device_map="auto",            # let accelerate place/shard the model across available devices
    trust_remote_code=True,       # executes code from the model repo — only safe for trusted repos
    ignore_mismatched_sizes=True  # NOTE(review): silently accepts weight-shape mismatches — verify this is intentional
)
logger.info("Model loaded successfully")

# The message-content intent must also be enabled in the Discord developer portal.
intents = discord.Intents.default()
intents.message_content = True
bot = commands.Bot(command_prefix="!", intents=intents)

async def generate_response(message, history, system_prompt):
    """Generate a model reply for *message* given a (user, assistant) history.

    Args:
        message: The new user message (str).
        history: Iterable of (prompt, answer) pairs of prior turns.
        system_prompt: System instruction prepended to the conversation.

    Returns:
        The assistant's reply text, stripped of special tokens and whitespace.
    """
    # Lazy %-args so the slice/format work is skipped when INFO is disabled.
    logger.info("Generating response for message: %s...", message[:50])
    conversation = [{"role": "system", "content": system_prompt}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": answer},
        ])
    conversation.append({"role": "user", "content": message})

    inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(device)

    def _generate():
        # Blocking, potentially multi-second call — executed in a worker
        # thread below so the Discord event loop (and heartbeats) keep running.
        with torch.no_grad():
            return model.generate(
                inputs,
                max_new_tokens=1024,
                do_sample=True,
                top_p=1.0,
                top_k=50,
                temperature=1.0,
                # Use the tokenizer's own special-token ids instead of the
                # previous hard-coded pad=0 / eos=361, which may not match
                # the loaded tokenizer's vocabulary.
                pad_token_id=(tokenizer.pad_token_id
                              if tokenizer.pad_token_id is not None
                              else tokenizer.eos_token_id),
                eos_token_id=tokenizer.eos_token_id,
            )

    output = await asyncio.get_running_loop().run_in_executor(None, _generate)

    # Decode only the newly generated tokens: decoding the full sequence and
    # splitting on "Assistant:" returned the echoed prompt too, because the
    # chat template does not emit that literal marker.
    new_tokens = output[0][inputs.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)
    logger.info("Generated response: %s...", response[:100])
    return response.strip()

@bot.event
async def on_ready():
    """Log a confirmation once the gateway connection is established."""
    connected_user = bot.user
    logger.info(f"{connected_user} has connected to Discord!")

@bot.event
async def on_message(message):
    """Answer messages in the configured channel with a model-generated reply.

    Ignores the bot's own messages, forwards prefix commands to the command
    framework, and splits long replies into 2000-character chunks (Discord's
    per-message limit).
    """
    # Never respond to ourselves — avoids reply loops.
    if message.author == bot.user:
        return

    # Overriding on_message on a commands.Bot suppresses command dispatch
    # unless the message is forwarded explicitly (documented discord.py gotcha).
    await bot.process_commands(message)

    logger.info(f"Received message: {message.content[:50]}...")  # Log first 50 chars of message
    logger.info(f"Message channel ID: {message.channel.id}")

    if message.channel.id != DISCORD_CHANNEL_ID:
        logger.info("Message not in target channel")
        return

    try:
        response = await generate_response(
            message.content,
            [],
            "You are EXAONE model from LG AI Research, a helpful assistant.",
        )

        if not response:
            # An empty reply previously produced zero chunks and silently sent nothing.
            logger.info("Model returned an empty response; nothing to send")
            return

        # Discord rejects messages longer than 2000 characters.
        chunks = [response[i:i+2000] for i in range(0, len(response), 2000)]

        for i, chunk in enumerate(chunks):
            await message.channel.send(chunk)
            logger.info(f"Sent response chunk {i+1}/{len(chunks)}")
    except Exception as e:
        # logger.exception preserves the traceback for debugging.
        logger.exception(f"Error generating or sending response: {e}")

if __name__ == "__main__":
    import subprocess
    import sys

    # Fail fast with a clear message instead of letting bot.run crash
    # on a missing token deep inside discord.py.
    if not DISCORD_TOKEN:
        raise SystemExit("DISCORD_TOKEN environment variable is not set")

    logger.info("Starting web.py...")
    # Use the current interpreter (sys.executable) rather than whatever
    # "python" resolves to on PATH, so the child shares this environment.
    subprocess.Popen([sys.executable, "web.py"])
    logger.info("Running Discord bot...")
    bot.run(DISCORD_TOKEN)