import os
from enum import Enum
from typing import Any, Optional, Union
import instructor
import weave
from PIL import Image
from ..utils import base64_encode_image


class ClientType(str, Enum):
    GEMINI = "gemini"
    MISTRAL = "mistral"


GOOGLE_MODELS = [
"gemini-1.0-pro-latest",
"gemini-1.0-pro",
"gemini-pro",
"gemini-1.0-pro-001",
"gemini-1.0-pro-vision-latest",
"gemini-pro-vision",
"gemini-1.5-pro-latest",
"gemini-1.5-pro-001",
"gemini-1.5-pro-002",
"gemini-1.5-pro",
"gemini-1.5-pro-exp-0801",
"gemini-1.5-pro-exp-0827",
"gemini-1.5-flash-latest",
"gemini-1.5-flash-001",
"gemini-1.5-flash-001-tuning",
"gemini-1.5-flash",
"gemini-1.5-flash-exp-0827",
"gemini-1.5-flash-002",
"gemini-1.5-flash-8b",
"gemini-1.5-flash-8b-001",
"gemini-1.5-flash-8b-latest",
"gemini-1.5-flash-8b-exp-0827",
"gemini-1.5-flash-8b-exp-0924",
]

MISTRAL_MODELS = [
"ministral-3b-latest",
"ministral-8b-latest",
"mistral-large-latest",
"mistral-small-latest",
"codestral-latest",
"pixtral-12b-2409",
"open-mistral-nemo",
"open-codestral-mamba",
"open-mistral-7b",
"open-mixtral-8x7b",
"open-mixtral-8x22b",
]


class LLMClient(weave.Model):
    """Unified LLM client that dispatches to the Gemini or Mistral SDK
    based on the model name or an explicit client type."""

    model_name: str
    client_type: Optional[ClientType]

    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
        # Infer the client type from the model name when not given explicitly.
        if client_type is None:
if model_name in GOOGLE_MODELS:
client_type = ClientType.GEMINI
elif model_name in MISTRAL_MODELS:
client_type = ClientType.MISTRAL
else:
raise ValueError(f"Invalid model name: {model_name}")
        super().__init__(model_name=model_name, client_type=client_type)

    @weave.op()
def execute_gemini_sdk(
self,
user_prompt: Union[str, list[str]],
system_prompt: Optional[Union[str, list[str]]] = None,
schema: Optional[Any] = None,
) -> Union[str, Any]:
        import google.generativeai as genai

        # Normalize prompts to lists; treat a missing system prompt as an
        # empty list so the concatenation below never fails on None.
        if system_prompt is None:
            system_prompt = []
        elif isinstance(system_prompt, str):
            system_prompt = [system_prompt]
        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
        genai.configure(api_key=os.environ.get("GOOGLE_API_KEY"))
        model = genai.GenerativeModel(self.model_name)
        # When a schema is given, request JSON output conforming to it.
        generation_config = (
            None
            if schema is None
            else genai.GenerationConfig(
                response_mime_type="application/json", response_schema=list[schema]
            )
        )
response = model.generate_content(
system_prompt + user_prompt, generation_config=generation_config
)
        return response.text if schema is None else response

    @weave.op()
def execute_mistral_sdk(
self,
        user_prompt: Union[str, Image.Image, list[Union[str, Image.Image]]],
system_prompt: Optional[Union[str, list[str]]] = None,
schema: Optional[Any] = None,
) -> Union[str, Any]:
        from mistralai import Mistral

        # Normalize prompts to lists; treat a missing system prompt as an
        # empty list so iteration below never fails on None.
        if system_prompt is None:
            system_prompt = []
        elif isinstance(system_prompt, str):
            system_prompt = [system_prompt]
        user_prompt = [user_prompt] if isinstance(user_prompt, str) else user_prompt
        system_messages = [{"type": "text", "text": prompt} for prompt in system_prompt]
        # Build multimodal user content: images are base64-encoded, text is passed through.
        user_messages = []
for prompt in user_prompt:
if isinstance(prompt, Image.Image):
user_messages.append(
{
"type": "image_url",
"image_url": base64_encode_image(prompt, "image/png"),
}
)
else:
user_messages.append({"type": "text", "text": prompt})
        # Only send a system message when there is actual system content.
        messages = []
        if system_messages:
            messages.append({"role": "system", "content": system_messages})
        messages.append({"role": "user", "content": user_messages})
        client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))
        # Wrap the client with instructor when structured output is requested.
        client = instructor.from_mistral(client) if schema is not None else client
        response = (
            client.chat.complete(model=self.model_name, messages=messages)
            if schema is None
            else client.messages.create(
                response_model=schema, messages=messages, temperature=0
            )
        )
        # instructor returns the parsed schema instance directly; the raw SDK
        # response nests the generated text inside choices.
        return response.choices[0].message.content if schema is None else response

    @weave.op()
def predict(
self,
        user_prompt: Union[str, Image.Image, list[Union[str, Image.Image]]],
system_prompt: Optional[Union[str, list[str]]] = None,
schema: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Route the request to the SDK matching this client's type."""
        if self.client_type == ClientType.GEMINI:
return self.execute_gemini_sdk(user_prompt, system_prompt, schema)
elif self.client_type == ClientType.MISTRAL:
return self.execute_mistral_sdk(user_prompt, system_prompt, schema)
else:
raise ValueError(f"Invalid client type: {self.client_type}")