Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -317,97 +317,102 @@ def process_citations(complete_answer: str, ranked_docs: List[dict]) -> Tuple[st
|
|
317 |
# ------------------------------------------------------------------------------
|
318 |
@app.websocket("/chat")
|
319 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
320 |
await websocket.accept()
|
|
|
321 |
try:
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
366 |
-
|
367 |
-
|
368 |
-
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
392 |
-
|
393 |
-
|
394 |
-
|
395 |
-
|
396 |
-
|
397 |
-
|
398 |
-
|
399 |
-
|
400 |
-
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
-
|
405 |
-
|
406 |
-
|
|
|
|
|
407 |
except Exception as e:
|
408 |
logger.error(f"Unexpected error: {e}")
|
409 |
await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
|
410 |
|
|
|
411 |
# ------------------------------------------------------------------------------
|
412 |
# Simple health check endpoint
|
413 |
# ------------------------------------------------------------------------------
|
|
|
317 |
# ------------------------------------------------------------------------------
|
318 |
# Complete inline citation marker such as "[3]".
_CITATION_MARKER = re.compile(r'\[\d+\]')
# Possible start of a citation marker at the very end of the buffered text.
_PARTIAL_CITATION_TAIL = re.compile(r'\[\d*$')


def _split_stream_text(buffered: str) -> Tuple[str, str]:
    """Strip citation markers and split text into (sendable, held_back).

    Streamed deltas can cut a marker like "[12]" into several pieces, so a
    trailing "[" or "[12" is held back until the next delta shows whether it
    completes a marker. Everything before that tail is safe to send.
    """
    cleaned = _CITATION_MARKER.sub('', buffered)
    tail = _PARTIAL_CITATION_TAIL.search(cleaned)
    if tail is None:
        return cleaned, ""
    return cleaned[:tail.start()], cleaned[tail.start():]


@app.websocket("/chat")
async def websocket_endpoint(websocket: WebSocket):
    """Serve a multi-turn chat session over a single WebSocket connection.

    For every JSON query received the handler validates it as a
    ``ChatRequest``, retrieves and reranks documents, streams the model answer
    back as ``{"response": <text>}`` messages (inline ``[n]`` citation markers
    removed), and finishes with one
    ``{"response": <full answer>, "sources": <citations>}`` message.
    A failure in any step is reported to the client and the loop waits for the
    next query; the loop ends when the client disconnects.
    """
    logger.info("Client connected to WebSocket")
    await websocket.accept()

    try:
        while True:
            try:
                # Wait indefinitely for the next query from the client.
                data = await websocket.receive_json()
            except WebSocketDisconnect:
                logger.info("Client disconnected")
                break
            except RuntimeError as e:
                # NOTE(review): assuming a Starlette/FastAPI WebSocket, a
                # RuntimeError here means the socket is no longer connected;
                # retrying would spin forever, so leave the loop instead.
                logger.error(f"Error receiving data: {e}")
                break
            except Exception as e:
                # Malformed payloads (e.g. invalid JSON) are recoverable.
                logger.error(f"Error receiving data: {e}")
                await safe_send(websocket, {"response": "Error receiving data", "sources": []})
                continue

            # Validate the received query
            try:
                chat_request = ChatRequest(**data)
            except Exception as e:
                logger.error(f"Validation error: {e}")
                await safe_send(websocket, {"response": "Invalid query data", "sources": []})
                continue

            question = chat_request.question
            language = chat_request.language
            previous_chats = chat_request.previous_chats

            # Retrieve documents using the retriever (run in a worker thread
            # so the blocking call does not stall the event loop).
            try:
                retrieved_docs = await asyncio.to_thread(retriever.invoke, question)
            except Exception as e:
                logger.error(f"Document retrieval error: {e}")
                await safe_send(websocket, {"response": "Document retrieval failed", "sources": []})
                continue

            docs = [{
                "summary": ele.metadata.get("summary", ""),
                "chunk": ele.page_content,
                "page_source": ele.metadata.get("source", "")
            } for ele in retrieved_docs]

            if not docs:
                await safe_send(websocket, {"response": "Cannot provide an answer to this question", "sources": []})
                continue

            # Rerank the documents; if the reranking fails, use the original docs
            try:
                ranked_docs = await asyncio.to_thread(rerank_docs, question, docs, pc)
            except Exception as e:
                logger.error(f"Reranking error: {e}")
                ranked_docs = docs

            # Prepare the conversation messages for the chat model.
            messages = [{"role": "system", "content": system_prompt}]
            # `or []` tolerates a missing history; presumably ChatRequest may
            # allow None here -- verify against the model definition.
            messages.extend(previous_chats or [])
            messages.append({"role": "user", "content": format_query(question, language, ranked_docs)})

            complete_answer = ""
            # Text held back because it may be the start of a citation marker.
            pending = ""

            # Generate the answer in streaming mode.
            try:
                completion = await openai_client.chat.completions.create(
                    model="gpt-4o",
                    messages=messages,
                    temperature=0.2,
                    max_completion_tokens=1024,
                    stream=True
                )
                async for chunk in completion:
                    # Some stream chunks carry no choices; skip them instead
                    # of failing the whole answer with an IndexError.
                    if not chunk.choices:
                        continue
                    delta_content = chunk.choices[0].delta.content
                    if not delta_content:
                        continue
                    complete_answer += delta_content
                    # Remove inline citation markers before sending, including
                    # markers that arrive split across several deltas.
                    sendable, pending = _split_stream_text(pending + delta_content)
                    if sendable:
                        await safe_send(websocket, {"response": sendable})
                if pending:
                    # The stream ended on an incomplete marker; send it as-is.
                    await safe_send(websocket, {"response": pending})
            except Exception as e:
                logger.error(f"Streaming error: {e}")
                await safe_send(websocket, {"response": "Response generation failed", "sources": []})
                continue

            # Process citations in the complete answer and send the final response.
            complete_answer, citations = process_citations(complete_answer, ranked_docs)
            await safe_send(websocket, {"response": complete_answer, "sources": citations})

    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        await safe_send(websocket, {"response": "Something went wrong! Please try again.", "sources": []})
|
414 |
|
415 |
+
|
416 |
# ------------------------------------------------------------------------------
|
417 |
# Simple health check endpoint
|
418 |
# ------------------------------------------------------------------------------
|