import asyncio
import logging
from typing import Any, Dict, Optional

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)


class Charm15Chatbot:
    def __init__(
        self,
        model_path: str,
        device: Optional[str] = None,
        tokenizer_kwargs: Optional[Dict[str, Any]] = None,
        model_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the chatbot.

        Args:
            model_path (str): Path or name of the pre-trained model.
            device (str, optional): Device to run the model on (e.g., "cuda" or "cpu").
                Defaults to "cuda" if available.
            tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer.
            model_kwargs (dict, optional): Additional arguments for the model.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer_kwargs = tokenizer_kwargs or {}
        self.model_kwargs = model_kwargs or {}

        logger.info(f"Loading model and tokenizer from {model_path}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, **self.tokenizer_kwargs)
        self.model = AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs).to(self.device)
        self.model.eval()
        logger.info("Model and tokenizer loaded successfully.")

    def generate_response(
        self,
        input_text: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Generate a response to the input text.

        Args:
            input_text (str): The input prompt.
            max_length (int): Maximum total sequence length (prompt plus generated tokens).
            temperature (float): Sampling temperature (higher = more random).
            top_p (float): Top-p (nucleus) sampling threshold.
            top_k (int, optional): Top-k sampling cutoff.
            repetition_penalty (float): Penalty for repeating tokens (1.0 = no penalty).
            **kwargs: Additional arguments passed to model.generate().

        Returns:
            str: The generated response.
        """
        try:
            # Tokenize the prompt, truncating overly long inputs.
            inputs = self.tokenizer(
                input_text,
                return_tensors="pt",
                truncation=True,
                max_length=1024,
            ).to(self.device)

            with torch.no_grad():
                output = self.model.generate(
                    **inputs,
                    max_length=max_length,
                    do_sample=True,  # sampling must be enabled for temperature/top_p/top_k to take effect
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    repetition_penalty=repetition_penalty,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **kwargs,
                )

            # Decode only the newly generated tokens so the prompt is not echoed back.
            prompt_length = inputs["input_ids"].shape[1]
            response = self.tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
            logger.info("Response generated successfully.")
            return response
        except Exception as e:
            logger.error(f"Error generating response: {e}")
            raise

    async def async_generate(
        self,
        input_text: str,
        max_length: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
        repetition_penalty: float = 1.0,
        **kwargs,
    ) -> str:
        """
        Asynchronously generate a response to the input text.

        The blocking generate_response() call is offloaded to a worker thread
        via asyncio.to_thread, so the event loop stays responsive.

        Args:
            input_text (str): The input prompt.
            max_length (int): Maximum total sequence length (prompt plus generated tokens).
            temperature (float): Sampling temperature (higher = more random).
            top_p (float): Top-p (nucleus) sampling threshold.
            top_k (int, optional): Top-k sampling cutoff.
            repetition_penalty (float): Penalty for repeating tokens (1.0 = no penalty).
            **kwargs: Additional arguments passed to model.generate().

        Returns:
            str: The generated response.
        """
        return await asyncio.to_thread(
            self.generate_response,
            input_text,
            max_length=max_length,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repetition_penalty=repetition_penalty,
            **kwargs,
        )
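

# Minimal usage sketch. "gpt2" is only a placeholder checkpoint used for
# illustration; substitute the actual Charm15 model path or Hub repo id.
if __name__ == "__main__":
    bot = Charm15Chatbot("gpt2")

    # Synchronous call: blocks until generation finishes.
    print(bot.generate_response("Hello, how are you?", max_length=64))

    # Asynchronous call: generate_response runs in a worker thread via
    # asyncio.to_thread, so other coroutines can run while the model works.
    async def main() -> None:
        reply = await bot.async_generate("Tell me a short joke.", max_length=64)
        print(reply)

    asyncio.run(main())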