pyresearch commited on
Commit
de2fcea
·
verified ·
1 Parent(s): cc28781

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -190
app.py CHANGED
@@ -1,199 +1,26 @@
1
- import streamlit as st
2
- from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
3
- from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
4
- from clarifai_grpc.grpc.api.status import status_code_pb2
5
  import torch
 
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
 
8
- torch.set_default_tensor_type(torch.FloatTensor) # Set the default tensor type to float32
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
- model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True).to(device)
11
- tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
12
-
13
- # GPT-4 credentials
14
- PAT_GPT4 = "3ca5bd8b0f2244eb8d0e4b2838fc3cf1"
15
- USER_ID_GPT4 = "openai"
16
- APP_ID_GPT4 = "chat-completion"
17
- MODEL_ID_GPT4 = "openai-gpt-4-vision"
18
- MODEL_VERSION_ID_GPT4 = "266df29bc09843e0aee9b7bf723c03c2"
19
-
20
- # DALL-E credentials
21
- PAT_DALLE = "bfdeb4029ef54d23a2e608b0aa4c00e4"
22
- USER_ID_DALLE = "openai"
23
- APP_ID_DALLE = "dall-e"
24
- MODEL_ID_DALLE = "dall-e-3"
25
- MODEL_VERSION_ID_DALLE = "dc9dcb6ee67543cebc0b9a025861b868"
26
-
27
- # TTS credentials
28
- PAT_TTS = "bfdeb4029ef54d23a2e608b0aa4c00e4"
29
- USER_ID_TTS = "openai"
30
- APP_ID_TTS = "tts"
31
- MODEL_ID_TTS = "openai-tts-1"
32
- MODEL_VERSION_ID_TTS = "fff6ce1fd487457da95b79241ac6f02d"
33
-
34
- # NewsGuardian model credentials
35
- PAT_NEWSGUARDIAN = "your_news_guardian_pat"
36
- USER_ID_NEWSGUARDIAN = "your_user_id"
37
- APP_ID_NEWSGUARDIAN = "your_app_id"
38
- MODEL_ID_NEWSGUARDIAN = "your_model_id"
39
- MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"
40
 
41
- # Set up gRPC channel for NewsGuardian model
42
- channel_tts = ClarifaiChannel.get_grpc_channel()
43
- stub_tts = service_pb2_grpc.V2Stub(channel_tts)
44
- metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
45
- userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
46
 
47
- # Streamlit app
48
- st.title("NewsGuardian")
49
 
50
- # Inserting logo
51
- st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
 
 
 
52
 
53
- # Function to generate text using the "microsoft/phi-2" model
54
- def generate_phi2_text(input_text):
55
- inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=False)
56
  outputs = model.generate(**inputs, max_length=200)
57
  generated_text = tokenizer.batch_decode(outputs)[0]
58
- return generated_text
59
-
60
- # User input
61
- raw_text_phi2 = st.text_area("Enter text for phi-2 model")
62
-
63
- # Button to generate result using "microsoft/phi-2" model
64
- if st.button("NewsGuardian model Generated fake news with phi-2"):
65
- if raw_text_phi2:
66
- generated_text_phi2 = generate_phi2_text(raw_text_phi2)
67
- st.text("NewsGuardian model Generated fake news with phi-2")
68
- st.text(generated_text_phi2)
69
- else:
70
- st.warning("Please enter news phi-2 model")
71
-
72
- # User input
73
- model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E"])
74
- raw_text_news_guardian = st.text_area("This news is real or fake?")
75
- image_upload_news_guardian = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
76
-
77
- # Button to generate result for NewsGuardian model
78
- if st.button("NewsGuardian News Result"):
79
- if model_type == "NewsGuardian model":
80
- # Set up gRPC channel for NewsGuardian model
81
- channel_news_guardian = ClarifaiChannel.get_grpc_channel()
82
- stub_news_guardian = service_pb2_grpc.V2Stub(channel_news_guardian)
83
- metadata_news_guardian = (('authorization', 'Key ' + PAT_NEWSGUARDIAN),)
84
- userDataObject_news_guardian = resources_pb2.UserAppIDSet(user_id=USER_ID_NEWSGUARDIAN, app_id=APP_ID_NEWSGUARDIAN)
85
-
86
- # Prepare the request for NewsGuardian model
87
- input_data_news_guardian = resources_pb2.Data()
88
-
89
- if raw_text_news_guardian:
90
- input_data_news_guardian.text.raw = raw_text_news_guardian
91
-
92
- if image_upload_news_guardian is not None:
93
- image_bytes_news_guardian = image_upload_news_guardian.read()
94
- input_data_news_guardian.image.base64 = image_bytes_news_guardian
95
-
96
- post_model_outputs_response_news_guardian = stub_news_guardian.PostModelOutputs(
97
- service_pb2.PostModelOutputsRequest(
98
- user_app_id=userDataObject_news_guardian,
99
- model_id=MODEL_ID_NEWSGUARDIAN,
100
- version_id=MODEL_VERSION_ID_NEWSGUARDIAN,
101
- inputs=[resources_pb2.Input(data=input_data_news_guardian)]
102
- ),
103
- metadata=metadata_news_guardian # Use metadata directly in the gRPC request
104
- )
105
-
106
- # Check if the request was successful for NewsGuardian model
107
- if post_model_outputs_response_news_guardian.status.code != status_code_pb2.SUCCESS:
108
- st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_news_guardian.status.description}")
109
- else:
110
- # Get the output for NewsGuardian model
111
- output_news_guardian = post_model_outputs_response_news_guardian.outputs[0].data
112
-
113
- # Display the result for NewsGuardian model
114
- if output_news_guardian.HasField("image"):
115
- st.image(output_news_guardian.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
116
- elif output_news_guardian.HasField("text"):
117
- # Display the text result
118
- st.text(output_news_guardian.text.raw)
119
-
120
- # Convert text to speech and play the audio
121
- tts_input_data = resources_pb2.Data()
122
- tts_input_data.text.raw = output_news_guardian.text.raw
123
-
124
- tts_response = stub_tts.PostModelOutputs(
125
- service_pb2.PostModelOutputsRequest(
126
- user_app_id=userDataObject_tts,
127
- model_id=MODEL_ID_TTS,
128
- version_id=MODEL_VERSION_ID_TTS,
129
- inputs=[resources_pb2.Input(data=tts_input_data)]
130
- ),
131
- metadata=metadata_tts # Use the same metadata for TTS
132
- )
133
-
134
- # Check if the TTS request was successful
135
- if tts_response.status.code == status_code_pb2.SUCCESS:
136
- tts_output = tts_response.outputs[0].data
137
- st.audio(tts_output.audio.base64, format='audio/wav')
138
- else:
139
- st.error(f"TTS API request failed: {tts_response.status.description}")
140
-
141
- elif model_type == "DALL-E":
142
- # Set up gRPC channel for DALL-E
143
- channel_dalle = ClarifaiChannel.get_grpc_channel()
144
- stub_dalle = service_pb2_grpc.V2Stub(channel_dalle)
145
- metadata_dalle = (('authorization', 'Key ' + PAT_DALLE),)
146
- userDataObject_dalle = resources_pb2.UserAppIDSet(user_id=USER_ID_DALLE, app_id=APP_ID_DALLE)
147
-
148
- # Prepare the request for DALL-E
149
- input_data_dalle = resources_pb2.Data()
150
-
151
- if raw_text_news_guardian:
152
- input_data_dalle.text.raw = raw_text_news_guardian
153
-
154
- post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
155
- service_pb2.PostModelOutputsRequest(
156
- user_app_id=userDataObject_dalle,
157
- model_id=MODEL_ID_DALLE,
158
- version_id=MODEL_VERSION_ID_DALLE,
159
- inputs=[resources_pb2.Input(data=input_data_dalle)]
160
- ),
161
- metadata=metadata_dalle
162
- )
163
-
164
- # Check if the request was successful for DALL-E
165
- if post_model_outputs_response_dalle.status.code != status_code_pb2.SUCCESS:
166
- st.error(f"DALL-E API request failed: {post_model_outputs_response_dalle.status.description}")
167
- else:
168
- # Get the output for DALL-E
169
- output_dalle = post_model_outputs_response_dalle.outputs[0].data
170
-
171
- # Display the result for DALL-E
172
- if output_dalle.HasField("image"):
173
- st.image(output_dalle.image.base64, caption='Generated Image (DALL-E)', use_column_width=True)
174
- elif output_dalle.HasField("text"):
175
- st.text(output_dalle.text.raw)
176
-
177
- # Add the beautiful social media icon section
178
- st.markdown("""
179
- <div align="center">
180
- <a href="https://github.com/pyresearch/pyresearch" style="text-decoration:none;">
181
- <img src="https://user-images.githubusercontent.com/34125851/226594737-c21e2dda-9cc6-42ef-b4e7-a685fea4a21d.png" width="2%" alt="" /></a>
182
- <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
183
- <a href="https://www.linkedin.com/company/pyresearch/" style="text-decoration:none;">
184
- <img src="https://user-images.githubusercontent.com/34125851/226596446-746ffdd0-a47e-4452-84e3-bf11ec2aa26a.png" width="2%" alt="" /></a>
185
- <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
186
- <a href="https://twitter.com/Noorkhokhar10" style="text-decoration:none;">
187
- <img src="https://user-images.githubusercontent.com/34125851/226599162-9b11194e-4998-440a-ba94-c8a5e1cdc676.png" width="2%" alt="" /></a>
188
- <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
189
- <a href="https://www.youtube.com/@Pyresearch" style="text-decoration:none;">
190
- <img src="https://user-images.githubusercontent.com/34125851/226599904-7d5cc5c0-89d2-4d1e-891e-19bee1951744.png" width="2%" alt="" /></a>
191
- <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
192
- <a href="https://www.facebook.com/Pyresearch" style="text-decoration:none;">
193
- <img src="https://user-images.githubusercontent.com/34125851/226600380-a87a9142-e8e0-4ec9-bf2c-dd6e9da2f05a.png" width="2%" alt="" /></a>
194
- <img src="https://user-images.githubusercontent.com/34125851/226595799-160b0da3-c9e0-4562-8544-5f20460f7cc9.png" width="2%" alt="" />
195
- <a href="https://www.instagram.com/pyresearch/" style="text-decoration:none;">
196
- <img src="https://user-images.githubusercontent.com/34125851/226601355-ffe0b597-9840-4e10-bbef-43d6c74b5a9e.png" width="2%" alt="" /></a>
197
- </div>
198
- <hr>
199
- """, unsafe_allow_html=True)
 
 
 
 
 
1
  import torch
2
+ import streamlit as st
3
  from transformers import AutoModelForCausalLM, AutoTokenizer
4
 
5
+ torch.set_default_device("cpu")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
8
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
 
 
 
9
 
10
+ st.title("Text Generation with phi-2 Model")
 
11
 
12
+ # Input text area for the user
13
+ user_input = st.text_area("Enter some text:", '''def print_prime(n):
14
+ """
15
+ Print all primes between 1 and n
16
+ """''')
17
 
18
+ # Generate text based on user input
19
+ if st.button("Generate Text"):
20
+ inputs = tokenizer(user_input, return_tensors="pt", return_attention_mask=False)
21
  outputs = model.generate(**inputs, max_length=200)
22
  generated_text = tokenizer.batch_decode(outputs)[0]
23
+
24
+ # Display the generated text
25
+ st.subheader("Generated Text:")
26
+ st.write(generated_text)