richardr1126 committed on
Commit
3247c3a
·
1 Parent(s): b748e45

testing num return sequences

Browse files
Files changed (2) hide show
  1. README.md +1 -1
  2. app.py +21 -12
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.37.0
8
- app_file: app-ngrok.py
9
  pinned: true
10
  license: bigcode-openrail-m
11
  tags:
 
5
  colorTo: purple
6
  sdk: gradio
7
  sdk_version: 3.37.0
8
+ app_file: app.py
9
  pinned: true
10
  license: bigcode-openrail-m
11
  tags:
app.py CHANGED
@@ -139,7 +139,7 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
139
 
140
  input_ids = tok(messages, return_tensors="pt").input_ids
141
  input_ids = input_ids.to(m.device)
142
- streamer = TextIteratorStreamer(tok, timeout=1000.0, skip_prompt=True, skip_special_tokens=True)
143
  generate_kwargs = dict(
144
  input_ids=input_ids,
145
  max_new_tokens=max_new_tokens,
@@ -147,27 +147,36 @@ def generate(input_message: str, db_info="", temperature=0.2, top_p=0.9, top_k=0
147
  top_p=top_p,
148
  top_k=top_k,
149
  repetition_penalty=repetition_penalty,
150
- streamer=streamer,
151
  stopping_criteria=StoppingCriteriaList([stop]),
152
  num_return_sequences=num_return_sequences,
153
  num_beams=num_beams,
154
  do_sample=do_sample,
155
  )
156
 
157
- stream_complete = Event()
158
 
159
- def generate_and_signal_complete():
160
- m.generate(**generate_kwargs)
161
- stream_complete.set()
162
 
163
- t1 = Thread(target=generate_and_signal_complete)
164
- t1.start()
165
 
166
- partial_text = ""
167
- for new_text in streamer:
168
- partial_text += new_text
169
 
170
- output = format(partial_text) if format_sql else partial_text
 
 
 
 
 
 
 
 
 
 
 
171
 
172
  if log:
173
  # Log the request to Firestore
 
139
 
140
  input_ids = tok(messages, return_tensors="pt").input_ids
141
  input_ids = input_ids.to(m.device)
142
+ #streamer = TextIteratorStreamer(tok, timeout=1000.0, skip_prompt=True, skip_special_tokens=True)
143
  generate_kwargs = dict(
144
  input_ids=input_ids,
145
  max_new_tokens=max_new_tokens,
 
147
  top_p=top_p,
148
  top_k=top_k,
149
  repetition_penalty=repetition_penalty,
150
+ #streamer=streamer,
151
  stopping_criteria=StoppingCriteriaList([stop]),
152
  num_return_sequences=num_return_sequences,
153
  num_beams=num_beams,
154
  do_sample=do_sample,
155
  )
156
 
157
+ #stream_complete = Event()
158
 
159
+ # def generate_and_signal_complete():
160
+ # m.generate(**generate_kwargs)
161
+ # stream_complete.set()
162
 
163
+ # t1 = Thread(target=generate_and_signal_complete)
164
+ # t1.start()
165
 
166
+ tokens = m.generate(**generate_kwargs)
 
 
167
 
168
+ responses = []
169
+ for response in tokens:
170
+ response_text = tok.decode(response, skip_special_tokens=True)
171
+
172
+ # Only take what comes after ### Response:
173
+ response_text = response_text.split("### Response:")[1].strip()
174
+
175
+ formatted_text = format(response_text) if format_sql else response_text
176
+ responses.append(formatted_text)
177
+
178
+ # Concatenate responses into a single string separated by a newline
179
+ output = "\n".join(responses)
180
 
181
  if log:
182
  # Log the request to Firestore