Liyan06 committed
Commit 66fde05 · 1 Parent(s): 6145e0a

add progress streaming function

Files changed (2)
  1. handler.py +11 -11
  2. requirements.txt +2 -1
handler.py CHANGED
@@ -1,5 +1,6 @@
 from minicheck_web.minicheck import MiniCheck
 from web_retrieval import *
+from flask import Response
 
 
 def sort_chunks_single_doc_claim(used_chunk, support_prob_per_chunk):
@@ -39,32 +40,31 @@ class EndpointHandler():
         assert len(data['inputs']['claims']) == 1, "Only one claim is allowed for web retrieval for the current version."
 
         claim = data['inputs']['claims'][0]
-        ranked_docs, scores, ranked_urls = self.search_relevant_docs(claim)
-
-        outputs = {
-            'ranked_docs': ranked_docs,
-            'scores': scores,
-            'ranked_urls': ranked_urls
-        }
+        progress_stream = self.search_relevant_docs(claim)
+        outputs = Response(progress_stream, mimetype='text/event-stream')
 
         return outputs
 
 
     def search_relevant_docs(self, claim, timeout=10, max_search_results_per_query=5, allow_duplicated_urls=False):
 
+        def progress(message):
+            yield f"data: {message}\n\n"
+
         search_results = search_google(claim, timeout=timeout)
 
-        print('Searching webpages...')
+        yield from progress('Searching webpages...')
         start = time()
         with concurrent.futures.ThreadPoolExecutor() as e:
             scraped_results = e.map(scrape_url, search_results, itertools.repeat(timeout))
         end = time()
-        print(f"Finished searching in {round((end - start), 1)} seconds.\n")
+        yield from progress(f"Finished searching in {round((end - start), 1)} seconds.")
+
         scraped_results = [(r[0][:50000], r[1]) for r in scraped_results if r[0] and '��' not in r[0] and ".pdf" not in r[1]]
 
         retrieved_docs, urls = zip(*scraped_results[:max_search_results_per_query])
 
-        print('Scoring webpages...')
+        yield from progress('Scoring webpages...')
         start = time()
         retrieved_data = {
             'inputs': {
@@ -75,7 +75,7 @@ class EndpointHandler():
         _, _, used_chunk, support_prob_per_chunk = self.scorer.score(data=retrieved_data)
         end = time()
         num_chunks = len([item for items in used_chunk for item in items])
-        print(f'Finished {num_chunks} entailment checks in {round((end - start), 1)} seconds ({round(num_chunks / (end - start) * 60)} Doc./min).')
+        yield from progress(f'Finished {num_chunks} entailment checks in {round((end - start), 1)} seconds ({round(num_chunks / (end - start) * 60)} Doc./min).')
 
         ranked_docs, scores, ranked_urls = order_doc_score_url(used_chunk, support_prob_per_chunk, urls, allow_duplicated_urls=allow_duplicated_urls)
 
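Because the new `search_relevant_docs` body contains `yield` statements, calling it returns a generator immediately; Flask's `Response` then drives that generator lazily and forwards each yielded `data: ...` frame to the client as a server-sent event. A minimal sketch of the same pattern as a standalone Flask app (the route and function names here are illustrative, not from this repo):

```python
from flask import Flask, Response

app = Flask(__name__)

def run_pipeline():
    # Any function containing `yield` returns a generator when called;
    # its body only executes as the response is streamed out.
    def progress(message):
        yield f"data: {message}\n\n"  # one SSE frame per message

    yield from progress('Searching webpages...')
    # ...long-running search step would go here...
    yield from progress('Scoring webpages...')
    # ...long-running scoring step would go here...

@app.route('/stream')
def stream():
    # Flask sends each yielded string to the client as it is produced.
    return Response(run_pipeline(), mimetype='text/event-stream')
```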
requirements.txt CHANGED
@@ -4,4 +4,5 @@ nltk==3.8.1
 pandas==2.2.1
 numpy==1.26.2
 tqdm
-bs4
+bs4
+flask
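On the consumer side, a client could read the progress events line by line. A sketch using `requests` (the endpoint URL is hypothetical, and `requests` is assumed to be available, since it is not pinned in requirements.txt):

```python
import requests

# Hypothetical address; substitute the deployed endpoint's URL.
url = "http://localhost:5000/stream"

with requests.get(url, stream=True) as resp:
    resp.raise_for_status()
    for line in resp.iter_lines(decode_unicode=True):
        # SSE frames arrive as "data: ..." lines separated by blank lines.
        if line.startswith("data: "):
            print(line[len("data: "):])
```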