Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
/* | |
# Copyright 2025 Google LLC | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
*/ | |
import React, {useEffect, useRef, useState} from 'react'; | |
import IconBackArrow from '../icons/IconBackArrow'; | |
import IconCodeBlocks from '../icons/IconCodeBlocks'; | |
import IconWarning from '../icons/IconWarning'; | |
import MCQOption from '../components/MCQOption'; | |
import ChatMessage from '../components/ChatMessage'; | |
import styles from './ChatScreen.module.css'; | |
import TextWithTooltips from '../components/TextWithTooltips'; | |
import {redactPhrases} from "../components/RedactedTextView.js"; | |
import {CONDITION_TERMS} from "../data/constants.js"; | |
import { fetchWithRetry } from '../utils/fetchWithRetry'; | |
const API_ENDPOINTS = { | |
getCaseImage: (journeyId) => `/api/case/${journeyId}/stub`, | |
getCaseQuestions: (journeyId) => `/api/case/${journeyId}/all-questions`, | |
summarizeCase: (journeyId) => `/api/case/${journeyId}/summarize`, | |
}; | |
const ChatScreen = ({journey, onNavigate, onShowDetails, cachedImage, onImageLoad, onGoToSummary}) => { | |
const [allQuestions, setAllQuestions] = useState([]); | |
const [currentQuestionIndex, setCurrentQuestionIndex] = useState(0); | |
const [messages, setMessages] = useState([]); | |
const [caseImage, setCaseImage] = useState(cachedImage || ''); | |
const [modelResponseHistory, setModelResponseHistory] = useState([]); | |
const [isSummarizing, setIsSummarizing] = useState(false); | |
const [isLoading, setIsLoading] = useState(true); | |
const [isImageLoading, setIsImageLoading] = useState(true); | |
const [imageError, setImageError] = useState(false); | |
const [userActionHistory, setUserActionHistory] = useState({}); | |
const chatWindowRef = useRef(null); | |
const timeoutIdsRef = useRef([]); | |
useEffect(() => { | |
if (chatWindowRef.current) { | |
chatWindowRef.current.scrollTo({ | |
top: chatWindowRef.current.scrollHeight, | |
behavior: 'smooth' | |
}); | |
} | |
}, [messages, isLoading]); | |
useEffect(() => { | |
const fetchImage = async () => { | |
if (cachedImage) { | |
setCaseImage(cachedImage); | |
setIsImageLoading(false); | |
return; | |
} | |
if (journey) { | |
setIsImageLoading(true); | |
setImageError(false); | |
try { | |
const response = await fetchWithRetry(API_ENDPOINTS.getCaseImage(journey.id)); | |
if (!response.ok) throw new Error("Failed to fetch case image"); | |
const data = await response.json(); | |
const imageUrl = data.download_image_url; | |
setCaseImage(imageUrl); | |
if (onImageLoad) onImageLoad(imageUrl); | |
} catch (error) { | |
console.error("Error fetching case data:", error); | |
setImageError(true); | |
} finally { | |
setIsImageLoading(false); | |
} | |
} | |
}; | |
fetchImage(); | |
}, [journey, cachedImage, onImageLoad]); | |
useEffect(() => { | |
if (!journey) { | |
setIsLoading(false); | |
return; | |
} | |
const fetchQuestions = async () => { | |
setIsLoading(true); | |
setMessages([]); | |
try { | |
const response = await fetchWithRetry(API_ENDPOINTS.getCaseQuestions(journey.id)); | |
if (!response.ok) throw new Error("Failed to fetch questions"); | |
const questions = await response.json(); | |
if (questions && questions.length > 0) { | |
const questionsWithIds = questions.map((q, index) => ({...q, id: `q-${journey.id}-${index}`})); | |
setAllQuestions(questionsWithIds); | |
setCurrentQuestionIndex(0); | |
displayQuestion(questionsWithIds[0], 0, questionsWithIds.length); | |
} else { | |
setMessages([{ | |
type: 'system', | |
id: Date.now(), | |
content: "Sorry, I couldn't load the questions for this case. Try again later!" | |
}]); | |
} | |
} catch (error) { | |
console.error("Error fetching questions:", error); | |
setMessages([{ | |
type: 'system', | |
id: Date.now(), | |
content: "Sorry, I couldn't load the questions for this case. Try again later!" | |
}]); | |
} finally { | |
setIsLoading(false); | |
} | |
}; | |
fetchQuestions(); | |
}, [journey]); | |
useEffect(() => { | |
return () => { | |
timeoutIdsRef.current.forEach(clearTimeout); | |
}; | |
}, []); | |
const displayQuestion = (questionData, index, totalQuestions) => { | |
let questionText = `Question ${index + 1}: ${questionData.question}`; | |
if (index === 0) { | |
questionText = `Okay, let's start with Case ${journey.id.padStart(2, '0')}. ${questionData.question}`; | |
} | |
const questionMessage = {type: 'system', id: Date.now(), content: questionText}; | |
const mcqMessage = { | |
type: 'mcq', | |
id: questionData.id || `q-${index}`, | |
data: questionData, | |
isLast: index === totalQuestions - 1, | |
incorrectAttempts: [], | |
isAnswered: false, | |
}; | |
setMessages(prev => [...prev, questionMessage, mcqMessage]); | |
}; | |
const handleSelectOption = (selectedOptionKey, messageId) => { | |
const currentMCQMessageIndex = messages.findIndex(m => m.id === messageId && !m.isAnswered); | |
if (currentMCQMessageIndex === -1) return; | |
const currentMCQMessage = messages[currentMCQMessageIndex]; | |
const {answer, rationale, hint, choices, id: questionId} = currentMCQMessage.data; | |
const selectedOptionText = choices[selectedOptionKey]; | |
const isCorrect = selectedOptionKey === answer; | |
setUserActionHistory(prev => { | |
const newHistory = {...prev}; | |
if (!newHistory[questionId]) { | |
newHistory[questionId] = {attempt1: selectedOptionKey}; | |
} else if (!newHistory[questionId].attempt2) { | |
newHistory[questionId] = {...newHistory[questionId], attempt2: selectedOptionKey}; | |
} | |
return newHistory; | |
}); | |
let userResponseMessage = {type: 'user', id: Date.now(), content: `You responded: "${selectedOptionText}"`}; | |
const updatedMessages = messages.map(msg => | |
msg.id === messageId ? {...msg, isAnswered: true} : msg | |
); | |
setMessages([...updatedMessages, userResponseMessage]); | |
let feedbackMessages = []; | |
const handleNextStep = (isQuestionComplete) => { | |
if (isQuestionComplete) { | |
setModelResponseHistory(prev => [...prev, currentMCQMessage.data]); | |
} | |
if (currentMCQMessage.isLast && isQuestionComplete) { | |
feedbackMessages.push({type: 'summary_button', id: Date.now() + 2}); | |
} else if (isQuestionComplete) { | |
const nextIndex = currentQuestionIndex + 1; | |
if (nextIndex < allQuestions.length) { | |
const timerId = setTimeout(() => { | |
setCurrentQuestionIndex(nextIndex); | |
displayQuestion(allQuestions[nextIndex], nextIndex, allQuestions.length); | |
}, 1500); | |
timeoutIdsRef.current.push(timerId); | |
} | |
} | |
}; | |
let redactedRationale = ""; | |
if (!currentMCQMessage.isLast) { | |
redactedRationale = redactPhrases(rationale, CONDITION_TERMS); | |
} else { | |
redactedRationale = rationale | |
} | |
if (isCorrect) { | |
feedbackMessages.push({type: 'system', id: Date.now() + 1, content: `That's right. ${redactedRationale}`}); | |
handleNextStep(true); | |
} else { | |
const attempts = [...currentMCQMessage.incorrectAttempts, selectedOptionKey]; | |
if (attempts.length < 2) { | |
feedbackMessages.push({ | |
type: 'system_hint', id: Date.now() + 1, | |
content: `That's not quite right. Would you like to try again?`, | |
hint: `Hint: ${hint}`, | |
}); | |
feedbackMessages.push({ | |
...currentMCQMessage, type: 'mcq_retry', id: Date.now() + 2, | |
incorrectAttempts: attempts, isAnswered: false, | |
}); | |
} else { | |
feedbackMessages.push({ | |
type: 'system', id: Date.now() + 1, | |
content: `That's not right. The correct answer is "${choices[answer]}". ${redactedRationale}` | |
}); | |
handleNextStep(true); | |
} | |
} | |
const timerId = setTimeout(() => { | |
setMessages(prev => [...prev, ...feedbackMessages]); | |
}, 800); | |
timeoutIdsRef.current.push(timerId); | |
}; | |
const handleGoToSummary = async () => { | |
setIsSummarizing(true); | |
const conversation_history = modelResponseHistory.map(modelResponse => { | |
const userResponse = userActionHistory[modelResponse.id] || {}; | |
const finalUserResponse = { | |
attempt1: userResponse.attempt1 || null, | |
attempt2: userResponse.attempt2 || null, | |
}; | |
return { | |
ModelResponse: modelResponse, | |
UserResponse: finalUserResponse | |
}; | |
}); | |
try { | |
const response = await fetchWithRetry(API_ENDPOINTS.summarizeCase(journey.id), { | |
method: 'POST', | |
headers: {'Content-Type': 'application/json'}, | |
body: JSON.stringify({conversation_history}) | |
}); | |
if (!response.ok) { | |
throw new Error(`Failed to fetch summary. Status: ${response.status}`); | |
} | |
const summaryData = await response.json(); | |
onGoToSummary(summaryData); | |
} catch (error) { | |
console.error("Error fetching summary:", error); | |
setMessages(prev => [...prev, { | |
type: 'system', | |
id: Date.now(), | |
content: "Sorry, there was an error generating the summary.Please try again." | |
}]); | |
} finally { | |
setIsSummarizing(false); | |
} | |
}; | |
return ( | |
<div className={styles.pageContainer}> | |
<div className={styles.topNav}> | |
<button className={styles.navButton} onClick={() => onNavigate('landing')}> | |
<IconBackArrow/> <span>Exit</span> | |
</button> | |
<button className={styles.detailsButton} onClick={onShowDetails}> | |
<IconCodeBlocks fill="#004A77"/> Details about this Demo | |
</button> | |
</div> | |
<div className={styles.contentBox}> | |
<div className={styles.mainLayout}> | |
<div className={styles.leftPanel}> | |
{isImageLoading ? ( | |
<div className={styles.loadingContainer}> | |
<div className={styles.loadingSpinner}></div> | |
<p className={styles.loadingText}>Loading Chest X-Ray image...</p> | |
</div> | |
) : imageError ? ( | |
<div className={styles.imageErrorFallback}> | |
<p>⚠️</p> | |
<p>Could not load case image. Please try again.</p> | |
</div> | |
) : ( | |
<img | |
src={caseImage} | |
alt={`Image for Case ${journey.label}`} | |
className={styles.caseImage} | |
/> | |
)} | |
</div> | |
<div className={styles.rightPanel}> | |
{!isLoading && allQuestions.length > 0 && ( | |
<div className={styles.progressTracker}> | |
{modelResponseHistory.length} / {allQuestions.length} questions answered | |
</div> | |
)} | |
<div className={styles.chatWindow} ref={chatWindowRef}> | |
{isLoading ? ( | |
<div className={styles.loadingContainer}> | |
<div className={styles.loadingSpinner}></div> | |
<p className={styles.loadingText}>Loading questions...</p> | |
</div> | |
) : ( | |
messages.map((msg) => { | |
switch (msg.type) { | |
case 'mcq': | |
case 'mcq_retry': | |
return ( | |
<div key={msg.id} className={styles.mcqOptionsOnly}> | |
{Object.entries(msg.data.choices).map(([key, value]) => ( | |
<MCQOption | |
key={key} | |
text={value} | |
onClick={() => handleSelectOption(key, msg.id)} | |
disabled={msg.isAnswered || msg.incorrectAttempts.includes(key)} | |
isIncorrect={msg.incorrectAttempts.includes(key)} | |
/> | |
))} | |
</div> | |
); | |
case 'user': | |
return ( | |
<ChatMessage key={msg.id} type="user" text={msg.content}/> | |
); | |
case 'system': | |
case 'system_hint': | |
return ( | |
<ChatMessage key={msg.id} type="system"> | |
<p> | |
<TextWithTooltips text={msg.content}/> | |
</p> | |
{msg.hint && ( | |
<p className={styles.hintText}> | |
<TextWithTooltips text={msg.hint}/> | |
</p> | |
)} | |
</ChatMessage> | |
); | |
case 'summary_button': | |
return ( | |
<div key={msg.id} className={styles.summaryButtonContainer}> | |
<button onClick={handleGoToSummary} className={styles.summaryButton} disabled={isSummarizing}> | |
{isSummarizing ? 'Generating Summary...' : 'Go to case review and summary'} | |
</button> | |
</div> | |
); | |
default: | |
return null; | |
} | |
}) | |
)} | |
</div> | |
<div className={styles.disclaimerBox}> | |
<IconWarning className={styles.disclaimerIcon}/> | |
<p className={styles.disclaimerText}>This demonstration is for illustrative purposes of MedGemma’s | |
baseline capabilities only. It does not represent a finished or approved product, is not intended to | |
diagnose or suggest treatment of any disease or condition, and should not be used for medical | |
advice.</p> | |
</div> | |
</div> | |
</div> | |
</div> | |
</div> | |
); | |
}; | |
export default ChatScreen; |