import functools

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import gradio as gr

# Hugging Face Hub identifiers: the fine-tuned classifier and the tokenizer
# of the base model it was fine-tuned from.
MODEL_ID = "spawn99/modernbert-wine-classification"
TOKENIZER_ID = "answerdotai/ModernBERT-base"

# Maximum number of characters accepted for the review description.
MAX_DESCRIPTION_CHARS = 750

# Token length the input is padded/truncated to before inference.
MAX_SEQ_LENGTH = 256


@functools.lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the tokenizer and classifier once and reuse them on later calls.

    Loading weights is by far the most expensive step, so it is cached
    instead of being repeated for every prediction.

    Returns:
        tuple: ``(tokenizer, model)``, with the model already in eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_ID)
    model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
    model.eval()  # inference only: disable dropout etc.
    return tokenizer, model


def run_inference(review_text: str) -> str:
    """
    Perform inference on the given wine review text and return the predicted wine variety.

    Args:
        review_text (str): Wine review text in the format "country [SEP] description".

    Returns:
        str: The predicted wine variety from the model's id2label mapping if
        available, otherwise the predicted class index as a string.
    """
    tokenizer, model = _load_model_and_tokenizer()

    # Tokenize the input text
    inputs = tokenizer(
        review_text,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=MAX_SEQ_LENGTH,
    )

    with torch.no_grad():
        logits = model(**inputs).logits

    # Determine prediction and map to label if available
    pred = torch.argmax(logits, dim=-1).item()
    id2label = getattr(model.config, "id2label", None)
    if id2label:
        return id2label.get(pred, str(pred))
    return str(pred)


def predict_wine_variety(country: str, description: str) -> dict:
    """
    Combine the provided country and description, then perform inference.
    Enforces a maximum character limit of 750 on the description.

    Args:
        country (str): The country of wine origin.
        description (str): The wine review description.

    Returns:
        dict: ``{"Variety": <label>}`` on success, or ``{"error": <message>}``
        if the description exceeds the character limit.
    """
    # Treat a missing value (None) as empty text so len() and the string
    # methods below do not raise.
    country = country or ""
    description = description or ""

    # Validate description length
    if len(description) > MAX_DESCRIPTION_CHARS:
        return {
            "error": f"Description exceeds {MAX_DESCRIPTION_CHARS} character limit. "
            "Please shorten your input."
        }

    # Capitalize input values and format the review text accordingly.
    # NOTE(review): str.capitalize() also lowercases everything after the first
    # character (e.g. "US" -> "Us", "New Zealand" -> "New zealand"); confirm this
    # matches the preprocessing used when the model was trained.
    review_text = f"{country.capitalize()} [SEP] {description.capitalize()}"
    predicted_variety = run_inference(review_text)
    return {"Variety": predicted_variety}


if __name__ == "__main__":
    # Load weights before serving so the first request does not pay the
    # download/initialization cost.
    _load_model_and_tokenizer()

    iface = gr.Interface(
        fn=predict_wine_variety,
        inputs=[
            gr.Textbox(label="Country", placeholder="Enter country of origin..."),
            gr.Textbox(label="Description", placeholder="Enter wine review description..."),
        ],
        outputs=gr.JSON(label="Prediction"),
        title="Wine Variety Predictor",
        description="Predict the wine variety based on country and description.",
        # Gradio 5 keyword for disabling the flag button (Gradio 4 used
        # `allow_flagging`); the original `flagging=` is not a recognized option.
        flagging_mode="never",
    )
    iface.launch()