sklearn-transformers / pipeline.py
merve's picture
merve HF staff
Update pipeline.py
9694378
raw
history blame
598 Bytes
import json
from typing import Any, Dict, List
import sklearn
import os
import joblib
import numpy as np
class PreTrainedPipeline:
    """Inference wrapper around a serialized scikit-learn pipeline.

    The pipeline is loaded from ``<path>/pipeline.pkl`` and scored one input
    at a time.  Results are returned as a list of ``{"label", "score"}``
    dicts, one per column of the model's ``predict_proba`` output.
    """

    def __init__(self, path: str):
        """Load the pickled pipeline stored in directory *path*.

        Args:
            path: Directory that contains ``pipeline.pkl``.
        """
        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code.  Only load pipeline files from trusted sources.
        self.model = joblib.load(os.path.join(path, "pipeline.pkl"))

    def __call__(self, inputs: str) -> List[Dict[str, float]]:
        """Return per-class probabilities for a single input.

        Args:
            inputs: One raw sample.  It is wrapped in a list because
                ``predict_proba`` scores a batch of samples.

        Returns:
            One ``{"label": "LABEL_<i>", "score": p}`` dict per column of
            ``predict_proba``, in column order.  ``<i>`` is the column index
            and ``p`` is a plain Python ``float``.
        """
        # Row 0 holds the probabilities for our single-sample batch.
        probabilities = self.model.predict_proba([inputs])[0]
        # Label by column position, not by the probability value: using the
        # float score as an index fails on numpy arrays.  float() turns numpy
        # scalars (e.g. np.float32) into JSON-serializable Python floats.
        return [
            {"label": f"LABEL_{index}", "score": float(score)}
            for index, score in enumerate(probabilities)
        ]