import pandas as pd | |
import streamlit as st | |
import torch | |
from src.model_finetuning.config import CONFIG | |
from src.model_finetuning.model import BertLightningModel | |
from src.utils import get_target_columns | |
def load_model() -> BertLightningModel: | |
ckpt_path = "demo/model.ckpt" | |
model = BertLightningModel.load_from_checkpoint(ckpt_path, config=CONFIG, map_location='cpu') | |
return model | |
def process_text(_text: str, _model: BertLightningModel) -> pd.DataFrame: | |
tokens = _model.tokenizer([_text], return_tensors='pt') | |
outputs = _model(tokens)[0].tolist() | |
df = pd.DataFrame({ | |
'Criterion': get_target_columns(), | |
'Grade': outputs | |
}) | |
return df | |