Update app.py
Browse files
app.py
CHANGED
@@ -2,27 +2,25 @@ import streamlit as st
|
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
|
5 |
-
# Load
|
6 |
-
|
7 |
-
|
8 |
-
# Specify the model path
|
9 |
-
model_name = "ipc_refined_approach_model" # Replace with your actual model path or name
|
10 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
11 |
-
model = model.to(device)
|
12 |
|
13 |
-
|
|
|
14 |
|
15 |
-
#
|
16 |
-
|
|
|
17 |
|
18 |
-
#
|
19 |
-
|
20 |
|
21 |
-
#
|
22 |
-
|
23 |
-
input_text = st.text_area("Input Text", height=250)
|
24 |
|
25 |
-
#
|
26 |
def predict_text(text):
|
27 |
# Tokenize and encode input text
|
28 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
@@ -42,21 +40,23 @@ def predict_text(text):
|
|
42 |
# Convert probabilities to binary predictions (threshold 0.5)
|
43 |
predictions = {section: int(prob > 0.5) for section, prob in zip(sections, probs[0])}
|
44 |
|
45 |
-
#
|
46 |
sections_belongs_to = [section for section, pred in predictions.items() if pred == 1]
|
47 |
-
|
48 |
-
|
49 |
-
# Show results if input text is provided
|
50 |
-
if input_text:
|
51 |
-
st.subheader("Prediction Results")
|
52 |
-
|
53 |
-
# Get predictions for the input text
|
54 |
-
predicted_sections = predict_text(input_text)
|
55 |
-
|
56 |
-
# Show predictions
|
57 |
-
if predicted_sections:
|
58 |
-
st.write(f"This case belongs to Section(s): {', '.join(predicted_sections)}")
|
59 |
else:
|
60 |
st.write("This case does not belong to any known section.")
|
61 |
-
|
62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import torch
|
3 |
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
4 |
|
5 |
+
# Load model and tokenizer (make sure the model path is correct)
|
6 |
+
model_name = "ipc_refined_approach_model"
|
7 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
|
|
|
8 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
|
|
9 |
|
10 |
+
# Example sections list (ensure it's aligned with your model)
|
11 |
+
sections = ['465', '395', '332', '353', '467']
|
12 |
|
13 |
+
# Save labels to a file
|
14 |
+
with open("labels.txt", "w") as f:
|
15 |
+
f.write("\n".join(sections)) # 'sections' should be a list of section names like ['465', '395', ...].
|
16 |
|
17 |
+
# Ensure consistency in device setup
|
18 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
19 |
|
20 |
+
# Move the model to the correct device
|
21 |
+
model = model.to(device)
|
|
|
22 |
|
23 |
+
# Function for prediction
|
24 |
def predict_text(text):
|
25 |
# Tokenize and encode input text
|
26 |
inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
|
|
|
40 |
# Convert probabilities to binary predictions (threshold 0.5)
|
41 |
predictions = {section: int(prob > 0.5) for section, prob in zip(sections, probs[0])}
|
42 |
|
43 |
+
# Print the sections the case belongs to
|
44 |
sections_belongs_to = [section for section, pred in predictions.items() if pred == 1]
|
45 |
+
if sections_belongs_to:
|
46 |
+
st.write(f"This case belongs to Section(s): {', '.join(sections_belongs_to)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
else:
|
48 |
st.write("This case does not belong to any known section.")
|
49 |
+
|
50 |
+
return predictions
|
51 |
+
|
52 |
+
# Streamlit app interface
|
53 |
+
st.title("Legal Section Classification")
|
54 |
+
st.write("Enter the text for case classification:")
|
55 |
+
|
56 |
+
# Text input for case description
|
57 |
+
sample_text = st.text_area("Case Description", "Attack on a police officer to avoid him from doing his duty")
|
58 |
+
|
59 |
+
# Button to make prediction
|
60 |
+
if st.button("Classify Case"):
|
61 |
+
predictions = predict_text(sample_text)
|
62 |
+
st.write(predictions)
|