File size: 2,167 Bytes
170d9a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from PIL import Image
from pydantic import BaseModel

from medrag_multi_modal.assistant.llm_client import ClientType, LLMClient


class CalendarEvent(BaseModel):
    # Structured-output schema passed as `schema=` to `LLMClient.predict` by
    # the event-extraction tests in this module.
    # NOTE(review): documented with comments rather than a docstring on
    # purpose -- pydantic includes a model's docstring in the JSON schema it
    # generates, which could change what is sent to the provider; confirm
    # before converting this to a docstring.
    name: str  # event title; the tests expect "science fair" (case-insensitive)
    date: str  # free-form date text; the tests expect "friday" (case-insensitive)
    participants: list[str]  # attendee names; the tests expect ["alice", "bob"] once lowercased


class ImageDescription(BaseModel):
    # Structured-output schema passed as `schema=` to `LLMClient.predict` by
    # the image-description tests in this module.
    # NOTE(review): documented with comments rather than a docstring on
    # purpose -- pydantic includes a model's docstring in the JSON schema it
    # generates; confirm before converting this to a docstring.
    description: str  # description text; the tests check that it mentions "astronaut"


def test_openai_llm_client():
    """Check structured event extraction through the OpenAI-backed client.

    Sends a one-sentence prompt together with the ``CalendarEvent`` schema
    and verifies, case-insensitively, the name, date and participants read
    from the attributes of the returned object.
    """
    instructions = "Extract the event information"
    sentence = "Alice and Bob are going to a science fair on Friday."

    client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
    extracted = client.predict(
        system_prompt=instructions,
        user_prompt=sentence,
        schema=CalendarEvent,
    )

    assert extracted.name.lower() == "science fair"
    assert extracted.date.lower() == "friday"
    lowered_participants = [person.lower() for person in extracted.participants]
    assert lowered_participants == ["alice", "bob"]


def test_openai_image_description():
    """Check that the OpenAI-backed client describes the test image.

    Opens ``./assets/test_image.png`` (resolved against the current working
    directory), passes it as the only element of ``user_prompt`` with the
    ``ImageDescription`` schema, and asserts that the returned description
    mentions "astronaut" (case-insensitive).
    """
    llm_client = LLMClient(model_name="gpt-4o-mini", client_type=ClientType.OPENAI)
    # Image.open keeps the underlying file open; the context manager makes
    # sure the handle is released once the prediction call has returned,
    # even if that call raises.
    with Image.open("./assets/test_image.png") as image:
        description = llm_client.predict(
            system_prompt="Describe the image",
            user_prompt=[image],
            schema=ImageDescription,
        )
    assert "astronaut" in description.description.lower()


def test_google_llm_client():
    """Check structured event extraction through the Gemini-backed client.

    Sends a one-sentence prompt together with the ``CalendarEvent`` schema
    and verifies, case-insensitively, the name, date and participants read
    by key from the returned mapping.
    """
    instructions = "Extract the event information"
    sentence = "Alice and Bob are going to a science fair on Friday."

    client = LLMClient(
        model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
    )
    result = client.predict(
        system_prompt=instructions,
        user_prompt=sentence,
        schema=CalendarEvent,
    )
    # A list result is tolerated: only its first element is inspected.
    if isinstance(result, list):
        result = result[0]

    assert result["name"].lower() == "science fair"
    assert result["date"].lower() == "friday"
    lowered_participants = [person.lower() for person in result["participants"]]
    assert lowered_participants == ["alice", "bob"]


def test_google_image_client():
    """Check that the Gemini-backed client describes the test image.

    Opens ``./assets/test_image.png`` (resolved against the current working
    directory), passes it as the only element of ``user_prompt`` with the
    ``ImageDescription`` schema, and asserts that the returned description
    mentions "astronaut" (case-insensitive).
    """
    llm_client = LLMClient(
        model_name="gemini-1.5-flash-latest", client_type=ClientType.GEMINI
    )
    # Image.open keeps the underlying file open; the context manager makes
    # sure the handle is released once the prediction call has returned,
    # even if that call raises.
    with Image.open("./assets/test_image.png") as image:
        description = llm_client.predict(
            system_prompt="Describe the image",
            user_prompt=[image],
            schema=ImageDescription,
        )
    # A list result is tolerated: only its first element is inspected.
    description = description[0] if isinstance(description, list) else description
    assert "astronaut" in description["description"].lower()