Abhaykoul committed
Commit d460687 · verified · 1 Parent(s): a1c55c3

Update app.py

Files changed (1): app.py (+7 −3)
app.py CHANGED

@@ -5,6 +5,7 @@ import threading
 import queue
 import time
 import spaces
+
 # Model configuration
 model_name = "HelpingAI/Dhanishtha-2.0-preview"
 
@@ -31,10 +32,12 @@ def load_model():
 
 class GradioTextStreamer(TextStreamer):
     """Custom TextStreamer for Gradio integration"""
-    def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=True):
-        super().__init__(tokenizer, skip_prompt, skip_special_tokens)
+    def __init__(self, tokenizer, skip_prompt=True):
+        # TextStreamer only accepts tokenizer and skip_prompt parameters
+        super().__init__(tokenizer, skip_prompt)
         self.text_queue = queue.Queue()
         self.generated_text = ""
+        self.skip_special_tokens = True  # Handle this manually if needed
 
     def on_finalized_text(self, text: str, stream_end: bool = False):
         """Called when text is finalized"""
@@ -56,6 +59,7 @@ class GradioTextStreamer(TextStreamer):
                 self.text_queue.get_nowait()
             except queue.Empty:
                 break
+
 @spaces.GPU()
 def generate_response(message, history, max_tokens, temperature, top_p):
     """Generate streaming response"""
@@ -86,7 +90,7 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
     # Create and setup streamer
-    streamer = GradioTextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+    streamer = GradioTextStreamer(tokenizer, skip_prompt=True)
     streamer.reset()
 
     # Start generation in a separate thread
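
For reference: the removed constructor call passed skip_special_tokens as a third positional argument. In the transformers releases this Space appears to build against, the base class is defined as TextStreamer.__init__(self, tokenizer, skip_prompt=False, **decode_kwargs), so a third positional argument raises a TypeError, while the same flag passed by keyword is collected into decode_kwargs and forwarded to tokenizer.decode(). A minimal sketch of that keyword-based alternative, assuming that signature (the commit instead drops the argument and tracks a manual skip_special_tokens flag):

    from transformers import AutoTokenizer, TextStreamer

    # Hypothetical alternative to the committed fix, assuming the signature
    # TextStreamer.__init__(self, tokenizer, skip_prompt=False, **decode_kwargs):
    # skip_special_tokens passed by keyword lands in decode_kwargs and is
    # forwarded to tokenizer.decode() when chunks are finalized.
    tokenizer = AutoTokenizer.from_pretrained("HelpingAI/Dhanishtha-2.0-preview")
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)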
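
The diff shows only the setup side ("Start generation in a separate thread"), not how the queue is drained. As a minimal sketch of the usual pattern for a queue-backed streamer like GradioTextStreamer: model.generate() runs in a background thread and pushes finalized chunks into text_queue via on_finalized_text(), while the foreground generator polls the queue and yields progressively longer strings to Gradio. The stream_reply name, the generate kwargs, and the polling loop below are assumptions for illustration, not code from app.py:

    import queue
    import threading

    def stream_reply(model, tokenizer, model_inputs, max_tokens, temperature, top_p):
        # Hypothetical consumption loop (not from app.py).
        streamer = GradioTextStreamer(tokenizer, skip_prompt=True)
        streamer.reset()

        # Background thread: generate() blocks until done, feeding the streamer.
        thread = threading.Thread(
            target=model.generate,
            kwargs=dict(
                **model_inputs,
                max_new_tokens=max_tokens,
                do_sample=True,          # so temperature/top_p take effect
                temperature=temperature,
                top_p=top_p,
                streamer=streamer,
            ),
        )
        thread.start()

        # Foreground: drain chunks while generation is running, then flush.
        partial = ""
        while thread.is_alive() or not streamer.text_queue.empty():
            try:
                partial += streamer.text_queue.get(timeout=0.1)
            except queue.Empty:
                continue
            yield partial

        thread.join()

Yielding the accumulated string rather than each individual chunk matches how Gradio generator callbacks work: each yield replaces the pending chat message instead of appending to it.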