abdullahmubeen10 committed
Commit 2b86da8 · verified · 1 Parent(s): 68044b8

Update Demo.py

Files changed (1): Demo.py (+140 -140)
Demo.py CHANGED
@@ -1,141 +1,141 @@
import streamlit as st
import sparknlp
import os
import pandas as pd

from sparknlp.base import *
from sparknlp.annotator import *
from pyspark.ml import Pipeline
from sparknlp.pretrained import PretrainedPipeline
from streamlit_tags import st_tags

# Page configuration
st.set_page_config(
    layout="wide",
    initial_sidebar_state="auto"
)

# CSS for styling
st.markdown("""
    <style>
        .main-title {
            font-size: 36px;
            color: #4A90E2;
            font-weight: bold;
            text-align: center;
        }
        .section {
            background-color: #f9f9f9;
            padding: 10px;
            border-radius: 10px;
            margin-top: 10px;
        }
        .section p, .section ul {
            color: #666666;
        }
    </style>
""", unsafe_allow_html=True)

@st.cache_resource
def init_spark():
    return sparknlp.start()

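# Build a two-stage Spark ML pipeline: ImageAssembler reads the raw image into the
# annotation format, and CLIPForZeroShotClassification scores it against the candidate
# labels. pretrained() loads the default CLIP checkpoint; a specific model name from
# sparknlp.org/models can be passed instead.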
@st.cache_resource
def create_pipeline(model, labels):
    image_assembler = ImageAssembler() \
        .setInputCol("image") \
        .setOutputCol("image_assembler")

    imageClassifier = CLIPForZeroShotClassification \
        .pretrained() \
        .setInputCols(["image_assembler"]) \
        .setOutputCol("label") \
        .setCandidateLabels(labels)

    pipeline = Pipeline(stages=[
        image_assembler,
        imageClassifier,
    ])
    return pipeline

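# Fit on the selected image path and wrap the model in a LightPipeline, which annotates
# a single image quickly without a full Spark DataFrame; the first predicted label is
# returned as a plain string.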
def fit_data(pipeline, data):
    model = pipeline.fit(data)
    light_pipeline = LightPipeline(model)
    annotations_result = light_pipeline.fullAnnotateImage(data)
    return annotations_result[0]['label'][0].result

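# Save an uploaded file into the examples folder (IMAGE_FILE_PATH) so it can be re-read by path.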
def save_uploadedfile(uploadedfile):
    filepath = os.path.join(IMAGE_FILE_PATH, uploadedfile.name)
    with open(filepath, "wb") as f:
        if hasattr(uploadedfile, 'getbuffer'):
            f.write(uploadedfile.getbuffer())
        else:
            f.write(uploadedfile.read())

# Sidebar content
model = st.sidebar.selectbox(
    "Choose the pretrained model",
    ["CLIPForZeroShotClassification"],
    help="For more info about the models visit: https://sparknlp.org/models"
)

# Set up the page layout
st.markdown('<div class="main-title">CLIPForZeroShotClassification</div>', unsafe_allow_html=True)
# st.markdown(f'<div class="section"><p>{sub_title}</p></div>', unsafe_allow_html=True)

# Reference notebook link in sidebar
link = """
<a href="https://github.com/JohnSnowLabs/spark-nlp/blob/master/examples/python/annotation/image/CLIPForZeroShotClassification.ipynb">
    <img src="https://colab.research.google.com/assets/colab-badge.svg" style="zoom: 1.3" alt="Open In Colab"/>
</a>
"""
st.sidebar.markdown('Reference notebook:')
st.sidebar.markdown(link, unsafe_allow_html=True)

# Load examples
- IMAGE_FILE_PATH = "/content/sparknlp CLIPForZeroShotClassification/input"
+ IMAGE_FILE_PATH = "input"
image_files = sorted(
    file for file in os.listdir(IMAGE_FILE_PATH)
    if file.lower().endswith(('.png', '.jpg', '.jpeg'))
)

img_options = st.selectbox("Select an image", image_files)
uploadedfile = st.file_uploader("Try it for yourself!")

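# An uploaded image takes precedence over the example selected in the dropdown.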
if uploadedfile:
    file_details = {"FileName": uploadedfile.name, "FileType": uploadedfile.type}
    save_uploadedfile(uploadedfile)
    selected_image = f"{IMAGE_FILE_PATH}/{uploadedfile.name}"
elif img_options:
    selected_image = f"{IMAGE_FILE_PATH}/{img_options}"

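# Default zero-shot prompts; the tag widget below lets users edit the label set that
# CLIP scores the image against.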
candidateLabels = [
    "a photo of a bird",
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a hen",
    "a photo of a hippo",
    "a photo of a room",
    "a photo of a tractor",
    "a photo of an ostrich",
    "a photo of an ox"
]

labels = st_tags(
    label='Select labels',
    text='Press enter to add more',
    value=candidateLabels,
    maxtags=-1
)

st.subheader('Classified Image')

image_size = st.slider('Image Size', 400, 1000, value=400, step=100)
st.image(selected_image, width=image_size)

st.subheader('Classification')

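# Start Spark NLP, build the pipeline with the current label set, and classify the selected image.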
init_spark()
pipeline = create_pipeline(model, labels)
output = fit_data(pipeline, selected_image)

st.markdown(f'This image has been classified as: **{output}**')