File size: 1,787 Bytes
3c2e6b2
 
 
ce0c8db
eaba143
 
7aae4ed
9d5b1d9
ce0c8db
9d5b1d9
90a7e15
9d5b1d9
90a7e15
fe3279c
9d5b1d9
3c2e6b2
 
 
 
 
 
9d5b1d9
eaba143
fe3279c
 
 
 
eaba143
 
 
 
 
3c2e6b2
87e1cd4
 
 
90a7e15
eaba143
9d5b1d9
eaba143
fe3279c
 
 
eaba143
 
82ec9f7
eaba143
ce0c8db
 
53bf7b2
3c2e6b2
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import json
from pathlib import Path

import pandas as pd
import streamlit as st

from category_classification.models import models as class_models
from languages import *
from results import process_results

# UI strings keyed by language. `en` / `ru` come from the `languages`
# star-import (presumably language codes; confirm in languages.py). Every
# label must have an entry for each language offered by the switcher.
page_title = {en: "Papers classification", ru: "Классификация статей"}
# "Выберите" is the imperative ("choose"); "Выберете" was the future tense.
model_label = {en: "Select model", ru: "Выберите модель"}
title_label = {en: "Title", ru: "Название статьи"}
authors_label = {en: "Author(s)", ru: "Автор(ы)"}
abstract_label = {en: "Abstract", ru: "Аннотация"}
# Metrics are read from test_results.json, so both languages must say "test"
# (the Russian label previously said "training dataset").
metrics_label = {en: "Test metrics", ru: "Метрики на тестовом датасете"}

# Evaluation metrics for the classification models, looked up by language
# further down the page. Decode explicitly as UTF-8: without `encoding=` the
# platform locale decides, and the file may hold non-ASCII (e.g. Russian)
# text that would fail or be garbled on a non-UTF-8 locale.
with open(
    Path(__file__).parent / "category_classification" / "test_results.json",
    "r",
    encoding="utf-8",
) as metric_f:
    metrics = json.load(metric_f)


def text_area_height(line_height: int):
    """Convert a number of visible text lines into a text-area height in pixels.

    Despite its name, ``line_height`` is the number of lines to show; each
    line is allotted a fixed 34 pixels.
    """
    pixels_per_line = 34
    return line_height * pixels_per_line


@st.cache_resource
def load_class_model(name):
    """Load the classification model registered under ``name``.

    Cached with ``st.cache_resource`` so the model is built once per ``name``
    and the same object is reused across reruns. ``st.cache_data`` (used
    previously) pickles the return value and hands back a fresh copy on
    every call. That is expensive for a model and fails outright if the
    model cannot be pickled.

    NOTE(review): a cache_resource object is shared between all sessions, so
    calling the model must not mutate shared state. Confirm for every model
    returned by ``class_models.get_model``.

    :param name: Model name, as offered by
        ``class_models.get_model_names_by_lang``.
    :return: The object returned by ``class_models.get_model(name)``.
    """
    return class_models.get_model(name)


# Language switcher. st.pills returns None when no pill is selected, so fall
# back to English in that case.
lang = st.pills(label=langs_str, options=langs)
if lang is None:
    lang = en
st.title(page_title[lang])
# May be None if no models are registered for this language.
model_name = st.selectbox(
    model_label[lang], options=class_models.get_model_names_by_lang(lang)
)
title = st.text_area(title_label[lang], height=text_area_height(2))
authors = st.text_area(authors_label[lang], height=text_area_height(2))
abstract = st.text_area(abstract_label[lang], height=text_area_height(5))

# Classify only once a non-blank title is entered and a model is available.
# Authors and abstract are passed through even when empty.
if title.strip() and model_name is not None:
    # Named `paper` rather than `input` to avoid shadowing the builtin.
    paper = {"title": title, "abstract": abstract, "authors": authors}
    model = load_class_model(model_name)
    results = model(paper)
    results = process_results(results, lang)
    st.dataframe(results, hide_index=True)

# Show stored metrics for the selected language, if any. A language missing
# from the metrics file yields an empty frame and the expander is hidden,
# instead of raising KeyError.
lang_metrics = pd.DataFrame(metrics.get(lang, {}))
if not lang_metrics.empty:
    with st.expander(metrics_label[lang]):
        st.dataframe(lang_metrics)