from transformers import Pipeline
class LangDetectionPipeline(Pipeline):
    """Pipeline that reports the top predicted language label for a text.

    ``self.model`` is called directly on the input string and must return a
    ``(predictions, probabilities)`` pair.

    NOTE(review): the ``[0][0]`` indexing in ``postprocess`` assumes both
    members of that pair are nested sequences (first row, first entry) and
    that labels carry a ``__label__`` prefix -- confirm against the model
    wrapper this pipeline is paired with.
    """

    def _sanitize_parameters(self, **kwargs):
        """Return the (preprocess, forward, postprocess) keyword dicts.

        This pipeline has no tunable parameters, so all three dicts are empty
        and any keyword arguments are ignored.

        A ``text`` keyword is deliberately *not* forwarded to ``preprocess``:
        the base ``Pipeline`` already passes the input positionally, where it
        binds to ``preprocess``'s ``text`` parameter, so forwarding it again
        could only raise ``TypeError`` (multiple values for ``text``).
        """
        return {}, {}, {}

    def preprocess(self, text, **kwargs):
        """Return ``text`` unchanged; the model consumes the raw string."""
        # Nothing to preprocess
        return text

    def _forward(self, text, **kwargs):
        """Run the model on ``text`` and return its raw output pair.

        Returns:
            The ``(predictions, probabilities)`` tuple produced by
            ``self.model``.
        """
        # Collapse newlines into spaces so the model sees a single line.
        # NOTE(review): presumably required because fastText-style predictors
        # reject input containing "\n" -- confirm against the model wrapper.
        predictions, probabilities = self.model(text.replace("\n", " "))
        return predictions, probabilities

    def postprocess(self, outputs, **kwargs):
        """Convert the raw model output into a JSON-compatible dictionary.

        Args:
            outputs: The ``(predictions, probabilities)`` pair returned by
                ``_forward``.

        Returns:
            ``{"language": <label>, "score": <float>}`` where the label has
            its ``__label__`` prefix removed and the score is rounded to two
            decimal places.
        """
        predictions, probabilities = outputs
        label = predictions[0][0].replace("__label__", "")  # Remove __label__ prefix
        # Convert to a built-in float for JSON serialization.
        confidence = float(probabilities[0][0])
        # Format as JSON-compatible dictionary
        model_output = {"language": label, "score": round(confidence, 2)}
        return model_output