import io

import streamlit as st
import torch
from transformers import AutoTokenizer

from semviqa.ser.qatc_model import QATCForQuestionAnswering
from semviqa.ser.ser_eval import extract_evidence_tfidf_qatc
from semviqa.tvc.model import ClaimModelForClassification
from semviqa.tvc.tvc_eval import classify_claim

# ---------------------------------------------------------------------------
# SemViQA Streamlit demo: verify a Vietnamese-language claim against a
# context paragraph using evidence extraction (TF-IDF + QATC) followed by a
# binary and a 3-class claim classifier.
#
# NOTE(review): this file was recovered from a whitespace-mangled copy. The
# custom CSS payload, the HTML wrappers of several st.markdown calls, and the
# exact verdict cascade were lost; reconstructed spans are marked with
# NOTE(review) comments and must be confirmed against the original demo.
# ---------------------------------------------------------------------------


# Load models with caching so reruns don't re-download/re-instantiate them.
@st.cache_resource()
def load_model(model_name, model_class, is_bc=False):
    """Load and cache a (tokenizer, model) pair.

    Args:
        model_name: Hugging Face model id to load.
        model_class: Class exposing ``from_pretrained`` (QATC or claim model).
        is_bc: True for the binary classifier (2 labels); otherwise 3 labels.

    Returns:
        Tuple ``(tokenizer, model)`` with the model switched to eval mode.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Binary-classification checkpoints expose 2 labels, 3-class ones 3.
    model = model_class.from_pretrained(model_name, num_labels=2 if is_bc else 3)
    model.eval()
    return tokenizer, model


# Set up page configuration.
st.set_page_config(page_title="SemViQA Demo", layout="wide")

# Custom CSS: fixed navigation, header, and adjusted height.
# NOTE(review): the original stylesheet was lost in this copy; the empty
# string keeps the call a harmless no-op — restore the CSS if available.
st.markdown("", unsafe_allow_html=True)

# Main container with the page header.
with st.container():
    # NOTE(review): original HTML wrappers around these strings were lost.
    st.markdown(
        "SemViQA: Semantic QA Information Verification System for Vietnamese",
        unsafe_allow_html=True,
    )
    st.markdown(
        "Enter information to verify and context to check its accuracy",
        unsafe_allow_html=True,
    )

# Sidebar: Global Settings shared by every tab.
with st.sidebar.expander("⚙️ Settings", expanded=True):
    tfidf_threshold = st.slider("TF-IDF Threshold", 0.0, 1.0, 0.5, 0.01)
    length_ratio_threshold = st.slider("Length Ratio Threshold", 0.1, 1.0, 0.5, 0.01)
    qatc_model_name = st.selectbox("QATC Model", [
        "SemViQA/qatc-infoxlm-viwikifc",
        "SemViQA/qatc-infoxlm-isedsc01",
        "SemViQA/qatc-vimrc-viwikifc",
        "SemViQA/qatc-vimrc-isedsc01",
    ])
    bc_model_name = st.selectbox("Binary Classification Model", [
        "SemViQA/bc-xlmr-viwikifc",
        "SemViQA/bc-xlmr-isedsc01",
        "SemViQA/bc-infoxlm-viwikifc",
        "SemViQA/bc-infoxlm-isedsc01",
        "SemViQA/bc-erniem-viwikifc",
        "SemViQA/bc-erniem-isedsc01",
    ])
    tc_model_name = st.selectbox("3-Class Classification Model", [
        "SemViQA/tc-xlmr-viwikifc",
        "SemViQA/tc-xlmr-isedsc01",
        "SemViQA/tc-infoxlm-viwikifc",
        "SemViQA/tc-infoxlm-isedsc01",
        "SemViQA/tc-erniem-viwikifc",
        "SemViQA/tc-erniem-isedsc01",
    ])
    show_details = st.checkbox("Show probability details", value=False)

# Store verification history, the latest result, and the in-progress flag
# across Streamlit reruns.
if "history" not in st.session_state:
    st.session_state.history = []
if "latest_result" not in st.session_state:
    st.session_state.latest_result = None
if "is_verifying" not in st.session_state:
    st.session_state.is_verifying = False

# Load the currently selected models (cached by load_model).
tokenizer_qatc, model_qatc = load_model(qatc_model_name, QATCForQuestionAnswering)
tokenizer_bc, model_bc = load_model(bc_model_name, ClaimModelForClassification, is_bc=True)
tokenizer_tc, model_tc = load_model(tc_model_name, ClaimModelForClassification)

# Icons shown next to each verdict.
verdict_icons = {
    "SUPPORTED": "✅",
    "REFUTED": "❌",
    "NEI": "⚠️",
}

# Create tabs: Verify, History, About.
tabs = st.tabs(["Verify", "History", "About"])

# --- Verify Tab ---
with tabs[0]:
    st.subheader("Verify Information")

    # 2-column layout: inputs on the left, results on the right.
    col_input, col_result = st.columns([2, 1])

    with col_input:
        claim = st.text_area("Enter Claim", "Vietnam is a country in Southeast Asia.")
        context = st.text_area(
            "Enter Context",
            "Vietnam is a country located in Southeast Asia, covering an area "
            "of over 331,000 km² with a population of more than 98 million people.",
        )

        def start_verification():
            # Only flip the flag: Streamlit reruns the script automatically
            # after an on_click callback, so the original
            # st.experimental_rerun() here was a deprecated no-op and has
            # been removed.
            st.session_state.is_verifying = True

        # The original `if st.button(...): pass` conditional did nothing;
        # the on_click callback carries all the behavior.
        st.button("Verify", key="verify_button", on_click=start_verification)

    # Display results in the right column.
    with col_result:
        if st.session_state.is_verifying:
            st.markdown("Processing verification...")
            device = "cuda" if torch.cuda.is_available() else "cpu"

            st.markdown("1. Extracting evidence...")
            # NOTE(review): the exact keyword names for the thresholds were
            # lost in this copy — confirm against semviqa.ser.ser_eval.
            evidence = extract_evidence_tfidf_qatc(
                claim,
                context,
                model_qatc,
                tokenizer_qatc,
                device,
                confidence_threshold=tfidf_threshold,
                length_ratio_threshold=length_ratio_threshold,
            )

            st.markdown("2. Running binary classification...")
            prob2class, pred_bc = classify_claim(
                claim, evidence, model_bc, tokenizer_bc, device
            )

            st.markdown("3. Running 3-class classification...")
            prob3class, pred_tc = classify_claim(
                claim, evidence, model_tc, tokenizer_tc, device
            )

            st.markdown("4. Determining final verdict...")
            # NOTE(review): the exact cascade was lost in this copy; this
            # follows the published SemViQA pipeline (3-class label, with
            # the binary classifier overriding when it is more confident)
            # — confirm against the original demo before shipping.
            if pred_tc == 0:
                verdict = "NEI"
            elif prob2class > prob3class:
                verdict = "SUPPORTED" if pred_bc == 1 else "REFUTED"
            else:
                verdict = ["NEI", "SUPPORTED", "REFUTED"][pred_tc]

            # Probability details shown only when the sidebar toggle is on.
            details = ""
            if show_details:
                details = (
                    f"3-Class Probability: {prob3class.item():.2f} - "
                    f"2-Class Probability: {prob2class.item():.2f}"
                )

            # Save verification history and latest result.
            st.session_state.history.append({
                "claim": claim,
                "evidence": evidence,
                "verdict": verdict,
            })
            st.session_state.latest_result = {
                "claim": claim,
                "evidence": evidence,
                "verdict": verdict,
                "details": details,
            }

            # Release cached GPU memory between runs.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # Turn off verification flag and refresh to show the result.
            # NOTE(review): kept the original (deprecated) API for version
            # compatibility; on Streamlit >= 1.27 prefer st.rerun().
            st.session_state.is_verifying = False
            st.experimental_rerun()

        elif st.session_state.latest_result is not None:
            res = st.session_state.latest_result
            # NOTE(review): original HTML wrappers around these strings
            # were lost; plain markdown is used instead.
            st.markdown(f"Claim: {res['claim']}", unsafe_allow_html=True)
            st.markdown(f"Evidence: {res['evidence']}", unsafe_allow_html=True)
            st.markdown(
                f"{res['details']} "
                f"{verdict_icons.get(res['verdict'], '')} {res['verdict']}",
                unsafe_allow_html=True,
            )