import abc
import os
from dataclasses import field
from typing import Any, Dict, List, Literal, Optional, Union
from .artifact import Artifact
from .operator import PackageRequirementsMixin
class InferenceEngine(abc.ABC, Artifact):
"""Abstract base class for inference."""
@abc.abstractmethod
def _infer(self, dataset):
"""Perform inference on the input dataset."""
pass
def infer(self, dataset):
"""Verifies instances of a dataset and performs inference."""
        for instance in dataset:
            self.verify_instance(instance)
return self._infer(dataset)
class HFPipelineBasedInferenceEngine(InferenceEngine, PackageRequirementsMixin):
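    """Inference engine that runs a local model through the transformers pipeline API.

    The task (text-generation vs. text2text-generation) is derived from the model
    config, and generation runs with `max_new_tokens` on the best available device.
    """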
model_name: str
max_new_tokens: int
use_fp16: bool = True
    _requirement = {
        "transformers": "Install huggingface package using 'pip install --upgrade transformers'"
    }
def prepare(self):
import torch
from transformers import AutoConfig, pipeline
model_args: Dict[str, Any] = (
{"torch_dtype": torch.float16} if self.use_fp16 else {}
)
model_args.update({"max_new_tokens": self.max_new_tokens})
device = torch.device(
"mps"
if torch.backends.mps.is_available()
else 0
if torch.cuda.is_available()
else "cpu"
)
        # Avoid device_map="auto" on a single-GPU machine: it can offload some weights
        # to the CPU even when the model would just fit on one GPU, and because the
        # inputs are always placed on the GPU this produces a device-mismatch error.
        # Only fall back to device_map="auto" when more than one GPU is available.
if torch.cuda.device_count() > 1:
assert device == torch.device(0)
model_args.update({"device_map": "auto"})
else:
model_args.update({"device": device})
task = (
"text2text-generation"
if AutoConfig.from_pretrained(
self.model_name, trust_remote_code=True
).is_encoder_decoder
else "text-generation"
)
if task == "text-generation":
model_args.update({"return_full_text": False})
self.model = pipeline(
model=self.model_name, trust_remote_code=True, **model_args
)
def _infer(self, dataset):
outputs = []
for output in self.model([instance["source"] for instance in dataset]):
if isinstance(output, list):
output = output[0]
outputs.append(output["generated_text"])
return outputs
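# A minimal usage sketch (illustrative only: the model id is a placeholder, and the
# single-field {"source": ...} instance format is assumed from _infer above):
#
#   engine = HFPipelineBasedInferenceEngine(
#       model_name="google/flan-t5-small", max_new_tokens=32
#   )
#   predictions = engine.infer([{"source": "Answer briefly: what is 2 + 2?"}])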
class MockInferenceEngine(InferenceEngine):
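    """Stub engine that returns the same fixed prediction for every instance; useful in tests."""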
model_name: str
def prepare(self):
return
def _infer(self, dataset):
        return ["[[10]]" for _ in dataset]
class IbmGenAiInferenceEngineParams(Artifact):
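    """Optional generation parameters forwarded to the ibm-generative-ai service."""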
decoding_method: Optional[Literal["greedy", "sample"]] = None
max_new_tokens: Optional[int] = None
min_new_tokens: Optional[int] = None
random_seed: Optional[int] = None
repetition_penalty: Optional[float] = None
stop_sequences: Optional[List[str]] = None
temperature: Optional[float] = None
top_k: Optional[int] = None
top_p: Optional[float] = None
typical_p: Optional[float] = None
class IbmGenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
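    """Inference engine backed by IBM's Generative AI (genai) text generation client."""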
label: str = "ibm_genai"
model_name: str
parameters: IbmGenAiInferenceEngineParams = field(
default_factory=IbmGenAiInferenceEngineParams
)
    _requirement = {
        "genai": "Install ibm-genai package using 'pip install --upgrade ibm-generative-ai'"
    }
data_classification_policy = ["public", "proprietary"]
def prepare(self):
from genai import Client, Credentials
api_key_env_var_name = "GENAI_KEY"
api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run IbmGenAiInferenceEngine."
            f" Please set the environment variable '{api_key_env_var_name}'."
        )
        # The endpoint has its own environment variable (GENAI_API, the one the genai
        # SDK itself reads); it must not be read from GENAI_KEY, which holds the api key.
        api_endpoint = os.environ.get("GENAI_API")
        if api_endpoint is not None:
            credentials = Credentials(api_key=api_key, api_endpoint=api_endpoint)
        else:
            credentials = Credentials(api_key=api_key)
self.client = Client(credentials=credentials)
def _infer(self, dataset):
from genai.schema import TextGenerationParameters
genai_params = TextGenerationParameters(
max_new_tokens=self.parameters.max_new_tokens,
min_new_tokens=self.parameters.min_new_tokens,
random_seed=self.parameters.random_seed,
repetition_penalty=self.parameters.repetition_penalty,
stop_sequences=self.parameters.stop_sequences,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
top_k=self.parameters.top_k,
typical_p=self.parameters.typical_p,
decoding_method=self.parameters.decoding_method,
)
        # create() yields one response per input; extract the generated text so this
        # engine returns plain strings, like the other inference engines.
        return [
            response.results[0].generated_text
            for response in self.client.text.generation.create(
                model_id=self.model_name,
                inputs=[instance["source"] for instance in dataset],
                parameters=genai_params,
            )
        ]
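# A minimal usage sketch (illustrative only: the model id is a placeholder, and the
# GENAI_KEY environment variable must be set, as checked in prepare() above):
#
#   engine = IbmGenAiInferenceEngine(
#       model_name="google/flan-ul2",
#       parameters=IbmGenAiInferenceEngineParams(max_new_tokens=32),
#   )
#   predictions = engine.infer([{"source": "Answer briefly: what is 2 + 2?"}])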
class OpenAiInferenceEngineParams(Artifact):
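    """Optional generation parameters forwarded to the OpenAI chat completions API."""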
frequency_penalty: Optional[float] = None
presence_penalty: Optional[float] = None
max_tokens: Optional[int] = None
seed: Optional[int] = None
stop: Union[Optional[str], List[str]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
class OpenAiInferenceEngine(InferenceEngine, PackageRequirementsMixin):
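    """Inference engine backed by the OpenAI chat completions API."""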
label: str = "openai"
model_name: str
parameters: OpenAiInferenceEngineParams = field(
default_factory=OpenAiInferenceEngineParams
)
    _requirement = {
        "openai": "Install openai package using 'pip install --upgrade openai'"
    }
def prepare(self):
from openai import OpenAI
api_key_env_var_name = "OPENAI_API_KEY"
api_key = os.environ.get(api_key_env_var_name)
        assert api_key is not None, (
            f"Error while trying to run OpenAiInferenceEngine."
            f" Please set the environment variable '{api_key_env_var_name}'."
        )
self.client = OpenAI(api_key=api_key)
def _infer(self, dataset):
        # Extract only the generated message content so this engine returns plain
        # strings, consistent with the other inference engines.
        return [
self.client.chat.completions.create(
messages=[
# {
# "role": "system",
# "content": self.system_prompt,
# },
{
"role": "user",
"content": instance["source"],
}
],
model=self.model_name,
frequency_penalty=self.parameters.frequency_penalty,
presence_penalty=self.parameters.presence_penalty,
max_tokens=self.parameters.max_tokens,
seed=self.parameters.seed,
stop=self.parameters.stop,
temperature=self.parameters.temperature,
top_p=self.parameters.top_p,
            )
            .choices[0]
            .message.content
for instance in dataset
]
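# A minimal usage sketch (illustrative only: the model id is a placeholder, and the
# OPENAI_API_KEY environment variable must be set, as checked in prepare() above):
#
#   engine = OpenAiInferenceEngine(
#       model_name="gpt-3.5-turbo",
#       parameters=OpenAiInferenceEngineParams(max_tokens=32),
#   )
#   predictions = engine.infer([{"source": "Answer briefly: what is 2 + 2?"}])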