Spaces:
Running
Running
File size: 2,167 Bytes
170d9a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from PIL import Image
from pydantic import BaseModel
from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient
# Structured-output schema for the event-extraction tests; it is passed to
# LLMClient.predict(schema=...), which is expected to fill these fields from
# a free-text sentence.
# NOTE: documented with comments instead of a class docstring on purpose —
# Pydantic copies a model docstring into its JSON-schema "description", which
# would presumably change the schema handed to the LLM.
class CalendarEvent(BaseModel):
    name: str  # event name as extracted from the text, e.g. "science fair"
    date: str  # date as free-form text (e.g. "Friday"), not a parsed date
    participants: list[str]  # names of the people attending
# Structured-output schema for the image-description tests; it is passed to
# LLMClient.predict(schema=...).
# NOTE: documented with comments instead of a class docstring on purpose —
# Pydantic copies a model docstring into its JSON-schema "description", which
# would presumably change the schema handed to the LLM.
class ImageDescription(BaseModel):
    description: str  # free-text description of the image content
def test_openai_llm_client():
    """Extract a structured ``CalendarEvent`` from text via the OpenAI client.

    All comparisons are case-insensitive. Participants are compared as a
    sorted list because the order in which the model lists names is not part
    of what the extraction should guarantee: returning "Bob" before "Alice"
    is still a correct answer.
    """
    llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
    event = llm_client.predict(
        system_prompt="Extract the event information",
        user_prompt="Alice and Bob are going to a science fair on Friday.",
        schema=CalendarEvent,
    )
    assert event.name.lower() == "science fair"
    assert event.date.lower() == "friday"
    assert sorted(item.lower() for item in event.participants) == ["alice", "bob"]
def test_openai_image_description():
    """Describe a local test image via the OpenAI client and check its subject.

    The image is opened in a ``with`` block. ``Image.open`` is lazy and leaves
    the file open until the pixel data is loaded or the object is collected,
    so the context manager is needed to close the handle deterministically,
    including when ``predict`` raises.
    """
    llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
    # NOTE(review): this path is relative to the current working directory,
    # so the test presumably has to be run from the repository root.
    with Image.open("./assets/test_image.png") as image:
        description = llm_client.predict(
            system_prompt="Describe the image",
            user_prompt=[image],
            schema=ImageDescription,
        )
    assert "astronaut" in description.description.lower()
def test_google_llm_client():
    """Extract a structured ``CalendarEvent`` from text via the Gemini client.

    On this path the result is indexed as a mapping rather than read through
    attributes, and it may arrive wrapped in a list. Comparisons are
    case-insensitive. Participants are compared as a sorted list because the
    order in which the model lists names is not part of the contract.
    """
    llm_client = LLMClient(
        model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
    )
    event = llm_client.predict(
        system_prompt="Extract the event information",
        user_prompt="Alice and Bob are going to a science fair on Friday.",
        schema=CalendarEvent,
    )
    if isinstance(event, list):
        # An empty list means nothing was extracted; fail with a clear
        # message instead of an IndexError from event[0].
        assert event, "Gemini returned an empty list of events"
        event = event[0]
    assert event["name"].lower() == "science fair"
    assert event["date"].lower() == "friday"
    assert sorted(item.lower() for item in event["participants"]) == ["alice", "bob"]
def test_google_image_client():
    """Describe a local test image via the Gemini client and check its subject.

    The image is opened in a ``with`` block so its file handle is closed
    deterministically, even if ``predict`` raises. ``Image.open`` is lazy and
    otherwise keeps the file open. The result is indexed as a mapping and may
    arrive wrapped in a list.
    """
    llm_client = LLMClient(
        model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
    )
    # NOTE(review): this path is relative to the current working directory,
    # so the test presumably has to be run from the repository root.
    with Image.open("./assets/test_image.png") as image:
        description = llm_client.predict(
            system_prompt="Describe the image",
            user_prompt=[image],
            schema=ImageDescription,
        )
    if isinstance(description, list):
        # An empty list means no description came back; fail with a clear
        # message instead of an IndexError from description[0].
        assert description, "Gemini returned an empty list of descriptions"
        description = description[0]
    assert "astronaut" in description["description"].lower()
|