import asyncio
import logging
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Charm15Chatbot:
    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the chatbot.

        Args:
            model_path (str): Path or name of the pre-trained model.
            device (str, optional): Device to run the model on (e.g., "cuda" or "cpu").
                Defaults to "cuda" if available.
            tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer.
            model_kwargs (dict, optional): Additional arguments for the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.model_kwargs = model_kwargs or {}

        # Load tokenizer and model
        logger.info(f"Loading model and tokenizer from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **self.tokenizer_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs).to(self.device)
        self.model.eval()
        logger.info("Model and tokenizer loaded successfully.")

    def generate_response(
        self,
        input_text: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Generate a response to the input text.

        Args:
            input_text (str): The input prompt.
            max_length (int): Maximum total length (prompt plus generated tokens).
            temperature (float): Sampling temperature (higher = more random).
            top_p (float): Top-p (nucleus) sampling.
            top_k (int, optional): Top-k sampling.
            repetition_penalty (float): Penalty for repeating tokens.
            **kwargs: Additional arguments for model.generate().

        Returns:
            str: The generated response, without the echoed prompt.
        """
        try:
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            ).to(self.device)

            gen_kwargs = dict(
                max_length=max_length,
                do_sample=True,  # required so temperature/top_p/top_k actually take effect
                temperature=temperature,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
                pad_token_id=self.tokenizer.eos_token_id,
                **kwargs,
            )
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

            with torch.no_grad():
                output = self.model.generate(**inputs, **gen_kwargs)

            # Causal LMs return the prompt tokens followed by the new tokens;
            # slice off the prompt so only the generated continuation is decoded.
            generated_tokens = output[0][inputs["input_ids"].shape[-1]:]
            response = self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
            logger.info("Response generated successfully.")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    async def async_generate(
        self,
        input_text: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Asynchronously generate a response to the input text.

        Runs the blocking generate_response() in a worker thread so the event
        loop is not blocked while the model generates.

        Args:
            input_text (str): The input prompt.
            max_length (int): Maximum total length (prompt plus generated tokens).
            temperature (float): Sampling temperature (higher = more random).
            top_p (float): Top-p (nucleus) sampling.
            top_k (int, optional): Top-k sampling.
            repetition_penalty (float): Penalty for repeating tokens.
            **kwargs: Additional arguments for model.generate().

        Returns:
            str: The generated response.
        """
        return await asyncio.to_thread(
            self.generate_response,
            input_text,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
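

# Minimal usage sketch. Assumptions: the original module does not show how the class
# is driven, and no real Charm 15 checkpoint path is given here, so "gpt2" is used
# purely as a hypothetical stand-in; replace it with the actual model path.
if __name__ == "__main__":
    bot = Charm15Chatbot("gpt2")  # placeholder checkpoint, not the real model

    # Synchronous generation
    print(bot.generate_response("Hello, how are you?", max_length=64))

    # Asynchronous generation, e.g. from a web handler or another coroutine
    async def main() -> None:
        reply = await bot.async_generate("Tell me a short joke.", max_length=64)
        print(reply)

    asyncio.run(main())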