Spaces:
Running
Running
File size: 1,017 Bytes
b82a487 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
from enum import Enum
import wandb
class SimilarityMetric(Enum):
    """Enumeration of similarity metric identifiers.

    Each member's value is its lowercase string name, so a member can be
    recovered from that string, e.g. ``SimilarityMetric("cosine")``.
    """

    COSINE = "cosine"
    EUCLIDEAN = "euclidean"
def mean_pooling(token_embeddings, mask):
    """Average token embeddings over dim 1, ignoring masked-out positions.

    Args:
        token_embeddings: Tensor of per-token embeddings. Presumably shaped
            (batch, seq_len, hidden) -- TODO confirm against callers.
        mask: Tensor with one fewer trailing dimension than
            ``token_embeddings`` (a trailing axis is appended here so it
            broadcasts). Non-zero entries mark the positions to keep.

    Returns:
        ``token_embeddings`` summed over dim 1 (masked positions zeroed)
        divided by the mask summed over dim 1. For a 0/1 mask, a row with no
        kept positions yields zeros rather than NaN.
    """
    keep = mask[..., None].bool()
    masked_embeddings = token_embeddings.masked_fill(~keep, 0.0)

    token_counts = mask.sum(dim=1)[..., None]
    # A fully masked row would otherwise compute 0 / 0 and produce NaN.
    # Only exactly-zero denominators are replaced, so every other row is
    # divided by the same value as before.
    token_counts = token_counts.masked_fill(token_counts == 0, 1)

    return masked_embeddings.sum(dim=1) / token_counts
def get_wandb_artifact(artifact_address: str, artifact_type: str):
    """Download a Weights & Biases artifact and return its directory and metadata.

    When a run is active (``wandb.run`` is set) the artifact is fetched via
    ``wandb.run.use_artifact``; otherwise it is fetched through the public
    ``wandb.Api``.

    Args:
        artifact_address: Artifact name/path passed to W&B.
        artifact_type: Expected artifact type. It is passed as ``type=`` in
            both code paths so the type is checked whether or not a run is
            active.

    Returns:
        A ``(artifact_dir, metadata)`` tuple: the value returned by
        ``artifact.download()`` and the artifact's ``metadata`` attribute.
    """
    if wandb.run is not None:
        artifact = wandb.run.use_artifact(artifact_address, type=artifact_type)
    else:
        # Previously the type was dropped on this path, so a mismatched
        # artifact type was accepted only when no run was active.
        artifact = wandb.Api().artifact(artifact_address, type=artifact_type)

    artifact_dir = artifact.download()
    return artifact_dir, artifact.metadata
def argsort_scores(scores: list[float], descending: bool = False) -> list[dict]:
    """Sort scores while remembering where each one came from.

    Args:
        scores: Scores to order.
        descending: If true, the largest score comes first; otherwise the
            smallest comes first.

    Returns:
        One dict per score, in sorted order, with keys ``"item"`` (the score)
        and ``"original_index"`` (its position in ``scores``). Equal scores
        keep their original relative order in both directions because
        ``sorted`` is stable.
    """
    # ``sorted`` accepts any iterable, so the enumerate object is passed
    # directly instead of being copied into a list first.
    ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=descending)
    return [{"item": score, "original_index": index} for index, score in ranked]
|