Abhaykoul committed
Commit d104a8c · verified · 1 Parent(s): d460687

Update app.py

Files changed (1): app.py (+41 -30)
app.py CHANGED
@@ -5,6 +5,8 @@ import threading
 import queue
 import time
 import spaces
+import sys
+from io import StringIO
 
 # Model configuration
 model_name = "HelpingAI/Dhanishtha-2.0-preview"
@@ -30,30 +32,30 @@ def load_model():
 
     print("Model loaded successfully!")
 
-class GradioTextStreamer(TextStreamer):
-    """Custom TextStreamer for Gradio integration"""
-    def __init__(self, tokenizer, skip_prompt=True):
-        # TextStreamer only accepts tokenizer and skip_prompt parameters
-        super().__init__(tokenizer, skip_prompt)
+class StreamCapture:
+    """Capture streaming output from TextStreamer"""
+    def __init__(self):
         self.text_queue = queue.Queue()
-        self.generated_text = ""
-        self.skip_special_tokens = True  # Handle this manually if needed
-
-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Called when text is finalized"""
-        self.generated_text += text
-        self.text_queue.put(text)
-        if stream_end:
-            self.text_queue.put(None)
-
-    def get_generated_text(self):
-        """Get all generated text so far"""
-        return self.generated_text
+        self.captured_text = ""
+
+    def write(self, text):
+        """Capture written text"""
+        if text and text.strip():
+            self.captured_text += text
+            self.text_queue.put(text)
+        return len(text)
+
+    def flush(self):
+        """Flush method for compatibility"""
+        pass
+
+    def get_text(self):
+        """Get all captured text"""
+        return self.captured_text
 
     def reset(self):
-        """Reset the streamer"""
-        self.generated_text = ""
-        # Clear the queue
+        """Reset the capture"""
+        self.captured_text = ""
         while not self.text_queue.empty():
             try:
                 self.text_queue.get_nowait()
@@ -89,11 +91,16 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     # Tokenize input
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Create and setup streamer
-    streamer = GradioTextStreamer(tokenizer, skip_prompt=True)
-    streamer.reset()
+    # Create stream capture
+    stream_capture = StreamCapture()
 
-    # Start generation in a separate thread
+    # Create TextStreamer with our capture
+    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Temporarily redirect the streamer's output
+    original_stdout = sys.stdout
+
+    # Generation parameters
     generation_kwargs = {
         **model_inputs,
         "max_new_tokens": max_tokens,
@@ -102,17 +109,21 @@ def generate_response(message, history, max_tokens, temperature, top_p):
         "do_sample": True,
         "pad_token_id": tokenizer.eos_token_id,
         "streamer": streamer,
-        "return_dict_in_generate": True
     }
 
-    # Run generation in thread
+    # Start generation in a separate thread
     def generate():
         try:
+            # Redirect stdout to capture streamer output
+            sys.stdout = stream_capture
             with torch.no_grad():
                 model.generate(**generation_kwargs)
         except Exception as e:
-            streamer.text_queue.put(f"Error: {str(e)}")
-            streamer.text_queue.put(None)
+            stream_capture.text_queue.put(f"Error: {str(e)}")
+        finally:
+            # Restore stdout
+            sys.stdout = original_stdout
+            stream_capture.text_queue.put(None)  # Signal end
 
     thread = threading.Thread(target=generate)
     thread.start()
@@ -121,7 +132,7 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     generated_text = ""
     while True:
         try:
-            new_text = streamer.text_queue.get(timeout=30)
+            if new_text is None:
-            new_text = stream_capture.text_queue.get(timeout=30)
+            new_text = stream_capture.text_queue.get(timeout=30)
             if new_text is None:
                 break
             generated_text += new_text
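
Why this change works: transformers' TextStreamer prints each decoded chunk to stdout, and Python's print() looks up sys.stdout at call time, so swapping in a file-like object intercepts the stream. Below is a minimal, self-contained sketch of the same capture pattern; the fake_generation stand-in (plain print() calls) replaces model.generate(streamer=...) so the sketch runs without loading the model and is not part of the commit. One caveat: sys.stdout is process-global, so the redirect also captures prints from any other thread while generation runs.

import queue
import sys
import threading
import time

class StreamCapture:
    """File-like object that forwards stdout writes into a queue."""
    def __init__(self):
        self.text_queue = queue.Queue()
        self.captured_text = ""

    def write(self, text):
        # TextStreamer emits chunks via print(); drop whitespace-only flushes
        if text and text.strip():
            self.captured_text += text
            self.text_queue.put(text)
        return len(text)

    def flush(self):
        pass  # file-like compatibility; print(..., flush=True) calls this

def fake_generation():
    # Stand-in for model.generate(streamer=TextStreamer(...)), which
    # prints decoded chunks to stdout as tokens are produced
    for chunk in ["Hello", " world", "!"]:
        print(chunk, end="", flush=True)
        time.sleep(0.05)

capture = StreamCapture()
original_stdout = sys.stdout

def worker():
    try:
        sys.stdout = capture  # redirect: print() now lands in capture.write()
        fake_generation()
    finally:
        sys.stdout = original_stdout  # always restore
        capture.text_queue.put(None)  # sentinel: stream finished

threading.Thread(target=worker).start()

# Consumer side: the same drain loop as generate_response() above
generated_text = ""
while True:
    try:
        new_text = capture.text_queue.get(timeout=30)
    except queue.Empty:
        break
    if new_text is None:
        break
    generated_text += new_text
    # a Gradio generator would `yield generated_text` here

print("captured:", generated_text)  # -> captured: Hello world!

Note that write() filters out whitespace-only chunks with text.strip(); that drops stray newline flushes from print(), but it would also swallow a chunk that is purely a space, so the chunks above are chosen to carry their leading spaces.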