File size: 3,830 Bytes
bd3f85d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0205253
bd3f85d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8bc93f3
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import fitz
import os

@st.cache_resource(show_spinner=False)
def _load_model_and_tokenizer():
    """Load the fine-tuned classifier and its tokenizer, cached across reruns.

    Streamlit re-executes the script from the top on every widget
    interaction; caching keeps both ``from_pretrained`` calls from being
    repeated on each rerun. ``show_spinner=False`` is used so the loader does
    not render a spinner element ahead of ``st.set_page_config``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    classifier = AutoModelForSequenceClassification.from_pretrained("Reem333/Citaion-Classifier")
    longformer_tokenizer = AutoTokenizer.from_pretrained("allenai/longformer-base-4096")
    return classifier, longformer_tokenizer


# Module-level names are kept because predict_class reads them as globals.
model, tokenizer = _load_model_and_tokenizer()

def predict_class(text):
    """Classify a paper's text and return the predicted class index.

    Args:
        text: The paper text to classify.

    Returns:
        The index of the highest-scoring class as an ``int``, or ``None`` when
        tokenization or inference raised; in that case the error is reported
        to the user through ``st.error``.
    """
    try:
        # ``max_length`` is a token budget enforced by the tokenizer itself
        # (``truncation=True``). Slicing the raw string to the same number of
        # *characters* first would throw away text that still fits in the
        # token window, so the text is passed through untouched.
        max_length = 4096
        inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=max_length)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            logits = model(**inputs).logits
        return torch.argmax(logits, dim=1).item()
    except Exception as e:
        # UI boundary: surface any failure in the page instead of crashing the app.
        st.error(f"Error during prediction: {e}")
        return None

# Directory presumably meant for uploaded files; it is created at startup if
# it does not already exist.
uploaded_files_dir = "uploaded_files"
os.makedirs(uploaded_files_dir, exist_ok=True)

# Highlight colour per predicted class index (index 0 corresponds to Level 1).
class_colors = dict(
    enumerate(
        (
            "#d62728",  # Level 1
            "#ff7f0e",  # Level 2
            "#2ca02c",  # Level 3
            "#1f77b4",  # Level 4
        )
    )
)

# Browser-tab title and icon for the page.
st.set_page_config(page_icon="logo.png", page_title="Paper Citation Classifier")

# Markdown snippets shown in the sidebar below the logo, in display order.
sidebar_markdown = (
    "# Paper Citation Classifier",
    "---",
    "## About",
    '''
    This is a tool to classify paper citations into different levels based on their number of citations.
    Powered by Fine-Tuned [Longformer model](https://huggingface.co/REEM-ALRASHIDI/LongFormer-Paper-Citaion-Classifier) with custom data.
    ''',
    "### Class Levels:",
    "- Level 1: Highly cited papers",
    "- Level 2: Average cited papers",
    "- Level 3: More cited papers",
    "- Level 4: Low cited papers",
    "---",
    'Tabuk University',
)

with st.sidebar:
    # Logo followed by an empty absolutely-positioned div.
    st.image("logo.png", width=70)
    st.markdown('<div style="position: absolute; left: 5px;"></div>', unsafe_allow_html=True)

    # Each snippet is rendered as its own markdown element.
    for snippet in sidebar_markdown:
        st.markdown(snippet)

def _render_prediction(predicted_class):
    """List all class labels, highlighting the predicted one in its colour."""
    class_labels = ["Level 1", "Level 2", "Level 3", "Level 4"]

    st.text("Predicted Class:")
    for i, label in enumerate(class_labels):
        if i == predicted_class:
            st.markdown(
                f'<div style="background-color: {class_colors[predicted_class]}; padding: 10px; border-radius: 5px; color: white; font-weight: bold;">{label}</div>',
                unsafe_allow_html=True
            )
        else:
            st.text(label)


def _extract_pdf_text(pdf_bytes):
    """Return the plain text of every page of a PDF supplied as raw bytes."""
    with fitz.open(stream=pdf_bytes, filetype="pdf") as document:
        return "\n".join(page.get_text() for page in document)


st.title("Check Your Paper Now!")

option = st.radio("Select input type:", ("Text", "PDF"))

if option == "Text":
    title_input = st.text_area("Enter Title:")
    abstract_input = st.text_area("Enter Abstract:")
    full_text_input = st.text_area("Enter Full Text:")
    affiliations_input = st.text_area("Enter Affiliations:")
    keywords_input = st.text_area("Enter Keywords:")
    options = ["Nursing", "Physics", "Maths", "Chemical", "Nuclear", "Engineering", "Other"]

    selected_category = st.selectbox("Select WoS categories:", options, index=None)
    if selected_category == "Other":
        custom_category = st.text_input("Enter custom category:")
        selected_category = custom_category if custom_category else "Other"

    # The selectbox yields None until a category is picked; use an empty
    # field rather than feeding the literal text "None" to the model.
    category_text = selected_category or ""

    # The full text is a single string, so it is used as one field. Calling
    # str.join on it would insert the separator between every character.
    combined_text = " [SEP] ".join(
        [
            title_input,
            keywords_input,
            abstract_input,
            category_text,
            affiliations_input,
            full_text_input,
        ]
    )

    if st.button("Predict"):
        if not any([title_input, abstract_input, keywords_input, full_text_input, affiliations_input]):
            st.warning("Please enter paper text.")
        else:
            with st.spinner("Predicting..."):
                predicted_class = predict_class(combined_text)
            # predict_class returns None after reporting its own error.
            if predicted_class is not None:
                _render_prediction(predicted_class)
else:
    # "PDF" input: classify the text extracted from an uploaded document.
    uploaded_file = st.file_uploader("Upload a PDF file:", type=["pdf"])

    if st.button("Predict"):
        if uploaded_file is None:
            st.warning("Please upload a PDF file.")
        else:
            pdf_text = None
            try:
                pdf_text = _extract_pdf_text(uploaded_file.getvalue())
            except Exception as e:
                # UI boundary: a corrupt or unreadable file must not crash the app.
                st.error(f"Error reading PDF: {e}")

            if pdf_text is not None:
                if not pdf_text.strip():
                    # E.g. a scanned document with no text layer.
                    st.warning("No extractable text was found in this PDF.")
                else:
                    with st.spinner("Predicting..."):
                        predicted_class = predict_class(pdf_text)
                    if predicted_class is not None:
                        _render_prediction(predicted_class)