|
from transformers import Pipeline |
|
|
|
|
|
class LangDetectionPipeline(Pipeline):
    """Pipeline that reports the first predicted label and its confidence.

    The wrapped ``self.model`` is called directly on the input and must
    return a two-element ``(predictions, probabilities)`` result.
    """

    def _sanitize_parameters(self, **kwargs):
        """Split call-time kwargs into preprocess/forward/postprocess dicts.

        Only ``text`` is recognised, and it is forwarded to the preprocess
        stage; the forward and postprocess dicts are always empty.
        """
        # NOTE(review): ``preprocess`` also takes ``text`` as its input
        # argument; confirm how the base Pipeline passes these kwargs along.
        preprocess_params = (
            {"text": kwargs["text"]} if "text" in kwargs else {}
        )
        return preprocess_params, {}, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` untouched; no transformation is applied here."""
        return text

    def _forward(self, text, **kwargs):
        """Call the model on ``text`` and return its two outputs as a tuple."""
        # Unpacking enforces that the model yields exactly two values.
        labels, scores = self.model(text)
        return labels, scores

    def postprocess(self, outputs, **kwargs):
        """Turn raw model outputs into a ``{"label", "confidence"}`` dict.

        ``outputs`` is unpacked into ``(predictions, probabilities)`` and only
        the ``[0][0]`` entry of each is used. Every ``"__label__"`` occurrence
        is removed from the label, and the score is multiplied by 100 and
        rounded to two decimals (presumably a 0-1 probability expressed as a
        percentage -- confirm against the model).
        """
        labels, scores = outputs
        top_label = labels[0][0].replace("__label__", "")
        top_score = float(scores[0][0])
        return {"label": top_label, "confidence": round(top_score * 100, 2)}
|
|