Andrew Green commited on
Commit
c23cd24
·
1 Parent(s): 0e4ad79

batch inference and add progressbar

Browse files
Files changed (1) hide show
  1. app.py +56 -11
app.py CHANGED
@@ -23,7 +23,7 @@ def get_pipeline():
23
  model_name = "afg1/pombe_curation_fold_0"
24
 
25
 
26
- pipe = pipeline(model=model_name)
27
  return pipe
28
 
29
 
@@ -31,16 +31,58 @@ def get_pipeline():
31
 
32
 
33
  @spaces.GPU
34
- def classify_abstracts(abstracts:Dict[str, str]) -> None:
35
  pipe = get_pipeline()
36
- pmids = list(abstracts.keys())
37
- classification = pipe(text=list(abstracts.values()))
 
 
 
 
38
 
39
- for pmid, abs in zip(pmids, classification):
40
- abs['label'] = label_lookup[abs['label']]
41
- abs['pmid'] = pmid
42
 
43
- return classification
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
 
46
 
@@ -122,9 +164,9 @@ def fetch_abstracts_batch(pmids: List[str], batch_size: int = 200) -> Dict[str,
122
  # Simple abstract
123
  abstract_text = abstract_element.text
124
  else:
125
- abstract_text = "No abstract available"
126
-
127
- all_abstracts[pmid] = abstract_text
128
 
129
  # Respect NCBI's rate limits
130
  time.sleep(0.34)
@@ -275,6 +317,9 @@ def create_interface():
275
  with gr.Row():
276
  d = gr.DownloadButton("Download results", visible=True, interactive=False)
277
 
 
 
 
278
  d.click(download_file, None, d)
279
 
280
  search_button.click(
 
23
  model_name = "afg1/pombe_curation_fold_0"
24
 
25
 
26
+ pipe = pipeline(model=model_name, task="text-classification")
27
  return pipe
28
 
29
 
 
31
 
32
 
33
  @spaces.GPU
34
+ def classify_abstracts(abstracts:Dict[str, str],batch_size=64, progress=gr.Progress()) -> None:
35
  pipe = get_pipeline()
36
+ # pmids = list(abstracts.keys())
37
+ # batch_size = 64
38
+ # classification = []
39
+ # abstracts_list = list(abstracts.values())
40
+ # for i in range(0, len(abstracts), batch_size):
41
+ # classification.extend(pipe(abstracts_list[i:i+batch_size]))
42
 
43
+ # for pmid, abs in zip(pmids, classification):
44
+ # abs['label'] = label_lookup[abs['label']]
45
+ # abs['pmid'] = pmid
46
 
47
+ # return classification
48
+ results = []
49
+ total = len(abstracts)
50
+
51
+ # Convert dictionary to lists of PMIDs and abstracts, preserving order
52
+ pmids = list(abstracts.keys())
53
+ abstract_texts = list(abstracts.values())
54
+
55
+ # Initialize progress bar
56
+ progress(0, desc="Starting classification...")
57
+
58
+ # Process in batches
59
+ for i in range(0, total, batch_size):
60
+ # Get current batch
61
+ batch_abstracts = abstract_texts[i:i + batch_size]
62
+ batch_pmids = pmids[i:i + batch_size]
63
+
64
+ try:
65
+ # Classify the batch
66
+ classifications = pipe(batch_abstracts)
67
+
68
+ # Process each result in the batch
69
+ for pmid, classification in zip(batch_pmids, classifications):
70
+ results.append({
71
+ 'pmid': pmid,
72
+ 'classification': classification['label'],
73
+ 'score': classification['score']
74
+ })
75
+
76
+ # Update progress
77
+ progress(min((i + batch_size) / total, 1.0),
78
+ desc=f"Classified {min(i + batch_size, total)}/{total} abstracts...")
79
+
80
+ except Exception as e:
81
+ print(f"Error classifying batch starting at index {i}: {str(e)}")
82
+ continue
83
+
84
+ progress(1.0, desc="Classification complete!")
85
+ return results
86
 
87
 
88
 
 
164
  # Simple abstract
165
  abstract_text = abstract_element.text
166
  else:
167
+ abstract_text = ""
168
+ if len(abstract_text) > 0:
169
+ all_abstracts[pmid] = abstract_text
170
 
171
  # Respect NCBI's rate limits
172
  time.sleep(0.34)
 
317
  with gr.Row():
318
  d = gr.DownloadButton("Download results", visible=True, interactive=False)
319
 
320
+ with gr.Row():
321
+ progress=gr.Progress()
322
+
323
  d.click(download_file, None, d)
324
 
325
  search_button.click(