Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, HTTPException, APIRouter | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from typing import List, Dict, Any, Optional | |
| import os | |
| import json | |
| from workflow import create_workflow, run_workflow | |
| import logging | |
| from dotenv import load_dotenv | |
| from langchain_openai import ChatOpenAI | |
| from langchain.prompts import ChatPromptTemplate | |
| from langchain.text_splitter import RecursiveCharacterTextSplitter | |
| from langchain_community.vectorstores import Qdrant | |
| from langchain_openai.embeddings import OpenAIEmbeddings | |
| from langchain_openai.chat_models import ChatOpenAI | |
| from operator import itemgetter | |
| from langchain.schema.output_parser import StrOutputParser | |
| from langchain.schema.runnable import RunnablePassthrough | |
# Pull settings from a local .env file into the process environment.
load_dotenv()

# Module-wide logger; INFO level so request handling is traceable.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Fail fast at import time when the OpenAI credential is absent.
openai_api_key = os.environ.get("OPENAI_API_KEY")
if not openai_api_key:
    raise ValueError("OpenAI API key not configured")

# Shared chat model instance built from the key above.
chat_model = ChatOpenAI(
    model_name="gpt-3.5-turbo", temperature=0.7, openai_api_key=openai_api_key
)
| # Define Pydantic models | |
class ChatMessage(BaseModel):
    """Request payload consumed by the ``chat`` handler."""

    # Message text; forwarded to ``run_workflow`` as its input.
    content: str
    # Optional extra data forwarded to ``run_workflow`` as ``context``.
    context: Optional[Dict[str, Any]] = None
    # Forwarded to ``run_workflow`` as ``agent_type``; defaults to "believer".
    agent_type: Optional[str] = "believer"
class WorkflowResponse(BaseModel):
    """Schema for a workflow result.

    NOTE(review): not referenced by any handler visible in this section;
    presumably it mirrors the value returned by ``run_workflow`` -- confirm.
    """

    # List of string-to-string mappings, one per debate entry.
    debate_history: List[Dict[str, str]]
    supervisor_notes: List[str]
    # Each item maps a string key to a list of text chunks.
    supervisor_chunks: List[Dict[str, List[str]]]
    extractor_data: Dict[str, Any]
    final_podcast: Dict[str, Any]
class PodcastChatRequest(BaseModel):
    """Request payload consumed by the ``podcast_chat`` handler."""

    # The user's question; used as the retrieval query and the prompt question.
    message: str
class PodcastChatResponse(BaseModel):
    """Response payload returned by the ``podcast_chat`` handler."""

    # Text content of the chat model's answer.
    response: str
# Initialize FastAPI app
app = FastAPI()

# Create API router; everything registered on it is served under /api.
# NOTE(review): no route decorators are visible on the handlers defined in
# this file section -- confirm they were not lost and that each handler is
# actually registered on ``api_router``.
api_router = APIRouter(prefix="/api")

# Configure CORS.
# ``allow_origins`` entries are compared literally, so a pattern such as
# "https://*.hf.space" never matches anything; Hugging Face Space origins are
# matched with ``allow_origin_regex`` instead.  The bare "*" entry is dropped
# because a wildcard origin must not be combined with
# ``allow_credentials=True``.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:5173", "http://localhost:3000"],
    allow_origin_regex=r"https://[A-Za-z0-9-]+\.hf\.space",
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "HEAD"],
    allow_headers=["*"],
    expose_headers=["Content-Type", "Content-Length"],
    max_age=600,
)
# Storage locations, all resolved next to this source file.
_base_dir = os.path.dirname(__file__)
audio_dir = os.path.join(_base_dir, "audio_storage")
context_dir = os.path.join(_base_dir, "context_storage")
transcripts_dir = os.path.join(_base_dir, "transcripts")

# Create each storage directory on first start; existing ones are left alone.
for _storage_dir in (audio_dir, context_dir, transcripts_dir):
    os.makedirs(_storage_dir, exist_ok=True)

# Seed the transcript store with an empty JSON list when it is missing.
transcripts_file = os.path.join(transcripts_dir, "podcasts.json")
if not os.path.exists(transcripts_file):
    with open(transcripts_file, 'w') as f:
        json.dump([], f)
| # API Routes | |
async def chat(message: ChatMessage):
    """Process a chat message by running it through the agent workflow.

    Args:
        message: Request payload. ``content`` is the workflow input;
            ``agent_type`` and ``context`` are forwarded to ``run_workflow``.

    Returns:
        Whatever ``run_workflow`` returns for this message.

    Raises:
        HTTPException: 500 when ``TAVILY_API_KEY`` is not set, when the
            workflow cannot be created, or when running it fails.
    """
    try:
        # Get API key
        tavily_api_key = os.getenv("TAVILY_API_KEY")
        if not tavily_api_key:
            logger.error("Tavily API key not found")
            raise HTTPException(status_code=500, detail="Tavily API key not configured")

        # Initialize the workflow
        try:
            workflow = create_workflow(tavily_api_key)
            logger.info("Workflow created successfully")
        except Exception as e:
            logger.error(f"Error creating workflow: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error creating workflow: {str(e)}"
            ) from e

        # Run the workflow with context
        try:
            result = await run_workflow(
                workflow,
                message.content,
                agent_type=message.agent_type,
                context=message.context,
            )
            logger.info("Workflow completed successfully")
            return result
        except Exception as e:
            logger.error(f"Error running workflow: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error running workflow: {str(e)}"
            ) from e
    except HTTPException:
        # Already carries the intended status and detail; re-wrapping it in
        # the generic handler below would replace the detail with the
        # exception's string form.
        raise
    except Exception as e:
        logger.error(f"Error in chat endpoint: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def list_audio_files():
    """Return metadata for every ``.mp3``/``.wav`` file in the audio directory.

    Returns:
        A list of dicts with ``filename``, ``path`` (the ``/audio-files/``
        URL of the file) and ``size`` in bytes.  The list is empty when there
        are no audio files or when listing fails; failures are logged, not
        raised.
    """
    try:
        return [
            {
                "filename": name,
                "path": f"/audio-files/{name}",
                "size": os.path.getsize(os.path.join(audio_dir, name)),
            }
            for name in os.listdir(audio_dir)
            if name.endswith(('.mp3', '.wav'))
        ]
    except Exception as err:
        logger.error(f"Error listing audio files: {str(err)}")
        return []
async def get_audio_file(filename: str):
    """Serve one audio file from the audio directory.

    Args:
        filename: Bare file name (no directory components) of the audio file.

    Returns:
        A ``FileResponse`` streaming the file; ``audio/wav`` for ``.wav``
        files and ``audio/mpeg`` otherwise.

    Raises:
        HTTPException: 404 when the name is not a plain file name or the file
            does not exist; 500 on any unexpected error.
    """
    try:
        # Only bare names are accepted so a crafted value cannot escape
        # audio_dir via directory components.
        is_bare_name = bool(filename) and os.path.basename(filename) == filename
        file_path = os.path.join(audio_dir, filename)
        if not is_bare_name or not os.path.isfile(file_path):
            logger.error(f"Audio file not found: {filename}")
            raise HTTPException(status_code=404, detail="File not found")

        media_type = "audio/wav" if filename.lower().endswith(".wav") else "audio/mpeg"
        return FileResponse(file_path, media_type=media_type)
    except HTTPException:
        # Keep the 404 as a 404 instead of converting it into a 500 below.
        raise
    except Exception as e:
        logger.error(f"Error serving audio file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def delete_audio_file(filename: str):
    """Delete an audio file and its corresponding transcript.

    The transcript to drop is chosen by the file's position in the audio
    directory listing, captured *before* the file is removed (after removal
    the file can no longer be found in the listing, so no transcript would
    ever be dropped).

    NOTE(review): this assumes the order of ``os.listdir(audio_dir)`` matches
    the order of entries in ``transcripts/podcasts.json`` -- confirm, or
    switch to matching on the transcript's ``topic`` as ``podcast_chat`` does.

    Args:
        filename: Bare file name (no directory components) of the audio file.

    Returns:
        ``{"message": "File deleted successfully"}`` on success.

    Raises:
        HTTPException: 404 when the name is not a plain file name or the file
            does not exist; 500 when deletion or the transcript update fails.
    """
    try:
        # Only bare names are accepted so a crafted value cannot delete files
        # outside audio_dir.
        is_bare_name = bool(filename) and os.path.basename(filename) == filename
        file_path = os.path.join(audio_dir, filename)
        if not is_bare_name or not os.path.isfile(file_path):
            logger.error(f"File not found for deletion: {filename}")
            raise HTTPException(status_code=404, detail="File not found")

        try:
            # Record the file's position while it is still part of the listing.
            audio_files = [f for f in os.listdir(audio_dir) if f.endswith(('.mp3', '.wav'))]
            try:
                podcast_index = audio_files.index(filename)
            except ValueError:
                # Not an .mp3/.wav file, so it has no transcript slot.
                podcast_index = None

            os.remove(file_path)
            logger.info(f"Deleted audio file: {filename}")

            # Try to update transcripts if they exist
            transcripts_file = os.path.join(os.path.dirname(__file__), "transcripts", "podcasts.json")
            if os.path.exists(transcripts_file):
                if podcast_index is None:
                    logger.warning(f"Could not find podcast ID for {filename} in transcripts")
                else:
                    with open(transcripts_file, 'r') as f:
                        transcripts = json.load(f)
                    if podcast_index < len(transcripts):
                        transcripts.pop(podcast_index)
                        with open(transcripts_file, 'w') as f:
                            json.dump(transcripts, f, indent=2)
                        logger.info("Updated transcripts after deletion")

            return {"message": "File deleted successfully"}
        except Exception as e:
            logger.error(f"Error during file deletion process: {str(e)}")
            raise HTTPException(status_code=500, detail=str(e)) from e
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in delete_audio_file: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def get_podcast_context(podcast_id: str):
    """Get or generate context for a podcast.

    A previously saved ``<podcast_id>_context.json`` in the context directory
    is returned as-is.  Otherwise the 1-based ``podcast_id`` is resolved to
    an ``.mp3`` file in the audio directory, a topic is derived from the file
    name (text before the first ``-``, underscores turned into spaces), the
    chat model is asked for "believer" and "skeptic" chunks about that topic,
    and the result is saved and returned.

    Args:
        podcast_id: 1-based position of the podcast among the ``.mp3`` files,
            as a string.

    Returns:
        A dict with ``topic``, ``believer_chunks`` and ``skeptic_chunks`` (or
        the content of the saved context file).

    Raises:
        HTTPException: 404 when ``podcast_id`` is not an integer, is out of
            range, or no podcast files exist; 500 when the model call, the
            parsing of its reply, or saving the context fails.
    """
    try:
        logger.info(f"Getting context for podcast {podcast_id}")

        # Validate the ID before it is used to build a filesystem path: a
        # value that parses as an integer cannot contain path separators.
        try:
            podcast_index = int(podcast_id) - 1  # Convert 1-based ID to 0-based index
        except ValueError as e:
            logger.error(f"Invalid podcast ID: {podcast_id}, Error: {str(e)}")
            raise HTTPException(status_code=404, detail=f"Invalid podcast ID: {podcast_id}") from e

        context_path = os.path.join(context_dir, f"{podcast_id}_context.json")

        # If context exists, return it
        if os.path.exists(context_path):
            logger.info(f"Found existing context file at {context_path}")
            with open(context_path, 'r') as f:
                return json.load(f)

        # If no context exists, we need to create it from the podcast content
        logger.info("No existing context found, creating new context")

        # Get the audio files to find the podcast filename
        files = os.listdir(audio_dir)
        logger.info(f"Found {len(files)} files in audio directory")
        podcast_files = [f for f in files if f.endswith('.mp3')]
        logger.info(f"Found {len(podcast_files)} podcast files: {podcast_files}")
        if not podcast_files:
            logger.error("No podcast files found")
            raise HTTPException(status_code=404, detail="No podcast files found")

        # Find the podcast file that matches this ID
        if podcast_index < 0 or podcast_index >= len(podcast_files):
            logger.error(f"Invalid podcast ID: {podcast_id}, total podcasts: {len(podcast_files)}")
            raise HTTPException(status_code=404, detail=f"Invalid podcast ID: {podcast_id}")
        podcast_filename = podcast_files[podcast_index]
        logger.info(f"Selected podcast file: {podcast_filename}")

        # Extract topic from filename
        topic = podcast_filename.split('-')[0].replace('_', ' ')
        logger.info(f"Extracted topic: {topic}")

        # Initialize OpenAI chat model for content analysis
        try:
            analysis_model = ChatOpenAI(
                model_name="gpt-3.5-turbo",
                temperature=0.3,
                openai_api_key=openai_api_key,
            )
            logger.info("Successfully initialized ChatOpenAI")
        except Exception as e:
            logger.error(f"Error initializing ChatOpenAI: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error initializing chat model: {str(e)}"
            ) from e

        # Create prompt template for content analysis
        prompt = ChatPromptTemplate.from_messages([
            ("system", """You are an expert content analyzer. Your task is to:
1. Analyze the given topic and create balanced, factual content chunks about it
2. Generate two types of chunks:
- Believer chunks: Positive aspects, opportunities, and solutions related to the topic
- Skeptic chunks: Challenges, risks, and critical questions about the topic
3. Each chunk should be self-contained and focused on a single point
4. Keep chunks concise (2-3 sentences each)
5. Ensure all content is factual and balanced
Format your response as a JSON object with two arrays:
{{
"believer_chunks": ["chunk1", "chunk2", ...],
"skeptic_chunks": ["chunk1", "chunk2", ...]
}}"""),
            ("human", "Create balanced content chunks about this topic: {topic}")
        ])

        # Generate content chunks.  Only the model call sits in this ``try``
        # so that the more specific parse/save errors raised further down are
        # not re-wrapped as a generic "generating" failure.
        chain = prompt | analysis_model
        try:
            logger.info(f"Generating content chunks for topic: {topic}")
            response = await chain.ainvoke({"topic": topic})
            logger.info("Successfully received response from OpenAI")
        except Exception as e:
            logger.error(f"Error generating content chunks: {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Error generating content chunks: {str(e)}"
            ) from e

        # Parse the response content as JSON
        try:
            content_chunks = json.loads(response.content)
        except json.JSONDecodeError as e:
            logger.error(f"Error parsing response JSON: {str(e)}, Response content: {response.content}")
            raise HTTPException(
                status_code=500, detail=f"Error parsing content chunks: {str(e)}"
            ) from e
        if not isinstance(content_chunks, dict):
            # Valid JSON but not the requested object (e.g. a bare list).
            logger.error(f"Error parsing response JSON: expected an object, Response content: {response.content}")
            raise HTTPException(
                status_code=500,
                detail="Error parsing content chunks: expected a JSON object"
            )
        logger.info(f"Successfully parsed response JSON with {len(content_chunks.get('believer_chunks', []))} believer chunks and {len(content_chunks.get('skeptic_chunks', []))} skeptic chunks")

        # Create the context object
        context = {
            "topic": topic,
            "believer_chunks": content_chunks.get("believer_chunks", []),
            "skeptic_chunks": content_chunks.get("skeptic_chunks", []),
        }

        # Save the context
        try:
            with open(context_path, 'w') as f:
                json.dump(context, f)
            logger.info(f"Saved new context to {context_path}")
        except Exception as e:
            logger.error(f"Error saving context file: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error saving context file: {str(e)}"
            ) from e

        return context
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in get_podcast_context: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
async def podcast_chat(podcast_id: str, request: PodcastChatRequest):
    """Handle chat messages for a specific podcast.

    The 1-based ``podcast_id`` is resolved to an ``.mp3`` file in the audio
    directory and a topic is derived from the file name.  The transcript with
    the same topic is loaded from ``transcripts/podcasts.json``, split into
    chunks and indexed in an in-memory Qdrant collection.  The chunks most
    similar to the question are retrieved and passed to the chat model
    together with the question.

    NOTE(review): the vector store is rebuilt on every request and
    ``Qdrant.from_texts`` is a synchronous call; only the chain itself is
    awaited.

    Args:
        podcast_id: 1-based position of the podcast among the ``.mp3`` files,
            as a string.
        request: Payload carrying the user's question in ``message``.

    Returns:
        ``PodcastChatResponse`` with the model's answer.

    Raises:
        HTTPException: 404 for an invalid ID, a missing transcripts file, a
            missing transcript or an empty transcript; 500 when reading the
            transcripts, splitting the text or generating the answer fails.
    """
    try:
        logger.info(f"Processing chat message for podcast {podcast_id}")
        logger.info(f"User message: {request.message}")

        # Get list of audio files
        audio_files = [f for f in os.listdir(audio_dir) if f.endswith('.mp3')]
        logger.info(f"Found {len(audio_files)} audio files: {audio_files}")

        # Convert podcast_id to zero-based index and get the filename
        try:
            podcast_index = int(podcast_id) - 1
            if podcast_index < 0 or podcast_index >= len(audio_files):
                logger.error(f"Invalid podcast index: {podcast_index} (total files: {len(audio_files)})")
                raise ValueError(f"Invalid podcast ID: {podcast_id}")
            podcast_filename = audio_files[podcast_index]
            logger.info(f"Found podcast file: {podcast_filename}")
        except ValueError as e:
            logger.error(f"Error converting podcast ID: {str(e)}")
            # Same detail for non-integer and out-of-range IDs, so the raw
            # int() parsing message is never sent to the client.
            raise HTTPException(status_code=404, detail=f"Invalid podcast ID: {podcast_id}") from e

        # Extract topic from filename
        topic = podcast_filename.split('-')[0].replace('_', ' ')
        logger.info(f"Extracted topic: {topic}")

        # Path to transcripts file
        transcripts_file = os.path.join(os.path.dirname(__file__), "transcripts", "podcasts.json")

        # Check if transcripts file exists
        if not os.path.exists(transcripts_file):
            logger.error("Transcripts file not found")
            raise HTTPException(status_code=404, detail="Transcripts file not found")

        # Read transcripts
        try:
            with open(transcripts_file, 'r') as f:
                transcripts = json.load(f)
            logger.info(f"Loaded {len(transcripts)} transcripts")
            logger.info(f"Available topics: {[t.get('topic', 'NO_TOPIC') for t in transcripts]}")
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding transcripts file: {str(e)}")
            raise HTTPException(status_code=500, detail="Error reading transcripts file") from e

        # Find matching transcript by topic (case- and whitespace-insensitive)
        wanted_topic = topic.lower().strip()
        podcast_transcript = None
        for transcript in transcripts:
            # ``or ""`` tolerates entries whose topic is stored as null.
            transcript_topic = str(transcript.get("topic") or "").lower().strip()
            if transcript_topic == wanted_topic:
                podcast_transcript = transcript.get("podcastScript")
                logger.info(f"Found matching transcript for topic: {topic}")
                break

        if not podcast_transcript:
            logger.error(f"No transcript found for topic: {topic}")
            logger.error(f"Available topics: {[t.get('topic', 'NO_TOPIC') for t in transcripts]}")
            raise HTTPException(status_code=404, detail=f"No transcript found for topic: {topic}")

        logger.info(f"Found transcript for topic: {topic}")
        logger.info(f"Full transcript length: {len(podcast_transcript)} characters")
        logger.debug(f"Transcript preview: {podcast_transcript[:200]}...")

        # Split text into chunks
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=100,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""]
        )

        # Use split_text for strings instead of split_documents
        try:
            logger.info("Starting text splitting process...")
            chunks = text_splitter.split_text(podcast_transcript)
            logger.info(f"Successfully split transcript into {len(chunks)} chunks")

            # Log some sample chunks
            logger.info("\nSample chunks:")
            for i, chunk in enumerate(chunks[:3]):  # Log first 3 chunks
                logger.info(f"\nChunk {i+1}:")
                logger.info("=" * 50)
                logger.info(chunk)
                logger.info("=" * 50)
            if len(chunks) > 3:
                logger.info(f"... and {len(chunks) - 3} more chunks")
        except Exception as e:
            logger.error(f"Error splitting text into chunks: {str(e)}")
            raise HTTPException(status_code=500, detail=f"Error splitting text: {str(e)}") from e

        if not chunks:
            logger.error("No content chunks found in transcript")
            raise HTTPException(status_code=404, detail="No content chunks found in transcript")

        # Validate chunk sizes
        chunk_sizes = [len(chunk) for chunk in chunks]
        logger.info("\nChunk size statistics:")
        logger.info(f"Min chunk size: {min(chunk_sizes)} characters")
        logger.info(f"Max chunk size: {max(chunk_sizes)} characters")
        logger.info(f"Average chunk size: {sum(chunk_sizes)/len(chunk_sizes):.2f} characters")

        # Initialize embedding model
        embedding_model = OpenAIEmbeddings(
            model="text-embedding-3-small",
            openai_api_key=openai_api_key
        )

        # Create a unique collection name for this podcast
        collection_name = f"podcast_{podcast_id}"

        # Initialize Qdrant with local storage
        vectorstore = Qdrant.from_texts(
            texts=chunks,
            embedding=embedding_model,
            location=":memory:",  # Use in-memory storage
            collection_name=collection_name
        )
        logger.info(f"Created vector store for podcast {podcast_id}")

        # Configure the retriever with search parameters
        qdrant_retriever = vectorstore.as_retriever(
            search_type="similarity",  # Use simple similarity search
            search_kwargs={
                "k": 8,  # Increased from 5 to 8 chunks
                "score_threshold": 0.05  # Lowered threshold further for more matches
            }
        )

        base_rag_prompt_template = """\
You are a helpful podcast assistant. Answer the user's question based on the provided context from the podcast transcript.
If the context contains relevant information, use it to answer the question.
If you can't find relevant information in the context to answer the question, say "I don't have enough information to answer that question."
Keep your responses concise and focused on the question.
Important: Even if only part of the context is relevant to the question, use that part to provide a partial answer rather than saying there isn't enough information.
Context:
{context}
Question:
{question}
Answer the question using the information from the context above. If you find ANY relevant information, use it to provide at least a partial answer. Only say "I don't have enough information" if there is absolutely nothing relevant in the context.
"""
        base_rag_prompt = ChatPromptTemplate.from_template(base_rag_prompt_template)
        base_llm = ChatOpenAI(
            model="gpt-3.5-turbo",
            temperature=0.7,
            openai_api_key=openai_api_key
        )

        # Create the RAG chain
        def format_docs(docs):
            """Join the retrieved documents' text into one context string."""
            formatted = "\n\n".join(doc.page_content for doc in docs)
            logger.info(f"Formatted {len(docs)} documents into context of length: {len(formatted)}")
            return formatted

        # Add logging for the retrieved documents and final prompt
        def get_context_and_log(input_dict):
            """Retrieve context for the question and build the prompt inputs.

            Never raises: a retrieval failure yields a placeholder context so
            the model can still produce an answer.
            """
            try:
                logger.info("\nAttempting to retrieve relevant documents...")
                # Log the query being used
                logger.info(f"Query: {input_dict['question']}")

                # Use the newer invoke method instead of get_relevant_documents
                retrieved_docs = qdrant_retriever.invoke(input_dict["question"])
                logger.info(f"Successfully retrieved {len(retrieved_docs)} documents")

                if not retrieved_docs:
                    logger.warning("No documents were retrieved!")
                    return {"context": "No relevant context found.", "question": input_dict["question"]}

                # Log each retrieved document with its content and metadata
                total_content_length = 0
                for i, doc in enumerate(retrieved_docs):
                    logger.info(f"\nDocument {i+1}:")
                    logger.info("=" * 50)
                    logger.info(f"Content: {doc.page_content}")
                    logger.info(f"Content Length: {len(doc.page_content)} characters")
                    logger.info(f"Metadata: {doc.metadata}")
                    logger.info("=" * 50)
                    total_content_length += len(doc.page_content)

                context = format_docs(retrieved_docs)

                # Log the final formatted context and question
                logger.info("\nRetrieval Statistics:")
                logger.info(f"Total documents retrieved: {len(retrieved_docs)}")
                logger.info(f"Total content length: {total_content_length} characters")
                logger.info(f"Average document length: {total_content_length/len(retrieved_docs):.2f} characters")
                logger.info("\nFinal Context and Question:")
                logger.info("=" * 50)
                logger.info("Context:")
                logger.info(f"{context}")
                logger.info("-" * 50)
                logger.info(f"Question: {input_dict['question']}")
                logger.info("=" * 50)

                if not context.strip():
                    logger.error("Warning: Empty context retrieved!")
                    return {"context": "No relevant context found.", "question": input_dict["question"]}

                return {"context": context, "question": input_dict["question"]}
            except Exception as e:
                logger.error(f"Error in get_context_and_log: {str(e)}")
                logger.error("Stack trace:", exc_info=True)
                return {"context": "Error retrieving context.", "question": input_dict["question"]}

        # Create the chain
        chain = (
            RunnablePassthrough()
            | get_context_and_log
            | base_rag_prompt
            | base_llm
        )

        # Get response with enhanced logging
        try:
            logger.info("\nGenerating response...")
            # Awaited so the event loop is not blocked while the chain runs,
            # matching how get_podcast_context calls its chain.
            response = await chain.ainvoke({"question": request.message})
            logger.info("=" * 50)
            logger.info("Final Response:")
            logger.info(f"{response.content}")
            logger.info("=" * 50)
            return PodcastChatResponse(response=response.content)
        except Exception as e:
            logger.error(f"Error generating response: {str(e)}")
            raise HTTPException(
                status_code=500, detail=f"Error generating response: {str(e)}"
            ) from e
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Error in podcast chat: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Include the API router
# NOTE(review): presumably done before the "/" mount below so that /api routes
# are matched ahead of the catch-all static frontend -- confirm before
# reordering these statements.
app.include_router(api_router)
# Mount static directories
# Files in audio_dir are served under /audio-files, the same URL prefix that
# list_audio_files puts into each entry's "path".
app.mount("/audio-files", StaticFiles(directory=audio_dir), name="audio")
# "static" is a relative path, unlike the storage directories above, which are
# built from this file's location.
app.mount("/", StaticFiles(directory="static", html=True), name="frontend")
if __name__ == "__main__":
    # When executed as a script, start uvicorn on all interfaces, port 8000.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)