File size: 3,556 Bytes
c14261a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ca25148
c14261a
9690d2b
 
c14261a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f1bf2e5
 
 
 
 
 
c14261a
f1bf2e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import fitz
import os

@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the fine-tuned citation classifier and its Longformer tokenizer.

    Streamlit re-executes this whole script on every widget interaction, so a
    plain module-level ``from_pretrained`` call would rebuild both objects on
    every rerun. ``st.cache_resource`` keeps a single instance across reruns.
    ``show_spinner=False`` ensures the cached call emits no Streamlit element
    before ``st.set_page_config`` further down.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained("Reem333/Citaion-Classifier")
    loaded_tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    return loaded_model, loaded_tokenizer


# Module-level names kept so the rest of the script (e.g. predict_class) is unchanged.
model, tokenizer = _load_model_and_tokenizer()

def predict_class(text):
    """Classify a paper's text into a citation-level class index.

    The text is tokenized with the module-level ``tokenizer`` and scored by the
    module-level ``model`` with gradient tracking disabled.

    Args:
        text: The text to classify. The app passes the title, keywords,
            abstract, category, affiliations and full text joined by "[SEP]".

    Returns:
        The index of the highest-scoring class as a Python int (0-3, shown as
        Level 1-4 by this app). Returns ``None`` if tokenization or inference
        raised; in that case the error is reported on the page via ``st.error``.
    """
    try:
        # The Longformer context is 4096 *tokens*. Let the tokenizer truncate to
        # that limit instead of slicing the raw string to 4096 *characters*
        # first, which would discard most of the usable context.
        max_length = 4096
        inputs = tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=max_length,
        )
        with torch.no_grad():
            logits = model(**inputs).logits
        return torch.argmax(logits, dim=1).item()
    except Exception as e:  # UI boundary: show the error rather than crash the app
        st.error(f"Error during prediction: {e}")
        return None


# Highlight (background) colour for each predicted class index; index i is shown as "Level i+1".
class_colors = dict(enumerate([
    "#2ca02c",  # Level 1
    "#ff7f0e",  # Level 2
    "#ffff00",  # Level 3
    "#d62728",  # Level 4
]))

st.set_page_config(page_title="Paper Citation Classifier", page_icon="logo.png")

# Sidebar: logo, app name, an "About" blurb and a legend describing the four class levels.
# NOTE(review): the "About" link points at REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier,
# while the weights are loaded from Reem333/Citaion-Classifier -- confirm they are the same model.
# NOTE(review): Level 3 is labelled "More cited" between "Average" and "Low" -- confirm the wording.
with st.sidebar:
    st.image("logo.png", width=70)
    # An empty, absolutely-positioned div with no content; TODO confirm it is still needed.
    st.markdown('<div style="position: absolute; left: 5px;"></div>', unsafe_allow_html=True)

    for sidebar_line in ("# Paper Citation Classifier", "---", "## About"):
        st.markdown(sidebar_line)
    st.markdown('''
    This is a tool to classify paper citations into different levels based on their number of citations.
    Powered by Fine-Tuned [Longformer model](https://huggingface.co/REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier) with custom data.
    ''')
    st.markdown("### Class Levels:")
    for sidebar_line in (
        "- Level 1: Highly cited papers",
        "- Level 2: Average cited papers",
        "- Level 3: More cited papers",
        "- Level 4: Low cited papers",
        "---",
    ):
        st.markdown(sidebar_line)
    st.markdown('Tabuk University')

st.title("Check Your Paper Now!")


title_input = st.text_area("Enter Title:")
abstract_input = st.text_area("Enter Abstract:")
full_text_input = st.text_area("Enter Full Text:")
affiliations_input = st.text_area("Enter Affiliations:")
keywords_input = st.text_area("Enter Keywords:")
options = ["Nursing", "Physics", "Maths", "Chemical", "Nuclear", "Engineering", "Other"]

# index=None starts the selectbox empty, so selected_category is None until the user picks one.
selected_category = st.selectbox("Select WoS categories:", options, index=None)
if selected_category == "Other":
    # Allow a category outside the fixed list; fall back to "Other" when left blank.
    custom_category = st.text_input("Enter custom category:")
    selected_category = custom_category if custom_category else "Other"

# Model input: the paper fields in a fixed order, separated by a literal " [SEP] " marker.
# full_text_input is one string, so it is inserted as-is. Passing it to str.join would
# iterate it character by character and put "[SEP]" between every letter.
# An unselected category (None) becomes "" instead of the literal text "None".
# NOTE(review): presumably the model was fine-tuned on this exact "[SEP]"-joined layout;
# Longformer's own separator token differs, so "[SEP]" is likely tokenized as plain text -- confirm.
combined_text = " [SEP] ".join([
    title_input,
    keywords_input,
    abstract_input,
    selected_category or "",
    affiliations_input,
    full_text_input,
])

def _readable_text_color(background_hex):
    """Return "black" or "white", whichever reads better on a "#rrggbb" background colour."""
    red, green, blue = (int(background_hex[i:i + 2], 16) for i in (1, 3, 5))
    # Perceived brightness using BT.601 luma weights. Only bright backgrounds (e.g. the
    # Level 3 yellow) cross the threshold; the green, orange and red keep white text.
    luma = 0.299 * red + 0.587 * green + 0.114 * blue
    return "black" if luma > 186 else "white"


if st.button("Predict"):
    # At least one paper field must contain non-whitespace text; the category alone is not enough.
    paper_fields = [title_input, abstract_input, keywords_input, full_text_input, affiliations_input]
    if not any(field.strip() for field in paper_fields):
        st.warning("Please enter paper text.")
    else:
        with st.spinner("Predicting..."):
            predicted_class = predict_class(combined_text)
            # predict_class returns None after it has already shown an error message.
            if predicted_class is not None:
                class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]

                st.text("Predicted Class:")
                # List every level; the predicted one is rendered as a coloured badge.
                for i, label in enumerate(class_labels):
                    if i == predicted_class:
                        background = class_colors[predicted_class]
                        text_color = _readable_text_color(background)
                        st.markdown(
                            f'<div style="background-color: {background}; padding: 10px; border-radius: 5px; color: {text_color}; font-weight: bold;">{label}</div>',
                            unsafe_allow_html=True
                        )
                    else:
                        st.text(label)