"""Streamlit app that assigns IPC sections to an FIR case description.

A Hugging Face sequence-classification checkpoint produces one logit per
section; each logit is passed through a sigmoid and thresholded independently.
"""

import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Path (or hub id) of the fine-tuned checkpoint — make sure it is correct.
model_name = "ipc_refined_approach_model"

# Section labels, positionally matched to the model's output logits.
# NOTE(review): this order must match the label order used at training time — confirm.
sections = ['465', '395', '332', '353', '467']


@st.cache_resource
def _load_model(name):
    """Load the tokenizer and model once per server process.

    Streamlit re-runs this whole script on every widget interaction; caching
    the loaded objects avoids re-reading the checkpoint from disk each time.

    Returns:
        (tokenizer, model, device) with the model already on ``device`` and
        switched to eval mode.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    loaded_model = loaded_model.to(target_device)
    loaded_model.eval()  # inference only: disable dropout once, at load time
    return loaded_tokenizer, loaded_model, target_device


tokenizer, model, device = _load_model(model_name)

# Persist the label order next to the app. 'sections' is a list of section names.
with open("labels.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(sections))


def predict_text(text, threshold=0.5):
    """Classify *text* into sections and render the verdict in the app.

    Args:
        text: Free-text case description.
        threshold: Minimum sigmoid probability for a section to count as
            predicted (defaults to 0.5).

    Returns:
        Dict mapping every section label to 1 (predicted) or 0 (not predicted).
    """
    # Tokenize, truncating to at most 512 tokens, and move tensors to the model's device.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Independent sigmoid per logit (multi-label style), for the single input in the batch.
    probs = torch.sigmoid(logits)[0].cpu().numpy()

    predictions = {section: int(prob > threshold) for section, prob in zip(sections, probs)}

    matched = [section for section, pred in predictions.items() if pred == 1]
    if matched:
        st.write(f"This case belongs to Section(s): **{', '.join(matched)}**")
    else:
        st.write("This case does not belong to any known section.")
    return predictions


# Inject the custom stylesheet into the page via a <style> tag.
with open("style.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)

# Streamlit app interface
st.title("Legal Section Classification for FIR")
st.write("Enter the text for case classification:")

# zip() in predict_text silently truncates to the shorter sequence, so make a
# mismatch between the model's output size and the configured labels visible.
if model.config.num_labels != len(sections):
    st.warning(
        f"Model outputs {model.config.num_labels} labels but {len(sections)} "
        "section names are configured; predictions may be incomplete."
    )
# Text input for the case description.
sample_text = st.text_area("Case Description", "")

# Classify only when the button is pressed AND there is text to classify;
# otherwise an empty description would still be run through the model.
if st.button("Classify Case"):
    if sample_text.strip():
        predictions = predict_text(sample_text)
        st.write(predictions)
    else:
        st.warning("Please enter a case description to classify.")