# Source snapshot: commit ab631a4 (918 bytes).
from sentence_transformers import SentenceTransformer
import torch
class Model:
    """Callable wrapper that embeds sentences with a SentenceTransformer model.

    An instance is called with a payload of the form::

        {"inputs": {"source_sentence": "...", "sentences": ["...", ...]}}

    It returns the embeddings of ``[source_sentence, *sentences]`` as nested
    lists, together with the shape of the embedding array.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the pre-trained embedding model.

        Args:
            model_name: Model name or path passed to ``SentenceTransformer``.
                Defaults to ``'all-MiniLM-L6-v2'``, the model this class has
                always loaded, so existing ``Model()`` calls are unchanged.
        """
        self.embedding_model = SentenceTransformer(model_name)

    def __call__(self, payload):
        """Embed the source sentence followed by the candidate sentences.

        Args:
            payload: Mapping that may contain an ``"inputs"`` mapping with
                optional ``"source_sentence"`` (string) and ``"sentences"``
                (list/tuple of strings, or a single string) entries. Missing
                or ``None`` values fall back to ``""`` / no sentences.

        Returns:
            dict with ``"embeddings"`` (nested Python lists, one row per input
            text, source sentence first) and ``"shape"`` (list of ints).
        """
        # `or` also covers keys that are present but explicitly None, which
        # previously raised AttributeError / TypeError.
        inputs = payload.get("inputs") or {}
        source_sentence = inputs.get("source_sentence") or ""
        sentences = inputs.get("sentences") or []
        if isinstance(sentences, str):
            # A bare string is one sentence; unpacking it below would
            # otherwise split it into individual characters.
            sentences = [sentences]

        # The source sentence is always first (even when empty), so row 0 of
        # the result is its embedding and rows 1.. follow `sentences` order.
        chunks = [source_sentence, *sentences]

        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)

        # Convert to plain lists so the response is JSON-serializable.
        return {
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }