File size: 833 Bytes
f51abca
 
 
5bf9e80
f51abca
5a5b916
5bf9e80
f51abca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
import torch
import gradio as gr
from transformers import AutoConfig
from models import CustomClassifier, CustomClassificationConfig

MODEL_ID = "yhamidullah/custom-classifier-demo"


def _load_classifier(model_id):
    """Load the config and the model for ``model_id`` and put the model in eval mode.

    ``model_id`` looks like a Hugging Face Hub repo id (presumably resolved by
    ``from_pretrained`` against the Hub or a local cache — confirm).

    Returns:
        A ``(config, model)`` pair. ``config`` is loaded separately from the
        model. ``predict`` reads ``config.input_dim`` from it.
    """
    cfg = CustomClassificationConfig.from_pretrained(model_id)
    net = CustomClassifier.from_pretrained(model_id)
    # Inference only: switch off training-time behavior before serving.
    net.eval()
    return cfg, net


config, model = _load_classifier(MODEL_ID)

def predict(input_csv: str) -> str:
    """Classify one feature vector entered as comma-separated numbers.

    Args:
        input_csv: Text such as ``"0.1, 2, -3.5"``. Each comma-separated token
            must be parseable by ``float()``; whitespace around a token is
            accepted.

    Returns:
        ``"Predicted class: <index>"`` on success, where the index is the
        argmax of the model's logits. Otherwise, a string starting with
        ``"Error:"`` that explains why the input was rejected. Errors are
        returned rather than raised so they show up as the text output.
    """
    # A blank or missing value would otherwise fail inside float("").
    if not input_csv or not input_csv.strip():
        return "Error: Input is empty; enter comma-separated numbers"
    try:
        vec = [float(tok) for tok in input_csv.split(",")]
    except ValueError:
        # Non-numeric token, or an empty field such as "1,,2" or "1,2,".
        return "Error: Input must be comma-separated numbers"
    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"
    # Shape (1, input_dim): a batch holding a single example.
    x = torch.tensor([vec])
    with torch.no_grad():
        logits = model(input_ids=x)["logits"]
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"

# Single-textbox UI: the user types a comma-separated vector and sees the
# string that predict() returns (a predicted class or an "Error: ..." message).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Input Vector (comma-separated)"),
    outputs="text",
    title="Custom Classifier Demo",
)

# Start the server only when this file is run as a script. Importing the
# module (for tests or tooling) then does not block inside demo.launch().
if __name__ == "__main__":
    demo.launch()