Draken007's picture
Upload 7228 files
2a0bc63 verified
from typing import Callable, Dict, List, Tuple
import numpy as np
VectorType = List[float]
# Distance definitions. Each of these functions is batched over its first argument.
def distance_dot_product(
    embedding_vectors: List[VectorType], reference_embedding_vector: VectorType
) -> List[float]:
    """
    Compute the dot product of each vector in a batch with a reference vector.

    Given a list [emb_i] and a reference vector rEmb, return the list
    [dot(emb_i, rEmb)], in the same order as the input.
    For unit-norm vectors this coincides with the cosine similarity.
    Not particularly optimized.

    Args:
        embedding_vectors: the batch of vectors [emb_i].
        reference_embedding_vector: the single reference vector rEmb.

    Returns:
        One dot-product value per input vector, or an empty list
        if the batch itself is empty.
    """
    v1s = np.array(embedding_vectors, dtype=float)
    # An empty batch converts to a shape-(0,) array, which cannot be
    # multiplied against the reference vector: answer it directly.
    if v1s.shape == (0,):
        return []
    v2 = np.array(reference_embedding_vector, dtype=float)
    return list(np.dot(v1s, v2.T))
def distance_cos_difference(
    embedding_vectors: List[VectorType], reference_embedding_vector: VectorType
) -> List[float]:
    """
    Compute the cosine of the angle between each batch vector and a reference.

    Given a list [emb_i] and a reference vector rEmb, return the list
    [dot(emb_i, rEmb) / (|emb_i| * |rEmb|)], in the same order as the input.
    Zero-norm vectors are not special-cased: the division is left
    to NumPy's floating-point rules.

    Args:
        embedding_vectors: the batch of vectors [emb_i].
        reference_embedding_vector: the single reference vector rEmb.

    Returns:
        One cosine value per input vector, or an empty list
        if the batch itself is empty.
    """
    v1s = np.array(embedding_vectors, dtype=float)
    # An empty batch converts to a shape-(0,) array, which has no axis 1
    # to take norms along: answer it directly.
    if v1s.shape == (0,):
        return []
    v2 = np.array(reference_embedding_vector, dtype=float)
    dot_products = np.dot(v1s, v2.T)
    # Per-row norm of the batch times the (scalar) norm of the reference.
    norm_products = np.linalg.norm(v1s, axis=1) * np.linalg.norm(v2)
    return list(dot_products / norm_products)
def distance_l1(
    embedding_vectors: List[VectorType], reference_embedding_vector: VectorType
) -> List[float]:
    """
    Compute the L1 (sum of absolute differences) distance to a reference.

    Given a list [emb_i] and a reference vector rEmb, return the list
    [sum_k |emb_i[k] - rEmb[k]|], in the same order as the input.

    Args:
        embedding_vectors: the batch of vectors [emb_i].
        reference_embedding_vector: the single reference vector rEmb.

    Returns:
        One distance per input vector, or an empty list
        if the batch itself is empty.
    """
    v1s = np.array(embedding_vectors, dtype=float)
    # An empty batch converts to a shape-(0,) array, which has no axis 1
    # to take norms along: answer it directly.
    if v1s.shape == (0,):
        return []
    v2 = np.array(reference_embedding_vector, dtype=float)
    # The reference is broadcast against every row of the batch.
    return list(np.linalg.norm(v1s - v2, axis=1, ord=1))
def distance_l2(
    embedding_vectors: List[VectorType], reference_embedding_vector: VectorType
) -> List[float]:
    """
    Compute the L2 (Euclidean) distance of each batch vector to a reference.

    Given a list [emb_i] and a reference vector rEmb, return the list
    [|emb_i - rEmb|_2], in the same order as the input.

    Args:
        embedding_vectors: the batch of vectors [emb_i].
        reference_embedding_vector: the single reference vector rEmb.

    Returns:
        One distance per input vector, or an empty list
        if the batch itself is empty.
    """
    v1s = np.array(embedding_vectors, dtype=float)
    # An empty batch converts to a shape-(0,) array, which has no axis 1
    # to take norms along: answer it directly.
    if v1s.shape == (0,):
        return []
    v2 = np.array(reference_embedding_vector, dtype=float)
    # The reference is broadcast against every row of the batch.
    return list(np.linalg.norm(v1s - v2, axis=1, ord=2))
def distance_max(
    embedding_vectors: List[VectorType], reference_embedding_vector: VectorType
) -> List[float]:
    """
    Compute the max-norm (largest absolute component difference) distance.

    Given a list [emb_i] and a reference vector rEmb, return the list
    [max_k |emb_i[k] - rEmb[k]|], in the same order as the input.

    Args:
        embedding_vectors: the batch of vectors [emb_i].
        reference_embedding_vector: the single reference vector rEmb.

    Returns:
        One distance per input vector, or an empty list
        if the batch itself is empty.
    """
    v1s = np.array(embedding_vectors, dtype=float)
    # An empty batch converts to a shape-(0,) array, which has no axis 1
    # to take norms along: answer it directly.
    if v1s.shape == (0,):
        return []
    v2 = np.array(reference_embedding_vector, dtype=float)
    # The reference is broadcast against every row of the batch.
    return list(np.linalg.norm(v1s - v2, axis=1, ord=np.inf))
# Registry of the available metrics, keyed by their short name.
# Each entry is a pair:
#   (
#       metric function,
#       'reverse' argument to use when sorting nearest-to-farthest
#   )
# A flag of True means that, for that metric:
#   - a higher value means a closer vector, and
#   - a cutoff should be applied as metric > threshold.
distance_metrics: Dict[
    str, Tuple[Callable[[List[VectorType], VectorType], List[float]], bool]
] = {
    "cos": (distance_cos_difference, True),
    "dot": (distance_dot_product, True),
    "l1": (distance_l1, False),
    "l2": (distance_l2, False),
    "max": (distance_max, False),
}