MassageMateNLP / app.py
BiEchi
final push
24fe5ed
raw
history blame
981 Bytes
import gradio as gr
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import BertModel
# Suppress every Python warning for the rest of the process lifetime.
# NOTE(review): this also hides deprecation notices raised later by
# torch/transformers/gradio — confirm that a blanket filter is intended.
import warnings
warnings.filterwarnings("ignore")
# Target attributes; each one has its own fine-tuned classifier on disk.
_COLUMNS = ('position_x', 'position_y', 'force', 'velocity_xy', 'velocity_z')
# Class index -> label string shown to the user.
_LABELS = ('-1', '0', '1')
# Lazily filled cache: column name -> (tokenizer, model).
_MODEL_CACHE = {}


def _load_classifier(col):
    """Return the ``(tokenizer, model)`` pair for *col*, loading it on first use.

    The pair is read from ``models/bert/<col>`` once and kept in
    ``_MODEL_CACHE`` so later requests do not reload the weights from disk.
    """
    if col not in _MODEL_CACHE:
        model_path = f'models/bert/{col}'
        tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForSequenceClassification.from_pretrained(model_path)
        model.eval()  # inference mode: disables dropout etc.
        _MODEL_CACHE[col] = (tokenizer, model)
    return _MODEL_CACHE[col]


def infer(text):
    """Classify *text* with one model per attribute and report the labels.

    For every column in ``_COLUMNS`` the matching sequence classifier is run
    on *text*, and the highest-scoring class index is mapped to one of the
    label strings ``'-1'``, ``'0'`` or ``'1'``.

    Args:
        text: The input string to classify.

    Returns:
        A string with one ``"<column>: <label>\\n"`` line per column, in the
        order of ``_COLUMNS``.
    """
    lines = []
    for col in _COLUMNS:
        tokenizer, model = _load_classifier(col)
        # truncation=True caps the input at the tokenizer's configured maximum
        # length, so over-long text is cut instead of overflowing the model.
        encoded_input = tokenizer(text, return_tensors='pt', truncation=True)
        # No gradients are needed for prediction; skip building the graph.
        with torch.no_grad():
            output = model(**encoded_input)
        # output[0] is the score tensor; [0] selects the single input in the batch.
        scores = output[0].detach().cpu().numpy()[0]
        answer = _LABELS[scores.argmax()]
        lines.append(f'{col}: {answer}\n')
    return ''.join(lines)
# Wrap infer() in a minimal Gradio UI: one text input, one text output.
iface = gr.Interface(
    fn=infer,
    inputs="text",
    outputs="text",
)
# Start the web server; share=True is passed through to Gradio's launch().
iface.launch(share=True)