barunsaha committed
Commit 24afa64 · unverified · 2 Parent(s): 0707284 44d6df8

Merge pull request #58 from barun-saha/byok

Files changed (7)
  1. app.py +137 -111
  2. global_config.py +28 -5
  3. helpers/llm_helper.py +126 -94
  4. helpers/pptx_helper.py +27 -22
  5. helpers/text_helper.py +17 -23
  6. requirements.txt +2 -1
  7. strings.json +2 -1
app.py CHANGED
@@ -5,7 +5,6 @@ import datetime
 import logging
 import pathlib
 import random
-import sys
 import tempfile
 from typing import List, Union
 
@@ -17,9 +16,6 @@ from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain_core.messages import HumanMessage
 from langchain_core.prompts import ChatPromptTemplate
 
-sys.path.append('..')
-sys.path.append('../..')
-
 from global_config import GlobalConfig
 from helpers import llm_helper, pptx_helper, text_helper
 
@@ -54,17 +50,58 @@ def _get_prompt_template(is_refinement: bool) -> str:
     return template
 
 
-@st.cache_resource
-def _get_llm(repo_id: str, max_new_tokens: int):
-    """
-    Get an LLM instance.
-
-    :param repo_id: The model name.
-    :param max_new_tokens: The max new tokens to generate.
-    :return: The LLM.
-    """
-
-    return llm_helper.get_hf_endpoint(repo_id, max_new_tokens)
+def are_all_inputs_valid(
+        user_prompt: str,
+        selected_provider: str,
+        selected_model: str,
+        user_key: str,
+) -> bool:
+    """
+    Validate user input and LLM selection.
+
+    :param user_prompt: The prompt.
+    :param selected_provider: The LLM provider.
+    :param selected_model: Name of the model.
+    :param user_key: User-provided API key.
+    :return: `True` if all inputs "look" OK; `False` otherwise.
+    """
+
+    if not text_helper.is_valid_prompt(user_prompt):
+        handle_error(
+            'Not enough information provided!'
+            ' Please be a little more descriptive and type a few words'
+            ' with a few characters :)',
+            False
+        )
+        return False
+
+    if not selected_provider or not selected_model:
+        handle_error('No valid LLM provider and/or model name found!', False)
+        return False
+
+    if not llm_helper.is_valid_llm_provider_model(selected_provider, selected_model, user_key):
+        handle_error(
+            'The LLM settings do not look correct. Make sure that an API key/access token'
+            ' is provided if the selected LLM requires it.',
+            False
+        )
+        return False
+
+    return True
+
+
+def handle_error(error_msg: str, should_log: bool):
+    """
+    Display an error message in the app.
+
+    :param error_msg: The error message to be displayed.
+    :param should_log: If `True`, log the message.
+    """
+
+    if should_log:
+        logger.error(error_msg)
+
+    st.error(error_msg)
 
 
 APP_TEXT = _load_strings()
@@ -81,18 +118,32 @@ texts = list(GlobalConfig.PPTX_TEMPLATE_FILES.keys())
 captions = [GlobalConfig.PPTX_TEMPLATE_FILES[x]['caption'] for x in texts]
 
 with st.sidebar:
+    # The PPT templates
     pptx_template = st.sidebar.radio(
-        'Select a presentation template:',
+        '1: Select a presentation template:',
         texts,
         captions=captions,
        horizontal=True
     )
-    st.divider()
-    llm_to_use = st.sidebar.selectbox(
-        'Select an LLM to use:',
-        [f'{k} ({v["description"]})' for k, v in GlobalConfig.HF_MODELS.items()]
+
+    # The LLMs
+    llm_provider_to_use = st.sidebar.selectbox(
+        label='2: Select an LLM to use:',
+        options=[f'{k} ({v["description"]})' for k, v in GlobalConfig.VALID_MODELS.items()],
+        index=GlobalConfig.DEFAULT_MODEL_INDEX,
+        help=GlobalConfig.LLM_PROVIDER_HELP,
     ).split(' ')[0]
 
+    # The API key/access token
+    api_key_token = st.text_input(
+        label=(
+            '3: Paste your API key/access token:\n\n'
+            '*Mandatory* for Cohere and Gemini LLMs.'
+            ' *Optional* for HF Mistral LLMs but still encouraged.\n\n'
+        ),
+        type='password',
+    )
+
 
 def build_ui():
     """
@@ -119,119 +170,108 @@ def set_up_chat_ui():
     with st.expander('Usage Instructions'):
         st.markdown(GlobalConfig.CHAT_USAGE_INSTRUCTIONS)
 
-    st.info(
-        'If you like SlideDeck AI, please consider leaving a heart ❤️ on the'
-        ' [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or'
-        ' a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai).'
-        ' Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated.'
-    )
-
-    # view_messages = st.expander('View the messages in the session state')
-
-    st.chat_message('ai').write(
-        random.choice(APP_TEXT['ai_greetings'])
-    )
+    st.info(APP_TEXT['like_feedback'])
+    st.chat_message('ai').write(random.choice(APP_TEXT['ai_greetings']))
 
     history = StreamlitChatMessageHistory(key=CHAT_MESSAGES)
-
-    if _is_it_refinement():
-        template = _get_prompt_template(is_refinement=True)
-    else:
-        template = _get_prompt_template(is_refinement=False)
-
-    prompt_template = ChatPromptTemplate.from_template(template)
+    prompt_template = ChatPromptTemplate.from_template(
+        _get_prompt_template(
+            is_refinement=_is_it_refinement()
+        )
+    )
 
     # Since Streamlit app reloads at every interaction, display the chat history
     # from the save session state
     for msg in history.messages:
-        msg_type = msg.type
-        if msg_type == 'user':
-            st.chat_message(msg_type).write(msg.content)
-        else:
-            st.chat_message(msg_type).code(msg.content, language='json')
+        st.chat_message(msg.type).code(msg.content, language='json')
 
     if prompt := st.chat_input(
            placeholder=APP_TEXT['chat_placeholder'],
            max_chars=GlobalConfig.LLM_MODEL_MAX_INPUT_LENGTH
    ):
-        if not text_helper.is_valid_prompt(prompt):
-            st.error(
-                'Not enough information provided!'
-                ' Please be a little more descriptive and type a few words'
-                ' with a few characters :)'
-            )
+        provider, llm_name = llm_helper.get_provider_model(llm_provider_to_use)
+
+        if not are_all_inputs_valid(prompt, provider, llm_name, api_key_token):
             return
 
         logger.info(
             'User input: %s | #characters: %d | LLM: %s',
-            prompt, len(prompt), llm_to_use
+            prompt, len(prompt), llm_name
         )
         st.chat_message('user').write(prompt)
 
-        user_messages = _get_user_messages()
-        user_messages.append(prompt)
-        list_of_msgs = [
-            f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
-        ]
-        list_of_msgs = '\n'.join(list_of_msgs)
-
         if _is_it_refinement():
+            user_messages = _get_user_messages()
+            user_messages.append(prompt)
+            list_of_msgs = [
+                f'{idx + 1}. {msg}' for idx, msg in enumerate(user_messages)
+            ]
             formatted_template = prompt_template.format(
                 **{
-                    'instructions': list_of_msgs,
+                    'instructions': '\n'.join(list_of_msgs),
                    'previous_content': _get_last_response(),
                 }
             )
         else:
-            formatted_template = prompt_template.format(
-                **{
-                    'question': prompt,
-                }
-            )
+            formatted_template = prompt_template.format(**{'question': prompt})
 
         progress_bar = st.progress(0, 'Preparing to call LLM...')
         response = ''
 
         try:
-            for chunk in _get_llm(
-                repo_id=llm_to_use,
-                max_new_tokens=GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens']
-            ).stream(formatted_template):
-                response += chunk
-
-                # Update the progress bar
-                progress_percentage = min(
-                    len(response) / GlobalConfig.HF_MODELS[llm_to_use]['max_new_tokens'], 0.95
-                )
+            llm = llm_helper.get_langchain_llm(
+                provider=provider,
+                model=llm_name,
+                max_new_tokens=GlobalConfig.VALID_MODELS[llm_provider_to_use]['max_new_tokens'],
+                api_key=api_key_token.strip(),
+            )
+
+            if not llm:
+                handle_error(
+                    'Failed to create an LLM instance! Make sure that you have selected the'
+                    ' correct model from the dropdown list and have provided correct API key'
+                    ' or access token.',
+                    False
+                )
+                return
+
+            for _ in llm.stream(formatted_template):
+                response += _
+
+                # Update the progress bar with an approx progress percentage
                 progress_bar.progress(
-                    progress_percentage,
+                    min(
+                        len(response) / GlobalConfig.VALID_MODELS[
+                            llm_provider_to_use
+                        ]['max_new_tokens'],
+                        0.95
+                    ),
                     text='Streaming content...this might take a while...'
                 )
         except requests.exceptions.ConnectionError:
-            msg = (
+            handle_error(
                 'A connection error occurred while streaming content from the LLM endpoint.'
                 ' Unfortunately, the slide deck cannot be generated. Please try again later.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.',
+                True
            )
-            logger.error(msg)
-            st.error(msg)
            return
        except huggingface_hub.errors.ValidationError as ve:
-            msg = (
+            handle_error(
                f'An error occurred while trying to generate the content: {ve}'
-                '\nPlease try again with a significantly shorter input text.'
+                '\nPlease try again with a significantly shorter input text.',
+                True
            )
-            logger.error(msg)
-            st.error(msg)
            return
        except Exception as ex:
-            msg = (
+            handle_error(
                f'An unexpected error occurred while generating the content: {ex}'
                '\nPlease try again later, possibly with different inputs.'
-                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' Alternatively, try selecting a different LLM from the dropdown list.'
+                ' If you are using Cohere or Gemini models, make sure that you have provided'
+                ' a correct API key.',
+                True
            )
-            logger.error(msg)
-            st.error(msg)
            return
 
        history.add_user_message(prompt)
@@ -240,25 +280,20 @@ def set_up_chat_ui():
        # The content has been generated as JSON
        # There maybe trailing ``` at the end of the response -- remove them
        # To be careful: ``` may be part of the content as well when code is generated
-        response_cleaned = text_helper.get_clean_json(response)
-
+        response = text_helper.get_clean_json(response)
        logger.info(
-            'Cleaned JSON response:: original length: %d | cleaned length: %d',
-            len(response), len(response_cleaned)
+            'Cleaned JSON length: %d', len(response)
        )
-        # logger.debug('Cleaned JSON: %s', response_cleaned)
 
        # Now create the PPT file
        progress_bar.progress(
            GlobalConfig.LLM_PROGRESS_MAX,
            text='Finding photos online and generating the slide deck...'
        )
-        path = generate_slide_deck(response_cleaned)
        progress_bar.progress(1.0, text='Done!')
-
        st.chat_message('ai').code(response, language='json')
 
-        if path:
+        if path := generate_slide_deck(response):
            _display_download_button(path)
 
        logger.info(
@@ -279,44 +314,35 @@ def generate_slide_deck(json_str: str) -> Union[pathlib.Path, None]:
    try:
        parsed_data = json5.loads(json_str)
    except ValueError:
-        st.error(
-            'Encountered error while parsing JSON...will fix it and retry'
-        )
-        logger.error(
-            'Caught ValueError: trying again after repairing JSON...'
+        handle_error(
+            'Encountered error while parsing JSON...will fix it and retry',
+            True
        )
        try:
            parsed_data = json5.loads(text_helper.fix_malformed_json(json_str))
        except ValueError:
-            st.error(
+            handle_error(
                'Encountered an error again while fixing JSON...'
                'the slide deck cannot be created, unfortunately ☹'
-                '\nPlease try again later.'
-            )
-            logger.error(
-                'Caught ValueError: failed to repair JSON!'
+                '\nPlease try again later.',
+                True
            )
-
            return None
    except RecursionError:
-        st.error(
-            'Encountered an error while parsing JSON...'
+        handle_error(
+            'Encountered a recursion error while parsing JSON...'
            'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
+            '\nPlease try again later.',
+            True
        )
-        logger.error('Caught RecursionError while parsing JSON. Cannot generate the slide deck!')
-
        return None
    except Exception:
-        st.error(
+        handle_error(
            'Encountered an error while parsing JSON...'
            'the slide deck cannot be created, unfortunately ☹'
-            '\nPlease try again later.'
+            '\nPlease try again later.',
+            True
        )
-        logger.error(
-            'Caught ValueError: failed to parse JSON!'
-        )
-
        return None
 
    if DOWNLOAD_FILE_KEY in st.session_state:
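
Side note: the streaming loop above cannot obtain a true completion ratio from the LLM, so it approximates progress from the number of characters received so far relative to the model's `max_new_tokens` (characters, not tokens, so only a rough proxy), capped at 0.95 until the stream actually finishes. A minimal standalone sketch of that heuristic; the `fake_stream()` generator and the token budget are made up for illustration:

```python
def fake_stream():
    # Illustrative stand-in for llm.stream(): yields text chunks.
    yield from ['{"title": ', '"Demo deck", ', '"slides": []', '}']


MAX_NEW_TOKENS = 20  # in the app: GlobalConfig.VALID_MODELS[...]['max_new_tokens']
response = ''

for chunk in fake_stream():
    response += chunk
    # Characters received stand in for tokens; cap at 95% so the bar
    # never claims completion before the stream actually ends.
    progress = min(len(response) / MAX_NEW_TOKENS, 0.95)
    print(f'progress ~ {progress:.2f}')

print('Done:', response)
```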
global_config.py CHANGED
@@ -17,16 +17,39 @@ class GlobalConfig:
     A data class holding the configurations.
     """
 
-    HF_MODELS = {
-        'mistralai/Mistral-7B-Instruct-v0.2': {
+    PROVIDER_COHERE = 'co'
+    PROVIDER_GOOGLE_GEMINI = 'gg'
+    PROVIDER_HUGGING_FACE = 'hf'
+    VALID_PROVIDERS = {PROVIDER_COHERE, PROVIDER_GOOGLE_GEMINI, PROVIDER_HUGGING_FACE}
+    VALID_MODELS = {
+        '[co]command-r-08-2024': {
+            'description': 'simpler, slower',
+            'max_new_tokens': 4096,
+            'paid': True,
+        },
+        '[gg]gemini-1.5-flash-002': {
+            'description': 'faster response',
+            'max_new_tokens': 8192,
+            'paid': True,
+        },
+        '[hf]mistralai/Mistral-7B-Instruct-v0.2': {
             'description': 'faster, shorter',
-            'max_new_tokens': 8192
+            'max_new_tokens': 8192,
+            'paid': False,
         },
-        'mistralai/Mistral-Nemo-Instruct-2407': {
+        '[hf]mistralai/Mistral-Nemo-Instruct-2407': {
             'description': 'longer response',
-            'max_new_tokens': 12228
+            'max_new_tokens': 10240,
+            'paid': False,
         },
     }
+    LLM_PROVIDER_HELP = (
+        'LLM provider codes:\n\n'
+        '- **[co]**: Cohere\n'
+        '- **[gg]**: Google Gemini API\n'
+        '- **[hf]**: Hugging Face Inference Endpoint\n'
+    )
+    DEFAULT_MODEL_INDEX = 2
     LLM_MODEL_TEMPERATURE = 0.2
     LLM_MODEL_MIN_OUTPUT_LENGTH = 100
     LLM_MODEL_MAX_INPUT_LENGTH = 400  # characters
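
Each `VALID_MODELS` key embeds the provider code in square brackets, so a single string carries both the provider and the model name through the UI. A quick sketch of how the sidebar builds its display strings from this dict and recovers the key afterwards; it relies on the keys containing no spaces, and the trimmed-down dict here is only illustrative:

```python
VALID_MODELS = {
    '[co]command-r-08-2024': {'description': 'simpler, slower'},
    '[gg]gemini-1.5-flash-002': {'description': 'faster response'},
    '[hf]mistralai/Mistral-7B-Instruct-v0.2': {'description': 'faster, shorter'},
}

# What st.selectbox() displays, e.g. '[gg]gemini-1.5-flash-002 (faster response)'
options = [f'{k} ({v["description"]})' for k, v in VALID_MODELS.items()]

# What the app keeps: everything before the first space, i.e. the dict key
selected = options[1].split(' ')[0]
print(selected)                # [gg]gemini-1.5-flash-002
print(VALID_MODELS[selected])  # that model's settings
```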
helpers/llm_helper.py CHANGED
@@ -1,18 +1,28 @@
+"""
+Helper functions to access LLMs.
+"""
 import logging
+import re
+import sys
+from typing import Tuple, Union
+
 import requests
 from requests.adapters import HTTPAdapter
 from urllib3.util import Retry
+from langchain_core.language_models import BaseLLM
 
-from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
-from langchain_core.language_models import LLM
+sys.path.append('..')
 
 from global_config import GlobalConfig
 
 
-HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
+LLM_PROVIDER_MODEL_REGEX = re.compile(r'\[(.*?)\](.*)')
+HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
 REQUEST_TIMEOUT = 35
 
 logger = logging.getLogger(__name__)
+logging.getLogger('httpx').setLevel(logging.WARNING)
+logging.getLogger('httpcore').setLevel(logging.WARNING)
 
 retries = Retry(
     total=5,
@@ -27,101 +37,123 @@ http_session.mount('https://', adapter)
 http_session.mount('http://', adapter)
 
 
-def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
+def get_provider_model(provider_model: str) -> Tuple[str, str]:
     """
-    Get an LLM via the HuggingFaceEndpoint of LangChain.
+    Parse and get LLM provider and model name from strings like `[provider]model/name-version`.
 
-    :param repo_id: The model name.
-    :param max_new_tokens: The max new tokens to generate.
-    :return: The HF LLM inference endpoint.
+    :param provider_model: The provider, model name string from `GlobalConfig`.
+    :return: The provider and the model name.
     """
 
-    logger.debug('Getting LLM via HF endpoint: %s', repo_id)
-
-    return HuggingFaceEndpoint(
-        repo_id=repo_id,
-        max_new_tokens=max_new_tokens,
-        top_k=40,
-        top_p=0.95,
-        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
-        repetition_penalty=1.03,
-        streaming=True,
-        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
-        return_full_text=False,
-        stop_sequences=['</s>'],
-    )
-
-
-# def hf_api_query(payload: dict) -> dict:
-#     """
-#     Invoke HF inference end-point API.
-#
-#     :param payload: The prompt for the LLM and related parameters.
-#     :return: The output from the LLM.
-#     """
-#
-#     try:
-#         response = http_session.post(
-#             HF_API_URL,
-#             headers=HF_API_HEADERS,
-#             json=payload,
-#             timeout=REQUEST_TIMEOUT
-#         )
-#         result = response.json()
-#     except requests.exceptions.Timeout as te:
-#         logger.error('*** Error: hf_api_query timeout! %s', str(te))
-#         result = []
-#
-#     return result
-
-
-# def generate_slides_content(topic: str) -> str:
-#     """
-#     Generate the outline/contents of slides for a presentation on a given topic.
-#
-#     :param topic: Topic on which slides are to be generated.
-#     :return: The content in JSON format.
-#     """
-#
-#     with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
-#         template_txt = in_file.read().strip()
-#         template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
-#
-#     output = hf_api_query({
-#         'inputs': template_txt,
-#         'parameters': {
-#             'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
-#             'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
-#             'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
-#             'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
-#             'num_return_sequences': 1,
-#             'return_full_text': False,
-#             # "repetition_penalty": 0.0001
-#         },
-#         'options': {
-#             'wait_for_model': True,
-#             'use_cache': True
-#         }
-#     })
-#
-#     output = output[0]['generated_text'].strip()
-#     # output = output[len(template_txt):]
-#
-#     json_end_idx = output.rfind('```')
-#     if json_end_idx != -1:
-#         # logging.debug(f'{json_end_idx=}')
-#         output = output[:json_end_idx]
-#
-#     logger.debug('generate_slides_content: output: %s', output)
-#
-#     return output
+    match = LLM_PROVIDER_MODEL_REGEX.match(provider_model)
+
+    if match:
+        inside_brackets = match.group(1)
+        outside_brackets = match.group(2)
+        return inside_brackets, outside_brackets
+
+    return '', ''
+
+
+def is_valid_llm_provider_model(provider: str, model: str, api_key: str) -> bool:
+    """
+    Verify whether LLM settings are proper.
+    This function does not verify whether `api_key` is correct. It only confirms that the key has
+    at least five characters. Key verification is done when the LLM is created.
+
+    :param provider: Name of the LLM provider.
+    :param model: Name of the model.
+    :param api_key: The API key or access token.
+    :return: `True` if the settings "look" OK; `False` otherwise.
+    """
+
+    if not provider or not model or provider not in GlobalConfig.VALID_PROVIDERS:
+        return False
+
+    if provider in [GlobalConfig.PROVIDER_GOOGLE_GEMINI, ]:
+        if not api_key or len(api_key) < 5:
+            return False
+
+    return True
+
+
+def get_langchain_llm(
+        provider: str,
+        model: str,
+        max_new_tokens: int,
+        api_key: str = ''
+) -> Union[BaseLLM, None]:
+    """
+    Get an LLM based on the provider and model specified.
+
+    :param provider: The LLM provider. Valid values are `hf` for Hugging Face.
+    :param model: The name of the LLM.
+    :param max_new_tokens: The maximum number of tokens to generate.
+    :param api_key: API key or access token to use.
+    :return: An instance of the LLM or `None` in case of any error.
+    """
+
+    if provider == GlobalConfig.PROVIDER_HUGGING_FACE:
+        from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
+
+        logger.debug('Getting LLM via HF endpoint: %s', model)
+        return HuggingFaceEndpoint(
+            repo_id=model,
+            max_new_tokens=max_new_tokens,
+            top_k=40,
+            top_p=0.95,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            repetition_penalty=1.03,
+            streaming=True,
+            huggingfacehub_api_token=api_key or GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
+            return_full_text=False,
+            stop_sequences=['</s>'],
+        )
+
+    if provider == GlobalConfig.PROVIDER_GOOGLE_GEMINI:
+        from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory
+        from langchain_google_genai import GoogleGenerativeAI
+
+        logger.debug('Getting LLM via Google Gemini: %s', model)
+        return GoogleGenerativeAI(
+            model=model,
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout=None,
+            max_retries=2,
+            google_api_key=api_key,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT:
+                    HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT:
+                    HarmBlockThreshold.BLOCK_LOW_AND_ABOVE
+            }
+        )
+
+    if provider == GlobalConfig.PROVIDER_COHERE:
+        from langchain_cohere.llms import Cohere
+
+        logger.debug('Getting LLM via Cohere: %s', model)
+        return Cohere(
+            temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
+            max_tokens=max_new_tokens,
+            timeout_seconds=None,
+            max_retries=2,
+            cohere_api_key=api_key,
+            streaming=True,
+        )
+
+    return None
 
 
-if __name__ == '__main__':
-    # results = get_related_websites('5G AI WiFi 6')
-    #
-    # for a_result in results.results:
-    #     print(a_result.title, a_result.url, a_result.extract)
-
-    # get_ai_image('A talk on AI, covering pros and cons')
-    pass
+if __name__ == '__main__':
+    inputs = [
+        '[co]Cohere',
+        '[hf]mistralai/Mistral-7B-Instruct-v0.2',
+        '[gg]gemini-1.5-flash-002'
+    ]
+
+    for text in inputs:
+        print(get_provider_model(text))
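
`get_langchain_llm()` imports each provider's LangChain wrapper lazily, so only the package for the selected provider gets loaded, and every returned object exposes the same LangChain `stream()` interface. A hedged usage sketch; it assumes the project root is on the path with the new dependencies installed, and it needs a real key for Cohere or Gemini (optional for Hugging Face):

```python
from helpers import llm_helper

# Split '[hf]mistralai/...' into the provider code and the model name
provider, model = llm_helper.get_provider_model('[hf]mistralai/Mistral-7B-Instruct-v0.2')

llm = llm_helper.get_langchain_llm(
    provider=provider,
    model=model,
    max_new_tokens=8192,
    api_key='',  # optional for HF; mandatory for Cohere and Gemini
)

if llm:
    # The same streaming interface regardless of the provider
    for chunk in llm.stream('Outline a three-slide deck on solar energy'):
        print(chunk, end='')
else:
    print('Unknown provider code')
```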
helpers/pptx_helper.py CHANGED
@@ -115,37 +115,42 @@ def generate_powerpoint_presentation(
 
     # Add content in a loop
     for a_slide in parsed_data['slides']:
-        is_processing_done = _handle_icons_ideas(
-            presentation=presentation,
-            slide_json=a_slide,
-            slide_width_inch=slide_width_inch,
-            slide_height_inch=slide_height_inch
-        )
-
-        if not is_processing_done:
-            is_processing_done = _handle_double_col_layout(
+        try:
+            is_processing_done = _handle_icons_ideas(
                 presentation=presentation,
                 slide_json=a_slide,
                 slide_width_inch=slide_width_inch,
                 slide_height_inch=slide_height_inch
             )
 
-        if not is_processing_done:
-            is_processing_done = _handle_step_by_step_process(
-                presentation=presentation,
-                slide_json=a_slide,
-                slide_width_inch=slide_width_inch,
-                slide_height_inch=slide_height_inch
-            )
+            if not is_processing_done:
+                is_processing_done = _handle_double_col_layout(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
+                )
 
-        if not is_processing_done:
-            _handle_default_display(
-                presentation=presentation,
-                slide_json=a_slide,
-                slide_width_inch=slide_width_inch,
-                slide_height_inch=slide_height_inch
-            )
+            if not is_processing_done:
+                is_processing_done = _handle_step_by_step_process(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
+                )
+
+            if not is_processing_done:
+                _handle_default_display(
+                    presentation=presentation,
+                    slide_json=a_slide,
+                    slide_width_inch=slide_width_inch,
+                    slide_height_inch=slide_height_inch
+                )
 
+        except Exception:
+            # In case of any unforeseen error, try to salvage what is available
+            continue
+
     # The thank-you slide
     last_slide_layout = presentation.slide_layouts[0]
     slide = presentation.slides.add_slide(last_slide_layout)
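
The slide loop is effectively a chain of responsibility: each `_handle_*` function either claims the slide (returns a truthy flag) or defers to the next one, with `_handle_default_display()` as the catch-all, and the new `try`/`except` skips a malformed slide instead of aborting the whole deck. A generic sketch of the pattern; the handler names and slide dicts below are illustrative:

```python
def handle_icons(slide):
    return 'icons' in slide  # claims only slides it understands


def handle_double_col(slide):
    return 'columns' in slide


def handle_default(slide):
    return True  # the catch-all fallback


HANDLERS = [handle_icons, handle_double_col, handle_default]
slides = [{'icons': ['bulb']}, {'columns': 2}, {'text': 'plain'}, None]  # None: malformed

for a_slide in slides:
    try:
        for handler in HANDLERS:
            if handler(a_slide):
                print(f'{handler.__name__} -> {a_slide}')
                break
    except Exception:
        # Mirrors the new except clause: skip the bad slide, salvage the rest
        print(f'skipped malformed slide: {a_slide!r}')
        continue
```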
helpers/text_helper.py CHANGED
@@ -1,3 +1,6 @@
+"""
+Utility functions to help with text processing.
+"""
 import json_repair as jr
 
 
@@ -17,28 +20,19 @@ def is_valid_prompt(prompt: str) -> bool:
 
 def get_clean_json(json_str: str) -> str:
     """
-    Attempt to clean a JSON response string from the LLM by removing the trailing ```
-    and any text beyond that.
+    Attempt to clean a JSON response string from the LLM by removing ```json at the beginning and
+    trailing ``` and any text beyond that.
     CAUTION: May not be always accurate.
 
     :param json_str: The input string in JSON format.
     :return: The "cleaned" JSON string.
     """
 
-    # An example of response containing JSON and other text:
-    # {
-    #     "title": "AI and the Future: A Transformative Journey",
-    #     "slides": [
-    #         ...
-    #     ]
-    # } <<---- This is end of valid JSON content
-    # ```
-    #
-    # ```vbnet
-    # Please note that the JSON output is in valid format but the content of the "Role of GPUs in AI" slide is just an example and may not be factually accurate. For accurate information, you should consult relevant resources and update the content accordingly.
-    # ```
     response_cleaned = json_str
 
+    if json_str.startswith('```json'):
+        json_str = json_str[7:]
+
     while True:
         idx = json_str.rfind('```')  # -1 on failure
@@ -46,7 +40,7 @@ def get_clean_json(json_str: str) -> str:
             break
 
         # In the ideal scenario, the character before the last ``` should be
-        # a new line or a closing bracket }
+        # a new line or a closing bracket
         prev_char = json_str[idx - 1]
 
         if (prev_char == '}') or (prev_char == '\n' and json_str[idx - 2] == '}'):
@@ -69,13 +63,13 @@ def fix_malformed_json(json_str: str) -> str:
 
 
 if __name__ == '__main__':
-    json1 = '''{
+    JSON1 = '''{
     "key": "value"
     }
     '''
-    json2 = '''["Reason": "Regular updates help protect against known vulnerabilities."]'''
-    json3 = '''["Reason" Regular updates help protect against known vulnerabilities."]'''
-    json4 = '''
+    JSON2 = '''["Reason": "Regular updates help protect against known vulnerabilities."]'''
+    JSON3 = '''["Reason" Regular updates help protect against known vulnerabilities."]'''
+    JSON4 = '''
     {"bullet_points": [
     ">> Write without stopping or editing",
     >> Set daily writing goals and stick to them,
@@ -83,7 +77,7 @@ if __name__ == '__main__':
     ],}
     '''
 
-    print(fix_malformed_json(json1))
-    print(fix_malformed_json(json2))
-    print(fix_malformed_json(json3))
-    print(fix_malformed_json(json4))
+    print(fix_malformed_json(JSON1))
+    print(fix_malformed_json(JSON2))
+    print(fix_malformed_json(JSON3))
+    print(fix_malformed_json(JSON4))
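
With the new startswith check, `get_clean_json()` strips an opening ```json fence before hunting for the last closing fence, so a fully fenced LLM reply is reduced to bare JSON before `fix_malformed_json()` ever sees it. A quick illustration; it assumes the project root is on the path, and the sample reply and the exact whitespace of the output are only indicative:

```python
from helpers.text_helper import get_clean_json

llm_reply = (
    '```json\n'
    '{"title": "Demo", "slides": []}\n'
    '```\n'
    'Note: the above JSON is only an example.'
)

print(get_clean_json(llm_reply))
# Roughly: {"title": "Demo", "slides": []} -- fences and trailing note removed
```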
requirements.txt CHANGED
@@ -10,6 +10,8 @@ pydantic==2.9.1
 langchain~=0.3.7
 langchain-core~=0.3.0
 langchain-community==0.3.0
+langchain-google-genai==2.0.6
+langchain-cohere==0.3.3
 streamlit~=1.38.0
 
 python-pptx
@@ -19,7 +21,6 @@ requests~=2.32.3
 
 transformers~=4.44.0
 torch==2.4.0
-langchain-community
 
 urllib3~=2.2.1
 lxml~=4.9.3
strings.json CHANGED
@@ -33,5 +33,6 @@
     "Looks like you have a looming deadline. Can I help you get started with your slide deck?",
     "Hello! What topic do you have on your mind today?"
   ],
-  "chat_placeholder": "Write the topic or instructions here"
+  "chat_placeholder": "Write the topic or instructions here",
+  "like_feedback": "If you like SlideDeck AI, please consider leaving a heart ❤\uFE0F on the [Hugging Face Space](https://huggingface.co/spaces/barunsaha/slide-deck-ai/) or a star ⭐ on [GitHub](https://github.com/barun-saha/slide-deck-ai). Your [feedback](https://forms.gle/JECFBGhjvSj7moBx9) is appreciated."
 }