Jiangxz01 commited on
Commit
7635190
·
verified ·
1 Parent(s): 10b2a46

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -12
app.py CHANGED
@@ -1,7 +1,5 @@
1
  import gradio as gr
2
  from pydub import AudioSegment
3
- import google.generativeai as genai
4
- from google.generativeai.types import HarmCategory, HarmBlockThreshold
5
  import json
6
  import uuid
7
  import io
@@ -13,6 +11,11 @@ import os
13
  import time
14
  from typing import List, Dict, Tuple
15
  import openai
 
 
 
 
 
16
 
17
  class PodcastGenerator:
18
  def __init__(self):
@@ -253,12 +256,10 @@ class PodcastGenerator:
253
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
254
 
255
  # 配置 SambaNova API client
256
- if api_key:
257
- openai.api_key = api_key
258
- else:
259
- openai.api_key = os.getenv("YOUR_API_TOKEN")
260
  client = openai.OpenAI(
261
- api_key=openai.api_key,
262
  base_url="https://api.sambanova.ai/v1",
263
  )
264
 
@@ -270,24 +271,36 @@ class PodcastGenerator:
270
  {"role": "system", "content": system_prompt},
271
  {"role": "user", "content": user_prompt}
272
  ],
273
- temperature=0.7,
274
  max_tokens=8192
275
  )
276
- generated_text = response.choices[0].message.content
 
 
 
 
 
 
 
277
  except Exception as e:
 
278
  # 處理可能的錯誤
279
  if "API key not valid" in str(e):
280
  raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
281
  elif "rate limit" in str(e).lower():
282
  raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
283
  else:
284
- raise gr.Error(f"Failed to generate podcast script: {e}")
285
 
286
  # 列印生成的Podcast指令碼
287
  print(f"Generated podcast script:\n{generated_text}")
288
 
289
- # 返回解析後的JSON資料
290
- return json.loads(generated_text)
 
 
 
 
291
 
292
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
293
  """
 
1
  import gradio as gr
2
  from pydub import AudioSegment
 
 
3
  import json
4
  import uuid
5
  import io
 
11
  import time
12
  from typing import List, Dict, Tuple
13
  import openai
14
+ import logging
15
+
16
+ # At the beginning of your script, set up logging
17
+ logging.basicConfig(level=logging.INFO)
18
+ logger = logging.getLogger(__name__)
19
 
20
  class PodcastGenerator:
21
  def __init__(self):
 
256
  user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
257
 
258
  # 配置 SambaNova API client
259
+ if not api_key:
260
+ api_key = os.getenv("YOUR_API_TOKEN")
 
 
261
  client = openai.OpenAI(
262
+ api_key=api_key,
263
  base_url="https://api.sambanova.ai/v1",
264
  )
265
 
 
271
  {"role": "system", "content": system_prompt},
272
  {"role": "user", "content": user_prompt}
273
  ],
274
+ temperature=1,
275
  max_tokens=8192
276
  )
277
+ logger.info(f"API Response: {response}")
278
+
279
+ if response.choices and len(response.choices) > 0:
280
+ generated_text = response.choices[0].message.content
281
+ else:
282
+ logger.warning("No content generated from the API")
283
+ raise ValueError("No content generated from the API")
284
+
285
  except Exception as e:
286
+ logger.error(f"Error generating script: {str(e)}")
287
  # 處理可能的錯誤
288
  if "API key not valid" in str(e):
289
  raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
290
  elif "rate limit" in str(e).lower():
291
  raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
292
  else:
293
+ raise gr.Error(f"Failed to generate podcast script: {str(e)}")
294
 
295
  # 列印生成的Podcast指令碼
296
  print(f"Generated podcast script:\n{generated_text}")
297
 
298
+ # 嘗試解析JSON,如果失敗則返回原始文本
299
+ try:
300
+ return json.loads(generated_text)
301
+ except json.JSONDecodeError:
302
+ print("Warning: Generated text is not valid JSON. Returning raw text.")
303
+ return {"raw_text": generated_text}
304
 
305
  async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
306
  """