João Pedro committed on
Commit
6c6f2d5
·
1 Parent(s): b90d0b6

try to avoid streamlit re-running everything on ui changes

Browse files
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -32,8 +32,13 @@ labels = [
32
  id2label = {i: label for i, label in enumerate(labels)}
33
  label2id = {v: k for k, v in id2label.items()}
34
 
35
- processor = LayoutLMv3Processor.from_pretrained("model/layoutlmv3/")
36
- model = LayoutLMv3ForSequenceClassification.from_pretrained("model/layoutlmv3/")
 
 
 
 
 
37
 
38
  st.title("Document Classification with LayoutLMv3")
39
 
@@ -41,15 +46,32 @@ uploaded_file = st.file_uploader(
41
  "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
42
  )
43
 
44
-
45
  feedback_table = wandb.Table(columns=[
46
  'image', 'filetype', 'predicted_label', 'predicted_label_id',
47
  'correct_label', 'correct_label_id'
48
  ])
49
 
50
- if uploaded_file:
51
- run = wandb.init(project='hydra-classifier', name='feedback-loop')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
 
53
  if uploaded_file.type == "application/pdf":
54
  images = convert_from_bytes(uploaded_file.getvalue())
55
  else:
@@ -58,16 +80,7 @@ if uploaded_file:
58
  for i, image in enumerate(images):
59
  st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
60
 
61
- print(f'Encoding image with index {i}')
62
- encoding = processor(
63
- image,
64
- return_tensors="pt",
65
- truncation=True,
66
- max_length=512,
67
- )
68
- print(f'Predicting image with index {i}')
69
- outputs = model(**encoding)
70
- prediction = outputs.logits.argmax(-1)[0].item()
71
 
72
  st.write(f"Prediction: {id2label[prediction]}")
73
 
@@ -96,5 +109,7 @@ if uploaded_file:
96
  st.success(f"Feedback for Image {i} submitted!")
97
 
98
  print(feedback_table)
 
99
  run.log({'feedback_table': feedback_table})
100
  run.finish()
 
 
32
  id2label = {i: label for i, label in enumerate(labels)}
33
  label2id = {v: k for k, v in id2label.items()}
34
 
35
+ if 'model' not in st.session_state:
36
+ st.session_state.model = LayoutLMv3ForSequenceClassification.from_pretrained("model/layoutlmv3/")
37
+ if 'processor' not in st.session_state:
38
+ st.session_state.processor = LayoutLMv3Processor.from_pretrained("model/layoutlmv3/")
39
+
40
+ model = st.session_state.model
41
+ processor = st.session_state.processor
42
 
43
  st.title("Document Classification with LayoutLMv3")
44
 
 
46
  "Upload Document", type=["pdf", "jpg", "png"], accept_multiple_files=False
47
  )
48
 
 
49
  feedback_table = wandb.Table(columns=[
50
  'image', 'filetype', 'predicted_label', 'predicted_label_id',
51
  'correct_label', 'correct_label_id'
52
  ])
53
 
54
# Start one Weights & Biases run per browser session and keep it across
# Streamlit reruns. Streamlit's per-session store is `st.session_state`
# (there is no `st.session_data` attribute). Checking for None rather than
# key membership also re-creates the run after the slot has been reset to
# None once an earlier run was finished.
if st.session_state.get('wandb_run') is None:
    st.session_state.wandb_run = wandb.init(
        project='hydra-classifier', name='feedback-loop'
    )
56
+
57
+
58
@st.cache_data
def classify_image(image):
    """Return the predicted label id for one document image.

    The image is encoded with the module-level ``processor`` and scored by
    the module-level ``model``. The ``st.cache_data`` decorator memoises the
    result, so a Streamlit rerun with the same image does not repeat the
    forward pass.

    Args:
        image: The page image to classify; it is handed to ``processor``
            unchanged (presumably a PIL image - confirm against the caller).

    Returns:
        int: Index of the highest logit for the first item in the batch.
    """
    # Imported locally so the function does not rely on a module-level
    # `import torch` being present elsewhere in the file.
    import torch

    # The log lines no longer interpolate `i`: it is not a parameter of this
    # function, so the old messages depended on the caller's loop variable.
    print('Encoding image')
    encoding = processor(
        image,
        return_tensors="pt",
        truncation=True,
        max_length=512,
    )

    print('Predicting image')
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        outputs = model(**encoding)

    return outputs.logits.argmax(-1)[0].item()
72
+
73
 
74
+ if uploaded_file:
75
  if uploaded_file.type == "application/pdf":
76
  images = convert_from_bytes(uploaded_file.getvalue())
77
  else:
 
80
  for i, image in enumerate(images):
81
  st.image(image, caption=f'Uploaded Image {i}', use_container_width=True)
82
 
83
+ prediction = classify_image(image)
 
 
 
 
 
 
 
 
 
84
 
85
  st.write(f"Prediction: {id2label[prediction]}")
86
 
 
109
  st.success(f"Feedback for Image {i} submitted!")
110
 
111
  print(feedback_table)
112
+ run = st.session_data.wandb_run
113
  run.log({'feedback_table': feedback_table})
114
  run.finish()
115
+ st.session_data.wandb_run = None