"""Gradio demo for zero-shot text classification with a ModernBERT NLI model."""

import logging

import gradio as gr
import torch
from transformers import pipeline

logger = logging.getLogger(__name__)

# Initialize the zero-shot classification pipeline once at import time so every
# request reuses the same loaded model. Use the first CUDA device when one is
# available, otherwise run on CPU (device=-1).
try:
    classifier = pipeline(
        "zero-shot-classification",
        model="models/tasksource/ModernBERT-nli",
        device=0 if torch.cuda.is_available() else -1,
    )
except Exception as e:
    # Keep the app importable even if the model cannot be loaded;
    # classify_text reports the failure instead of the whole app crashing.
    logger.exception("Error loading model: %s", e)
    classifier = None


def classify_text(text, candidate_labels):
    """Perform zero-shot classification on input text.

    Args:
        text (str): Input text to classify.
        candidate_labels (str): Comma-separated string of possible labels.
            Whitespace around each label is stripped, empty entries are
            ignored, and duplicates are collapsed (first occurrence wins).

    Returns:
        dict[str, float]: Mapping of label -> score, which is the format the
        ``gr.Label`` output component expects. ``{"error": 1.0}`` is returned
        when the model failed to load, when the text or label list is empty,
        or when classification raises.
    """
    if classifier is None:
        # Return a default response when the model failed to load.
        return {"error": 1.0}

    # Convert the comma-separated string to an ordered, de-duplicated list,
    # dropping empty entries such as those produced by "a,,b" or "a,b,".
    labels = list(
        dict.fromkeys(
            label.strip()
            for label in (candidate_labels or "").split(",")
            if label.strip()
        )
    )
    if not (text and text.strip()) or not labels:
        return {"error": 1.0}

    try:
        result = classifier(text, labels)
    except Exception:
        logger.exception("Classification error")
        return {"error": 1.0}

    # gr.Label needs a {label: confidence} dict; a list of tuples makes its
    # postprocess step raise ValueError.
    return dict(zip(result["labels"], result["scores"]))


# Create Gradio interface
iface = gr.Interface(
    fn=classify_text,
    inputs=[
        gr.Textbox(
            label="Text to classify",
            placeholder="Enter text here...",
            value="all cats are blue",
        ),
        gr.Textbox(
            label="Possible labels (comma-separated)",
            placeholder="Enter labels...",
            value="true,false",
        ),
    ],
    outputs=gr.Label(label="Classification Results"),
    title="Zero-Shot Text Classification",
    description="Classify text into given categories without any training examples.",
    examples=[
        ["all cats are blue", "true,false"],
        ["the sky is above us", "true,false"],
        ["birds can fly", "true,false,unknown"],
    ],
)

# Launch the app
if __name__ == "__main__":
    iface.launch(share=True)