Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -5,6 +5,8 @@ import threading
 import queue
 import time
 import spaces
+import sys
+from io import StringIO
 
 # Model configuration
 model_name = "HelpingAI/Dhanishtha-2.0-preview"
@@ -30,30 +32,30 @@ def load_model():
 
     print("Model loaded successfully!")
 
-class …(TextStreamer):
-    """…"""
-    def __init__(self, …):
-        # TextStreamer only accepts tokenizer and skip_prompt parameters
-        super().__init__(tokenizer, skip_prompt)
+class StreamCapture:
+    """Capture streaming output from TextStreamer"""
+    def __init__(self):
         self.text_queue = queue.Queue()
-        self.generated_text = ""
-        self.skip_special_tokens = True  # Handle this manually if needed
-
-    def on_finalized_text(self, text: str, stream_end: bool = False):
-        """Called when text is finalized"""
-        self.generated_text += text
-        self.text_queue.put(text)
-        if stream_end:
-            self.text_queue.put(None)
-
-    def get_generated_text(self):
-        """Get all generated text so far"""
-        return self.generated_text
+        self.captured_text = ""
 
+    def write(self, text):
+        """Capture written text"""
+        if text and text.strip():
+            self.captured_text += text
+            self.text_queue.put(text)
+        return len(text)
+
+    def flush(self):
+        """Flush method for compatibility"""
+        pass
+
+    def get_text(self):
+        """Get all captured text"""
+        return self.captured_text
+
     def reset(self):
-        """Reset the …"""
-        self.generated_text = ""
-        # Clear the queue
+        """Reset the capture"""
+        self.captured_text = ""
         while not self.text_queue.empty():
             try:
                 self.text_queue.get_nowait()
@@ -89,11 +91,16 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     # Tokenize input
     model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
 
-    # Create …
-    …
-    streamer.reset()
+    # Create stream capture
+    stream_capture = StreamCapture()
 
-    # …
+    # Create TextStreamer with our capture
+    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    # Temporarily redirect the streamer's output
+    original_stdout = sys.stdout
+
+    # Generation parameters
     generation_kwargs = {
         **model_inputs,
         "max_new_tokens": max_tokens,
@@ -102,17 +109,21 @@ def generate_response(message, history, max_tokens, temperature, top_p):
         "do_sample": True,
         "pad_token_id": tokenizer.eos_token_id,
         "streamer": streamer,
-        "return_dict_in_generate": True
     }
 
-    # …
+    # Start generation in a separate thread
     def generate():
         try:
+            # Redirect stdout to capture streamer output
+            sys.stdout = stream_capture
             with torch.no_grad():
                 model.generate(**generation_kwargs)
         except Exception as e:
-            …
-            …
+            stream_capture.text_queue.put(f"Error: {str(e)}")
+        finally:
+            # Restore stdout
+            sys.stdout = original_stdout
+            stream_capture.text_queue.put(None)  # Signal end
 
     thread = threading.Thread(target=generate)
     thread.start()
@@ -121,7 +132,7 @@ def generate_response(message, history, max_tokens, temperature, top_p):
     generated_text = ""
     while True:
         try:
-            new_text = …
+            new_text = stream_capture.text_queue.get(timeout=30)
             if new_text is None:
                 break
             generated_text += new_text