|
from abc import abstractmethod |
|
|
|
import torch |
|
import numpy as np |
|
|
|
from .utils import recursive_find_device, recursive_find_dtype |
|
|
|
class EncodingSampler:
    """
    Class to sample encodings given low dimensional spatial relationships.

    Subclasses implement ``__call__`` to turn a query point plus the points of
    the stored encodings into one coefficient per encoding, then delegate to
    :meth:`apply_coefs` to form the weighted combination of the encodings.
    """

    def __init__(self, encodes):
        """
        :param encodes: A tensor of stacked encodings, or a list/tuple of such
            tensors (entries may be ``None``). The first axis of every tensor
            indexes the encodings that get combined.
        """
        self.encodes = encodes

    def apply_coefs(self, coefs):
        """
        Linear combination of encodings given coefs.

        :param coefs: 1-D array-like of weights (numpy array, torch tensor or
            plain sequence), one per encoding along the first axis.
        :return: The weighted sum over the first axis. When ``self.encodes`` is
            a list/tuple, a list of results is returned, with ``None`` entries
            passed through unchanged.
        :raises ValueError: If an encoding tensor has fewer than 2 dimensions.
        """
        device = recursive_find_device(self.encodes)
        dtype = recursive_find_dtype(self.encodes)

        # as_tensor accepts numpy arrays, tensors and plain sequences (unlike
        # torch.from_numpy) and does the dtype/device conversion in one step.
        coefs = torch.as_tensor(coefs, dtype=dtype, device=device)

        def single_apply(encodes):
            if encodes is None:
                return None
            ndim = len(encodes.shape)
            if ndim < 2:
                raise ValueError("Encoding Sampler couldn't figure out shape of encodings")
            # Reshape coefs to (N, 1, ..., 1) so they broadcast against an
            # (N, ...) tensor of any rank, then sum out the encoding axis.
            weights = coefs.reshape(-1, *([1] * (ndim - 1)))
            return (weights * encodes).sum(0)

        if isinstance(self.encodes, (list, tuple)):
            return [single_apply(enc) for enc in self.encodes]
        return single_apply(self.encodes)

    @abstractmethod
    def __call__(self, point, other_points):
        """
        Produce the sampled encoding for ``point``; implemented by subclasses.

        :param point: Point in low space representing user input ([2,] array)
        :param other_points: Points in low space representing existing prompts ([N,2] array)

        NOTE(review): ``@abstractmethod`` is not enforced because this class does
        not use ``abc.ABCMeta`` / ``abc.ABC``; calling the base implementation
        simply returns ``None``.
        """
        pass
|
|
|
class DistanceSampling(EncodingSampler):
    """
    Sampler whose weights fall off with squared distance in the low-dimensional
    space: each stored point gets weight 1 / (1 + distance**2).
    """

    def __call__(self, point, other_points):
        """
        :param point: Query point in the low-dimensional space.
        :param other_points: Stacked points of the stored encodings, one per row.
        :return: Combination of the stored encodings produced by ``apply_coefs``.
        """
        # Row-wise offsets from the query point to every stored point.
        offsets = point[None, :] - other_points
        distances = np.linalg.norm(offsets, axis=1)
        # Inverse-quadratic kernel; the weights are not normalized to sum to 1.
        weights = 1. / (1. + distances ** 2)
        return self.apply_coefs(weights)
|
|
|
class CircleSampling(EncodingSampler):
    """
    Sampler that treats every encoding as a point on a unit circle and weights
    each one by its dot product with the query point.
    """

    def __call__(self, point, other_points):
        """
        :param point: Query point in the low-dimensional space.
        :param other_points: Stacked points of the stored encodings, one per row.
        :return: Combination of the stored encodings produced by ``apply_coefs``.
        """
        # NOTE(review): nothing here normalizes the inputs, so these dot products
        # are true cosine similarities only when both the query and every row
        # have unit length -- presumably guaranteed by the caller; confirm.
        similarities = point @ other_points.T
        return self.apply_coefs(similarities)
|
|