Spaces:
Runtime error
Runtime error
Commit
·
ac02643
1
Parent(s):
f432512
fix: label cache error
Browse files- ai/classify_paper.py +30 -14
- fetch_paper.py +6 -5
ai/classify_paper.py
CHANGED
|
@@ -1,15 +1,19 @@
|
|
|
|
|
| 1 |
import time
|
| 2 |
from typing import List, Dict, Optional
|
| 3 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 4 |
from threading import Lock
|
| 5 |
from typing import List, Dict, Any
|
| 6 |
import json
|
|
|
|
| 7 |
|
| 8 |
try:
|
| 9 |
from .ai import complete
|
| 10 |
except ImportError:
|
| 11 |
from ai import complete
|
| 12 |
|
|
|
|
|
|
|
| 13 |
|
| 14 |
paper_types: Dict[str, str] = {
|
| 15 |
"CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
|
|
@@ -41,19 +45,21 @@ You should output in the following format with a code block:
|
|
| 41 |
```json
|
| 42 |
[
|
| 43 |
{
|
| 44 |
-
"
|
| 45 |
"category": ["RO"]
|
| 46 |
},
|
| 47 |
{
|
| 48 |
-
"
|
| 49 |
"category": ["ML"]
|
| 50 |
},
|
| 51 |
{
|
| 52 |
-
"
|
| 53 |
"category": ["LLM", "NLP"]
|
| 54 |
}
|
| 55 |
]
|
| 56 |
```
|
|
|
|
|
|
|
| 57 |
""".strip(),
|
| 58 |
"""
|
| 59 |
The followings are the papers you need to classify:
|
|
@@ -61,24 +67,25 @@ The followings are the papers you need to classify:
|
|
| 61 |
])
|
| 62 |
|
| 63 |
|
| 64 |
-
def build_paper(
|
| 65 |
|
| 66 |
if abstract is None:
|
| 67 |
-
return f"{
|
| 68 |
|
| 69 |
-
return f"{
|
| 70 |
|
| 71 |
|
| 72 |
def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
|
| 73 |
prompt = []
|
| 74 |
|
| 75 |
for index, paper in enumerate(papers, start=1):
|
| 76 |
-
prompt.append(build_paper(paper["
|
| 77 |
|
| 78 |
return user_prompt + "\n\n" + "\n\n".join(prompt)
|
| 79 |
|
| 80 |
|
| 81 |
def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
|
|
|
| 82 |
# 匹配code block
|
| 83 |
response = response.strip()
|
| 84 |
if not response.startswith("```") or not response.endswith("```"):
|
|
@@ -90,13 +97,14 @@ def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
|
| 90 |
try:
|
| 91 |
data = json.loads(response)
|
| 92 |
except json.JSONDecodeError:
|
|
|
|
| 93 |
return None
|
| 94 |
|
| 95 |
for paper in data:
|
| 96 |
-
if "
|
| 97 |
return None
|
| 98 |
|
| 99 |
-
if not isinstance(paper["
|
| 100 |
return None
|
| 101 |
|
| 102 |
for category in paper["category"]:
|
|
@@ -160,9 +168,15 @@ class PaperCache:
|
|
| 160 |
def get(self, paper):
|
| 161 |
key = paper["id"]
|
| 162 |
with self.lock:
|
| 163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
def set(self, paper, result):
|
|
|
|
| 166 |
key = paper["id"]
|
| 167 |
with self.lock:
|
| 168 |
self.cache[key] = result
|
|
@@ -212,12 +226,14 @@ def classify_papers(papers: List[Dict[str, str]]) -> Optional[List[Dict[str, Lis
|
|
| 212 |
for f in futures:
|
| 213 |
f.cancel()
|
| 214 |
return None
|
| 215 |
-
for paper, result in zip(uncached_papers, batch_result):
|
| 216 |
-
paper_cache.set(paper, result)
|
| 217 |
results.extend(batch_result)
|
| 218 |
-
results
|
|
|
|
|
|
|
|
|
|
| 219 |
return cached_results + results
|
| 220 |
-
except Exception:
|
|
|
|
| 221 |
return None
|
| 222 |
|
| 223 |
|
|
|
|
| 1 |
+
import traceback
|
| 2 |
import time
|
| 3 |
from typing import List, Dict, Optional
|
| 4 |
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 5 |
from threading import Lock
|
| 6 |
from typing import List, Dict, Any
|
| 7 |
import json
|
| 8 |
+
from rich.console import Console
|
| 9 |
|
| 10 |
try:
|
| 11 |
from .ai import complete
|
| 12 |
except ImportError:
|
| 13 |
from ai import complete
|
| 14 |
|
| 15 |
+
print = Console().log
|
| 16 |
+
|
| 17 |
|
| 18 |
paper_types: Dict[str, str] = {
|
| 19 |
"CV": "computer vision, any paper that deals with image, video, point cloud or 3D model data",
|
|
|
|
| 45 |
```json
|
| 46 |
[
|
| 47 |
{
|
| 48 |
+
"id": "2402.01032",
|
| 49 |
"category": ["RO"]
|
| 50 |
},
|
| 51 |
{
|
| 52 |
+
"id": "2402.03254",
|
| 53 |
"category": ["ML"]
|
| 54 |
},
|
| 55 |
{
|
| 56 |
+
"id": "2403.00043",
|
| 57 |
"category": ["LLM", "NLP"]
|
| 58 |
}
|
| 59 |
]
|
| 60 |
```
|
| 61 |
+
|
| 62 |
+
Do not add any additional information in the output. The order of the papers in the output should match the order of the papers in the input.
|
| 63 |
""".strip(),
|
| 64 |
"""
|
| 65 |
The followings are the papers you need to classify:
|
|
|
|
| 67 |
])
|
| 68 |
|
| 69 |
|
| 70 |
+
def build_paper(id: str, title: str, abstract: Optional[str] = None) -> str:
    """Format a single paper as one entry of the classification prompt.

    Args:
        id: Identifier of the paper; emitted verbatim as the entry prefix.
            (The name shadows the ``id`` builtin but is kept because it is
            part of the function's keyword interface.)
        title: Title of the paper.
        abstract: Abstract text, or ``None`` when no abstract is available.

    Returns:
        ``"<id>: <title>"`` when ``abstract`` is ``None``; otherwise the same
        header line followed by a blank line and the abstract.
    """
    if abstract is None:
        return f"{id}: {title}"

    return f"{id}: {title}\n\n{abstract}"
|
| 76 |
|
| 77 |
|
| 78 |
def get_classify_prompt(papers: List[Dict[str, str]]) -> str:
    """Build the user message listing every paper to classify.

    Args:
        papers: Paper records. Each one must provide ``"id"`` and ``"title"``;
            ``"abstract"`` is optional and is omitted from the entry when the
            key is missing.

    Returns:
        The module-level ``user_prompt`` followed by one ``build_paper`` entry
        per paper, in input order, with a blank line between sections.

    Raises:
        KeyError: If a paper lacks ``"id"`` or ``"title"``.
    """
    entries = [
        # dict.get yields None for a missing abstract, which build_paper
        # treats as "no abstract available".
        build_paper(paper["id"], paper["title"], paper.get("abstract"))
        for paper in papers
    ]

    return user_prompt + "\n\n" + "\n\n".join(entries)
|
| 85 |
|
| 86 |
|
| 87 |
def parse_response(response: str) -> List[Dict[str, List[str]]] | None:
|
| 88 |
+
print(response)
|
| 89 |
# 匹配code block
|
| 90 |
response = response.strip()
|
| 91 |
if not response.startswith("```") or not response.endswith("```"):
|
|
|
|
| 97 |
try:
|
| 98 |
data = json.loads(response)
|
| 99 |
except json.JSONDecodeError:
|
| 100 |
+
print(response)
|
| 101 |
return None
|
| 102 |
|
| 103 |
for paper in data:
|
| 104 |
+
if "id" not in paper or "category" not in paper:
|
| 105 |
return None
|
| 106 |
|
| 107 |
+
if not isinstance(paper["id"], str) or not isinstance(paper["category"], list):
|
| 108 |
return None
|
| 109 |
|
| 110 |
for category in paper["category"]:
|
|
|
|
| 168 |
def get(self, paper):
    """Look up the cached result for a paper.

    Args:
        paper: Mapping that provides the paper identifier under ``"id"``.

    Returns:
        The value stored in ``self.cache`` for that identifier, or ``None``
        when nothing is stored (a stored ``None`` is reported as a miss too).

    Raises:
        KeyError: If ``paper`` has no ``"id"`` entry.
    """
    key = paper["id"]
    with self.lock:
        data = self.cache.get(key)

    # Log only after the lock is released so console output does not extend
    # the critical section for other threads using the cache.
    if data is not None:
        print(f"Cache hit for {key}")
        return data

    print(f"Cache miss for {key}")
    return None
|
| 177 |
|
| 178 |
def set(self, paper, result):
    """Store ``result`` in the cache under the paper's identifier.

    Args:
        paper: Mapping that provides the paper identifier under ``"id"``.
        result: Value to remember for that identifier; any existing entry
            is overwritten.
    """
    paper_id = paper["id"]
    print(f"Setting cache for {paper_id}")

    # The write happens under the instance lock, as in every other access
    # to self.cache in this class's visible methods.
    with self.lock:
        self.cache[paper_id] = result
|
|
|
|
| 226 |
for f in futures:
|
| 227 |
f.cancel()
|
| 228 |
return None
|
|
|
|
|
|
|
| 229 |
results.extend(batch_result)
|
| 230 |
+
print(results)
|
| 231 |
+
results.sort(key=lambda x: x['id'])
|
| 232 |
+
for result in results:
|
| 233 |
+
paper_cache.set(result, result)
|
| 234 |
return cached_results + results
|
| 235 |
+
except Exception as e:
|
| 236 |
+
print(traceback.format_exc())
|
| 237 |
return None
|
| 238 |
|
| 239 |
|
fetch_paper.py
CHANGED
|
@@ -101,20 +101,21 @@ def fetch_papers_with_daterange(start_date: Date, end_date: Date):
|
|
| 101 |
|
| 102 |
print(f"Unique articles: {len(unique_articles)}")
|
| 103 |
|
| 104 |
-
|
| 105 |
|
| 106 |
preprocessed_articles = list(map(lambda article: {
|
| 107 |
"title": article.title,
|
| 108 |
"abstract": article.paper.summary,
|
| 109 |
"id": article.paper.id
|
| 110 |
-
},
|
| 111 |
|
| 112 |
classified_articles = classify_papers(preprocessed_articles)
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
|
|
|
| 116 |
|
| 117 |
-
return unique_articles
|
| 118 |
|
| 119 |
|
| 120 |
if __name__ == "__main__":
|
|
|
|
| 101 |
|
| 102 |
print(f"Unique articles: {len(unique_articles)}")
|
| 103 |
|
| 104 |
+
preprocessed_articles: List[Article] = list(unique_articles.values())
|
| 105 |
|
| 106 |
preprocessed_articles = list(map(lambda article: {
|
| 107 |
"title": article.title,
|
| 108 |
"abstract": article.paper.summary,
|
| 109 |
"id": article.paper.id
|
| 110 |
+
}, preprocessed_articles))
|
| 111 |
|
| 112 |
classified_articles = classify_papers(preprocessed_articles)
|
| 113 |
|
| 114 |
+
# 遍历 classified_articles,将分类结果写入到 unique_articles 中
|
| 115 |
+
for article in classified_articles:
|
| 116 |
+
unique_articles[article["id"]].paper.label = article["category"]
|
| 117 |
|
| 118 |
+
return list(unique_articles.values())
|
| 119 |
|
| 120 |
|
| 121 |
if __name__ == "__main__":
|