File size: 959 Bytes
ae8276c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from transformers import Pipeline
class LangDetectionPipeline(Pipeline):
    """Hugging Face ``Pipeline`` wrapper around a language-identification model.

    ``self.model`` is called directly with the input text and must return a
    ``(predictions, probabilities)`` pair whose label strings may carry a
    ``__label__`` prefix. NOTE(review): the concrete model type is not
    visible here (the prefix suggests fastText) -- confirm against the code
    that constructs this pipeline.

    Calling the pipeline returns ``{"label": str, "confidence": float}``,
    where ``confidence`` is the model probability expressed as a percentage
    rounded to two decimal places.
    """

    # Prefix removed from the model's raw label strings.
    _LABEL_PREFIX = "__label__"

    def _sanitize_parameters(self, **kwargs):
        """Split keyword arguments into (preprocess, forward, postprocess) params.

        The base ``Pipeline`` invokes ``preprocess(inputs, **preprocess_params)``,
        so the text to classify always arrives positionally. Forwarding a
        ``text`` keyword as well would make ``preprocess`` receive ``text``
        twice and fail with an opaque "multiple values" error. Reject it here
        with an explicit message instead.

        Returns:
            Three empty dicts. This pipeline has no per-stage parameters.

        Raises:
            TypeError: if ``text`` is supplied as a keyword argument.
        """
        if "text" in kwargs:
            raise TypeError(
                "LangDetectionPipeline takes the text to classify as its "
                "positional input; do not also pass it as text=..."
            )
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        """Return *text* unchanged; the model is fed the raw input as-is."""
        return text

    def _forward(self, text, **kwargs):
        """Run the wrapped model on *text*.

        Returns:
            The ``(predictions, probabilities)`` pair produced by ``self.model``.
        """
        predictions, probabilities = self.model(text)
        return predictions, probabilities

    def postprocess(self, outputs, **kwargs):
        """Turn the raw model output into a JSON-serialisable dict.

        Args:
            outputs: ``(predictions, probabilities)`` as returned by
                ``_forward``. Only element ``[0][0]`` of each is used --
                presumably the top-1 label/score for the single input;
                TODO confirm the model's output layout.

        Returns:
            ``{"label": <label with any leading "__label__" removed>,
            "confidence": <probability * 100, rounded to 2 decimals>}``
        """
        predictions, probabilities = outputs

        label = predictions[0][0]
        # Strip only a *leading* prefix; str.replace would also delete any
        # "__label__" occurring later in the label text.
        if label.startswith(self._LABEL_PREFIX):
            label = label[len(self._LABEL_PREFIX):]

        # float() converts e.g. numpy scalars into plain Python floats so the
        # result is JSON-serialisable.
        confidence = float(probabilities[0][0])

        return {"label": label, "confidence": round(confidence * 100, 2)}
|