Commit
·
ac02643
1
Parent(s):
f432512
fix: label cache error
Browse files- ai/classify_paper.py +30 -14
- fetch_paper.py +6 -5
ai/classify_paper.py
CHANGED
@@ -1,15 +1,19 @@
|
|
|
|
1 |
import time
|
2 |
from typing import List, Dict, Optional
|
3 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
4 |
from threading import Lock
|
5 |
from typing import List, Dict, Any
|
6 |
import json
|
|
|
7 |
|
8 |
try:
|
9 |
from .ai import complete
|
10 |
except ImportError:
|
11 |
from ai import complete
|
12 |
|
|
|
|
|
13 |
|
14 |
paper_types: Dict[str, str] = {
|
15 |
"CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
|
@@ -41,19 +45,21 @@ You should output in the following format with a code block:
|
|
41 |
```json
|
42 |
[
|
43 |
{
|
44 |
-
"
|
45 |
"category": ["RO"]
|
46 |
},
|
47 |
{
|
48 |
-
"
|
49 |
"category": ["ML"]
|
50 |
},
|
51 |
{
|
52 |
-
"
|
53 |
"category": ["LLM", "NLP"]
|
54 |
}
|
55 |
]
|
56 |
```
|
|
|
|
|
57 |
""".strip(),
|
58 |
"""
|
59 |
The followings are the papers you need to classify:
|
@@ -61,24 +67,25 @@ The followings are the papers you need to classify:
|
|
61 |
])
|
62 |
|
63 |
|
64 |
-
def build_paper(
|
65 |
|
66 |
if abstract is None:
|
67 |
-
return f"{
|
68 |
|
69 |
-
return f"{
|
70 |
|
71 |
|
72 |
def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
|
73 |
prompt = []
|
74 |
|
75 |
for index, paper in enumerate(papers, start=1):
|
76 |
-
prompt.append(build_paper(paper["
|
77 |
|
78 |
return user_prompt + "\n\n" + "\n\n".join(prompt)
|
79 |
|
80 |
|
81 |
def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
|
|
82 |
# 匹配code block
|
83 |
response = response.strip()
|
84 |
if not response.startswith("```") or not response.endswith("```"):
|
@@ -90,13 +97,14 @@ def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
|
90 |
try:
|
91 |
data = json.loads(response)
|
92 |
except json.JSONDecodeError:
|
|
|
93 |
return None
|
94 |
|
95 |
for paper in data:
|
96 |
-
if "
|
97 |
return None
|
98 |
|
99 |
-
if not isinstance(paper["
|
100 |
return None
|
101 |
|
102 |
for category in paper["category"]:
|
@@ -160,9 +168,15 @@ class PaperCache:
|
|
160 |
def get(self, paper):
|
161 |
key = paper["id"]
|
162 |
with self.lock:
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
164 |
|
165 |
def set(self, paper, result):
|
|
|
166 |
key = paper["id"]
|
167 |
with self.lock:
|
168 |
self.cache[key] = result
|
@@ -212,12 +226,14 @@ def classify_papers(papers: List[Dict[str, str]]) -> Optional[List[Dict[str, Lis
|
|
212 |
for f in futures:
|
213 |
f.cancel()
|
214 |
return None
|
215 |
-
for paper, result in zip(uncached_papers, batch_result):
|
216 |
-
paper_cache.set(paper, result)
|
217 |
results.extend(batch_result)
|
218 |
-
results
|
|
|
|
|
|
|
219 |
return cached_results + results
|
220 |
-
except Exception:
|
|
|
221 |
return None
|
222 |
|
223 |
|
|
|
1 |
+
import traceback
|
2 |
import time
|
3 |
from typing import List, Dict, Optional
|
4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
5 |
from threading import Lock
|
6 |
from typing import List, Dict, Any
|
7 |
import json
|
8 |
+
from rich.console import Console
|
9 |
|
10 |
try:
|
11 |
from .ai import complete
|
12 |
except ImportError:
|
13 |
from ai import complete
|
14 |
|
15 |
+
print = Console().log
|
16 |
+
|
17 |
|
18 |
paper_types: Dict[str, str] = {
|
19 |
"CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
|
|
|
45 |
```json
|
46 |
[
|
47 |
{
|
48 |
+
"id": "2402.01032",
|
49 |
"category": ["RO"]
|
50 |
},
|
51 |
{
|
52 |
+
"id": "2402.03254",
|
53 |
"category": ["ML"]
|
54 |
},
|
55 |
{
|
56 |
+
"id": "2403.00043",
|
57 |
"category": ["LLM", "NLP"]
|
58 |
}
|
59 |
]
|
60 |
```
|
61 |
+
|
62 |
+
Do not add any additional information in the output. The order of the papers in the output should match the order of the papers in the input.
|
63 |
""".strip(),
|
64 |
"""
|
65 |
The followings are the papers you need to classify:
|
|
|
67 |
])
|
68 |
|
69 |
|
70 |
+
def build_paper(id: str, title: str, abstract: str = None) -> str:
|
71 |
|
72 |
if abstract is None:
|
73 |
+
return f"{id}: {title}"
|
74 |
|
75 |
+
return f"{id}: {title}\n\n{abstract}"
|
76 |
|
77 |
|
78 |
def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
|
79 |
prompt = []
|
80 |
|
81 |
for index, paper in enumerate(papers, start=1):
|
82 |
+
prompt.append(build_paper(paper["id"], paper["title"], paper["abstract"] if "abstract" in paper else None))
|
83 |
|
84 |
return user_prompt + "\n\n" + "\n\n".join(prompt)
|
85 |
|
86 |
|
87 |
def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
88 |
+
print(response)
|
89 |
# 匹配code block
|
90 |
response = response.strip()
|
91 |
if not response.startswith("```") or not response.endswith("```"):
|
|
|
97 |
try:
|
98 |
data = json.loads(response)
|
99 |
except json.JSONDecodeError:
|
100 |
+
print(response)
|
101 |
return None
|
102 |
|
103 |
for paper in data:
|
104 |
+
if "id" not in paper or "category" not in paper:
|
105 |
return None
|
106 |
|
107 |
+
if not isinstance(paper["id"], str) or not isinstance(paper["category"], list):
|
108 |
return None
|
109 |
|
110 |
for category in paper["category"]:
|
|
|
168 |
def get(self, paper):
|
169 |
key = paper["id"]
|
170 |
with self.lock:
|
171 |
+
data = self.cache.get(key)
|
172 |
+
if data is not None:
|
173 |
+
print(f"Cache hit for {paper['id']}")
|
174 |
+
return data
|
175 |
+
print(f"Cache miss for {paper['id']}")
|
176 |
+
return None
|
177 |
|
178 |
def set(self, paper, result):
|
179 |
+
print(f"Setting cache for {paper['id']}")
|
180 |
key = paper["id"]
|
181 |
with self.lock:
|
182 |
self.cache[key] = result
|
|
|
226 |
for f in futures:
|
227 |
f.cancel()
|
228 |
return None
|
|
|
|
|
229 |
results.extend(batch_result)
|
230 |
+
print(results)
|
231 |
+
results.sort(key=lambda x: x['id'])
|
232 |
+
for result in results:
|
233 |
+
paper_cache.set(result, result)
|
234 |
return cached_results + results
|
235 |
+
except Exception as e:
|
236 |
+
print(traceback.format_exc())
|
237 |
return None
|
238 |
|
239 |
|
fetch_paper.py
CHANGED
@@ -101,20 +101,21 @@ def fetch_papers_with_daterange(start_date: Date, end_date: Date):
|
|
101 |
|
102 |
print(f"Unique articles: {len(unique_articles)}")
|
103 |
|
104 |
-
|
105 |
|
106 |
preprocessed_articles = list(map(lambda article: {
|
107 |
"title": article.title,
|
108 |
"abstract": article.paper.summary,
|
109 |
"id": article.paper.id
|
110 |
-
},
|
111 |
|
112 |
classified_articles = classify_papers(preprocessed_articles)
|
113 |
|
114 |
-
|
115 |
-
|
|
|
116 |
|
117 |
-
return unique_articles
|
118 |
|
119 |
|
120 |
if __name__ == "__main__":
|
|
|
101 |
|
102 |
print(f"Unique articles: {len(unique_articles)}")
|
103 |
|
104 |
+
preprocessed_articles: List[Article] = list(unique_articles.values())
|
105 |
|
106 |
preprocessed_articles = list(map(lambda article: {
|
107 |
"title": article.title,
|
108 |
"abstract": article.paper.summary,
|
109 |
"id": article.paper.id
|
110 |
+
}, preprocessed_articles))
|
111 |
|
112 |
classified_articles = classify_papers(preprocessed_articles)
|
113 |
|
114 |
+
# 遍历 classified_articles,将分类结果写入到 unique_articles 中
|
115 |
+
for article in classified_articles:
|
116 |
+
unique_articles[article["id"]].paper.label = article["category"]
|
117 |
|
118 |
+
return list(unique_articles.values())
|
119 |
|
120 |
|
121 |
if __name__ == "__main__":
|