File size: 750 Bytes
183e719
60532a1
183e719
 
 
 
 
 
 
 
 
 
 
60532a1
183e719
3536fc0
 
 
183e719
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from pydantic_ai.models import Model, KnownModelName
from knowlang.core.types import ModelProvider
from typing import get_args

def create_pydantic_model(
    model_provider: ModelProvider,
    model_name: str,
) -> Model | KnownModelName | None:
    """Build the pydantic-ai model spec for a provider / model-name pair.

    Args:
        model_provider: Provider enum member. Its string value is used as the
            ``"<provider>:"`` prefix of the model string.
        model_name: Provider-specific model identifier.

    Returns:
        * The ``"<provider>:<model_name>"`` string when it is one of the
          names listed in pydantic-ai's ``KnownModelName``.
        * A ``HuggingFaceModel`` instance for ``ModelProvider.HUGGINGFACE``.
        * ``None`` for ``ModelProvider.TESTING`` (testing use only).

    Raises:
        NotImplementedError: For any other provider / model combination.
    """
    # Prefer the enum's ``.value`` over formatting the member itself: since
    # Python 3.12, ``f"{member}"`` on a ``(str, Enum)`` (or any plain Enum)
    # yields "ClassName.MEMBER", which would never match a known model name.
    provider = getattr(model_provider, "value", model_provider)
    model_str = f"{provider}:{model_name}"

    # Some pydantic-ai releases declare KnownModelName as a ``TypeAliasType``
    # whose underlying ``Literal`` lives under ``__value__``; others expose the
    # bare ``Literal``. Unwrap if needed so ``get_args`` sees the literal values.
    known_model_names = get_args(getattr(KnownModelName, "__value__", KnownModelName))

    if model_str in known_model_names:
        return model_str
    elif model_provider == ModelProvider.HUGGINGFACE:
        # Imported lazily so the HuggingFace dependency is only loaded when used.
        from knowlang.models.huggingface import HuggingFaceModel
        return HuggingFaceModel(model_name=model_name)
    elif model_provider == ModelProvider.TESTING:
        # Should be used for testing purposes only. No model is built here;
        # presumably the caller injects a test model itself — verify at call sites.
        return None
    else:
        raise NotImplementedError(f"Model {model_str} is not supported")