GeminiFan207 committed on
Commit 81f26d6 · verified · 1 Parent(s): eb813a3

Create chatbot.py

Files changed (1):
  1. chatbot.py +125 -0
chatbot.py ADDED
@@ -0,0 +1,125 @@
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import Optional, Dict, Any
+ import logging
+ import asyncio
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
+ logger = logging.getLogger(__name__)
+
+ class Charm15Chatbot:
+     def __init__(
+         self,
+         model_path: str,
+         device: Optional[str] = None,
+         tokenizer_kwargs: Optional[Dict[str, Any]] = None,
+         model_kwargs: Optional[Dict[str, Any]] = None,
+     ):
+         """
+         Initialize the chatbot.
+
+         Args:
+             model_path (str): Path or name of the pre-trained model.
+             device (str, optional): Device to run the model on (e.g., "cuda" or "cpu").
+                 Defaults to "cuda" if available, otherwise "cpu".
+             tokenizer_kwargs (dict, optional): Additional arguments for the tokenizer.
+             model_kwargs (dict, optional): Additional arguments for the model.
+         """
+         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+         self.tokenizer_kwargs = tokenizer_kwargs or {}
+         self.model_kwargs = model_kwargs or {}
+
+         # Load tokenizer and model
+         logger.info(f"Loading model and tokenizer from {model_path}...")
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path, **self.tokenizer_kwargs)
+         self.model = AutoModelForCausalLM.from_pretrained(model_path, **self.model_kwargs).to(self.device)
+         self.model.eval()
+         logger.info("Model and tokenizer loaded successfully.")
+
+     def generate_response(
+         self,
+         input_text: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         top_k: Optional[int] = None,
+         repetition_penalty: float = 1.0,
+         **kwargs,
+     ) -> str:
+         """
+         Generate a response to the input text.
+
+         Args:
+             input_text (str): The input prompt.
+             max_new_tokens (int): Maximum number of tokens to generate.
+             temperature (float): Sampling temperature (higher = more random).
+             top_p (float): Top-p (nucleus) sampling threshold.
+             top_k (int, optional): Top-k sampling cutoff.
+             repetition_penalty (float): Penalty for repeating tokens.
+             **kwargs: Additional arguments for model.generate().
+
+         Returns:
+             str: The generated response.
+         """
+         try:
+             inputs = self.tokenizer(
+                 input_text,
+                 return_tensors="pt",
+                 truncation=True,
+                 max_length=1024,
+             ).to(self.device)
+
+             with torch.no_grad():
+                 output = self.model.generate(
+                     **inputs,
+                     # Cap newly generated tokens; a plain max_length smaller than the
+                     # (up to 1024-token) prompt would leave no budget for the reply.
+                     max_new_tokens=max_new_tokens,
+                     # Sampling must be enabled, otherwise temperature/top_p/top_k
+                     # are ignored under the default greedy decoding.
+                     do_sample=True,
+                     temperature=temperature,
+                     top_p=top_p,
+                     top_k=top_k,
+                     repetition_penalty=repetition_penalty,
+                     pad_token_id=self.tokenizer.eos_token_id,
+                     **kwargs,
+                 )
+
+             response = self.tokenizer.decode(output[0], skip_special_tokens=True)
+             logger.info("Response generated successfully.")
+             return response
+         except Exception as e:
+             logger.error(f"Error generating response: {e}")
+             raise
+
+     async def async_generate(
+         self,
+         input_text: str,
+         max_new_tokens: int = 512,
+         temperature: float = 0.7,
+         top_p: float = 0.9,
+         top_k: Optional[int] = None,
+         repetition_penalty: float = 1.0,
+         **kwargs,
+     ) -> str:
+         """
+         Asynchronously generate a response to the input text.
+
+         Runs the blocking generate_response call in a worker thread so the
+         event loop is not blocked during generation.
+
+         Args:
+             input_text (str): The input prompt.
+             max_new_tokens (int): Maximum number of tokens to generate.
+             temperature (float): Sampling temperature (higher = more random).
+             top_p (float): Top-p (nucleus) sampling threshold.
+             top_k (int, optional): Top-k sampling cutoff.
+             repetition_penalty (float): Penalty for repeating tokens.
+             **kwargs: Additional arguments for model.generate().
+
+         Returns:
+             str: The generated response.
+         """
+         return await asyncio.to_thread(
+             self.generate_response,
+             input_text,
+             max_new_tokens=max_new_tokens,
+             temperature=temperature,
+             top_p=top_p,
+             top_k=top_k,
+             repetition_penalty=repetition_penalty,
+             **kwargs,
+         )
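
As a quick sanity check, here is a minimal usage sketch of the class added in this commit. It is illustrative only and not part of the diff; the checkpoint name "gpt2" is a placeholder assumption, not the model this repo actually targets.

# Minimal usage sketch for Charm15Chatbot (illustrative, not part of this commit).
# "gpt2" is a placeholder checkpoint; substitute the real Charm 15 model path.
import asyncio

from chatbot import Charm15Chatbot

bot = Charm15Chatbot("gpt2")

# Synchronous generation
print(bot.generate_response("Hello, how are you?", max_new_tokens=64))

# Asynchronous generation: the blocking call runs in a worker thread
async def main() -> None:
    reply = await bot.async_generate("Tell me a joke.", max_new_tokens=64)
    print(reply)

asyncio.run(main())

Note that asyncio.to_thread was added in Python 3.9, so the async path assumes at least that interpreter version.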