File size: 5,885 Bytes
13cf51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dd29e
13cf51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47dd29e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3b05ad
47dd29e
9a94502
47dd29e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f3b05ad
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
from transformers import (
    AutoConfig,
    AutoModelForQuestionAnswering,
    AutoTokenizer,
    squad_convert_examples_to_features
)

from transformers.data.processors.squad import SquadResult, SquadV2Processor, SquadExample
from transformers.data.metrics.squad_metrics import compute_predictions_logits
import streamlit as st
import gradio as gr
import json
import torch
import time
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

# Checkpoint name passed to from_pretrained() by load_model() for both the
# question-answering model and its tokenizer.
model_checkpoint = "akdeniz27/roberta-base-cuad"

@st.cache(allow_output_mutation=True)
def run_prediction(question_texts, context_text, model_path):
    """Answer each question against a single context with a QA model.

    Every question is wrapped in a ``SquadExample`` whose ``qas_id`` is the
    question's position (as a string), converted to features, run through
    the model in batches of 10, and post-processed with
    ``compute_predictions_logits``.

    Args:
        question_texts: Iterable of question strings.
        context_text: Text that every question is asked against.
        model_path: Name or path handed to ``from_pretrained`` for the
            config, tokenizer and model.

    Returns:
        Whatever ``compute_predictions_logits`` returns. The caller in this
        file treats it as a mapping from ``qas_id`` to the predicted answer.
    """
    # Values forwarded by name to feature conversion / post-processing below.
    max_seq_length = 512
    doc_stride = 256
    n_best_size = 1
    max_query_length = 64
    max_answer_length = 512
    do_lower_case = False
    null_score_diff_threshold = 0.0

    def to_list(tensor):
        return tensor.detach().cpu().tolist()

    config = AutoConfig.from_pretrained(model_path)
    # NOTE(review): the tokenizer is created with do_lower_case=True while
    # post-processing receives do_lower_case=False. Kept exactly as before;
    # confirm which value is intended.
    tokenizer = AutoTokenizer.from_pretrained(
        model_path, do_lower_case=True, use_fast=False)
    model = AutoModelForQuestionAnswering.from_pretrained(
        model_path, config=config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: switch to eval mode once instead of on every batch.
    model.eval()

    examples = [
        SquadExample(
            qas_id=str(i),
            question_text=question_text,
            context_text=context_text,
            answer_text=None,
            start_position_character=None,
            title="Predict",
            answers=None,
        )
        for i, question_text in enumerate(question_texts)
    ]

    features, dataset = squad_convert_examples_to_features(
        examples=examples,
        tokenizer=tokenizer,
        max_seq_length=max_seq_length,
        doc_stride=doc_stride,
        max_query_length=max_query_length,
        is_training=False,
        return_dataset="pt",
        threads=1,
    )

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=10)

    all_results = []
    for batch in eval_dataloader:
        batch = tuple(t.to(device) for t in batch)
        inputs = {
            "input_ids": batch[0],
            "attention_mask": batch[1],
            "token_type_ids": batch[2],
        }
        # batch[3] holds, per row, the index into `features` for that row.
        example_indices = batch[3]
        with torch.no_grad():
            outputs = model(**inputs)
        for i, example_index in enumerate(example_indices):
            eval_feature = features[example_index.item()]
            unique_id = int(eval_feature.unique_id)
            # The model output is expected to unpack into exactly two
            # tensors (start and end logits); take row i of each.
            start_logits, end_logits = [
                to_list(output[i]) for output in outputs.to_tuple()
            ]
            all_results.append(
                SquadResult(unique_id, start_logits, end_logits))

    final_predictions = compute_predictions_logits(
        all_examples=examples,
        all_features=features,
        all_results=all_results,
        n_best_size=n_best_size,
        max_answer_length=max_answer_length,
        do_lower_case=do_lower_case,
        output_prediction_file=None,
        output_nbest_file=None,
        output_null_log_odds_file=None,
        verbose_logging=False,
        version_2_with_negative=True,
        null_score_diff_threshold=null_score_diff_threshold,
        tokenizer=tokenizer
    )
    return final_predictions

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the QA model and its slow tokenizer for ``model_checkpoint``.

    Returns:
        A ``(model, tokenizer)`` tuple, both created via ``from_pretrained``.
    """
    # Tuple elements are evaluated left to right: model first, then tokenizer.
    return (
        AutoModelForQuestionAnswering.from_pretrained(model_checkpoint),
        AutoTokenizer.from_pretrained(model_checkpoint, use_fast=False),
    )
    
@st.cache(allow_output_mutation=True)
def load_questions():
    """Return the question strings of the first paragraph in ``test.json``.

    Reads ``data[0]['paragraphs'][0]['qas']`` from the file and collects the
    ``'question'`` value of every entry, in file order.

    Returns:
        A list of question strings.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return [qa['question'] for qa in data['data'][0]['paragraphs'][0]['qas']]
	
@st.cache(allow_output_mutation=True)
def load_contracts():
    """Return the context text of every entry in ``test.json``.

    For each item under ``data``, takes ``['paragraphs'][0]['context']`` and
    collapses every run of whitespace into a single space.

    Returns:
        A list of whitespace-normalised contract strings, in file order.
    """
    # Decode explicitly as UTF-8 so the result does not depend on the
    # platform's default locale encoding.
    with open('test.json', encoding='utf-8') as json_file:
        data = json.load(json_file)
    return [
        ' '.join(entry['paragraphs'][0]['context'].split())
        for entry in data['data']
    ]
	
# --- Streamlit page: load data, collect user choices, run the model. ---
model, tokenizer = load_model()
questions = load_questions()
contracts = load_contracts()
contract = contracts[0]

st.header("Question Answering in CUAD (Contract Understanding Atticus Dataset)")

selected_question = st.selectbox('Choose one of the queries from the CUAD dataset or write a legal contract and see if the model can answer correctly. The model only supports English Language:', questions)
# The first question is always sent as entry 0; its prediction is skipped
# when the results are written out below.
question_set = [questions[0], selected_question]
contract_type = st.radio("Select Contract", ("Sample Contract", "New Contract"))

if contract_type == "Sample Contract":
    # Bound the slider by the number of loaded contracts so the selected
    # index is always valid (the default slider range is 0..100).
    last_contract_index = len(contracts) - 1
    if last_contract_index > 0:
        sample_contract_num = st.slider(
            "Select Sample Contract #",
            min_value=0,
            max_value=last_contract_index,
            value=0,
        )
    else:
        # Only one contract available: nothing to choose between.
        sample_contract_num = 0
    contract = contracts[sample_contract_num]
    with st.expander(f"Sample Contract #{sample_contract_num}"):
        st.write(contract)
else:
    contract = st.text_area("Input New Contract", "", height=256)
Run_Button = st.button("Run", key=None)
if Run_Button and contract and question_set:
    predictions = run_prediction(question_set, contract, model_checkpoint)

    # Keys of `predictions` are the qas_id strings, i.e. indexes into
    # question_set.
    for i, p in enumerate(predictions):
        if i != 0:
            st.write(f"Question: {question_set[int(p)]}\n\nAnswer: {predictions[p]}\n\n")


st.write("Model: akdeniz27/roberta-base-cuad")
st.write("Project: https://www.atticusprojectai.org/cuad")
st.write("Git Hub: https://github.com/TheAtticusProject/cuad")
st.write("CUAD Dataset: https://huggingface.co/datasets/cuad")
st.write("Based on https://github.com/marshmellow77/cuad-demo")