Commit b0ccc7f by Tuchuanhuhuhu (parents: 25b27e3, c12b724)

Merge branch 'local-upstream-sync'
modules/chat_func.py CHANGED

```diff
@@ -13,6 +13,9 @@ import colorama
 from duckduckgo_search import ddg
 import asyncio
 import aiohttp
+from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
+from llama_index.indices.query.schema import QueryBundle
+from langchain.llms import OpenAIChat
 
 from modules.presets import *
 from modules.llama_func import *
@@ -103,13 +106,17 @@ def stream_predict(
     else:
         chatbot.append((inputs, ""))
     user_token_count = 0
+    if fake_input is not None:
+        input_token_count = count_token(construct_user(fake_input))
+    else:
+        input_token_count = count_token(construct_user(inputs))
     if len(all_token_counts) == 0:
         system_prompt_token_count = count_token(construct_system(system_prompt))
         user_token_count = (
-            count_token(construct_user(inputs)) + system_prompt_token_count
+            input_token_count + system_prompt_token_count
         )
     else:
-        user_token_count = count_token(construct_user(inputs))
+        user_token_count = input_token_count
     all_token_counts.append(user_token_count)
     logging.info(f"输入token计数: {user_token_count}")
     yield get_return_value()
@@ -137,6 +144,8 @@ def stream_predict(
     yield get_return_value()
     error_json_str = ""
 
+    if fake_input is not None:
+        history[-2] = construct_user(fake_input)
     for chunk in tqdm(response.iter_lines()):
         if counter == 0:
             counter += 1
@@ -201,7 +210,10 @@ def predict_all(
         chatbot.append((fake_input, ""))
     else:
         chatbot.append((inputs, ""))
-    all_token_counts.append(count_token(construct_user(inputs)))
+    if fake_input is not None:
+        all_token_counts.append(count_token(construct_user(fake_input)))
+    else:
+        all_token_counts.append(count_token(construct_user(inputs)))
     try:
         response = get_response(
             openai_api_key,
@@ -224,13 +236,22 @@ def predict_all(
         status_text = standard_error_msg + ssl_error_prompt + error_retrieve_prompt
         return chatbot, history, status_text, all_token_counts
     response = json.loads(response.text)
-    content = response["choices"][0]["message"]["content"]
-    history[-1] = construct_assistant(content)
-    chatbot[-1] = (chatbot[-1][0], content+display_append)
-    total_token_count = response["usage"]["total_tokens"]
-    all_token_counts[-1] = total_token_count - sum(all_token_counts)
-    status_text = construct_token_message(total_token_count)
-    return chatbot, history, status_text, all_token_counts
+    if fake_input is not None:
+        history[-2] = construct_user(fake_input)
+    try:
+        content = response["choices"][0]["message"]["content"]
+        history[-1] = construct_assistant(content)
+        chatbot[-1] = (chatbot[-1][0], content+display_append)
+        total_token_count = response["usage"]["total_tokens"]
+        if fake_input is not None:
+            all_token_counts[-1] += count_token(construct_assistant(content))
+        else:
+            all_token_counts[-1] = total_token_count - sum(all_token_counts)
+        status_text = construct_token_message(total_token_count)
+        return chatbot, history, status_text, all_token_counts
+    except KeyError:
+        status_text = standard_error_msg + str(response)
+        return chatbot, history, status_text, all_token_counts
 
 
 def predict(
@@ -254,37 +275,55 @@ def predict(
     yield chatbot+[(inputs, "")], history, "开始生成回答……", all_token_counts
     if reply_language == "跟随问题语言(不稳定)":
         reply_language = "the same language as the question, such as English, 中文, 日本語, Español, Français, or Deutsch."
+    old_inputs = None
+    display_reference = []
+    limited_context = False
     if files:
+        limited_context = True
+        old_inputs = inputs
         msg = "加载索引中……(这可能需要几分钟)"
         logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
         index = construct_index(openai_api_key, file_src=files)
         msg = "索引构建完成,获取回答中……"
+        logging.info(msg)
         yield chatbot+[(inputs, "")], history, msg, all_token_counts
-        history, chatbot, status_text = chat_ai(openai_api_key, index, inputs, history, chatbot, reply_language)
-        yield chatbot, history, status_text, all_token_counts
-        return
-
-    old_inputs = ""
-    link_references = []
-    if use_websearch:
+        llm_predictor = LLMPredictor(llm=OpenAIChat(temperature=0, model_name=selected_model))
+        prompt_helper = PromptHelper(max_input_size = 4096, num_output = 5, max_chunk_overlap = 20, chunk_size_limit=600)
+        service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper)
+        query_object = GPTVectorStoreIndexQuery(index.index_struct, service_context=service_context, similarity_top_k=5, vector_store=index._vector_store, docstore=index._docstore)
+        query_bundle = QueryBundle(inputs)
+        nodes = query_object.retrieve(query_bundle)
+        reference_results = [n.node.text for n in nodes]
+        reference_results = add_source_numbers(reference_results, use_source=False)
+        display_reference = add_details(reference_results)
+        display_reference = "\n\n" + "".join(display_reference)
+        inputs = (
+            replace_today(PROMPT_TEMPLATE)
+            .replace("{query_str}", inputs)
+            .replace("{context_str}", "\n\n".join(reference_results))
+            .replace("{reply_language}", reply_language )
+        )
+    elif use_websearch:
+        limited_context = True
         search_results = ddg(inputs, max_results=5)
         old_inputs = inputs
-        web_results = []
+        reference_results = []
         for idx, result in enumerate(search_results):
             logging.info(f"搜索结果{idx + 1}:{result}")
             domain_name = urllib3.util.parse_url(result["href"]).host
-            web_results.append(f'[{idx+1}]"{result["body"]}"\nURL: {result["href"]}')
-            link_references.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
-        link_references = "\n\n" + "".join(link_references)
+            reference_results.append([result["body"], result["href"]])
+            display_reference.append(f"{idx+1}. [{domain_name}]({result['href']})\n")
+        reference_results = add_source_numbers(reference_results)
+        display_reference = "\n\n" + "".join(display_reference)
         inputs = (
             replace_today(WEBSEARCH_PTOMPT_TEMPLATE)
             .replace("{query}", inputs)
-            .replace("{web_results}", "\n\n".join(web_results))
+            .replace("{web_results}", "\n\n".join(reference_results))
             .replace("{reply_language}", reply_language )
         )
     else:
-        link_references = ""
+        display_reference = ""
 
     if len(openai_api_key) != 51:
         status_text = standard_error_msg + no_apikey_msg
@@ -317,7 +356,7 @@ def predict(
             temperature,
             selected_model,
             fake_input=old_inputs,
-            display_append=link_references
+            display_append=display_reference
         )
         for chatbot, history, status_text, all_token_counts in iter:
             if shared.state.interrupted:
@@ -337,7 +376,7 @@ def predict(
             temperature,
             selected_model,
             fake_input=old_inputs,
-            display_append=link_references
+            display_append=display_reference
         )
         yield chatbot, history, status_text, all_token_counts
 
@@ -350,6 +389,11 @@ def predict(
         + colorama.Style.RESET_ALL
     )
 
+    if limited_context:
+        history = history[-4:]
+        all_token_counts = all_token_counts[-2:]
+        yield chatbot, history, status_text, all_token_counts
+
     if stream:
        max_token = MODEL_SOFT_TOKEN_LIMIT[selected_model]["streaming"]
    else:
```
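With files attached, predict now retrieves context itself instead of handing the whole exchange to chat_ai. The retrieval step reads cleanly in isolation; below is a minimal sketch of it, lifted from the hunk above, wrapped in a hypothetical helper (`retrieve_references` is not in the commit) so the assumed inputs are explicit:

```python
from llama_index.indices.query.vector_store import GPTVectorStoreIndexQuery
from llama_index.indices.query.schema import QueryBundle

def retrieve_references(index, service_context, query, top_k=5):
    """Hypothetical wrapper mirroring the new file-chat path in predict():
    fetch the top-k matching nodes from a GPTSimpleVectorIndex (as returned
    by construct_index) and keep only their raw text."""
    query_object = GPTVectorStoreIndexQuery(
        index.index_struct,
        service_context=service_context,
        similarity_top_k=top_k,
        vector_store=index._vector_store,  # private attrs, used verbatim in the diff
        docstore=index._docstore,
    )
    nodes = query_object.retrieve(QueryBundle(query))
    return [n.node.text for n in nodes]
```

The numbered passages then fill {context_str} in PROMPT_TEMPLATE while display_reference is appended to the chat window; the limited_context flag later trims history to the last four messages, presumably so the long templated prompt does not accumulate across turns.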
modules/llama_func.py CHANGED

```diff
@@ -1,7 +1,7 @@
 import os
 import logging
 
-from llama_index import GPTSimpleVectorIndex
+from llama_index import GPTSimpleVectorIndex, ServiceContext
 from llama_index import download_loader
 from llama_index import (
     Document,
@@ -11,7 +11,10 @@ from llama_index import (
     RefinePrompt,
 )
 from langchain.llms import OpenAI
+from langchain.chat_models import ChatOpenAI
 import colorama
+import PyPDF2
+from tqdm import tqdm
 
 from modules.presets import *
 from modules.utils import *
@@ -28,6 +31,12 @@ def get_index_name(file_src):
 
     return md5_hash.hexdigest()
 
+def block_split(text):
+    blocks = []
+    while len(text) > 0:
+        blocks.append(Document(text[:1000]))
+        text = text[1000:]
+    return blocks
 
 def get_documents(file_src):
     documents = []
@@ -37,9 +46,12 @@ def get_documents(file_src):
         logging.info(f"loading file: {file.name}")
         if os.path.splitext(file.name)[1] == ".pdf":
             logging.debug("Loading PDF...")
-            CJKPDFReader = download_loader("CJKPDFReader")
-            loader = CJKPDFReader()
-            text_raw = loader.load_data(file=file.name)[0].text
+            pdftext = ""
+            with open(file.name, 'rb') as pdfFileObj:
+                pdfReader = PyPDF2.PdfReader(pdfFileObj)
+                for page in tqdm(pdfReader.pages):
+                    pdftext += page.extract_text()
+            text_raw = pdftext
         elif os.path.splitext(file.name)[1] == ".docx":
             logging.debug("Loading DOCX...")
             DocxReader = download_loader("DocxReader")
@@ -55,7 +67,10 @@ def get_documents(file_src):
             with open(file.name, "r", encoding="utf-8") as f:
                 text_raw = f.read()
         text = add_space(text_raw)
+        # text = block_split(text)
+        # documents += text
         documents += [Document(text)]
+    logging.debug("Documents loaded.")
     return documents
 
 
@@ -63,13 +78,11 @@ def construct_index(
     api_key,
     file_src,
     max_input_size=4096,
-    num_outputs=1,
+    num_outputs=5,
     max_chunk_overlap=20,
     chunk_size_limit=600,
     embedding_limit=None,
-    separator=" ",
-    num_children=10,
-    max_keywords_per_chunk=10,
+    separator=" "
 ):
     os.environ["OPENAI_API_KEY"] = api_key
     chunk_size_limit = None if chunk_size_limit == 0 else chunk_size_limit
@@ -77,16 +90,9 @@
     separator = " " if separator == "" else separator
 
     llm_predictor = LLMPredictor(
-        llm=OpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
-    )
-    prompt_helper = PromptHelper(
-        max_input_size,
-        num_outputs,
-        max_chunk_overlap,
-        embedding_limit,
-        chunk_size_limit,
-        separator=separator,
+        llm=ChatOpenAI(model_name="gpt-3.5-turbo-0301", openai_api_key=api_key)
     )
+    prompt_helper = PromptHelper(max_input_size = max_input_size, num_output = num_outputs, max_chunk_overlap = max_chunk_overlap, embedding_limit=embedding_limit, chunk_size_limit=600, separator=separator)
     index_name = get_index_name(file_src)
     if os.path.exists(f"./index/{index_name}.json"):
         logging.info("找到了缓存的索引文件,加载中……")
@@ -94,14 +100,19 @@
     else:
         try:
             documents = get_documents(file_src)
-            logging.debug("构建索引中……")
-            index = GPTSimpleVectorIndex(
-                documents, llm_predictor=llm_predictor, prompt_helper=prompt_helper
+            logging.info("构建索引中……")
+            service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor, prompt_helper=prompt_helper, chunk_size_limit=chunk_size_limit)
+            index = GPTSimpleVectorIndex.from_documents(
+                documents, service_context=service_context
             )
+            logging.debug("索引构建完成!")
             os.makedirs("./index", exist_ok=True)
             index.save_to_disk(f"./index/{index_name}.json")
+            logging.debug("索引已保存至本地!")
             return index
+
         except Exception as e:
+            logging.error("索引构建失败!", e)
             print(e)
             return None
 
@@ -148,7 +159,7 @@ def ask_ai(
     question,
     prompt_tmpl,
     refine_tmpl,
-    sim_k=1,
+    sim_k=5,
     temprature=0,
     prefix_messages=[],
     reply_language="中文",
@@ -158,7 +169,7 @@ def ask_ai(
     logging.debug("Index file found")
     logging.debug("Querying index...")
     llm_predictor = LLMPredictor(
-        llm=OpenAI(
+        llm=ChatOpenAI(
             temperature=temprature,
             model_name="gpt-3.5-turbo-0301",
             prefix_messages=prefix_messages,
@@ -170,7 +181,6 @@
     rf_prompt = RefinePrompt(refine_tmpl.replace("{reply_language}", reply_language))
     response = index.query(
         question,
-        llm_predictor=llm_predictor,
         similarity_top_k=sim_k,
         text_qa_template=qa_prompt,
        refine_template=rf_prompt,
```
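PDF loading switches from the CJKPDFReader plugin to direct PyPDF2 page extraction. A standalone sketch of the same loop; the file path is a placeholder, and the `or ""` guard against pages with no extractable text is an addition here, not part of the diff:

```python
import PyPDF2
from tqdm import tqdm

# Same approach as the new get_documents(): read the PDF page by page
# and concatenate whatever text PyPDF2 can extract.
with open("example.pdf", "rb") as f:  # placeholder path
    reader = PyPDF2.PdfReader(f)
    text = "".join(page.extract_text() or "" for page in tqdm(reader.pages))
print(text[:200])
```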
modules/presets.py CHANGED

```diff
@@ -83,7 +83,8 @@ MODEL_SOFT_TOKEN_LIMIT = {
 }
 
 REPLY_LANGUAGES = [
-    "中文",
+    "简体中文",
+    "繁體中文",
     "English",
     "日本語",
     "Español",
```
modules/utils.py CHANGED

```diff
@@ -375,8 +375,8 @@ def replace_today(prompt):
 
 
 def get_geoip():
-    response = requests.get("https://ipapi.co/json/", timeout=5)
     try:
+        response = requests.get("https://ipapi.co/json/", timeout=5)
         data = response.json()
     except:
         data = {"error": True, "reason": "连接ipapi失败"}
@@ -384,7 +384,7 @@ def get_geoip():
        logging.warning(f"无法获取IP地址信息。\n{data}")
        if data["reason"] == "RateLimited":
            return (
-                f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用,但请注意,如果您的IP地址在不受支持的地区,您可能会遇到问题。"
+                f"获取IP地理位置失败,因为达到了检测IP的速率限制。聊天功能可能仍然可用。"
            )
        else:
            return f"获取IP地理位置失败。原因:{data['reason']}。你仍然可以使用聊天功能。"
@@ -457,7 +457,7 @@
 
     if proxies == {}:
         proxies = None
-
+
     return proxies
 
 def run(command, desc=None, errdesc=None, custom_env=None, live=False):
@@ -500,4 +500,19 @@ Python: <span title="{sys.version}">{python_version}</span>
 Gradio: {gr.__version__}
  • 
 Commit: {commit_info}
-"""
+"""
+
+def add_source_numbers(lst, source_name = "Source", use_source = True):
+    if use_source:
+        return [f'[{idx+1}]\t "{item[0]}"\n{source_name}: {item[1]}' for idx, item in enumerate(lst)]
+    else:
+        return [f'[{idx+1}]\t "{item}"' for idx, item in enumerate(lst)]
+
+def add_details(lst):
+    nodes = []
+    for index, txt in enumerate(lst):
+        brief = txt[:25].replace("\n", "")
+        nodes.append(
+            f"<details><summary>{brief}...</summary><p>{txt}</p></details>"
+        )
+    return nodes
```
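The two new helpers format retrieved references: add_source_numbers numbers them for the prompt, and add_details wraps each one in a collapsible HTML block for the chat display. A quick illustration with made-up values:

```python
from modules.utils import add_source_numbers, add_details

refs = [["Paris is the capital of France.", "https://example.com/paris"]]
print(add_source_numbers(refs))
# ['[1]\t "Paris is the capital of France."\nSource: https://example.com/paris']

print(add_details(["Paris is the capital of France."]))
# ['<details><summary>Paris is the capital of F...</summary><p>Paris is the capital of France.</p></details>']
```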
requirements.txt CHANGED

```diff
@@ -10,3 +10,4 @@ Pygments
 llama_index==0.4.40
 langchain
 markdown
+PyPDF2
```