# Captured from a Hugging Face Space page (Space status at capture time: "Sleeping").
from fastapi import FastAPI, HTTPException, APIRouter | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import FileResponse | |
from fastapi.staticfiles import StaticFiles | |
from pydantic import BaseModel | |
from typing import List, Dict, Any, Optional | |
import os | |
import json | |
from workflow import create_workflow, run_workflow | |
import logging | |
from dotenv import load_dotenv | |
from langchain_openai import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain_community.vectorstores import Qdrant | |
from langchain_openai.embeddings import OpenAIEmbeddings | |
from langchain_openai.chat_models import ChatOpenAI | |
from operator import itemgetter | |
from langchain.schema.output_parser import StrOutputParser | |
from langchain.schema.runnable import RunnablePassthrough | |
# --- Environment, logging and shared model ----------------------------------

# Copy variables from a local .env file (when present) into os.environ.
load_dotenv()

# INFO-level logging with a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The OpenAI key is mandatory: abort at import time instead of failing on the
# first request that needs it.
openai_api_key = os.getenv("OPENAI_API_KEY")
if openai_api_key is None or openai_api_key == "":
    raise ValueError("OpenAI API key not configured")

# Module-level chat model (gpt-3.5-turbo at temperature 0.7).
chat_model = ChatOpenAI(
    openai_api_key=openai_api_key,
    model_name="gpt-3.5-turbo",
    temperature=0.7,
)
# --- Request / response schemas ----------------------------------------------
# (Plain `#` comments on purpose: class docstrings would become OpenAPI
# schema descriptions.)


class PodcastChatRequest(BaseModel):
    # The user's question about a single podcast.
    message: str


class PodcastChatResponse(BaseModel):
    # The assistant's answer text.
    response: str


class ChatMessage(BaseModel):
    # Input for the debate-workflow chat handler.
    content: str
    # Optional extra context, forwarded unchanged to the workflow.
    context: Optional[Dict[str, Any]] = None
    # Agent perspective to run; defaults to "believer".
    agent_type: Optional[str] = "believer"


class WorkflowResponse(BaseModel):
    # NOTE(review): not referenced in the visible code; presumably mirrors the
    # result of run_workflow — confirm against workflow.py.
    debate_history: List[Dict[str, str]]
    supervisor_notes: List[str]
    supervisor_chunks: List[Dict[str, List[str]]]
    extractor_data: Dict[str, Any]
    final_podcast: Dict[str, Any]
# --- FastAPI application -----------------------------------------------------

app = FastAPI()

# JSON API routes are collected on this router, under the /api prefix.
api_router = APIRouter(prefix="/api")

# CORS.
# Starlette compares `allow_origins` entries literally, so a pattern such as
# "https://*.hf.space" can never match a real origin; Hugging Face Space
# origins are matched with `allow_origin_regex` instead.
# NOTE(review): "*" is kept so existing cross-origin clients keep working, but
# combined with allow_credentials=True Starlette echoes back the caller's Origin
# for credentialed requests, i.e. any site is allowed. Consider dropping "*" now
# that the hf.space origins are covered by the regex.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:3000", "*"],
    allow_origin_regex=r"https://[A-Za-z0-9-]+\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "Content-Length"],
    max_age=600,
)

# Storage directories live next to this file and are created on import.
audio_dir = os.path.join(os.path.dirname(__file__), "audio_storage")
context_dir = os.path.join(os.path.dirname(__file__), "context_storage")
transcripts_dir = os.path.join(os.path.dirname(__file__), "transcripts")
for _storage_dir in (audio_dir, context_dir, transcripts_dir):
    os.makedirs(_storage_dir, exist_ok=True)

# Seed an empty transcript list. Exclusive-create mode ("x") never truncates a
# file that already exists, even one another worker created a moment ago
# (exists()-then-open("w") could clobber it).
transcripts_file = os.path.join(transcripts_dir, "podcasts.json")
try:
    with open(transcripts_file, "x", encoding="utf-8") as f:
        json.dump([], f)
except FileExistsError:
    pass
# --- API handlers --------------------------------------------------------------
# NOTE(review): no route decorator (e.g. `@api_router.post(...)`) appears on the
# handlers in this copy of the file, and nothing in the visible code attaches
# them to `api_router`; confirm how they are registered.
async def chat(message: ChatMessage):
    """Run the debate workflow for one chat message.

    Args:
        message: User text plus the optional ``context`` and ``agent_type``,
            which are forwarded unchanged to ``run_workflow``.

    Returns:
        Whatever ``run_workflow`` returns.

    Raises:
        HTTPException: 500 if ``TAVILY_API_KEY`` is not set, or if creating or
            running the workflow fails.
    """
    try:
        tavily_api_key = os.getenv("TAVILY_API_KEY")
        if not tavily_api_key:
            logger.error("Tavily API key not found")
            raise HTTPException(status_code=500, detail="Tavily API key not configured")

        try:
            workflow = create_workflow(tavily_api_key)
            logger.info("Workflow created successfully")
        except Exception as e:
            logger.error(f"Error creating workflow: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error creating workflow: {str(e)}") from e

        try:
            result = await run_workflow(
                workflow,
                message.content,
                agent_type=message.agent_type,
                context=message.context,
            )
        except Exception as e:
            logger.error(f"Error running workflow: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error running workflow: {str(e)}") from e

        logger.info("Workflow completed successfully")
        return result
    except HTTPException:
        # Already carries a specific detail; re-wrapping it in a new
        # HTTPException(detail=str(e)) would replace that detail.
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def list_audio_files():
    """List the ``.mp3`` / ``.wav`` files stored in ``audio_dir``.

    Returns:
        A list of ``{"filename": str, "path": str, "size": int}`` dicts in
        ``os.listdir`` order. ``path`` is the URL under the ``/audio-files``
        static mount and ``size`` is in bytes. Returns ``[]`` if the directory
        cannot be listed. An entry whose size cannot be read (deleted while
        listing, dangling symlink, ...) is skipped rather than emptying the
        whole result.
    """
    try:
        names = os.listdir(audio_dir)
    except Exception as e:
        # Deliberate best-effort: the listing endpoint never fails.
        logger.error(f"Error listing audio files: {str(e)}")
        return []

    audio_files = []
    for name in names:
        if not name.endswith(('.mp3', '.wav')):
            continue
        try:
            size = os.path.getsize(os.path.join(audio_dir, name))
        except OSError as e:
            logger.warning(f"Skipping unreadable audio file {name}: {str(e)}")
            continue
        audio_files.append({
            "filename": name,
            "path": f"/audio-files/{name}",
            "size": size,
        })
    return audio_files
async def get_audio_file(filename: str):
    """Serve one audio file from ``audio_dir``.

    Args:
        filename: Bare file name inside ``audio_dir``. Names containing a path
            separator, or naming anything other than a regular file, are
            treated as not found.

    Returns:
        A ``FileResponse`` with media type ``audio/wav`` for ``.wav`` files and
        ``audio/mpeg`` otherwise.

    Raises:
        HTTPException: 404 if the file does not exist, 500 on any other error.
    """
    try:
        file_path = os.path.join(audio_dir, filename)
        if os.path.basename(filename) != filename or not os.path.isfile(file_path):
            logger.error(f"Audio file not found: {filename}")
            raise HTTPException(status_code=404, detail="File not found")
        # list_audio_files also exposes .wav files; don't label them as MPEG.
        media_type = "audio/wav" if filename.lower().endswith(".wav") else "audio/mpeg"
        return FileResponse(file_path, media_type=media_type)
    except HTTPException:
        # Propagate the 404 as-is instead of converting it into a 500.
        raise
    except Exception as e:
        logger.error(f"Error serving audio file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def delete_audio_file(filename: str):
    """Delete an audio file and its corresponding transcript entry.

    The transcript entry is located the same way ``podcast_chat`` looks one up:
    the topic is the part of ``filename`` before the first ``-`` with ``_``
    replaced by spaces, compared case-insensitively (after stripping) with each
    entry's ``"topic"``. The first matching entry is removed from
    ``transcripts/podcasts.json``; if none matches, only a warning is logged.

    Args:
        filename: Bare name of a file inside ``audio_dir``.

    Returns:
        ``{"message": "File deleted successfully"}``.

    Raises:
        HTTPException: 404 if no such regular file exists (names containing a
            path separator are rejected the same way), 500 on any other error.
    """
    file_path = os.path.join(audio_dir, filename)
    if os.path.basename(filename) != filename or not os.path.isfile(file_path):
        logger.error(f"File not found for deletion: {filename}")
        raise HTTPException(status_code=404, detail="File not found")

    try:
        os.remove(file_path)
        logger.info(f"Deleted audio file: {filename}")

        transcripts_path = os.path.join(os.path.dirname(__file__), "transcripts", "podcasts.json")
        if os.path.exists(transcripts_path):
            with open(transcripts_path, 'r') as f:
                transcripts = json.load(f)

            # Match by topic rather than by position: the position of a file in
            # os.listdir() is unrelated to the order of entries in podcasts.json.
            topic = filename.split('-')[0].replace('_', ' ').lower().strip()
            match_index = next(
                (
                    i for i, entry in enumerate(transcripts)
                    if isinstance(entry, dict)
                    and (entry.get("topic") or "").lower().strip() == topic
                ),
                None,
            )
            if match_index is None:
                logger.warning(f"No transcript found for {filename} (topic: {topic})")
            else:
                transcripts.pop(match_index)
                with open(transcripts_path, 'w') as f:
                    json.dump(transcripts, f, indent=2)
                logger.info("Updated transcripts after deletion")

        return {"message": "File deleted successfully"}
    except Exception as e:
        logger.error(f"Error during file deletion process: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
# System prompt for generating believer/skeptic context chunks. "{{" / "}}"
# are ChatPromptTemplate escapes for literal braces.
_PODCAST_CONTEXT_SYSTEM_PROMPT = """You are an expert content analyzer. Your task is to:
1. Analyze the given topic and create balanced, factual content chunks about it
2. Generate two types of chunks:
- Believer chunks: Positive aspects, opportunities, and solutions related to the topic
- Skeptic chunks: Challenges, risks, and critical questions about the topic
3. Each chunk should be self-contained and focused on a single point
4. Keep chunks concise (2-3 sentences each)
5. Ensure all content is factual and balanced
Format your response as a JSON object with two arrays:
{{
"believer_chunks": ["chunk1", "chunk2", ...],
"skeptic_chunks": ["chunk1", "chunk2", ...]
}}"""


def _podcast_context_filename(podcast_id: str) -> str:
    """Return the ``.mp3`` in ``audio_dir`` selected by the 1-based ``podcast_id``.

    Files are taken in ``os.listdir`` order. Raises ``HTTPException(404)`` when
    there are no ``.mp3`` files or the ID is not an in-range integer.
    """
    files = os.listdir(audio_dir)
    logger.info(f"Found {len(files)} files in audio directory")
    podcast_files = [f for f in files if f.endswith('.mp3')]
    logger.info(f"Found {len(podcast_files)} podcast files: {podcast_files}")
    if not podcast_files:
        logger.error("No podcast files found")
        raise HTTPException(status_code=404, detail="No podcast files found")

    try:
        podcast_index = int(podcast_id) - 1  # 1-based ID -> 0-based index
    except ValueError:
        podcast_index = -1  # non-numeric IDs fall through to the range check
    if not 0 <= podcast_index < len(podcast_files):
        logger.error(f"Invalid podcast ID: {podcast_id}, total podcasts: {len(podcast_files)}")
        raise HTTPException(status_code=404, detail=f"Invalid podcast ID: {podcast_id}")

    podcast_filename = podcast_files[podcast_index]
    logger.info(f"Selected podcast file: {podcast_filename}")
    return podcast_filename


def _parse_podcast_context_chunks(raw: str) -> Dict[str, Any]:
    """Parse the model reply as a JSON object, tolerating a Markdown code fence.

    Raises ``HTTPException(500)`` if the reply is not valid JSON or is not a
    JSON object.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence line (``` or ```json) and a trailing fence.
        text = text.partition("\n")[2].rstrip()
        if text.endswith("```"):
            text = text[:-3]
    try:
        parsed = json.loads(text)
    except json.JSONDecodeError as e:
        logger.error(f"Error parsing response JSON: {str(e)}, Response content: {raw}")
        raise HTTPException(status_code=500, detail=f"Error parsing content chunks: {str(e)}") from e
    if not isinstance(parsed, dict):
        logger.error(f"Expected a JSON object from the model, got {type(parsed).__name__}: {raw}")
        raise HTTPException(status_code=500, detail="Error parsing content chunks: expected a JSON object")
    return parsed


async def get_podcast_context(podcast_id: str):
    """Return believer/skeptic context for a podcast, generating it on first use.

    A cached ``{podcast_id}_context.json`` in ``context_dir`` is returned as
    stored. Otherwise the topic is derived from the podcast's ``.mp3`` name
    (text before the first ``-``, underscores turned into spaces), gpt-3.5-turbo
    is asked for believer/skeptic chunks on that topic, and the result is cached.

    Args:
        podcast_id: 1-based index into the ``.mp3`` files of ``audio_dir``
            (``os.listdir`` order).

    Returns:
        ``{"topic": ..., "believer_chunks": [...], "skeptic_chunks": [...]}`` for
        a freshly generated context; a cached file is returned unchanged.

    Raises:
        HTTPException: 404 when there are no podcasts or the ID is invalid;
            500 if the model call fails, its reply cannot be parsed, or the
            cache file cannot be written.
    """
    try:
        logger.info(f"Getting context for podcast {podcast_id}")
        context_path = os.path.join(context_dir, f"{podcast_id}_context.json")
        if os.path.exists(context_path):
            logger.info(f"Found existing context file at {context_path}")
            with open(context_path, 'r') as f:
                return json.load(f)

        logger.info("No existing context found, creating new context")
        podcast_filename = _podcast_context_filename(podcast_id)
        topic = podcast_filename.split('-')[0].replace('_', ' ')
        logger.info(f"Extracted topic: {topic}")

        try:
            # Separate model instance at a lower temperature than the
            # module-level one (presumably for more consistent output).
            analysis_model = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0.3,
                openai_api_key=openai_api_key,
            )
            logger.info("Successfully initialized ChatOpenAI")
        except Exception as e:
            logger.error(f"Error initializing ChatOpenAI: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error initializing chat model: {str(e)}") from e

        prompt = ChatPromptTemplate.from_messages([
            ("system", _PODCAST_CONTEXT_SYSTEM_PROMPT),
            ("human", "Create balanced content chunks about this topic: {topic}"),
        ])
        try:
            logger.info(f"Generating content chunks for topic: {topic}")
            response = await (prompt | analysis_model).ainvoke({"topic": topic})
            logger.info("Successfully received response from OpenAI")
        except Exception as e:
            logger.error(f"Error generating content chunks: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error generating content chunks: {str(e)}") from e

        content_chunks = _parse_podcast_context_chunks(response.content)
        context = {
            "topic": topic,
            "believer_chunks": content_chunks.get("believer_chunks", []),
            "skeptic_chunks": content_chunks.get("skeptic_chunks", []),
        }
        logger.info(
            f"Successfully parsed response JSON with {len(context['believer_chunks'])} believer chunks "
            f"and {len(context['skeptic_chunks'])} skeptic chunks"
        )

        try:
            # Write to a temp file and atomically swap it in, so a failed or
            # concurrent write never leaves a truncated cache that later
            # requests would fail to load.
            tmp_path = f"{context_path}.tmp"
            with open(tmp_path, 'w') as f:
                json.dump(context, f)
            os.replace(tmp_path, context_path)
            logger.info(f"Saved new context to {context_path}")
        except Exception as e:
            logger.error(f"Error saving context file: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error saving context file: {str(e)}") from e

        return context
    except HTTPException:
        # Keep the specific status/detail chosen above.
        raise
    except Exception as e:
        logger.error(f"Error in get_podcast_context: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# RAG prompt used to answer questions from retrieved transcript chunks.
_PODCAST_RAG_PROMPT_TEMPLATE = """\
You are a helpful podcast assistant. Answer the user's question based on the provided context from the podcast transcript.
If the context contains relevant information, use it to answer the question.
If you can't find relevant information in the context to answer the question, say "I don't have enough information to answer that question."
Keep your responses concise and focused on the question.
Important: Even if only part of the context is relevant to the question, use that part to provide a partial answer rather than saying there isn't enough information.
Context:
{context}
Question:
{question}
Answer the question using the information from the context above. If you find ANY relevant information, use it to provide at least a partial answer. Only say "I don't have enough information" if there is absolutely nothing relevant in the context.
"""


def _podcast_chat_topic(podcast_id: str) -> str:
    """Return the topic encoded in the name of the ``podcast_id``-th ``.mp3``.

    ``podcast_id`` is 1-based over ``.mp3`` files in ``os.listdir`` order; the
    topic is the text before the first ``-`` with ``_`` replaced by spaces.
    Raises ``HTTPException(404)`` for a non-integer or out-of-range ID.
    """
    audio_files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]
    logger.info(f"Found {len(audio_files)} audio files: {audio_files}")
    try:
        podcast_index = int(podcast_id) - 1
        if podcast_index < 0 or podcast_index >= len(audio_files):
            logger.error(f"Invalid podcast index: {podcast_index} (total files: {len(audio_files)})")
            raise ValueError(f"Invalid podcast ID: {podcast_id}")
    except ValueError as e:
        logger.error(f"Error converting podcast ID: {str(e)}")
        raise HTTPException(status_code=404, detail=str(e)) from e

    podcast_filename = audio_files[podcast_index]
    logger.info(f"Found podcast file: {podcast_filename}")
    topic = podcast_filename.split('-')[0].replace('_', ' ')
    logger.info(f"Extracted topic: {topic}")
    return topic


def _podcast_chat_transcript(topic: str) -> str:
    """Return ``podcastScript`` of the first transcript whose topic matches ``topic``.

    Matching is case-insensitive after stripping whitespace. Raises
    ``HTTPException`` 404 if the transcripts file is missing or no matching,
    non-empty script exists, and 500 if the file is not valid JSON.
    """
    transcripts_path = os.path.join(os.path.dirname(__file__), "transcripts", "podcasts.json")
    if not os.path.exists(transcripts_path):
        logger.error("Transcripts file not found")
        raise HTTPException(status_code=404, detail="Transcripts file not found")

    try:
        with open(transcripts_path, 'r') as f:
            transcripts = json.load(f)
    except json.JSONDecodeError as e:
        logger.error(f"Error decoding transcripts file: {str(e)}")
        raise HTTPException(status_code=500, detail="Error reading transcripts file") from e

    available_topics = [t.get('topic', 'NO_TOPIC') for t in transcripts]
    logger.info(f"Loaded {len(transcripts)} transcripts")
    logger.info(f"Available topics: {available_topics}")

    wanted = topic.lower().strip()
    podcast_transcript = next(
        (
            t.get("podcastScript") for t in transcripts
            if t.get("topic", "").lower().strip() == wanted
        ),
        None,
    )
    if not podcast_transcript:
        logger.error(f"No transcript found for topic: {topic}")
        logger.error(f"Available topics: {available_topics}")
        raise HTTPException(status_code=404, detail=f"No transcript found for topic: {topic}")

    logger.info(f"Found transcript for topic: {topic}")
    logger.info(f"Full transcript length: {len(podcast_transcript)} characters")
    logger.debug(f"Transcript preview: {podcast_transcript[:200]}...")
    return podcast_transcript


def _podcast_chat_chunks(transcript: str) -> List[str]:
    """Split a transcript into chunks of about 1000 characters (100 overlap).

    Raises ``HTTPException`` 500 if splitting fails and 404 if it yields nothing.
    """
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=100,
        length_function=len,
        separators=["\n\n", "\n", ". ", " ", ""],
    )
    try:
        logger.info("Starting text splitting process...")
        chunks = text_splitter.split_text(transcript)
    except Exception as e:
        logger.error(f"Error splitting text into chunks: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error splitting text: {str(e)}") from e

    logger.info(f"Successfully split transcript into {len(chunks)} chunks")
    if not chunks:
        logger.error("No content chunks found in transcript")
        raise HTTPException(status_code=404, detail="No content chunks found in transcript")

    # Chunk text is transcript content: keep it at DEBUG, not INFO.
    for i, chunk in enumerate(chunks[:3]):
        logger.debug(f"Chunk {i + 1}:\n{chunk}")
    chunk_sizes = [len(chunk) for chunk in chunks]
    logger.info(
        f"Chunk sizes: min={min(chunk_sizes)}, max={max(chunk_sizes)}, "
        f"avg={sum(chunk_sizes) / len(chunk_sizes):.2f} characters"
    )
    return chunks


def _podcast_chat_answer(podcast_id: str, chunks: List[str], question: str) -> str:
    """Answer ``question`` from ``chunks`` with retrieval-augmented generation.

    Embeds the chunks into a new in-memory Qdrant collection (rebuilt on every
    call), retrieves up to 8 similar chunks, and asks gpt-3.5-turbo to answer
    from them. Every step is a synchronous OpenAI/Qdrant call, so callers on
    the event loop should run this in a worker thread.

    Raises ``HTTPException(500)`` if the final generation step fails.
    """
    embedding_model = OpenAIEmbeddings(
        model="text-embedding-3-small",
        openai_api_key=openai_api_key,
    )
    vectorstore = Qdrant.from_texts(
        texts=chunks,
        embedding=embedding_model,
        location=":memory:",
        collection_name=f"podcast_{podcast_id}",
    )
    logger.info(f"Created vector store for podcast {podcast_id}")
    qdrant_retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 8, "score_threshold": 0.05},
    )
    rag_prompt = ChatPromptTemplate.from_template(_PODCAST_RAG_PROMPT_TEMPLATE)
    llm = ChatOpenAI(
        model="gpt-3.5-turbo",
        temperature=0.7,
        openai_api_key=openai_api_key,
    )

    def retrieve_context(input_dict):
        """Build the prompt inputs; falls back to a placeholder context instead of raising."""
        question_text = input_dict["question"]
        try:
            logger.info("Attempting to retrieve relevant documents...")
            logger.debug(f"Query: {question_text}")
            retrieved_docs = qdrant_retriever.invoke(question_text)
            logger.info(f"Successfully retrieved {len(retrieved_docs)} documents")
            if not retrieved_docs:
                logger.warning("No documents were retrieved!")
                return {"context": "No relevant context found.", "question": question_text}

            for i, doc in enumerate(retrieved_docs):
                logger.debug(
                    f"Document {i + 1} ({len(doc.page_content)} characters, "
                    f"metadata={doc.metadata}):\n{doc.page_content}"
                )
            context = "\n\n".join(doc.page_content for doc in retrieved_docs)
            total_length = sum(len(doc.page_content) for doc in retrieved_docs)
            logger.info(
                f"Retrieved {len(retrieved_docs)} documents, {total_length} characters "
                f"(average {total_length / len(retrieved_docs):.2f}); context length {len(context)}"
            )
            logger.debug(f"Context:\n{context}")

            if not context.strip():
                logger.error("Warning: Empty context retrieved!")
                return {"context": "No relevant context found.", "question": question_text}
            return {"context": context, "question": question_text}
        except Exception as e:
            logger.error(f"Error in retrieve_context: {str(e)}", exc_info=True)
            return {"context": "Error retrieving context.", "question": question_text}

    chain = RunnablePassthrough() | retrieve_context | rag_prompt | llm
    try:
        logger.info("Generating response...")
        response = chain.invoke({"question": question})
    except Exception as e:
        logger.error(f"Error generating response: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error generating response: {str(e)}") from e
    logger.debug(f"Final Response: {response.content}")
    return response.content


async def podcast_chat(podcast_id: str, request: PodcastChatRequest):
    """Answer a question about one podcast using its stored transcript.

    Args:
        podcast_id: 1-based index into the ``.mp3`` files of ``audio_dir``
            (``os.listdir`` order).
        request: Carries the user's question in ``message``.

    Returns:
        ``PodcastChatResponse`` holding the model's answer.

    Raises:
        HTTPException: 404 for an invalid ID, a missing transcript, or an
            empty transcript; 500 for any other failure.
    """
    # Local import keeps this handler self-contained; fastapi is already a dependency.
    from fastapi.concurrency import run_in_threadpool

    try:
        logger.info(f"Processing chat message for podcast {podcast_id}")
        logger.debug(f"User message: {request.message}")
        topic = _podcast_chat_topic(podcast_id)
        transcript = _podcast_chat_transcript(topic)
        chunks = _podcast_chat_chunks(transcript)
        # Embedding, retrieval and the LLM call are synchronous; running them
        # directly in this coroutine would stall every other request meanwhile.
        answer = await run_in_threadpool(_podcast_chat_answer, podcast_id, chunks, request.message)
        return PodcastChatResponse(response=answer)
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in podcast chat: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Register the API router before the static mounts: routes are matched in
# registration order, and the catch-all "/" mount would otherwise shadow /api.
# NOTE(review): no handlers are attached to api_router in the visible code.
app.include_router(api_router)

# Serve stored audio and the built frontend. The frontend directory is resolved
# next to this file, like the storage directories above, so startup does not
# depend on the process's current working directory.
static_dir = os.path.join(os.path.dirname(__file__), "static")
app.mount("/audio-files", StaticFiles(directory=audio_dir), name="audio")
app.mount("/", StaticFiles(directory=static_dir, html=True), name="frontend")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)