annapurnapadmaprema-ji commited on
Commit
fbca1c4
·
verified ·
1 Parent(s): dc20adb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -30
app.py CHANGED
@@ -2,27 +2,25 @@ import streamlit as st
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
- # Load the model and tokenizer (make sure your model is correctly loaded here)
6
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
7
-
8
- # Specify the model path
9
- model_name = "ipc_refined_approach_model" # Replace with your actual model path or name
10
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
11
- model = model.to(device)
12
 
13
- tokenizer = AutoTokenizer.from_pretrained(model_name)
 
14
 
15
- # Define your legal sections
16
- sections = ['465', '467', '395', '332','353'] # Example sections, modify as per your actual list
 
17
 
18
- # Streamlit UI setup
19
- st.title("Legal Case Section Prediction")
20
 
21
- # Get input text from user
22
- st.subheader("Enter the legal text to predict the sections it belongs to:")
23
- input_text = st.text_area("Input Text", height=250)
24
 
25
- # Prediction function
26
  def predict_text(text):
27
  # Tokenize and encode input text
28
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
@@ -42,21 +40,23 @@ def predict_text(text):
42
  # Convert probabilities to binary predictions (threshold 0.5)
43
  predictions = {section: int(prob > 0.5) for section, prob in zip(sections, probs[0])}
44
 
45
- # Return the sections the case belongs to
46
  sections_belongs_to = [section for section, pred in predictions.items() if pred == 1]
47
- return sections_belongs_to
48
-
49
- # Show results if input text is provided
50
- if input_text:
51
- st.subheader("Prediction Results")
52
-
53
- # Get predictions for the input text
54
- predicted_sections = predict_text(input_text)
55
-
56
- # Show predictions
57
- if predicted_sections:
58
- st.write(f"This case belongs to Section(s): {', '.join(predicted_sections)}")
59
  else:
60
  st.write("This case does not belong to any known section.")
61
- else:
62
- st.write("Please enter some text to predict the sections.")
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # Load model and tokenizer (make sure the model path is correct)
6
+ model_name = "ipc_refined_approach_model"
7
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
 
 
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
 
9
 
10
+ # Example sections list (ensure it's aligned with your model)
11
+ sections = ['465', '395', '332', '353', '467']
12
 
13
+ # Save labels to a file
14
+ with open("labels.txt", "w") as f:
15
+ f.write("\n".join(sections)) # 'sections' should be a list of section names like ['465', '395', ...].
16
 
17
+ # Ensure consistency in device setup
18
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
 
20
+ # Move the model to the correct device
21
+ model = model.to(device)
 
22
 
23
+ # Function for prediction
24
  def predict_text(text):
25
  # Tokenize and encode input text
26
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
 
40
  # Convert probabilities to binary predictions (threshold 0.5)
41
  predictions = {section: int(prob > 0.5) for section, prob in zip(sections, probs[0])}
42
 
43
+ # Print the sections the case belongs to
44
  sections_belongs_to = [section for section, pred in predictions.items() if pred == 1]
45
+ if sections_belongs_to:
46
+ st.write(f"This case belongs to Section(s): {', '.join(sections_belongs_to)}")
 
 
 
 
 
 
 
 
 
 
47
  else:
48
  st.write("This case does not belong to any known section.")
49
+
50
+ return predictions
51
+
52
+ # Streamlit app interface
53
+ st.title("Legal Section Classification")
54
+ st.write("Enter the text for case classification:")
55
+
56
+ # Text input for case description
57
+ sample_text = st.text_area("Case Description", "Attack on a police officer to avoid him from doing his duty")
58
+
59
+ # Button to make prediction
60
+ if st.button("Classify Case"):
61
+ predictions = predict_text(sample_text)
62
+ st.write(predictions)