pyresearch committed on
Commit
b0e9ffc
·
verified ·
1 Parent(s): f191d24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +191 -23
app.py CHANGED
@@ -1,28 +1,196 @@
1
  import streamlit as st
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
3
  import torch
 
4
 
5
- # Load the Phi 2 model and tokenizer outside the Streamlit app
6
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
7
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", device_map="auto", trust_remote_code=True)
8
-
9
- # Streamlit UI
10
- st.title("Microsoft Phi 2 Streamlit App")
11
-
12
- # User input prompt
13
- prompt = st.text_area("Enter your prompt:", "Write a story about Nasa")
14
-
15
- # Generate output based on user input
16
- if st.button("Generate Output"):
17
- with torch.no_grad():
18
- token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
19
- output_ids = model.generate(
20
- token_ids.to(model.device),
21
- max_new_tokens=512,
22
- do_sample=True,
23
- temperature=0.3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
25
 
26
- output = tokenizer.decode(output_ids[0][token_ids.size(1):])
27
- st.text("Generated Output:")
28
- st.write(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
3
+ from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
4
+ from clarifai_grpc.grpc.api.status import status_code_pb2
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
+
9
# Load the phi-2 model and tokenizer once at module import time so Streamlit
# reruns do not reload them.  The original code had two defects here: the
# tokenizer line was truncated ("from_pretraine", never called), and the
# objects were bound to `model`/`tokenizer` while generate_phi2_text() below
# reads `model_phi2`/`tokenizer_phi2` — a NameError at first use.
model_phi2 = AutoModelForCausalLM.from_pretrained(
    "microsoft/phi-2", torch_dtype="auto", trust_remote_code=True
)
tokenizer_phi2 = AutoTokenizer.from_pretrained(
    "microsoft/phi-2", trust_remote_code=True
)
11
+
12
import os

# Clarifai credentials.
# SECURITY NOTE(review): hard-coding personal access tokens (PATs) in source
# control leaks them permanently via repository history — the literals below
# are already public and should be rotated.  Each PAT can now be overridden
# with a CLARIFAI_* environment variable; the original literal remains the
# fallback so existing deployments keep working unchanged.

# GPT-4 credentials
PAT_GPT4 = os.environ.get("CLARIFAI_PAT_GPT4", "3ca5bd8b0f2244eb8d0e4b2838fc3cf1")
USER_ID_GPT4 = "openai"
APP_ID_GPT4 = "chat-completion"
MODEL_ID_GPT4 = "openai-gpt-4-vision"
MODEL_VERSION_ID_GPT4 = "266df29bc09843e0aee9b7bf723c03c2"

# DALL-E credentials
PAT_DALLE = os.environ.get("CLARIFAI_PAT_DALLE", "bfdeb4029ef54d23a2e608b0aa4c00e4")
USER_ID_DALLE = "openai"
APP_ID_DALLE = "dall-e"
MODEL_ID_DALLE = "dall-e-3"
MODEL_VERSION_ID_DALLE = "dc9dcb6ee67543cebc0b9a025861b868"

# TTS credentials
PAT_TTS = os.environ.get("CLARIFAI_PAT_TTS", "bfdeb4029ef54d23a2e608b0aa4c00e4")
USER_ID_TTS = "openai"
APP_ID_TTS = "tts"
MODEL_ID_TTS = "openai-tts-1"
MODEL_VERSION_ID_TTS = "fff6ce1fd487457da95b79241ac6f02d"

# NewsGuardian model credentials (placeholders — fill in before use)
PAT_NEWSGUARDIAN = os.environ.get("CLARIFAI_PAT_NEWSGUARDIAN", "your_news_guardian_pat")
USER_ID_NEWSGUARDIAN = "your_user_id"
APP_ID_NEWSGUARDIAN = "your_app_id"
MODEL_ID_NEWSGUARDIAN = "your_model_id"
MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"
39
+
40
+
41
# gRPC channel and stub for the Clarifai TTS model (the original comment
# mislabelled this as the NewsGuardian setup; the stub created here is the
# one used later to speak the NewsGuardian text result aloud).
channel_tts = ClarifaiChannel.get_grpc_channel()
stub_tts = service_pb2_grpc.V2Stub(channel_tts)
metadata_tts = (("authorization", f"Key {PAT_TTS}"),)
userDataObject_tts = resources_pb2.UserAppIDSet(
    user_id=USER_ID_TTS, app_id=APP_ID_TTS
)

# Streamlit page header.
st.title("NewsGuardian and phi-2 Text Generation")
49
+
50
# Function to generate text using the "microsoft/phi-2" model
def generate_phi2_text(input_text, max_length=200):
    """Generate a continuation of *input_text* with the phi-2 model.

    Args:
        input_text: Prompt string fed to the tokenizer.
        max_length: Maximum total token length of the generated sequence
            (default 200, the value previously hard-coded inline).

    Returns:
        The decoded model output (prompt included) as a single string.
    """
    inputs = tokenizer_phi2(input_text, return_tensors="pt", return_attention_mask=False)
    # Inference only — disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model_phi2.generate(**inputs, max_length=max_length)
    return tokenizer_phi2.batch_decode(outputs)[0]
56
+
57
# Standalone phi-2 demo: its own prompt box and generate button.
raw_text_phi2 = st.text_area("Enter text for phi-2 model")

if st.button("Generate text with phi-2 model"):
    # Guard clause: nothing to generate from an empty prompt.
    if not raw_text_phi2:
        st.warning("Please enter text for phi-2 model")
    else:
        generated_text_phi2 = generate_phi2_text(raw_text_phi2)
        st.text("Generated text with phi-2 model")
        st.text(generated_text_phi2)
68
+
69
# Shared inputs for the combined "Generate Result" flow below: backend
# selector, one text box, and an optional image upload.
_model_choices = ["NewsGuardian model", "DALL-E", "phi-2"]
model_type = st.selectbox("Select Model", _model_choices)
raw_text_news_guardian = st.text_area("This news is real or fake?")
image_upload_news_guardian = st.file_uploader(
    "Upload Image", type=["jpg", "jpeg", "png"]
)
73
+
74
# Dispatch the "Generate Result" button to the backend chosen in the
# selectbox.  BUG FIX: the selectbox offers "phi-2" but the original handler
# had no branch for it, so that selection silently did nothing — a phi-2
# branch is added at the end.
if st.button("Generate Result"):
    if model_type == "NewsGuardian model":
        # Set up gRPC channel for NewsGuardian model
        channel_news_guardian = ClarifaiChannel.get_grpc_channel()
        stub_news_guardian = service_pb2_grpc.V2Stub(channel_news_guardian)
        metadata_news_guardian = (('authorization', 'Key ' + PAT_NEWSGUARDIAN),)
        userDataObject_news_guardian = resources_pb2.UserAppIDSet(
            user_id=USER_ID_NEWSGUARDIAN, app_id=APP_ID_NEWSGUARDIAN
        )

        # Build the request payload from whichever inputs were provided.
        input_data_news_guardian = resources_pb2.Data()
        if raw_text_news_guardian:
            input_data_news_guardian.text.raw = raw_text_news_guardian
        if image_upload_news_guardian is not None:
            image_bytes_news_guardian = image_upload_news_guardian.read()
            input_data_news_guardian.image.base64 = image_bytes_news_guardian

        post_model_outputs_response_news_guardian = stub_news_guardian.PostModelOutputs(
            service_pb2.PostModelOutputsRequest(
                user_app_id=userDataObject_news_guardian,
                model_id=MODEL_ID_NEWSGUARDIAN,
                version_id=MODEL_VERSION_ID_NEWSGUARDIAN,
                inputs=[resources_pb2.Input(data=input_data_news_guardian)]
            ),
            metadata=metadata_news_guardian  # auth header travels as gRPC metadata
        )

        # Check if the request was successful for NewsGuardian model
        if post_model_outputs_response_news_guardian.status.code != status_code_pb2.SUCCESS:
            st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_news_guardian.status.description}")
        else:
            output_news_guardian = post_model_outputs_response_news_guardian.outputs[0].data

            # Display the result for NewsGuardian model
            if output_news_guardian.HasField("image"):
                st.image(output_news_guardian.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
            elif output_news_guardian.HasField("text"):
                st.text(output_news_guardian.text.raw)

                # Convert the text result to speech via the shared TTS stub
                # and play the audio inline.
                tts_input_data = resources_pb2.Data()
                tts_input_data.text.raw = output_news_guardian.text.raw

                tts_response = stub_tts.PostModelOutputs(
                    service_pb2.PostModelOutputsRequest(
                        user_app_id=userDataObject_tts,
                        model_id=MODEL_ID_TTS,
                        version_id=MODEL_VERSION_ID_TTS,
                        inputs=[resources_pb2.Input(data=tts_input_data)]
                    ),
                    metadata=metadata_tts  # Use the same metadata for TTS
                )

                if tts_response.status.code == status_code_pb2.SUCCESS:
                    tts_output = tts_response.outputs[0].data
                    st.audio(tts_output.audio.base64, format='audio/wav')
                else:
                    st.error(f"TTS API request failed: {tts_response.status.description}")

    elif model_type == "DALL-E":
        # Set up gRPC channel for DALL-E
        channel_dalle = ClarifaiChannel.get_grpc_channel()
        stub_dalle = service_pb2_grpc.V2Stub(channel_dalle)
        metadata_dalle = (('authorization', 'Key ' + PAT_DALLE),)
        userDataObject_dalle = resources_pb2.UserAppIDSet(
            user_id=USER_ID_DALLE, app_id=APP_ID_DALLE
        )

        # DALL-E reuses the shared text box as its prompt.
        input_data_dalle = resources_pb2.Data()
        if raw_text_news_guardian:
            input_data_dalle.text.raw = raw_text_news_guardian

        post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
            service_pb2.PostModelOutputsRequest(
                user_app_id=userDataObject_dalle,
                model_id=MODEL_ID_DALLE,
                version_id=MODEL_VERSION_ID_DALLE,
                inputs=[resources_pb2.Input(data=input_data_dalle)]
            ),
            metadata=metadata_dalle
        )

        # Check if the request was successful for DALL-E
        if post_model_outputs_response_dalle.status.code != status_code_pb2.SUCCESS:
            st.error(f"DALL-E API request failed: {post_model_outputs_response_dalle.status.description}")
        else:
            output_dalle = post_model_outputs_response_dalle.outputs[0].data

            # Display the result for DALL-E
            if output_dalle.HasField("image"):
                st.image(output_dalle.image.base64, caption='Generated Image (DALL-E)', use_column_width=True)
            elif output_dalle.HasField("text"):
                st.text(output_dalle.text.raw)

    elif model_type == "phi-2":
        # Previously unhandled option: run the local phi-2 model on the
        # shared text box, mirroring the standalone phi-2 section above.
        if raw_text_news_guardian:
            st.text("Generated text with phi-2 model")
            st.text(generate_phi2_text(raw_text_news_guardian))
        else:
            st.warning("Please enter text for phi-2 model")
173
+
174
# Social-media footer.  The original markup repeated the divider <img> tag
# six times and the anchor/icon pattern five times; the row is now generated
# from a table of (profile URL, icon file) pairs, so adding or removing a
# link is a one-line change.
_FOOTER_ICON_BASE = "https://user-images.githubusercontent.com/34125851"
_FOOTER_DIVIDER = (
    f'<img src="{_FOOTER_ICON_BASE}/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" '
    'width="2%" alt="" />'
)
_FOOTER_LINKS = [
    ("https://github.com/pyresearch/pyresearch", "226594737-c21e2dda-9cc6-42ef-b4e7-a685fea4a21d.png"),
    ("https://www.linkedin.com/company/pyresearch/", "226596446-746ffdd0-a47e-4452-84e3-bf11ec2aa26a.png"),
    ("https://twitter.com/Noorkhokhar10", "226599162-9b11194e-4998-440a-ba94-c8a5e1cdc676.png"),
    ("https://www.youtube.com/@Pyresearch", "226599904-7d5cc5c0-89d2-4d1e-891e-19bee1951744.png"),
    ("https://www.facebook.com/Pyresearch", "226600380-a87a9142-e8e0-4ec9-bf2c-dd6e9da2f05a.png"),
    ("https://www.instagram.com/pyresearch/", "226601355-ffe0b597-9840-4e10-bbef-43d6c74b5a9e.png"),
]
_footer_icons = _FOOTER_DIVIDER.join(
    f'<a href="{url}" style="text-decoration:none;">\n'
    f'<img src="{_FOOTER_ICON_BASE}/{icon}" width="2%" alt="" /></a>\n'
    for url, icon in _FOOTER_LINKS
)
st.markdown(
    f'<div align="center">\n{_footer_icons}</div>\n<hr>',
    unsafe_allow_html=True,
)