File size: 918 Bytes
ab631a4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from sentence_transformers import SentenceTransformer
import torch

class Model:
    """Callable wrapper that embeds a source sentence plus comparison sentences.

    Holds a ``SentenceTransformer`` instance and turns a JSON-like payload into
    a JSON-serializable response containing the embeddings and their shape.
    """

    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Load the pre-trained embedding model.

        Args:
            model_name: Model name or path handed to ``SentenceTransformer``.
                Defaults to ``'all-MiniLM-L6-v2'`` (the previously hard-coded
                model), so existing zero-argument construction is unchanged.
        """
        self.embedding_model = SentenceTransformer(model_name)

    def __call__(self, payload):
        """Embed ``source_sentence`` followed by every entry of ``sentences``.

        Args:
            payload: Mapping shaped like
                ``{"inputs": {"source_sentence": str, "sentences": [str, ...]}}``.
                A missing or null ``inputs`` is treated as ``{}``; a missing or
                null ``source_sentence`` as ``""``; a missing or null
                ``sentences`` as ``[]``. ``sentences`` may also be a single
                string (treated as one sentence) or any iterable of strings.

        Returns:
            dict with:
              - ``"embeddings"``: the encoded tensor converted via ``tolist()``;
                the first chunk encoded is the source sentence, followed by
                ``sentences`` in input order.
              - ``"shape"``: the embeddings tensor's shape as a plain list.
        """
        # ``or {}`` also covers an explicit ``"inputs": null``, which would
        # otherwise fail on ``.get`` below.
        inputs = payload.get("inputs") or {}

        source_sentence = inputs.get("source_sentence")
        if source_sentence is None:
            source_sentence = ""

        sentences = inputs.get("sentences")
        if sentences is None:
            sentences = []
        elif isinstance(sentences, str):
            # A bare string is one sentence; unpacking it below would
            # otherwise split it into individual characters.
            sentences = [sentences]

        # Source sentence always goes first so callers can rely on its position.
        chunks = [source_sentence, *sentences]
        embeddings = self.embedding_model.encode(chunks, convert_to_tensor=True)

        return {
            # Tensors are not JSON-serializable; convert to nested lists.
            "embeddings": embeddings.tolist(),
            "shape": list(embeddings.shape),
        }