File size: 743 Bytes
8b414b0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import pandas as pd
import streamlit as st
import torch
from src.model_finetuning.config import CONFIG
from src.model_finetuning.model import BertLightningModel
from src.utils import get_target_columns
@st.cache(allow_output_mutation=True)
def load_model() -> BertLightningModel:
    """Load the fine-tuned model from the demo checkpoint for CPU inference.

    The function is wrapped in Streamlit's cache so the checkpoint does not
    have to be re-read on every script rerun.

    Returns:
        The model restored from ``demo/model.ckpt`` with its weights mapped
        to the CPU and switched to evaluation mode.
    """
    # NOTE(review): `st.cache` is deprecated in recent Streamlit releases in
    # favour of `st.cache_resource`; confirm the pinned Streamlit version
    # before switching decorators.
    ckpt_path = "demo/model.ckpt"
    model = BertLightningModel.load_from_checkpoint(
        ckpt_path,
        config=CONFIG,
        map_location='cpu',
    )
    # Inference only: put dropout (and similar layers) into eval behaviour so
    # the same text always yields the same grades. `torch.no_grad()` in the
    # caller does not do this. Calling eval() on an already-eval module is a
    # no-op, so this is safe regardless of how the checkpoint loader behaves.
    model.eval()
    return model
def process_text(_text: str, _model: BertLightningModel) -> pd.DataFrame:
    """Grade a single text with the model and return the result as a table.

    Args:
        _text: The text to grade.
        _model: Model exposing a ``tokenizer`` attribute and callable on the
            tokenizer's output.

    Returns:
        A DataFrame with a ``Criterion`` column (the names returned by
        ``get_target_columns()``) and a ``Grade`` column (the first element
        of the model output converted to a Python list).
    """
    # Gradient tracking is disabled for the whole computation.
    with torch.no_grad():
        # The text is wrapped in a list, i.e. passed as a one-item batch.
        # presumably return_tensors='pt' requests PyTorch tensors — verify
        # against the tokenizer class used by the model.
        encoded = _model.tokenizer([_text], return_tensors='pt')
        predictions = _model(encoded)
        # Index 0 selects the first entry of the model output.
        grades = predictions[0].tolist()
        return pd.DataFrame({
            'Criterion': get_target_columns(),
            'Grade': grades,
        })
|