"""Streamlit app that predicts which legal sections a piece of case text belongs to.

A fine-tuned sequence-classification model scores the text; each output logit is
passed through a sigmoid independently (multi-label setup), and every section whose
probability exceeds a threshold is reported.
"""
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify the model path
model_name = "ipc_refined_approach_model"  # Replace with your actual model path or name

# Section labels, one per model output logit, in logit order.
# Example sections, modify as per your actual list -- the order must match training.
sections = ['465', '467', '395', '332', '353']

# Probability above which a section counts as predicted.
DEFAULT_THRESHOLD = 0.5


@st.cache_resource
def _load_model_and_tokenizer(name, device_name):
    """Load the classifier and its tokenizer once per server process.

    Streamlit re-runs this whole script on every widget interaction; caching the
    resource avoids re-reading the model weights from disk on each rerun.

    Args:
        name: Model path or hub identifier passed to ``from_pretrained``.
        device_name: Device string (e.g. ``"cpu"`` or ``"cuda"``) to move the model to.

    Returns:
        ``(model, tokenizer)`` with the model already in evaluation mode.
    """
    loaded_model = AutoModelForSequenceClassification.from_pretrained(name)
    loaded_model = loaded_model.to(device_name)
    loaded_model.eval()  # inference only: disables dropout and similar training behavior
    loaded_tokenizer = AutoTokenizer.from_pretrained(name)
    return loaded_model, loaded_tokenizer


model, tokenizer = _load_model_and_tokenizer(model_name, str(device))


def predict_text(text, threshold=DEFAULT_THRESHOLD):
    """Return the sections whose predicted probability exceeds *threshold*.

    Args:
        text: Raw case text. Input longer than 512 tokens is truncated.
        threshold: Probability cut-off applied to each section independently.

    Returns:
        List of labels from ``sections`` predicted for *text*, in label order.
        Empty when no section clears the threshold.
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
    # Inputs must live on the same device as the model.
    inputs = {key: value.to(device) for key, value in inputs.items()}

    with torch.no_grad():
        logits = model(**inputs).logits

    # Multi-label setup: sigmoid yields an independent probability per section.
    # Row 0 because exactly one text is encoded per call.
    probs = torch.sigmoid(logits)[0].cpu().tolist()
    return [section for section, prob in zip(sections, probs) if prob > threshold]


# Streamlit UI setup
st.title("Legal Case Section Prediction")

# zip() in predict_text would silently drop labels or logits on a mismatch; surface it.
if model.config.num_labels != len(sections):
    st.warning(
        f"Model has {model.config.num_labels} output labels but {len(sections)} "
        "sections are configured; predictions may be incomplete."
    )

st.subheader("Enter the legal text to predict the sections it belongs to:")
input_text = st.text_area("Input Text", height=250)

# Skip inference for empty or whitespace-only input.
if input_text.strip():
    st.subheader("Prediction Results")
    predicted_sections = predict_text(input_text)
    if predicted_sections:
        st.write(f"This case belongs to Section(s): {', '.join(predicted_sections)}")
    else:
        st.write("This case does not belong to any known section.")
else:
    st.write("Please enter some text to predict the sections.")