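# Spaces: Sleeping
# (Appears to be hosting-page status text captured together with the file;
# it is not part of the program.)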
from fastapi import FastAPI | |
from sentence_transformers import SentenceTransformer | |
import pickle | |
import os | |
from pydantic import BaseModel | |
import numpy as np | |
from typing import List | |
# ASGI application object. The keyword arguments are the service metadata
# handed to FastAPI (title, description and version string).
app = FastAPI(
    title="SBERT Embedding API",
    description="API for generating sentence embeddings using SBERT",
    version="1.0",
)
# Load the SBERT model once, at import time, so every handler below shares
# the same instance instead of reloading it per request.
# Original note: "this will be cached after first load" — presumably refers
# to the library's download cache; TODO confirm.
model_name = 'taghyan/model'
model = SentenceTransformer(model_name)
# Embedding cache setup: path of the pickle file that update_cache() reads
# and rewrites. The path is relative, so it resolves against the working
# directory of the server process.
embedding_file = 'embeddings_sbert.pkl'
# Request payload carrying one text to embed (consumed by embed_text).
class TextRequest(BaseModel):
    text: str  # the sentence/passage passed to model.encode()
# Request payload carrying a batch of texts (consumed by embed_texts and
# update_cache).
class TextsRequest(BaseModel):
    texts: List[str]  # passed to model.encode() as one batch
# Response shape for a single embedding: one flat vector of floats.
# NOTE(review): no handler visible in this file references this model
# (e.g. as a response_model) — confirm whether it is wired up elsewhere.
class EmbeddingResponse(BaseModel):
    embedding: List[float]
# Response shape for a batch of embeddings: one float vector per input text.
# NOTE(review): no handler visible in this file references this model
# (e.g. as a response_model) — confirm whether it is wired up elsewhere.
class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere or the decorator line was
# lost — confirm before deploying.
def read_root():
    """Return a static banner that identifies the service.

    Returns:
        A dict with a single ``"message"`` key.
    """
    return {"message": "SBERT Embedding Service"}
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere or the decorator line was
# lost — confirm before deploying.
async def embed_text(request: TextRequest):
    """Generate the embedding for a single text.

    Args:
        request: Payload whose ``text`` field is encoded with the shared
            module-level ``model``.

    Returns:
        A dict ``{"embedding": [...]}`` holding the encoded vector converted
        to a plain Python list so it can be serialised as JSON.
    """
    # NOTE(review): encode() is a synchronous call made inside a coroutine,
    # so the event loop is blocked while it runs. Left as-is to preserve
    # behaviour; consider off-loading to a worker thread if latency under
    # concurrent load matters.
    vector = model.encode(request.text, convert_to_numpy=True)
    return {"embedding": vector.tolist()}
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere or the decorator line was
# lost — confirm before deploying.
async def embed_texts(request: TextsRequest):
    """Generate embeddings for multiple texts in one batch.

    Args:
        request: Payload whose ``texts`` list is encoded with the shared
            module-level ``model``.

    Returns:
        A dict ``{"embeddings": [[...], ...]}`` holding the encoded batch
        converted to nested Python lists so it can be serialised as JSON.
    """
    # NOTE(review): encode() is a synchronous call made inside a coroutine,
    # so the event loop is blocked while the batch is encoded. Left as-is to
    # preserve behaviour.
    vectors = model.encode(
        request.texts,
        show_progress_bar=True,
        convert_to_numpy=True,
    )
    return {"embeddings": vectors.tolist()}
# NOTE(review): no route decorator is visible on this function in this view;
# presumably it is registered on `app` elsewhere or the decorator line was
# lost — confirm before deploying.
async def update_cache(request: TextsRequest):
    """Append embeddings for new texts to the on-disk embedding cache.

    Loads the pickled list stored at ``embedding_file`` (an empty list when
    the file does not exist yet), encodes ``request.texts`` with the shared
    module-level ``model``, appends the new vectors and writes the combined
    list back to the same path.

    Args:
        request: Payload whose ``texts`` are encoded and appended.

    Returns:
        A dict with a human-readable ``"message"`` and ``"total_embeddings"``,
        the number of vectors stored after the update.

    Raises:
        TypeError: If the existing cache file does not contain a list.
    """
    # EAFP: open directly instead of exists()+open(), which can race with
    # another process creating or removing the file in between.
    try:
        with open(embedding_file, 'rb') as cache_in:
            # SECURITY: pickle.load can execute arbitrary code from the file.
            # Acceptable only while this file is written exclusively by this
            # service; never point embedding_file at untrusted data.
            existing_embeddings = pickle.load(cache_in)
    except FileNotFoundError:
        existing_embeddings = []

    # A non-list payload (e.g. an array) would make `+` below broadcast or
    # fail obscurely instead of concatenating, so reject it explicitly.
    if not isinstance(existing_embeddings, list):
        raise TypeError(
            f"Embedding cache {embedding_file!r} must contain a list, "
            f"got {type(existing_embeddings).__name__}"
        )

    # Nothing to encode for an empty request; skip the model call entirely.
    if request.texts:
        new_embeddings = model.encode(
            request.texts, show_progress_bar=True
        ).tolist()
    else:
        new_embeddings = []

    updated_embeddings = existing_embeddings + new_embeddings

    # Write to a sibling temp file and swap it in with os.replace so an
    # interrupted write cannot leave a truncated cache behind. The pid suffix
    # keeps separate worker processes from sharing a temp name.
    tmp_path = f"{embedding_file}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, 'wb') as cache_out:
            pickle.dump(updated_embeddings, cache_out)
        os.replace(tmp_path, embedding_file)
    except Exception:
        # Do not leave a stale temp file around after a failed write.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

    return {
        "message": f"Cache updated with {len(request.texts)} new embeddings",
        "total_embeddings": len(updated_embeddings),
    }