|
import gradio as gr |
|
import torch |
|
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig |
|
from peft import PeftModel |
|
|
|
|
|
# Hugging Face Hub identifiers: the base checkpoint (used for both the model
# weights and the tokenizer) and the PEFT adapter loaded on top of it.
BASE_MODEL_ID = "Equall/Saul-7B-Base"
ADAPTER_ID = "auslawbench/Cite-SaulLM-7B"

# Load the base model quantized to 8-bit via bitsandbytes, letting
# accelerate decide device placement.
_bnb_config = BitsAndBytesConfig(load_in_8bit=True)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_ID,
    quantization_config=_bnb_config,
    device_map="auto",
)

# Tokenizer from the same checkpoint; EOS doubles as the padding token.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

# Wrap the quantized base model with the citation-prediction adapter.
# NOTE(review): torch_dtype and device_map are forwarded to
# PeftModel.from_pretrained unchanged; confirm whether PEFT actually applies
# them here (dtype is normally chosen when loading the base model).
model = PeftModel.from_pretrained(
    model,
    ADAPTER_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)

# Switch to evaluation mode (disables dropout) since this app only runs inference.
model.eval()
|
|
|
|
|
def predict_case_citation(instruction, input_text):
    """Generate a case-citation prediction for a passage of legal text.

    Builds an Instruction/Input/Response prompt, runs it through the
    module-level ``model``/``tokenizer``, and post-processes the generation.

    Args:
        instruction: Task description placed under the ``### Instruction:``
            header.
        input_text: Legal passage placed under the ``### Input:`` header
            (the example in this app marks the citation slot with
            ``<CASENAME>``).

    Returns:
        The model's generated text, stripped, and truncated just after the
        first ``>`` if one was generated; otherwise the whole stripped
        generation (no ``>`` is invented).
    """
    # Prompt layout the adapter is queried with.
    fine_tuned_prompt = (
        "\n"
        "### Instruction:\n"
        "{}\n"
        "\n"
        "### Input:\n"
        "{}\n"
        "\n"
        "### Response:\n"
    )
    model_input = fine_tuned_prompt.format(instruction, input_text)

    # The model is loaded with device_map="auto", so send inputs to wherever
    # it actually lives instead of assuming "cuda".
    inputs = tokenizer(model_input, return_tensors="pt").to(model.device)

    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        # Kept from the original; it only has an effect if sampling is
        # enabled (e.g. by the model's generation config) — TODO confirm.
        temperature=1.0,
        # EOS is used as the pad token (set at load time); passing it
        # explicitly avoids generate() having to fall back to it.
        pad_token_id=tokenizer.pad_token_id,
    )

    # Decode only the newly generated tokens. Splitting the full decoded
    # text on "### Response:" would pick the wrong segment if the user's own
    # instruction or input contained that marker.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(
        outputs[0][prompt_length:], skip_special_tokens=True
    ).strip()

    # Keep everything up to and including the first ">"; append the
    # delimiter only if the model actually produced it.
    citation, closing, _ = response.partition(">")
    return citation + closing
|
|
|
|
|
# Default task instruction, used both as the textbox default and in the example.
DEFAULT_INSTRUCTION = (
    "Predict the name of the case that needs to be cited in the text and "
    "explain why it should be cited."
)

# Sample passage with the citation slot marked as <CASENAME>.
EXAMPLE_INPUT = (
    "Many of ZAR's grounds of appeal related to fact finding. Drawing on "
    "principles set down in several other courts and tribunals, the Appeal "
    "Panel summarised the circumstances in which leave may be granted for a "
    "person to appeal from findings of fact: <CASENAME> at [84]."
)

# Web UI: two text inputs (instruction, passage) -> predicted citation text.
iface = gr.Interface(
    fn=predict_case_citation,
    inputs=[
        gr.Textbox(label="Instruction", value=DEFAULT_INSTRUCTION),
        gr.Textbox(label="Input Text", lines=5),
    ],
    outputs=gr.Textbox(label="Predicted Case Citation"),
    title="Case Citation Predictor",
    description="This app predicts the name of the case that should be cited in the given legal text.",
    examples=[
        [DEFAULT_INSTRUCTION, EXAMPLE_INPUT],
    ],
)

# Only start the web server when run as a script, so importing this module
# (e.g. to reuse predict_case_citation) does not launch the UI.
if __name__ == "__main__":
    iface.launch()
|
|