lamhieu committed
Commit ef88752 · Parent(s): 7a736e5

chore: support tools with search on internet

Files changed (3):
  1. README.md +2 -4
  2. app.py +247 -56
  3. requirements.txt +3 -1
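At a high level, this commit teaches the Space to treat the model's first reply as a possible tool call: when tool use is enabled, `generate()` advertises a `search_on_internet` function, parses the reply as JSON, runs a Google or Wikipedia search, and regenerates the answer with the results attached. A minimal sketch (not part of the commit) of the tool-call payload the new parsing code in `app.py` looks for; the field names are taken from the diff below, the keyword value is illustrative:

```python
import json

# Shape of the reply that generate_chat_responses() in app.py tries to parse
# with json.loads() when tools are enabled (field names from the diff below).
example_tool_call = {
    "type": "function",
    "name": "search_on_internet",
    "arguments": {
        "keyword": "Ghost 8B Beta language model",  # illustrative search query
        "type": "google",                           # or "wikipedia"
    },
}

print(json.dumps(example_tool_call, ensure_ascii=False))
```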
README.md CHANGED
@@ -31,11 +31,9 @@ tags:

### Notes

-The extension source code belongs to: "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning".
+The extension source code belongs to: "LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning". See source code details [here](https://github.com/datamllab/LongLM).

-See source code details [here](https://github.com/datamllab/LongLM).
-
-```
+```tex
@misc{jin2024llm,
      title={LLM Maybe LongLM: Self-Extend LLM Context Window Without Tuning},
      author={Hongye Jin and Xiaotian Han and Jingfeng Yang and Zhimeng Jiang and Zirui Liu and Chia-Yuan Chang and Huiyuan Chen and Xia Hu},
app.py CHANGED
@@ -1,6 +1,8 @@
# pylint: skip-file

import subprocess
+import json
+import requests

subprocess.run(
    f"pip install flash-attn --no-build-isolation",

@@ -15,24 +17,27 @@ from typing import Iterator
import gradio as gr
import spaces
import torch
+import wikipedia
+import time
import SelfExtend
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from bs4 import BeautifulSoup
+from functools import lru_cache


-MAX_MAX_NEW_TOKENS = 4096
-DEFAULT_MAX_NEW_TOKENS = 1536
+MAX_MAX_NEW_TOKENS = 8192
+DEFAULT_MAX_NEW_TOKENS = 2048
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "123392"))

DESCRIPTION = """\
-# Playground with Ghost 8B Beta (β, 128k)
+# Playground with Ghost 8B Beta (β, 8k)

-**Ghost 8B Beta** is a large language model developed with goals that include excellent multilingual support, superior knowledge capabilities, and cost-effectiveness. The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.
-
-The Ghost 8B Beta model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/).
+**Ghost 8B Beta** model outperforms prominent models such as Llama 3 8B Instruct, GPT 3.5 Turbo in the lc_winrate score. In addition, it also outperforms Claude 3 Opus, Claude 3 Sonnet, GPT-4, and Mistral Large when comparing the winrate score of AlpacaEval 2.0, [*](https://ghost-x.org/docs/models/ghost-8b-beta/). The model comes in two context length versions, [8k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-8k) and [128k](https://huggingface.co/spaces/lamhieu/ghost-8b-beta-128k), along with multilingual function tools support by default.

The languages supported are 🇺🇸 English, 🇫🇷 French, 🇮🇹 Italian, 🇪🇸 Spanish, 🇵🇹 Portuguese, 🇩🇪 German, 🇻🇳 Vietnamese, 🇰🇷 Korean and 🇨🇳 Chinese.

-📋 Note: current model version is "disl-0x5" (10 Jul 2024), context length 128k (123392 tokens) and current status is "moderating / previewing". For detailed information about the model, see [here](https://ghost-x.org/docs/models/ghost-8b-beta/). Try to experience it the way you want!
+🗞️ **Updates**
+* Jul 23, 2024: added support for tools, now available to search for information on the internet.
"""


@@ -251,19 +256,19 @@ if not torch.cuda.is_available():

if torch.cuda.is_available():
    model_id = "ghost-x/ghost-8b-beta"
-    model_tk = os.getenv("HF_TOKEN", None)
+    hf_serect = os.getenv("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.bfloat16,
        attn_implementation="flash_attention_2",
        trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
    )
    tokenizer = AutoTokenizer.from_pretrained(
        model_id,
        trust_remote_code=True,
-        token=model_tk,
+        token=hf_serect,
    )
    SelfExtend.apply(
        model,

@@ -274,73 +279,259 @@ if torch.cuda.is_available():
    )
    model.generation_config.max_length = 123392

+waiting_tools_timeout = 7.5
+supported_tools = json.dumps(
+    [
+        {
+            "type": "function",
+            "function": {
+                "name": "search_on_internet",
+                "description": "Use this tool to search online, only use it for information you don't know or are unsure of, don't abuse it.",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "keyword": {
+                            "type": "string",
+                            "description": "Search keywords, rephrase to optimize search results based on questions suitable to the specified search type.",
+                            "required": True,
+                        },
+                        "type": {
+                            "type": "string",
+                            "description": "Search type, based on the question to determine whether to search for it in 'wikipedia' or 'google', prefer to use wikipedia for information about events, history and people.",
+                            "enum": ["wikipedia", "google"],
+                            "default": "google",
+                            "required": True,
+                        },
+                    },
+                },
+            },
+        }
+    ],
+    ensure_ascii=False,
+)
+
+
+@lru_cache(maxsize=128)
+def extract_text_from_webpage(html_content):
+    soup = BeautifulSoup(html_content, "html.parser")
+    for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
+        tag.extract()
+    visible_text = soup.get_text(strip=True, separator=" ")
+    return visible_text
+
+
+def search_with_wikipedia(query: str):
+    all_results = []
+    try:
+        all_results.append(wikipedia.summary(query))
+    except Exception as e:
+        pass
+    return all_results
+
+
+def search_with_google(
+    query: str,
+    num_results: int = 3,
+    timeout: int = 5,
+    ssl_verify: bool = None,
+):
+    all_results = []
+    max_chars_per_page = 4096
+    with requests.Session() as session:
+        resp = session.get(
+            url="https://www.google.com/search",
+            headers={
+                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+            },
+            params={
+                "q": query,
+                "num": num_results,
+                "udm": 14,
+            },
+            timeout=timeout,
+            verify=ssl_verify,
+        )
+        resp.raise_for_status()
+        soup = BeautifulSoup(resp.text, "html.parser")
+        result_block = soup.find_all("div", attrs={"class": "g"})
+        for result in result_block:
+            link = result.find("a", href=True)
+            if link:
+                link = link["href"]
+                try:
+                    webpage = session.get(
+                        link,
+                        headers={
+                            "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64; rv:109.0) Gecko/20100101 Firefox/111.0"
+                        },
+                    )
+                    webpage.raise_for_status()
+                    visible_text = extract_text_from_webpage(webpage.text)
+                    if len(visible_text) > max_chars_per_page:
+                        visible_text = visible_text[:max_chars_per_page]
+                    all_results.append({"link": link, "text": visible_text})
+                except requests.exceptions.RequestException as e:
+                    print(f"Error fetching or processing {link}: {e}")
+                    pass
+            else:
+                pass
+    return all_results
+
+
-@spaces.GPU(duration=120)
+@spaces.GPU(duration=180)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int = 1536,
+    allow_used_tools: bool = True,
+    system_prompt: str = "",
+    max_new_tokens: int = 2048,
    temperature: float = 0.4,
    top_p: float = 0.95,
    top_k: int = 50,
    repetition_penalty: float = 1.0,
) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend(
-            [
-                {"role": "user", "content": user},
-                {"role": "assistant", "content": assistant},
-            ]
-        )
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(
-        conversation, add_generation_prompt=True, return_tensors="pt"
-    )
-    input_ids = input_ids.to(model.device)
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(
-            f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
-        )
-
-    streamer = TextIteratorStreamer(
-        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
-    )
-    generate_kwargs = dict(
-        input_ids=input_ids,
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        repetition_penalty=repetition_penalty,
-    )
-    if temperature == 0:
-        generate_kwargs["do_sample"] = False
-    else:
-        generate_kwargs["temperature"] = temperature
-        generate_kwargs["top_p"] = top_p
-        generate_kwargs["top_k"] = top_k
-
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    # print()
+    # print("allow_used_tools:\n", allow_used_tools)
+    # print("system_prompt:\n", system_prompt)
+    # print("max_new_tokens:\n", max_new_tokens)
+    # print("temperature:\n", temperature)
+
+    def build_input_ids(
+        apply_tools: bool = None,
+        references: list[str] = None,
+    ):
+        conversation = []
+        if system_prompt:
+            conversation.append({"role": "system", "content": system_prompt})
+        if apply_tools is True:
+            conversation.append({"role": "tools", "content": supported_tools})
+        if (
+            references is not None
+            and isinstance(references, list)
+            and len(references) > 0
+        ):
+            conversation.append(
+                {
+                    "role": "refs",
+                    "content": json.dumps(references, ensure_ascii=False),
+                }
+            )
+
+        for user, assistant in chat_history:
+            conversation.extend(
+                [
+                    {"role": "user", "content": user},
+                    {"role": "assistant", "content": assistant},
+                ]
+            )
+        conversation.append({"role": "user", "content": message})
+
+        input_ids = tokenizer.apply_chat_template(
+            conversation, add_generation_prompt=True, return_tensors="pt"
+        )
+        input_ids = input_ids.to(model.device)
+        if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+            input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+            gr.Warning(
+                f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens."
+            )
+        return input_ids
+
+    def generate_chat_responses(
+        previous_response: str = None,
+    ):
+        document_references = []
+        if previous_response is not None:
+            scheduled_tools_runs = None
+            try:
+                scheduled_tools_runs = json.loads(previous_response)
+                if scheduled_tools_runs["type"] == "function" and scheduled_tools_runs[
+                    "name"
+                ] in ["search_on_internet"]:
+                    pass
+                else:
+                    scheduled_tools_runs = None
+            except Exception as e:
+                print(e)
+                pass
+
+            if (
+                scheduled_tools_runs is not None
+                and scheduled_tools_runs["name"] == "search_on_internet"
+            ):
+                keyword = scheduled_tools_runs["arguments"]["keyword"]
+                search_type = scheduled_tools_runs["arguments"]["type"]
+                if search_type == "wikipedia":
+                    gr.Info("Searching for information on the Wikipedia.")
+                    document_references = search_with_wikipedia(keyword)
+                else:
+                    gr.Info("Searching for information on the Google.")
+                    document_references = search_with_google(keyword)
+
+        input_ids = build_input_ids(
+            apply_tools=(
+                True
+                if allow_used_tools is True and previous_response is None
+                else False
+            ),
+            references=document_references,
+        )
+        streamer = TextIteratorStreamer(
+            tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generate_kwargs = dict(
+            input_ids=input_ids,
+            streamer=streamer,
+            max_new_tokens=max_new_tokens,
+            do_sample=True,
+            repetition_penalty=repetition_penalty,
+        )
+        if temperature == 0:
+            generate_kwargs["do_sample"] = False
+        else:
+            generate_kwargs["temperature"] = temperature
+            generate_kwargs["top_p"] = top_p
+            generate_kwargs["top_k"] = top_k
+
+        t = Thread(target=model.generate, kwargs=generate_kwargs)
+        t.start()
+
+        state = {
+            "mark": None,
+            "respond": False,
+        }
+        outputs = []
+        for text in streamer:
+            if state["mark"] is None:
+                state["mark"] = time.time()
+            outputs.append(text)
+            if state["mark"] + waiting_tools_timeout < time.time():
+                state["respond"] = True
+                yield "".join(outputs)
+
+        if (
+            state["respond"] is False
+            and state["mark"] + waiting_tools_timeout > time.time()
+        ):
+            gr.Info("Searching for information on the internet.")
+            previous_response = "".join(outputs)
+            yield from generate_chat_responses(previous_response=previous_response)
+
+    yield from generate_chat_responses(previous_response=None)


-chatbot = gr.Chatbot(height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta")
+chatbot = gr.Chatbot(
+    height=500, placeholder=PLACEHOLDER, label="Ghost 8B Beta", show_copy_button=True
+)

chat_interface = gr.ChatInterface(
    fn=generate,
    chatbot=chatbot,
    fill_height=True,
    additional_inputs=[
+        gr.Checkbox(
+            label="Allow used tools (available: search on internet)", value=True
+        ),
        gr.Textbox(label="System prompt", lines=6),
        gr.Slider(
            label="Max new tokens",

@@ -382,6 +573,7 @@ chat_interface = gr.ChatInterface(
    cache_examples=False,
    examples=EXAMPLES,
    examples_per_page=9,
+    concurrency_limit=100,
)

with gr.Blocks(fill_height=True, css="style.css") as demo:

@@ -391,4 +583,3 @@ with gr.Blocks(fill_height=True, css="style.css") as demo:

if __name__ == "__main__":
    demo.queue(max_size=20).launch(share=True)
-    # demo.launch(share=True)
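For orientation, a hedged sketch (not part of the commit) of the two conversation payloads that `build_input_ids` assembles: the first pass advertises the tools under a `tools` role, and, if the model replies with a `search_on_internet` call before `waiting_tools_timeout` expires, the second pass feeds the search results back under a `refs` role. The system prompt, user question, and reference entry below are illustrative; the role names and reference shape come from the diff above.

```python
import json

# Pass 1: tools are advertised; the model may answer with the JSON tool call.
first_pass = [
    {"role": "system", "content": "You are a helpful assistant."},       # illustrative
    {"role": "tools", "content": "<supported_tools JSON from app.py>"},  # placeholder
    {"role": "user", "content": "Who won the UEFA Euro 2024 final?"},    # illustrative
]

# Pass 2: the search has run; results are injected as "refs" and the tools
# role is dropped, so the model answers in natural language.
references = [
    {"link": "https://example.org/match-report", "text": "extracted page text ..."}  # illustrative
]
second_pass = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "refs", "content": json.dumps(references, ensure_ascii=False)},
    {"role": "user", "content": "Who won the UEFA Euro 2024 final?"},
]
```

Both lists are passed straight to `tokenizer.apply_chat_template`, so the non-standard `tools` and `refs` roles assume the ghost-8b-beta chat template knows how to render them.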
 
requirements.txt CHANGED
@@ -1,8 +1,10 @@
accelerate==0.30.1
bitsandbytes==0.43.1
-gradio==4.37.2
+gradio==4.39.0
scipy==1.13.0
sentencepiece==0.2.0
spaces==0.28.3
torch==2.0.0
transformers==4.41.0
+beautifulsoup4>=4.9
+wikipedia==1.4.0
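The two new pins back the search helpers added to `app.py` (`wikipedia` for `search_with_wikipedia`, `beautifulsoup4` for `extract_text_from_webpage`). A quick local smoke test of those dependencies, as a sketch independent of the Space runtime (the article title and HTML snippet are illustrative, and the first call needs network access):

```python
import wikipedia
from bs4 import BeautifulSoup

# search_with_wikipedia() in app.py wraps wikipedia.summary().
print(wikipedia.summary("Alan Turing", sentences=2))  # requires network access

# extract_text_from_webpage() strips non-content tags and keeps visible text.
html = "<html><body><nav>menu</nav><p>Hello, world.</p><script>x()</script></body></html>"
soup = BeautifulSoup(html, "html.parser")
for tag in soup(["script", "style", "header", "footer", "nav", "form", "svg"]):
    tag.extract()
print(soup.get_text(strip=True, separator=" "))  # -> Hello, world.
```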