"""Integration tests for ``LLMClient`` structured-output prediction.

Each test builds an ``LLMClient`` for one backend (OpenAI or Gemini), asks it
to fill a pydantic schema from a text or image prompt, and checks the fields
of the result.

NOTE(review): these tests appear to call the real model APIs, so they likely
need network access and provider credentials — confirm before running in CI.
"""

from PIL import Image
from pydantic import BaseModel

from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient

# Resolved against the current working directory (unchanged from the original
# tests), so pytest must be run from the directory that contains ``assets/``.
TEST_IMAGE_PATH = "./assets/test_image.png"

OPENAI_MODEL_NAME = "gpt-4o-mini"
GEMINI_MODEL_NAME = "gemini-1.5-flash-latest"

EVENT_SYSTEM_PROMPT = "Extract the event information"
EVENT_USER_PROMPT = "Alice and Bob are going to a science fair on Friday."
IMAGE_SYSTEM_PROMPT = "Describe the image"


class CalendarEvent(BaseModel):
    """Schema for the event the model should extract from ``EVENT_USER_PROMPT``."""

    name: str
    date: str
    participants: list[str]


class ImageDescription(BaseModel):
    """Schema for a free-text description of an image."""

    description: str


def _first_if_list(response):
    """Return ``response[0]`` if *response* is a list, else *response* itself.

    The Gemini tests accept either a single mapping or a list whose first
    element is the mapping.
    """
    return response[0] if isinstance(response, list) else response


def test_openai_llm_client():
    """OpenAI backend extracts the event fields into a ``CalendarEvent``."""
    llm_client = LLMClient(
        model_name=OPENAI_MODEL_NAME, client_type=ClientType.OPENAI
    )
    event = llm_client.predict(
        system_prompt=EVENT_SYSTEM_PROMPT,
        user_prompt=EVENT_USER_PROMPT,
        schema=CalendarEvent,
    )
    assert event.name.lower() == "science fair"
    assert event.date.lower() == "friday"
    # Order-insensitive: nothing in the prompt or schema requires the model to
    # list participants in input order, so a fixed-order check is flaky.
    assert sorted(item.lower() for item in event.participants) == ["alice", "bob"]


def test_openai_image_description():
    """OpenAI backend describes the test image (expected to show an astronaut)."""
    llm_client = LLMClient(
        model_name=OPENAI_MODEL_NAME, client_type=ClientType.OPENAI
    )
    # The context manager closes the image's file handle once the call returns;
    # the original left it open.
    with Image.open(TEST_IMAGE_PATH) as image:
        description = llm_client.predict(
            system_prompt=IMAGE_SYSTEM_PROMPT,
            user_prompt=[image],
            schema=ImageDescription,
        )
    assert "astronaut" in description.description.lower()


def test_google_llm_client():
    """Gemini backend extracts the event fields (returned as a mapping)."""
    llm_client = LLMClient(
        model_name=GEMINI_MODEL_NAME, client_type=ClientType.GEMINI
    )
    event = _first_if_list(
        llm_client.predict(
            system_prompt=EVENT_SYSTEM_PROMPT,
            user_prompt=EVENT_USER_PROMPT,
            schema=CalendarEvent,
        )
    )
    assert event["name"].lower() == "science fair"
    assert event["date"].lower() == "friday"
    # Order-insensitive for the same reason as the OpenAI test.
    assert sorted(item.lower() for item in event["participants"]) == [
        "alice",
        "bob",
    ]


def test_google_image_client():
    """Gemini backend describes the test image (expected to show an astronaut)."""
    llm_client = LLMClient(
        model_name=GEMINI_MODEL_NAME, client_type=ClientType.GEMINI
    )
    with Image.open(TEST_IMAGE_PATH) as image:
        description = _first_if_list(
            llm_client.predict(
                system_prompt=IMAGE_SYSTEM_PROMPT,
                user_prompt=[image],
                schema=ImageDescription,
            )
        )
    assert "astronaut" in description["description"].lower()