Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -1,7 +1,5 @@
|
|
| 1 |
import gradio as gr
|
| 2 |
from pydub import AudioSegment
|
| 3 |
-
import google.generativeai as genai
|
| 4 |
-
from google.generativeai.types import HarmCategory, HarmBlockThreshold
|
| 5 |
import json
|
| 6 |
import uuid
|
| 7 |
import io
|
|
@@ -13,6 +11,11 @@ import os
|
|
| 13 |
import time
|
| 14 |
from typing import List, Dict, Tuple
|
| 15 |
import openai
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
|
| 17 |
class PodcastGenerator:
|
| 18 |
def __init__(self):
|
|
@@ -253,12 +256,10 @@ class PodcastGenerator:
|
|
| 253 |
user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
|
| 254 |
|
| 255 |
# 配置 SambaNova API client
|
| 256 |
-
if api_key:
|
| 257 |
-
|
| 258 |
-
else:
|
| 259 |
-
openai.api_key = os.getenv("YOUR_API_TOKEN")
|
| 260 |
client = openai.OpenAI(
|
| 261 |
-
api_key=
|
| 262 |
base_url="https://api.sambanova.ai/v1",
|
| 263 |
)
|
| 264 |
|
|
@@ -270,24 +271,36 @@ class PodcastGenerator:
|
|
| 270 |
{"role": "system", "content": system_prompt},
|
| 271 |
{"role": "user", "content": user_prompt}
|
| 272 |
],
|
| 273 |
-
temperature=
|
| 274 |
max_tokens=8192
|
| 275 |
)
|
| 276 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 277 |
except Exception as e:
|
|
|
|
| 278 |
# 處理可能的錯誤
|
| 279 |
if "API key not valid" in str(e):
|
| 280 |
raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
|
| 281 |
elif "rate limit" in str(e).lower():
|
| 282 |
raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
|
| 283 |
else:
|
| 284 |
-
raise gr.Error(f"Failed to generate podcast script: {e}")
|
| 285 |
|
| 286 |
# 列印生成的Podcast指令碼
|
| 287 |
print(f"Generated podcast script:\n{generated_text}")
|
| 288 |
|
| 289 |
-
#
|
| 290 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
|
| 293 |
"""
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
from pydub import AudioSegment
|
|
|
|
|
|
|
| 3 |
import json
|
| 4 |
import uuid
|
| 5 |
import io
|
|
|
|
| 11 |
import time
|
| 12 |
from typing import List, Dict, Tuple
|
| 13 |
import openai
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
# At the beginning of your script, set up logging
|
| 17 |
+
logging.basicConfig(level=logging.INFO)
|
| 18 |
+
logger = logging.getLogger(__name__)
|
| 19 |
|
| 20 |
class PodcastGenerator:
|
| 21 |
def __init__(self):
|
|
|
|
| 256 |
user_prompt = f"Please generate a podcast script based on the following user input:\n{prompt}"
|
| 257 |
|
| 258 |
# 配置 SambaNova API client
|
| 259 |
+
if not api_key:
|
| 260 |
+
api_key = os.getenv("YOUR_API_TOKEN")
|
|
|
|
|
|
|
| 261 |
client = openai.OpenAI(
|
| 262 |
+
api_key=api_key,
|
| 263 |
base_url="https://api.sambanova.ai/v1",
|
| 264 |
)
|
| 265 |
|
|
|
|
| 271 |
{"role": "system", "content": system_prompt},
|
| 272 |
{"role": "user", "content": user_prompt}
|
| 273 |
],
|
| 274 |
+
temperature=1,
|
| 275 |
max_tokens=8192
|
| 276 |
)
|
| 277 |
+
logger.info(f"API Response: {response}")
|
| 278 |
+
|
| 279 |
+
if response.choices and len(response.choices) > 0:
|
| 280 |
+
generated_text = response.choices[0].message.content
|
| 281 |
+
else:
|
| 282 |
+
logger.warning("No content generated from the API")
|
| 283 |
+
raise ValueError("No content generated from the API")
|
| 284 |
+
|
| 285 |
except Exception as e:
|
| 286 |
+
logger.error(f"Error generating script: {str(e)}")
|
| 287 |
# 處理可能的錯誤
|
| 288 |
if "API key not valid" in str(e):
|
| 289 |
raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
|
| 290 |
elif "rate limit" in str(e).lower():
|
| 291 |
raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
|
| 292 |
else:
|
| 293 |
+
raise gr.Error(f"Failed to generate podcast script: {str(e)}")
|
| 294 |
|
| 295 |
# 列印生成的Podcast指令碼
|
| 296 |
print(f"Generated podcast script:\n{generated_text}")
|
| 297 |
|
| 298 |
+
# 嘗試解析JSON,如果失敗則返回原始文本
|
| 299 |
+
try:
|
| 300 |
+
return json.loads(generated_text)
|
| 301 |
+
except json.JSONDecodeError:
|
| 302 |
+
print("Warning: Generated text is not valid JSON. Returning raw text.")
|
| 303 |
+
return {"raw_text": generated_text}
|
| 304 |
|
| 305 |
async def tts_generate(self, text: str, speaker: int, speaker1: str, speaker2: str) -> str:
|
| 306 |
"""
|