Spaces:
Running
Running
Upload app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
from pydub import AudioSegment
|
3 |
-
import google.generativeai as genai
|
4 |
-
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
5 |
import json
|
6 |
import uuid
|
7 |
import io
|
@@ -13,6 +11,11 @@ import os
|
|
13 |
import time
|
14 |
from typing import List, Dict, Tuple
|
15 |
import openai
|
|
|
|
|
|
|
|
|
|
|
16 |
|
17 |
class PodcastGenerator:
|
18 |
def __init__(self):
|
@@ -253,12 +256,10 @@ class PodcastGenerator:
|
|
253 |
user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
|
254 |
|
255 |
# 配置 SambaNova API client
|
256 |
-
if api_key:
|
257 |
-
|
258 |
-
else:
|
259 |
-
openai.api_key = os.getenv("YOUR_API_TOKEN")
|
260 |
client = openai.OpenAI(
|
261 |
-
api_key=
|
262 |
base_url="https://api.sambanova.ai/v1",
|
263 |
)
|
264 |
|
@@ -270,24 +271,36 @@ class PodcastGenerator:
|
|
270 |
{"role": "system", "content": system_prompt},
|
271 |
{"role": "user", "content": user_prompt}
|
272 |
],
|
273 |
-
temperature=
|
274 |
max_tokens=8192
|
275 |
)
|
276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
277 |
except Exception as e:
|
|
|
278 |
# 處理可能的錯誤
|
279 |
if "API key not valid" in str(e):
|
280 |
raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
|
281 |
elif "rate limit" in str(e).lower():
|
282 |
raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
|
283 |
else:
|
284 |
-
raise gr.Error(f"Failed to generate podcast script: {e}")
|
285 |
|
286 |
# 列印生成的Podcast指令碼
|
287 |
print(f"Generated podcast script:\n{generated_text}")
|
288 |
|
289 |
-
#
|
290 |
-
|
|
|
|
|
|
|
|
|
291 |
|
292 |
async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
|
293 |
"""
|
|
|
1 |
import gradio as gr
|
2 |
from pydub import AudioSegment
|
|
|
|
|
3 |
import json
|
4 |
import uuid
|
5 |
import io
|
|
|
11 |
import time
|
12 |
from typing import List, Dict, Tuple
|
13 |
import openai
|
14 |
+
import logging
|
15 |
+
|
16 |
+
# At the beginning of your script, set up logging
|
17 |
+
logging.basicConfig(level=logging.INFO)
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
|
20 |
class PodcastGenerator:
|
21 |
def __init__(self):
|
|
|
256 |
user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
|
257 |
|
258 |
# 配置 SambaNova API client
|
259 |
+
if not api_key:
|
260 |
+
api_key = os.getenv("YOUR_API_TOKEN")
|
|
|
|
|
261 |
client = openai.OpenAI(
|
262 |
+
api_key=api_key,
|
263 |
base_url="https://api.sambanova.ai/v1",
|
264 |
)
|
265 |
|
|
|
271 |
{"role": "system", "content": system_prompt},
|
272 |
{"role": "user", "content": user_prompt}
|
273 |
],
|
274 |
+
temperature=1,
|
275 |
max_tokens=8192
|
276 |
)
|
277 |
+
logger.info(f"API Response: {response}")
|
278 |
+
|
279 |
+
if response.choices and len(response.choices) > 0:
|
280 |
+
generated_text = response.choices[0].message.content
|
281 |
+
else:
|
282 |
+
logger.warning("No content generated from the API")
|
283 |
+
raise ValueError("No content generated from the API")
|
284 |
+
|
285 |
except Exception as e:
|
286 |
+
logger.error(f"Error generating script: {str(e)}")
|
287 |
# 處理可能的錯誤
|
288 |
if "API key not valid" in str(e):
|
289 |
raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
|
290 |
elif "rate limit" in str(e).lower():
|
291 |
raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
|
292 |
else:
|
293 |
+
raise gr.Error(f"Failed to generate podcast script: {str(e)}")
|
294 |
|
295 |
# 列印生成的Podcast指令碼
|
296 |
print(f"Generated podcast script:\n{generated_text}")
|
297 |
|
298 |
+
# 嘗試解析JSON,如果失敗則返回原始文本
|
299 |
+
try:
|
300 |
+
return json.loads(generated_text)
|
301 |
+
except json.JSONDecodeError:
|
302 |
+
print("Warning: Generated text is not valid JSON. Returning raw text.")
|
303 |
+
return {"raw_text": generated_text}
|
304 |
|
305 |
async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
|
306 |
"""
|