|
import pandas as pd |
|
import streamlit as st |
|
|
|
from category_classification.models import models as class_models |
|
from common import Input |
|
from languages import * |
|
from results import process_results |
|
|
|
# Localized UI strings, one dict per widget, keyed by language code.
# `en` / `ru` presumably come from the star import of `languages`.
page_title = {en: "Papers classification", ru: "Классификация статей"}
model_label = {en: "Select model", ru: "Выберите модель"}
title_label = {en: "Title", ru: "Название статьи"}
authors_label = {en: "Author(s)", ru: "Автор(ы)"}
abstract_label = {en: "Abstract", ru: "Аннотация"}
|
|
|
|
|
def text_area_height(line_height: int, pixels_per_line: int = 34) -> int:
    """Return a text-area height in pixels large enough for *line_height* lines.

    Args:
        line_height: Number of text lines the area should show. Despite the
            name, this is a line count, not a pixel height.
        pixels_per_line: Pixels allotted per line. Defaults to 34, the value
            previously hard-coded here.

    Returns:
        ``pixels_per_line * line_height``.
    """
    # NOTE(review): some Streamlit releases reject st.text_area heights below
    # 68 px (i.e. 2 lines at the default); confirm for the pinned version.
    return pixels_per_line * line_height
|
|
|
|
|
@st.cache_resource
def load_class_model(name):
    """Return the classification model registered under *name*.

    The model comes from ``class_models.get_model``. It is cached with
    ``st.cache_resource`` (not ``st.cache_data``) so that one model object is
    loaded once and reused across reruns and sessions. ``st.cache_data`` would
    pickle the model and return a fresh copy on every call, which is expensive
    and fails for unpicklable objects.
    """
    # NOTE(review): the cached instance is shared across sessions; this assumes
    # calling the model does not mutate it. TODO confirm.
    return class_models.get_model(name)
|
|
|
|
|
# Language switcher. In single-selection mode st.pills returns None once the
# user clicks the active pill again (deselects it). Fall back to English so the
# label lookups below never receive a None key.
lang = st.pills(label=langs_str, options=langs, default=en)
if lang is None:
    lang = en

st.title(page_title[lang])

# Presumably only models registered for the selected language are offered.
model_name = st.selectbox(
    model_label[lang], options=class_models.get_model_names_by_lang(lang)
)

title = st.text_area(title_label[lang], height=text_area_height(2))
authors = st.text_area(authors_label[lang], height=text_area_height(2))
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))
|
|
|
# Classify only once a non-blank title has been entered. Authors and abstract
# are optional.
if title.strip():
    # Named `paper` rather than `input` to avoid shadowing the builtin.
    paper = Input(title=title, abstract=abstract, authors=authors)
    model = load_class_model(model_name)
    results = process_results(model(paper), lang)
    st.dataframe(results, hide_index=True)
|
|