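# LegalLLM / app.py
# Author: PKaushik
# Commit: "Create app.py" (2db95d5, verified)
# Size: 2.13 kB
# (The labels "raw" and "history blame" were navigation links on the hosting
# site's file viewer and are not part of the program.)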
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from peft import PeftModel
# Load the tokenizer, the quantized base model, and the fine-tuned adapter.

# The tokenizer is taken from the base checkpoint; its end-of-sequence token
# is reused as the padding token.
tokenizer = AutoTokenizer.from_pretrained("Equall/Saul-7B-Base")
tokenizer.pad_token = tokenizer.eos_token

# Base checkpoint, loaded with an 8-bit bitsandbytes quantization config.
eight_bit_config = BitsAndBytesConfig(load_in_8bit=True)
base_model = AutoModelForCausalLM.from_pretrained(
    "Equall/Saul-7B-Base",
    quantization_config=eight_bit_config,
    device_map="auto",
)

# Wrap the base model with the citation-prediction PEFT weights and switch
# the result to evaluation mode before it is used for generation.
model = PeftModel.from_pretrained(
    base_model,
    "auslawbench/Cite-SaulLM-7B",
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
model.eval()
# Define the prediction function
def predict_case_citation(instruction, input_text):
    """Generate the model's case-citation prediction for a legal passage.

    The instruction and passage are inserted into the prompt layout below,
    the module-level ``model`` generates up to 256 new tokens, and the text
    produced after the prompt is cut at the first ``'>'``.

    Args:
        instruction: Task description placed under ``### Instruction:``.
        input_text: Legal text placed under ``### Input:``.

    Returns:
        The generated response up to and including its first ``'>'``. If the
        model emitted no ``'>'`` (e.g. generation stopped at the token limit),
        the whole stripped response is returned without a fabricated ``'>'``.
    """
    prompt_template = (
        "\n"
        "### Instruction:\n"
        "{}\n"
        "### Input:\n"
        "{}\n"
        "### Response:\n"
    )
    model_input = prompt_template.format(instruction, input_text)

    # The model is placed with device_map="auto", so follow the model's own
    # device instead of assuming a CUDA device is present.
    inputs = tokenizer(model_input, return_tensors="pt").to(model.device)
    outputs = model.generate(**inputs, max_new_tokens=256, temperature=1.0)

    # generate() returns the prompt tokens followed by the new ones. Decoding
    # only the new tokens avoids searching the text for "### Response:",
    # which picks the wrong segment when the user's input contains that marker.
    prompt_length = inputs["input_ids"].shape[1]
    generated_tokens = outputs[0][prompt_length:]
    generated_text = tokenizer.decode(
        generated_tokens, skip_special_tokens=True
    ).strip()

    # Keep everything through the first '>' and discard what follows.
    head, closing, _ = generated_text.partition(">")
    if not closing:
        return generated_text
    return head + closing
# Assemble the Gradio interface from named pieces, then serve it.

# The same instruction text pre-fills the first textbox and appears in the
# bundled example, so it is defined once.
default_instruction = "Predict the name of the case that needs to be cited in the text and explain why it should be cited."
example_passage = "Many of ZAR's grounds of appeal related to fact finding. Drawing on principles set down in several other courts and tribunals, the Appeal Panel summarised the circumstances in which leave may be granted for a person to appeal from findings of fact: <CASENAME> at [84]."

# Input and output widgets.
instruction_box = gr.Textbox(label="Instruction", value=default_instruction)
passage_box = gr.Textbox(label="Input Text", lines=5)
citation_box = gr.Textbox(label="Predicted Case Citation")

iface = gr.Interface(
    fn=predict_case_citation,
    inputs=[instruction_box, passage_box],
    outputs=citation_box,
    title="Case Citation Predictor",
    description="This app predicts the name of the case that should be cited in the given legal text.",
    examples=[[default_instruction, example_passage]],
)

# Start the web app.
iface.launch()