pyresearch committed on
Commit abf471d · verified · 1 Parent(s): c925507

Update app.py

Files changed (1)
  app.py +227 -23
app.py CHANGED
@@ -1,33 +1,237 @@
  import streamlit as st
- from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

- # Use GPU if available
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

- st.title("Text Generation with Hugging Face Transformers")

- # Input prompt from user
- prompt = st.text_area("Enter a prompt:", "this news is real pyresearch given right computer vision videos?")

- # Load model and tokenizer
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)

- # Move the model to the desired device
- model.to(device)

- # Generate text on button click
- if st.button("Generate"):
-     with torch.no_grad():
-         token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt").to(device)
-         output_ids = model.generate(
-             token_ids,
-             max_new_tokens=512,
-             do_sample=True,
-             temperature=0.1
          )

-         generated_text = tokenizer.decode(output_ids[0][token_ids.size(1):])
-         st.text("Generated Text:")
-         st.write(generated_text)

  import streamlit as st
+ from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
+ from clarifai_grpc.grpc.api.status import status_code_pb2
  import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # GPT-4 credentials
+ PAT_GPT4 = "3ca5bd8b0f2244eb8d0e4b2838fc3cf1"
+ USER_ID_GPT4 = "openai"
+ APP_ID_GPT4 = "chat-completion"
+ MODEL_ID_GPT4 = "openai-gpt-4-vision"
+ MODEL_VERSION_ID_GPT4 = "266df29bc09843e0aee9b7bf723c03c2"
+
+ # DALL-E credentials
+ PAT_DALLE = "bfdeb4029ef54d23a2e608b0aa4c00e4"
+ USER_ID_DALLE = "openai"
+ APP_ID_DALLE = "dall-e"
+ MODEL_ID_DALLE = "dall-e-3"
+ MODEL_VERSION_ID_DALLE = "dc9dcb6ee67543cebc0b9a025861b868"
+
+ # TTS credentials
+ PAT_TTS = "bfdeb4029ef54d23a2e608b0aa4c00e4"
+ USER_ID_TTS = "openai"
+ APP_ID_TTS = "tts"
+ MODEL_ID_TTS = "openai-tts-1"
+ MODEL_VERSION_ID_TTS = "fff6ce1fd487457da95b79241ac6f02d"
+
+ # NewsGuardian model credentials
+ PAT_NEWSGUARDIAN = "your_news_guardian_pat"
+ USER_ID_NEWSGUARDIAN = "your_user_id"
+ APP_ID_NEWSGUARDIAN = "your_app_id"
+ MODEL_ID_NEWSGUARDIAN = "your_model_id"
+ MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"
+
+ #
+ import streamlit as st
+ from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
+ from clarifai_grpc.grpc.api.status import status_code_pb2
+
+ # Set up gRPC channel for NewsGuardian model
+ channel_tts = ClarifaiChannel.get_grpc_channel()
+ stub_tts = service_pb2_grpc.V2Stub(channel_tts)
+ metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
+ userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
+
+ # Streamlit app
+ st.title("NewsGuardian")
+
+ # Inserting logo
+ st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
+
+ # Function to get gRPC channel for NewsGuardian model
+ def get_tts_channel():
+     channel_tts = ClarifaiChannel.get_grpc_channel()
+     return channel_tts, channel_tts.metadata
+
+ # User input
+ model_type = st.selectbox("Select Model", ["NewsGuardian model", "NewsGuardian model"])
+ raw_text = st.text_area("This news is real or fake?")
+ image_upload = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
+
+ # Button to generate result
+ if st.button("NewsGuardian News Result"):
+     if model_type == "NewsGuardian model":
+         # Set up gRPC channel for NewsGuardian model
+         channel_gpt4 = ClarifaiChannel.get_grpc_channel()
+         stub_gpt4 = service_pb2_grpc.V2Stub(channel_gpt4)
+         metadata_gpt4 = (('authorization', 'Key ' + PAT_GPT4),)
+         userDataObject_gpt4 = resources_pb2.UserAppIDSet(user_id=USER_ID_GPT4, app_id=APP_ID_GPT4)
+
+         # Prepare the request for NewsGuardian model
+         input_data_gpt4 = resources_pb2.Data()
+
+         if raw_text:
+             input_data_gpt4.text.raw = raw_text
+
+         if image_upload is not None:
+             image_bytes_gpt4 = image_upload.read()
+             input_data_gpt4.image.base64 = image_bytes_gpt4
+
+         post_model_outputs_response_gpt4 = stub_gpt4.PostModelOutputs(
+             service_pb2.PostModelOutputsRequest(
+                 user_app_id=userDataObject_gpt4,
+                 model_id=MODEL_ID_GPT4,
+                 version_id=MODEL_VERSION_ID_GPT4,
+                 inputs=[resources_pb2.Input(data=input_data_gpt4)]
+             ),
+             metadata=metadata_gpt4  # Use metadata directly in the gRPC request
          )

+         # Check if the request was successful for NewsGuardian model
+         if post_model_outputs_response_gpt4.status.code != status_code_pb2.SUCCESS:
+             st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_gpt4.status.description}")
+         else:
+             # Get the output for NewsGuardian model
+             output_gpt4 = post_model_outputs_response_gpt4.outputs[0].data
+
+             # Display the result for NewsGuardian model
+             if output_gpt4.HasField("image"):
+                 st.image(output_gpt4.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
+             elif output_gpt4.HasField("text"):
+                 # Display the text result
+                 st.text(output_gpt4.text.raw)
+
+                 # Convert text to speech and play the audio
+                 stub_tts = service_pb2_grpc.V2Stub(channel_gpt4)  # Use the same channel for TTS
+
+                 tts_input_data = resources_pb2.Data()
+                 tts_input_data.text.raw = output_gpt4.text.raw
+
+                 tts_response = stub_tts.PostModelOutputs(
+                     service_pb2.PostModelOutputsRequest(
+                         user_app_id=userDataObject_tts,
+                         model_id=MODEL_ID_TTS,
+                         version_id=MODEL_VERSION_ID_TTS,
+                         inputs=[resources_pb2.Input(data=tts_input_data)]
+                     ),
+                     metadata=metadata_gpt4  # Use the same metadata for TTS
+                 )
+
+                 # Check if the TTS request was successful
+                 if tts_response.status.code == status_code_pb2.SUCCESS:
+                     tts_output = tts_response.outputs[0].data
+                     st.audio(tts_output.audio.base64, format='audio/wav')
+                 else:
+                     st.error(f"NewsGuardian model API request failed: {tts_response.status.description}")
+
+     elif model_type == "DALL-E":
+         # Set up gRPC channel for DALL-E
+         channel_dalle = ClarifaiChannel.get_grpc_channel()
+         stub_dalle = service_pb2_grpc.V2Stub(channel_dalle)
+         metadata_dalle = (('authorization', 'Key ' + PAT_DALLE),)
+         userDataObject_dalle = resources_pb2.UserAppIDSet(user_id=USER_ID_DALLE, app_id=APP_ID_DALLE)
+
+         # Prepare the request for DALL-E
+         input_data_dalle = resources_pb2.Data()
+
+         if raw_text:
+             input_data_dalle.text.raw = raw_text
+
+         post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
+             service_pb2.PostModelOutputsRequest(
+                 user_app_id=userDataObject_dalle,
+                 model_id=MODEL_ID_DALLE,
+                 version_id=MODEL_VERSION_ID_DALLE,
+                 inputs=[resources_pb2.Input(data=input_data_dalle)]
+             ),
+             metadata=metadata_dalle
+         )
+
+         # Check if the request was successful for DALL-E
+         if post_model_outputs_response_dalle.status.code != status_code_pb2.SUCCESS:
+             st.error(f"DALL-E API request failed: {post_model_outputs_response_dalle.status.description}")
+         else:
+             # Get the output for DALL-E
+             output_dalle = post_model_outputs_response_dalle.outputs[0].data
+
+             # Display the result for DALL-E
+             if output_dalle.HasField("image"):
+                 st.image(output_dalle.image.base64, caption='Generated Image (DALL-E)', use_column_width=True)
+             elif output_dalle.HasField("text"):
+                 st.text(output_dalle.text.raw)
+
+     elif model_type == "NewsGuardian model":
+         # Set up gRPC channel for NewsGuardian model
+         channel_tts = ClarifaiChannel.get_grpc_channel()
+         stub_tts = service_pb2_grpc.V2Stub(channel_tts)
+         metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
+         userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
+
+         # Prepare the request for NewsGuardian model
+         input_data_tts = resources_pb2.Data()
+
+         if raw_text:
+             input_data_tts.text.raw = raw_text
+
+         post_model_outputs_response_tts = stub_tts.PostModelOutputs(
+             service_pb2.PostModelOutputsRequest(
+                 user_app_id=userDataObject_tts,
+                 model_id=MODEL_ID_TTS,
+                 version_id=MODEL_VERSION_ID_TTS,
+                 inputs=[resources_pb2.Input(data=input_data_tts)]
+             ),
+             metadata=metadata_tts
+         )
+
+         # Check if the request was successful for NewsGuardian model
+         if post_model_outputs_response_tts.status.code != status_code_pb2.SUCCESS:
+             st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_tts.status.description}")
+         else:
+             # Get the output for NewsGuardian model
+             output_tts = post_model_outputs_response_tts.outputs[0].data
+
+             # Display the result for NewsGuardian model
+             if output_tts.HasField("text"):
+                 st.text(output_tts.text.raw)
+
+             if output_tts.HasField("audio"):
+                 st.audio(output_tts.audio.base64, format='audio/wav')
+
+ # Add the beautiful social media icon section
+ st.markdown("""
+ <div align="center">
+     <a href="https://github.com/pyresearch/pyresearch" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226594737-c21e2dda-9cc6-42ef-b4e7-a685fea4a21d.png" width="2%" alt="" /></a>
+     <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
+     <a href="https://www.linkedin.com/company/pyresearch/" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226596446-746ffdd0-a47e-4452-84e3-bf11ec2aa26a.png" width="2%" alt="" /></a>
+     <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
+     <a href="https://twitter.com/Noorkhokhar10" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226599162-9b11194e-4998-440a-ba94-c8a5e1cdc676.png" width="2%" alt="" /></a>
+     <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
+     <a href="https://www.youtube.com/@Pyresearch" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226599904-7d5cc5c0-89d2-4d1e-891e-19bee1951744.png" width="2%" alt="" /></a>
+     <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
+     <a href="https://www.facebook.com/Pyresearch" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226600380-a87a9142-e8e0-4ec9-bf2c-dd6e9da2f05a.png" width="2%" alt="" /></a>
+     <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
+     <a href="https://www.instagram.com/pyresearch/" style="text-decoration:none;">
+         <img src="https://user-images.githubusercontent.com/34125851/226601355-ffe0b597-9840-4e10-bbef-43d6c74b5a9e.png" width="2%" alt="" /></a>
+ </div>
+ <hr>
+ """, unsafe_allow_html=True)
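
The three PostModelOutputs blocks added above follow the same request/response pattern, and the personal access tokens are hard-coded in the source. As a minimal sketch only (not part of this commit), the call could be factored into one helper that reads the PAT from Streamlit's secrets store; the helper name call_clarifai_model and the CLARIFAI_PAT secrets key are illustrative assumptions.

import streamlit as st
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_code_pb2


def call_clarifai_model(user_id, app_id, model_id, version_id, raw_text=None, image_bytes=None):
    """Send one PostModelOutputs request and return the first output's data, or None on failure."""
    channel = ClarifaiChannel.get_grpc_channel()
    stub = service_pb2_grpc.V2Stub(channel)
    # Assumption: the PAT is stored in .streamlit/secrets.toml under the key "CLARIFAI_PAT".
    metadata = (("authorization", "Key " + st.secrets["CLARIFAI_PAT"]),)

    # Build the input the same way the committed code does: optional text plus optional image bytes.
    data = resources_pb2.Data()
    if raw_text:
        data.text.raw = raw_text
    if image_bytes:
        data.image.base64 = image_bytes

    response = stub.PostModelOutputs(
        service_pb2.PostModelOutputsRequest(
            user_app_id=resources_pb2.UserAppIDSet(user_id=user_id, app_id=app_id),
            model_id=model_id,
            version_id=version_id,
            inputs=[resources_pb2.Input(data=data)],
        ),
        metadata=metadata,
    )
    if response.status.code != status_code_pb2.SUCCESS:
        st.error(f"Clarifai request failed: {response.status.description}")
        return None
    return response.outputs[0].data

Each branch of the button handler could then reduce to a single call, for example output = call_clarifai_model(USER_ID_GPT4, APP_ID_GPT4, MODEL_ID_GPT4, MODEL_VERSION_ID_GPT4, raw_text=raw_text), with the display logic left unchanged.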