HowardZhangdqs committed on
Commit
ac02643
·
1 Parent(s): f432512

fix: label cache error

Browse files
Files changed (2) hide show
  1. ai/classify_paper.py +30 -14
  2. fetch_paper.py +6 -5
ai/classify_paper.py CHANGED
@@ -1,15 +1,19 @@
 
1
  import time
2
  from typing import List, Dict, Optional
3
  from concurrent.futures import ThreadPoolExecutor, as_completed
4
  from threading import Lock
5
  from typing import List, Dict, Any
6
  import json
 
7
 
8
  try:
9
  from .ai import complete
10
  except ImportError:
11
  from ai import complete
12
 
 
 
13
 
14
  paper_types: Dict[str, str] = {
15
  "CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
@@ -41,19 +45,21 @@ You should output in the following format with a code block:
41
  ```json
42
  [
43
  {
44
- "index": 1,
45
  "category": ["RO"]
46
  },
47
  {
48
- "index": 2,
49
  "category": ["ML"]
50
  },
51
  {
52
- "index": 3,
53
  "category": ["LLM", "NLP"]
54
  }
55
  ]
56
  ```
 
 
57
  """.strip(),
58
  """
59
  The followings are the papers you need to classify:
@@ -61,24 +67,25 @@ The followings are the papers you need to classify:
61
  ])
62
 
63
 
64
- def build_paper(index: int, title: str, abstract: str = None) -> str:
65
 
66
  if abstract is None:
67
- return f"{index}. {title}"
68
 
69
- return f"{index}. {title}\n\n{abstract}"
70
 
71
 
72
  def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
73
  prompt = []
74
 
75
  for index, paper in enumerate(papers, start=1):
76
- prompt.append(build_paper(paper["index"] if "index" in paper else index, paper["title"], paper["abstract"] if "abstract" in paper else None))
77
 
78
  return user_prompt + "\n\n" + "\n\n".join(prompt)
79
 
80
 
81
  def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
 
82
  # Match the code block
83
  response = response.strip()
84
  if not response.startswith("```") or not response.endswith("```"):
@@ -90,13 +97,14 @@ def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
90
  try:
91
  data = json.loads(response)
92
  except json.JSONDecodeError:
 
93
  return None
94
 
95
  for paper in data:
96
- if "index" not in paper or "category" not in paper:
97
  return None
98
 
99
- if not isinstance(paper["index"], int) or not isinstance(paper["category"], list):
100
  return None
101
 
102
  for category in paper["category"]:
@@ -160,9 +168,15 @@ class PaperCache:
160
  def get(self, paper):
161
  key = paper["id"]
162
  with self.lock:
163
- return self.cache.get(key)
 
 
 
 
 
164
 
165
  def set(self, paper, result):
 
166
  key = paper["id"]
167
  with self.lock:
168
  self.cache[key] = result
@@ -212,12 +226,14 @@ def classify_papers(papers: List[Dict[str, str]]) -> Optional[List[Dict[str, Lis
212
  for f in futures:
213
  f.cancel()
214
  return None
215
- for paper, result in zip(uncached_papers, batch_result):
216
- paper_cache.set(paper, result)
217
  results.extend(batch_result)
218
- results.sort(key=lambda x: x['index'])
 
 
 
219
  return cached_results + results
220
- except Exception:
 
221
  return None
222
 
223
 
 
1
+ import traceback
2
  import time
3
  from typing import List, Dict, Optional
4
  from concurrent.futures import ThreadPoolExecutor, as_completed
5
  from threading import Lock
6
  from typing import List, Dict, Any
7
  import json
8
+ from rich.console import Console
9
 
10
  try:
11
  from .ai import complete
12
  except ImportError:
13
  from ai import complete
14
 
15
+ print = Console().log
16
+
17
 
18
  paper_types: Dict[str, str] = {
19
  "CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
 
45
  ```json
46
  [
47
  {
48
+ "id": "2402.01032",
49
  "category": ["RO"]
50
  },
51
  {
52
+ "id": "2402.03254",
53
  "category": ["ML"]
54
  },
55
  {
56
+ "id": "2403.00043",
57
  "category": ["LLM", "NLP"]
58
  }
59
  ]
60
  ```
61
+
62
+ Do not add any additional information in the output. The order of the papers in the output should match the order of the papers in the input.
63
  """.strip(),
64
  """
65
  The followings are the papers you need to classify:
 
67
  ])
68
 
69
 
70
+ def build_paper(id: str, title: str, abstract: str = None) -> str:
71
 
72
  if abstract is None:
73
+ return f"{id}: {title}"
74
 
75
+ return f"{id}: {title}\n\n{abstract}"
76
 
77
 
78
  def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
79
  prompt = []
80
 
81
  for index, paper in enumerate(papers, start=1):
82
+ prompt.append(build_paper(paper["id"], paper["title"], paper["abstract"] if "abstract" in paper else None))
83
 
84
  return user_prompt + "\n\n" + "\n\n".join(prompt)
85
 
86
 
87
  def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
88
+ print(response)
89
  # Match the code block
90
  response = response.strip()
91
  if not response.startswith("```") or not response.endswith("```"):
 
97
  try:
98
  data = json.loads(response)
99
  except json.JSONDecodeError:
100
+ print(response)
101
  return None
102
 
103
  for paper in data:
104
+ if "id" not in paper or "category" not in paper:
105
  return None
106
 
107
+ if not isinstance(paper["id"], str) or not isinstance(paper["category"], list):
108
  return None
109
 
110
  for category in paper["category"]:
 
168
  def get(self, paper):
169
  key = paper["id"]
170
  with self.lock:
171
+ data = self.cache.get(key)
172
+ if data is not None:
173
+ print(f"Cache hit for {paper['id']}")
174
+ return data
175
+ print(f"Cache miss for {paper['id']}")
176
+ return None
177
 
178
  def set(self, paper, result):
179
+ print(f"Setting cache for {paper['id']}")
180
  key = paper["id"]
181
  with self.lock:
182
  self.cache[key] = result
 
226
  for f in futures:
227
  f.cancel()
228
  return None
 
 
229
  results.extend(batch_result)
230
+ print(results)
231
+ results.sort(key=lambda x: x['id'])
232
+ for result in results:
233
+ paper_cache.set(result, result)
234
  return cached_results + results
235
+ except Exception as e:
236
+ print(traceback.format_exc())
237
  return None
238
 
239
 
fetch_paper.py CHANGED
@@ -101,20 +101,21 @@ def fetch_papers_with_daterange(start_date: Date, end_date: Date):
101
 
102
  print(f"Unique articles: {len(unique_articles)}")
103
 
104
- unique_articles: List[Article] = list(unique_articles.values())
105
 
106
  preprocessed_articles = list(map(lambda article: {
107
  "title": article.title,
108
  "abstract": article.paper.summary,
109
  "id": article.paper.id
110
- }, unique_articles))
111
 
112
  classified_articles = classify_papers(preprocessed_articles)
113
 
114
- for i, article in enumerate(unique_articles):
115
- article.paper.label = classified_articles[i]["category"]
 
116
 
117
- return unique_articles
118
 
119
 
120
  if __name__ == "__main__":
 
101
 
102
  print(f"Unique articles: {len(unique_articles)}")
103
 
104
+ preprocessed_articles: List[Article] = list(unique_articles.values())
105
 
106
  preprocessed_articles = list(map(lambda article: {
107
  "title": article.title,
108
  "abstract": article.paper.summary,
109
  "id": article.paper.id
110
+ }, preprocessed_articles))
111
 
112
  classified_articles = classify_papers(preprocessed_articles)
113
 
114
+ # Iterate over classified_articles and write each classification result into unique_articles
115
+ for article in classified_articles:
116
+ unique_articles[article["id"]].paper.label = article["category"]
117
 
118
+ return list(unique_articles.values())
119
 
120
 
121
  if __name__ == "__main__":