Spaces:
Running
Running
File size: 1,682 Bytes
acd7cf4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
import asyncio
from dataclasses import dataclass
from typing import Optional

from tqdm.asyncio import tqdm as tqdm_async

from graphgen.models.text.text_pair import TextPair
from graphgen.utils import create_event_loop
@dataclass
class BaseEvaluator:
    """Base class for evaluators that score a batch of text pairs concurrently.

    Subclasses implement :meth:`evaluate_single`. This class fans the pairs
    out as coroutines, bounded by ``max_concurrent``, and offers simple
    aggregate statistics over the resulting scores.
    """

    # Upper bound on the number of evaluate_single coroutines running at once.
    max_concurrent: int = 100
    # Scores stored by the most recent get_average_score call; None until then.
    results: Optional[list[float]] = None

    def evaluate(self, pairs: list[TextPair]) -> list[float]:
        """Score every pair synchronously.

        Runs :meth:`async_evaluate` to completion on the loop returned by
        ``create_event_loop()``.

        Returns:
            One score per pair, in the same order as ``pairs``.
        """
        return create_event_loop().run_until_complete(self.async_evaluate(pairs))

    async def async_evaluate(self, pairs: list[TextPair]) -> list[float]:
        """Score every pair concurrently while showing a progress bar.

        At most ``max_concurrent`` calls to :meth:`evaluate_single` are in
        flight at any time.

        Returns:
            One score per pair, in the same order as ``pairs``.
        """
        semaphore = asyncio.Semaphore(self.max_concurrent)

        async def evaluate_with_semaphore(index: int, pair: TextPair):
            async with semaphore:  # acquire the semaphore
                return index, await self.evaluate_single(pair)

        # as_completed yields in completion order, so each coroutine carries
        # its input index and the score is written back to that slot; this
        # keeps the output aligned with ``pairs``.
        results: list[float] = [0.0] * len(pairs)
        pending = [
            evaluate_with_semaphore(index, pair) for index, pair in enumerate(pairs)
        ]
        for finished in tqdm_async(
            asyncio.as_completed(pending),
            total=len(pairs),
        ):
            index, score = await finished
            results[index] = score
        return results

    async def evaluate_single(self, pair: TextPair) -> float:
        """Score a single pair. Must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError()

    def get_average_score(self, pairs: list[TextPair]) -> float:
        """Evaluate ``pairs`` and return the mean score.

        The per-pair scores are stored on ``self.results`` as a side effect.

        Returns:
            The arithmetic mean of the scores, or ``0.0`` when there are no
            scores (an empty batch) instead of raising ``ZeroDivisionError``.
        """
        results = self.evaluate(pairs)
        self.results = results
        if not results:
            return 0.0
        return sum(results) / len(results)

    def get_min_max_score(self, pairs: list[TextPair]) -> tuple[float, float]:
        """Return the ``(min, max)`` score of a batch.

        NOTE: when ``self.results`` is already populated, those cached scores
        are reused and ``pairs`` is not evaluated again. ``pairs`` is only
        evaluated (via :meth:`get_average_score`) if no results are stored
        yet.

        Returns:
            ``(min, max)`` of the stored scores, or ``(0.0, 0.0)`` when the
            stored scores are empty.
        """
        if self.results is None:
            self.get_average_score(pairs)
        if not self.results:
            return 0.0, 0.0
        return min(self.results), max(self.results)
|