Ritesh-hf committed on
Commit
e524800
·
verified ·
1 Parent(s): a627b06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -85
app.py CHANGED
@@ -317,97 +317,102 @@ def process_citations(complete_answer: str, ranked_docs: List[dict]) -> Tuple[st
317
  # ------------------------------------------------------------------------------
318
@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Single-query RAG chat endpoint over a WebSocket.

    Handles exactly one request per connection: validate the payload,
    retrieve and rerank supporting documents, stream the model's answer
    chunk-by-chunk, then send the final answer with citation sources.
    """
    await websocket.accept()
    try:
        # Receive the payload (30 s budget) and validate it.
        try:
            payload = await asyncio.wait_for(websocket.receive_json(), timeout=30)
            chat_request = ChatRequest(**payload)
        except ValidationError as e:
            logger.error(f"Validation error: {e}")
            await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
            return
        except Exception as e:
            logger.error(f"Error receiving data: {e}")
            await safe_send(websocket, {"response": "Something went wrong with your request!", "sources": []})
            return

        question, language = chat_request.question, chat_request.language

        # Retrieval is blocking, so run it in a worker thread.
        try:
            hits = await asyncio.to_thread(retriever.invoke, question)
        except Exception as e:
            logger.error(f"Document retrieval error: {e}")
            await safe_send(websocket, {"response": "Document retrieval failed", "sources": []})
            return

        docs = [
            {
                "summary": hit.metadata.get("summary", ""),
                "chunk": hit.page_content,
                "page_source": hit.metadata.get("source", ""),
            }
            for hit in hits
        ]
        if not docs:
            await safe_send(websocket, {"response": "Cannot provide answer to this question", "sources": []})
            return

        # Rerank; on failure keep the retriever's original ordering.
        try:
            ranked_docs = await asyncio.to_thread(rerank_docs, question, docs, pc)
        except Exception as e:
            logger.error(f"Reranking error: {e}")
            ranked_docs = docs

        # Assemble the chat transcript for the model.
        messages = [{"role": "system", "content": system_prompt}]
        messages.extend(chat_request.previous_chats)
        messages.append({"role": "user", "content": format_query(question, language, ranked_docs)})

        complete_answer = ""
        pending = ""

        # Stream the generated answer to the client as it arrives.
        try:
            stream = await openai_client.chat.completions.create(
                model="gpt-4o",
                messages=messages,
                temperature=0.2,
                max_completion_tokens=1024,
                stream=True,
            )
            async for event in stream:
                piece = event.choices[0].delta.content
                if piece:
                    complete_answer += piece
                    # Strip inline citation markers from the outgoing chunk.
                    pending += re.sub(r'\[\d+\]', '', piece)
                    if pending:
                        await safe_send(websocket, {"response": pending})
                        pending = ""
            if pending:
                await safe_send(websocket, {"response": pending})
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            await safe_send(websocket, {"response": "Response generation failed", "sources": []})
            return

        # Resolve citation markers and send the final answer plus sources.
        complete_answer, citations = process_citations(complete_answer, ranked_docs)
        await safe_send(websocket, {
            "response": complete_answer,
            "sources": citations
        })

    except WebSocketDisconnect:
        logger.info("Client disconnected")
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
410
 
 
411
  # ------------------------------------------------------------------------------
412
  # Simple health check endpoint
413
  # ------------------------------------------------------------------------------
 
317
  # ------------------------------------------------------------------------------
318
@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Multi-turn RAG chat endpoint over a WebSocket.

    Serves queries in a loop for the lifetime of the connection:
    receive + validate -> retrieve -> rerank -> stream the LLM answer ->
    send the final answer with citation sources.  Per-query failures are
    reported to the client and the loop continues; only an unexpected
    error outside the per-query handling tears the session down.
    """
    logger.info("Client connected to WebSocket")
    await websocket.accept()

    # Compiled once per connection instead of re-compiling on every
    # streamed delta; strips inline citation markers such as "[3]".
    citation_marker = re.compile(r'\[\d+\]')

    try:
        while True:
            try:
                # Wait indefinitely for the next query from the client.
                data = await websocket.receive_json()
            except WebSocketDisconnect:
                logger.info("Client disconnected")
                break
            except Exception as e:
                logger.error(f"Error receiving data: {e}")
                await safe_send(websocket, {"response": "Error receiving data", "sources": []})
                continue

            # Validate the received query
            try:
                chat_request = ChatRequest(**data)
            except Exception as e:
                logger.error(f"Validation error: {e}")
                await safe_send(websocket, {"response": "Invalid query data", "sources": []})
                continue

            question = chat_request.question
            language = chat_request.language
            previous_chats = chat_request.previous_chats

            # Retrieval is blocking, so run it in a worker thread.
            try:
                retrieved_docs = await asyncio.to_thread(retriever.invoke, question)
            except Exception as e:
                logger.error(f"Document retrieval error: {e}")
                await safe_send(websocket, {"response": "Document retrieval failed", "sources": []})
                continue

            docs = [{
                "summary": ele.metadata.get("summary", ""),
                "chunk": ele.page_content,
                "page_source": ele.metadata.get("source", "")
            } for ele in retrieved_docs]

            if not docs:
                await safe_send(websocket, {"response": "Cannot provide an answer to this question", "sources": []})
                continue

            # Rerank the documents; if reranking fails, fall back to the
            # retriever's original ordering.
            try:
                ranked_docs = await asyncio.to_thread(rerank_docs, question, docs, pc)
            except Exception as e:
                logger.error(f"Reranking error: {e}")
                ranked_docs = docs

            # Conversation: system prompt + prior turns + current question.
            messages = [{"role": "system", "content": system_prompt}]
            messages.extend(previous_chats)
            messages.append({"role": "user", "content": format_query(question, language, ranked_docs)})

            complete_answer = ""

            # Generate the answer in streaming mode.
            try:
                completion = await openai_client.chat.completions.create(
                    model="gpt-4o",
                    messages=messages,
                    temperature=0.2,
                    max_completion_tokens=1024,
                    stream=True
                )
                async for chunk in completion:
                    # Some stream events carry an empty `choices` list
                    # (e.g. usage-only chunks); guard before indexing to
                    # avoid an IndexError aborting the whole answer.
                    if not chunk.choices:
                        continue
                    delta_content = chunk.choices[0].delta.content
                    if not delta_content:
                        continue
                    complete_answer += delta_content
                    # Remove inline citation markers before streaming out;
                    # a delta may become empty after stripping, so only
                    # send non-empty cleaned content.
                    cleaned_content = citation_marker.sub('', delta_content)
                    if cleaned_content:
                        await safe_send(websocket, {"response": cleaned_content})
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                await safe_send(websocket, {"response": "Response generation failed", "sources": []})
                continue

            # Process citations in the complete answer and send the final response.
            complete_answer, citations = process_citations(complete_answer, ranked_docs)
            await safe_send(websocket, {"response": complete_answer, "sources": citations})

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
414
 
415
+
416
  # ------------------------------------------------------------------------------
417
  # Simple health check endpoint
418
  # ------------------------------------------------------------------------------