pyresearch commited on
Commit
4e838e8
·
verified ·
1 Parent(s): 7d9f3d5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -46
app.py CHANGED
@@ -1,25 +1,43 @@
 
 
1
  import streamlit as st
2
  from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
3
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
4
  from clarifai_grpc.grpc.api.status import status_code_pb2
5
- from transformers import AutoModelForCausalLM, AutoTokenizer
6
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
7
-
8
- import torch
9
-
10
 
11
-
12
- torch.set_default_device("cpu")
13
-
14
- # Load the 'microsoft/phi-2' model and tokenizer
15
- model_phi2 = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", torch_dtype="auto", trust_remote_code=True)
16
- tokenizer_phi2 = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  # Set up gRPC channel for NewsGuardian model
19
  channel_tts = ClarifaiChannel.get_grpc_channel()
20
  stub_tts = service_pb2_grpc.V2Stub(channel_tts)
21
  metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
22
- userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS,)
23
 
24
  # Streamlit app
25
  st.title("NewsGuardian")
@@ -27,62 +45,76 @@ st.title("NewsGuardian")
27
  # Inserting logo
28
  st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
29
 
30
- # Function to get gRPC channel for NewsGuardian model
31
- def get_tts_channel():
32
- channel_tts = ClarifaiChannel.get_grpc_channel()
33
- return channel_tts, channel_tts.metadata
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  # User input
36
  model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E"])
37
- raw_text = st.text_area("This news is real or fake?")
38
- image_upload = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
39
 
40
- # Button to generate result
41
  if st.button("NewsGuardian News Result"):
42
  if model_type == "NewsGuardian model":
43
  # Set up gRPC channel for NewsGuardian model
44
- channel_gpt4 = ClarifaiChannel.get_grpc_channel()
45
- stub_gpt4 = service_pb2_grpc.V2Stub(channel_gpt4)
46
- metadata_gpt4 = (('authorization', 'Key ' + PAT_GPT4),)
47
- userDataObject_gpt4 = resources_pb2.UserAppIDSet(user_id=USER_ID_GPT4, app_id=APP_ID_GPT4)
48
 
49
  # Prepare the request for NewsGuardian model
50
- input_data_gpt4 = resources_pb2.Data()
51
 
52
- if raw_text:
53
- input_data_gpt4.text.raw = raw_text
54
 
55
- if image_upload is not None:
56
- image_bytes_gpt4 = image_upload.read()
57
- input_data_gpt4.image.base64 = image_bytes_gpt4
58
 
59
- post_model_outputs_response_gpt4 = stub_gpt4.PostModelOutputs(
60
  service_pb2.PostModelOutputsRequest(
61
- user_app_id=userDataObject_gpt4,
62
- model_id=MODEL_ID_GPT4,
63
- version_id=MODEL_VERSION_ID_GPT4,
64
- inputs=[resources_pb2.Input(data=input_data_gpt4)]
65
  ),
66
- metadata=metadata_gpt4 # Use metadata directly in the gRPC request
67
  )
68
 
69
  # Check if the request was successful for NewsGuardian model
70
- if post_model_outputs_response_gpt4.status.code != status_code_pb2.SUCCESS:
71
- st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_gpt4.status.description}")
72
  else:
73
  # Get the output for NewsGuardian model
74
- output_gpt4 = post_model_outputs_response_gpt4.outputs[0].data
75
 
76
  # Display the result for NewsGuardian model
77
- if output_gpt4.HasField("image"):
78
- st.image(output_gpt4.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
79
- elif output_gpt4.HasField("text"):
80
  # Display the text result
81
- st.text(output_gpt4.text.raw)
82
 
83
  # Convert text to speech and play the audio
84
  tts_input_data = resources_pb2.Data()
85
- tts_input_data.text.raw = output_gpt4.text.raw
86
 
87
  tts_response = stub_tts.PostModelOutputs(
88
  service_pb2.PostModelOutputsRequest(
@@ -99,7 +131,7 @@ if st.button("NewsGuardian News Result"):
99
  tts_output = tts_response.outputs[0].data
100
  st.audio(tts_output.audio.base64, format='audio/wav')
101
  else:
102
- st.error(f"NewsGuardian model API request failed: {tts_response.status.description}")
103
 
104
  elif model_type == "DALL-E":
105
  # Set up gRPC channel for DALL-E
@@ -111,8 +143,8 @@ if st.button("NewsGuardian News Result"):
111
  # Prepare the request for DALL-E
112
  input_data_dalle = resources_pb2.Data()
113
 
114
- if raw_text:
115
- input_data_dalle.text.raw = raw_text
116
 
117
  post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
118
  service_pb2.PostModelOutputsRequest(
 
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import streamlit as st
4
  from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
5
  from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
6
  from clarifai_grpc.grpc.api.status import status_code_pb2
 
 
 
 
 
7
 
8
+ # GPT-4 credentials
9
+ PAT_GPT4 = "3ca5bd8b0f2244eb8d0e4b2838fc3cf1"
10
+ USER_ID_GPT4 = "openai"
11
+ APP_ID_GPT4 = "chat-completion"
12
+ MODEL_ID_GPT4 = "openai-gpt-4-vision"
13
+ MODEL_VERSION_ID_GPT4 = "266df29bc09843e0aee9b7bf723c03c2"
14
+
15
+ # DALL-E credentials
16
+ PAT_DALLE = "bfdeb4029ef54d23a2e608b0aa4c00e4"
17
+ USER_ID_DALLE = "openai"
18
+ APP_ID_DALLE = "dall-e"
19
+ MODEL_ID_DALLE = "dall-e-3"
20
+ MODEL_VERSION_ID_DALLE = "dc9dcb6ee67543cebc0b9a025861b868"
21
+
22
+ # TTS credentials
23
+ PAT_TTS = "bfdeb4029ef54d23a2e608b0aa4c00e4"
24
+ USER_ID_TTS = "openai"
25
+ APP_ID_TTS = "tts"
26
+ MODEL_ID_TTS = "openai-tts-1"
27
+ MODEL_VERSION_ID_TTS = "fff6ce1fd487457da95b79241ac6f02d"
28
+
29
+ # NewsGuardian model credentials
30
+ PAT_NEWSGUARDIAN = "your_news_guardian_pat"
31
+ USER_ID_NEWSGUARDIAN = "your_user_id"
32
+ APP_ID_NEWSGUARDIAN = "your_app_id"
33
+ MODEL_ID_NEWSGUARDIAN = "your_model_id"
34
+ MODEL_VERSION_ID_NEWSGUARDIAN = "your_model_version_id"
35
 
36
  # Set up gRPC channel for NewsGuardian model
37
  channel_tts = ClarifaiChannel.get_grpc_channel()
38
  stub_tts = service_pb2_grpc.V2Stub(channel_tts)
39
  metadata_tts = (('authorization', 'Key ' + PAT_TTS),)
40
+ userDataObject_tts = resources_pb2.UserAppIDSet(user_id=USER_ID_TTS, app_id=APP_ID_TTS)
41
 
42
  # Streamlit app
43
  st.title("NewsGuardian")
 
45
  # Inserting logo
46
  st.image("https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTdA-MJ_SUCRgLs1prqudpMdaX4x-x10Zqlwp7cpzXWCMM9xjBAJYWdJsDlLoHBqNpj8qs&usqp=CAU")
47
 
48
+ # Function to generate text using the "microsoft/phi-2" model
49
+ def generate_phi2_text(input_text):
50
+ inputs = tokenizer(input_text, return_tensors="pt", return_attention_mask=False)
51
+ outputs = model.generate(**inputs, max_length=200)
52
+ generated_text = tokenizer.batch_decode(outputs)[0]
53
+ return generated_text
54
+
55
+ # User input
56
+ raw_text_phi2 = st.text_area("Enter text for phi-2 model")
57
+
58
+ # Button to generate result using "microsoft/phi-2" model
59
+ if st.button("NewsGuardian model Generated fake news with phi-2"):
60
+ if raw_text_phi2:
61
+ generated_text_phi2 = generate_phi2_text(raw_text_phi2)
62
+ st.text("NewsGuardian model Generated fake news with phi-2")
63
+ st.text(generated_text_phi2)
64
+ else:
65
+ st.warning("Please enter news phi-2 model")
66
 
67
  # User input
68
  model_type = st.selectbox("Select Model", ["NewsGuardian model", "DALL-E"])
69
+ raw_text_news_guardian = st.text_area("This news is real or fake?")
70
+ image_upload_news_guardian = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
71
 
72
+ # Button to generate result for NewsGuardian model
73
  if st.button("NewsGuardian News Result"):
74
  if model_type == "NewsGuardian model":
75
  # Set up gRPC channel for NewsGuardian model
76
+ channel_news_guardian = ClarifaiChannel.get_grpc_channel()
77
+ stub_news_guardian = service_pb2_grpc.V2Stub(channel_news_guardian)
78
+ metadata_news_guardian = (('authorization', 'Key ' + PAT_NEWSGUARDIAN),)
79
+ userDataObject_news_guardian = resources_pb2.UserAppIDSet(user_id=USER_ID_NEWSGUARDIAN, app_id=APP_ID_NEWSGUARDIAN)
80
 
81
  # Prepare the request for NewsGuardian model
82
+ input_data_news_guardian = resources_pb2.Data()
83
 
84
+ if raw_text_news_guardian:
85
+ input_data_news_guardian.text.raw = raw_text_news_guardian
86
 
87
+ if image_upload_news_guardian is not None:
88
+ image_bytes_news_guardian = image_upload_news_guardian.read()
89
+ input_data_news_guardian.image.base64 = image_bytes_news_guardian
90
 
91
+ post_model_outputs_response_news_guardian = stub_news_guardian.PostModelOutputs(
92
  service_pb2.PostModelOutputsRequest(
93
+ user_app_id=userDataObject_news_guardian,
94
+ model_id=MODEL_ID_NEWSGUARDIAN,
95
+ version_id=MODEL_VERSION_ID_NEWSGUARDIAN,
96
+ inputs=[resources_pb2.Input(data=input_data_news_guardian)]
97
  ),
98
+ metadata=metadata_news_guardian # Use metadata directly in the gRPC request
99
  )
100
 
101
  # Check if the request was successful for NewsGuardian model
102
+ if post_model_outputs_response_news_guardian.status.code != status_code_pb2.SUCCESS:
103
+ st.error(f"NewsGuardian model API request failed: {post_model_outputs_response_news_guardian.status.description}")
104
  else:
105
  # Get the output for NewsGuardian model
106
+ output_news_guardian = post_model_outputs_response_news_guardian.outputs[0].data
107
 
108
  # Display the result for NewsGuardian model
109
+ if output_news_guardian.HasField("image"):
110
+ st.image(output_news_guardian.image.base64, caption='Generated Image (NewsGuardian model)', use_column_width=True)
111
+ elif output_news_guardian.HasField("text"):
112
  # Display the text result
113
+ st.text(output_news_guardian.text.raw)
114
 
115
  # Convert text to speech and play the audio
116
  tts_input_data = resources_pb2.Data()
117
+ tts_input_data.text.raw = output_news_guardian.text.raw
118
 
119
  tts_response = stub_tts.PostModelOutputs(
120
  service_pb2.PostModelOutputsRequest(
 
131
  tts_output = tts_response.outputs[0].data
132
  st.audio(tts_output.audio.base64, format='audio/wav')
133
  else:
134
+ st.error(f"TTS API request failed: {tts_response.status.description}")
135
 
136
  elif model_type == "DALL-E":
137
  # Set up gRPC channel for DALL-E
 
143
  # Prepare the request for DALL-E
144
  input_data_dalle = resources_pb2.Data()
145
 
146
+ if raw_text_news_guardian:
147
+ input_data_dalle.text.raw = raw_text_news_guardian
148
 
149
  post_model_outputs_response_dalle = stub_dalle.PostModelOutputs(
150
  service_pb2.PostModelOutputsRequest(