Jiangxz01 committed on
Commit
d77a3a5
·
verified ·
1 Parent(s): 1282b4e

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -48
app.py CHANGED
@@ -266,48 +266,58 @@ class PodcastGenerator:
266
  base_url="https://api.sambanova.ai/v1",
267
  )
268
 
269
- # 嘗試生成內容
270
- try:
271
- # Calculate the available tokens for generation
272
- prompt_tokens = len(user_prompt.split()) # This is a rough estimate
273
- system_tokens = len(system_prompt.split()) # This is a rough estimate
274
- max_tokens = 4096 - prompt_tokens - system_tokens - 100 # 100 is a safety margin
275
-
276
- if max_tokens <= 0:
277
- return {"error": "Input prompt is too long. Please provide a shorter prompt."}
278
-
279
- logger.info(f"Sending request to SambaNova API with prompt: {user_prompt[:100]}...")
280
- response = client.chat.completions.create(
281
- model='Meta-Llama-3.1-405B-Instruct',
282
- messages=[
283
- {"role": "system", "content": system_prompt},
284
- {"role": "user", "content": user_prompt}
285
- ],
286
- temperature=1,
287
- max_tokens=max_tokens
288
- )
289
- logger.info(f"Received response from API: {response}")
290
-
291
- if hasattr(response, 'error'):
292
- logger.error(f"API returned an error: {response.error}")
293
- return {"error": f"API error: {response.error.get('message', 'Unknown error')}"}
294
-
295
- if response.choices and len(response.choices) > 0:
296
- generated_text = response.choices[0].message.content
297
- logger.info(f"Generated text: {generated_text[:100]}...")
298
- else:
299
- logger.warning("No content generated from the API")
300
- return {"error": "No content generated from the API"}
301
-
302
- except Exception as e:
303
- logger.error(f"Error generating script: {str(e)}")
304
- # 處理可能的錯誤
305
- if "API key not valid" in str(e):
306
- raise gr.Error("Invalid API key. Please provide a valid SambaNova API key.")
307
- elif "rate limit" in str(e).lower():
308
- raise gr.Error("Rate limit exceeded for the API key. Please try again later or provide your own SambaNova API key.")
309
- else:
310
- raise gr.Error(f"Failed to generate podcast script: {str(e)}")
 
 
 
 
 
 
 
 
 
 
311
 
312
  # 嘗試解析JSON,如果失敗則嘗試從原始文本中提取對話
313
  try:
@@ -520,12 +530,6 @@ async def process_input(input_text: str, input_file, language: str, speaker1: st
520
  gr.Error(f"Selected voices may not be compatible with the chosen language: {language}")
521
  return None
522
 
523
- # Check input text length
524
- max_input_length = 1000 # Adjust this value as needed
525
- if len(input_text) > max_input_length:
526
- gr.Error(f"Input text is too long. Please limit your input to {max_input_length} characters.")
527
- return None
528
-
529
  # 如果提供了輸入檔案,則從檔案中提取文字
530
  if input_file:
531
  input_text = await TextExtractor.extract_text(input_file.name)
 
266
  base_url="https://api.sambanova.ai/v1",
267
  )
268
 
269
+ async def generate_chunk(chunk: str) -> str:
270
+ try:
271
+ # Calculate the available tokens for generation
272
+ prompt_tokens = len(chunk.split()) # This is a rough estimate
273
+ system_tokens = len(system_prompt.split()) # This is a rough estimate
274
+ max_tokens = 4096 - prompt_tokens - system_tokens - 100 # 100 is a safety margin
275
+
276
+ if max_tokens <= 0:
277
+ return {"error": "Input chunk is too long. Please provide a shorter prompt."}
278
+
279
+ logger.info(f"Sending request to SambaNova API with prompt chunk: {chunk[:100]}...")
280
+ response = client.chat.completions.create(
281
+ model='Meta-Llama-3.1-405B-Instruct',
282
+ messages=[
283
+ {"role": "system", "content": system_prompt},
284
+ {"role": "user", "content": chunk}
285
+ ],
286
+ temperature=1,
287
+ max_tokens=max_tokens
288
+ )
289
+ logger.info(f"Received response from API: {response}")
290
+
291
+ if hasattr(response, 'error'):
292
+ logger.error(f"API returned an error: {response.error}")
293
+ return {"error": f"API error: {response.error.get('message', 'Unknown error')}"}
294
+
295
+ if response.choices and len(response.choices) > 0:
296
+ generated_text = response.choices[0].message.content
297
+ logger.info(f"Generated text: {generated_text[:100]}...")
298
+ return generated_text
299
+ else:
300
+ logger.warning("No content generated from the API")
301
+ return {"error": "No content generated from the API"}
302
+
303
+ except Exception as e:
304
+ logger.error(f"Error generating script chunk: {str(e)}")
305
+ return {"error": f"Failed to generate podcast script chunk: {str(e)}"}
306
+
307
+ # Split the prompt into chunks
308
+ chunk_size = 1000 # Adjust this value as needed
309
+ chunks = [prompt[i:i+chunk_size] for i in range(0, len(prompt), chunk_size)]
310
+
311
+ # Generate script for each chunk
312
+ generated_chunks = []
313
+ for chunk in chunks:
314
+ result = await generate_chunk(chunk)
315
+ if isinstance(result, dict) and "error" in result:
316
+ return result
317
+ generated_chunks.append(result)
318
+
319
+ # Combine generated chunks
320
+ generated_text = " ".join(generated_chunks)
321
 
322
  # 嘗試解析JSON,如果失敗則嘗試從原始文本中提取對話
323
  try:
 
530
  gr.Error(f"Selected voices may not be compatible with the chosen language: {language}")
531
  return None
532
 
 
 
 
 
 
 
533
  # 如果提供了輸入檔案,則從檔案中提取文字
534
  if input_file:
535
  input_text = await TextExtractor.extract_text(input_file.name)