dbleek commited on
Commit
d2e0837
·
1 Parent(s): 930a750

implemented patent classifier

Browse files
Files changed (1) hide show
  1. app_pt.py +59 -0
app_pt.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ from datasets import load_dataset
4
+ from transformers import AutoTokenizer
5
+ from transformers import AutoModelForSequenceClassification
6
+ from transformers import pipeline
7
+
8
+ dataset_dict = load_dataset('HUPD/hupd',
9
+ name='sample',
10
+ data_files="https://huggingface.co/datasets/HUPD/hupd/blob/main/hupd_metadata_2022-02-22.feather",
11
+ icpr_label=None,
12
+ train_filing_start_date='2016-01-01',
13
+ train_filing_end_date='2016-01-21',
14
+ val_filing_start_date='2016-01-22',
15
+ val_filing_end_date='2016-01-31',
16
+ )
17
+ model = torch.load("/workspaces/cs-gy-6613-project/patent_classification(1).pt", map_location=torch.device('cpu'))
18
+ tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
19
+ classifier = pipeline("text-classification", model=model, tokenizer=tokenizer)
20
+ filtered_dataset = dataset_dict['validation'].filter(lambda e: e['decision'] == 'ACCEPTED' or e['decision'] == 'REJECTED')
21
+ dataset = filtered_dataset.shuffle(seed=42).select(range(20))
22
+ dataset = dataset.sort("patent_number")
23
+ applications = {}
24
+
25
+ for ds_index, example in enumerate(dataset):
26
+ applications.update({example['patent_number']: ds_index })
27
+
28
+ def load_patent():
29
+ selected_application = dataset.select([applications[st.session_state.id]])
30
+ st.session_state.abstract = selected_application['abstract'][0]
31
+ st.session_state.claims = selected_application['claims'][0]
32
+ st.session_state.title = selected_application['title'][0]
33
+
34
+ st.title("CS-GY-6613 Project Milestone 3")
35
+
36
+ patent_number = st.selectbox("Select a patent application:", applications, on_change=load_patent, key="id")
37
+ title = st.text_area("Title", key="title", value=dataset[0]['title'], height=50)
38
+
39
+ with st.form('Details'):
40
+ abstract = st.text_area("Abstract", key="abstract", value=dataset[0]['abstract'], height=200)
41
+ claims = st.text_area("Claims", key="claims", value=dataset[0]['abstract'], height=200)
42
+ submitted = st.form_submit_button("Get Patentability Score")
43
+
44
+ if submitted:
45
+ selected_application = dataset.select([applications[st.session_state.id]])
46
+ res = classifier(abstract, claims)
47
+ if res[0]["label"] == 'LABEL_0':
48
+ pred = "ACCEPTED"
49
+ elif res[0]["label"] == 'LABEL_1':
50
+ pred = "REJECTED"
51
+ score = res[0]["score"]
52
+ label = selected_application['decision'][0]
53
+ result = st.markdown("This text was classified as **{}** with a confidence score of **{}**.".format(pred, score))
54
+ check = st.markdown("Actual Label: **{}**.".format(label))
55
+
56
+
57
+
58
+
59
+