from typing import Any, Dict, Optional

import comet_llm
from langchain.callbacks.base import BaseCallbackHandler

from financial_bot import constants


class CometLLMMonitoringHandler(BaseCallbackHandler):
"""
A callback handler for monitoring LLM models using Comet.ml.
Args:
project_name (str): The name of the Comet.ml project to log to.
llm_model_id (str): The ID of the LLM model to use for inference.
llm_qlora_model_id (str): The ID of the PEFT model to use for inference.
llm_inference_max_new_tokens (int): The maximum number of new tokens to generate during inference.
llm_inference_temperature (float): The temperature to use during inference.
"""
def __init__(
self,
project_name: str = None,
llm_model_id: str = constants.LLM_MODEL_ID,
llm_qlora_model_id: str = constants.LLM_QLORA_CHECKPOINT,
llm_inference_max_new_tokens: int = constants.LLM_INFERNECE_MAX_NEW_TOKENS,
llm_inference_temperature: float = constants.LLM_INFERENCE_TEMPERATURE,
):
self._project_name = project_name
self._llm_model_id = llm_model_id
self._llm_qlora_model_id = llm_qlora_model_id
self._llm_inference_max_new_tokens = llm_inference_max_new_tokens
self._llm_inference_temperature = llm_inference_temperature

    def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> None:
        """
        A callback function that logs the prompt and output to Comet.ml.

        Args:
            outputs (Dict[str, Any]): The output of the LLM model.
            **kwargs (Any): Additional arguments passed to the function.
        """

        should_log_prompt = "metadata" in kwargs
        if should_log_prompt:
            metadata = kwargs["metadata"]

            comet_llm.log_prompt(
                project=self._project_name,
                prompt=metadata["prompt"],
                output=outputs["answer"],
                prompt_template=metadata["prompt_template"],
                prompt_template_variables=metadata["prompt_template_variables"],
                metadata={
                    "usage.prompt_tokens": metadata["usage.prompt_tokens"],
                    "usage.total_tokens": metadata["usage.total_tokens"],
                    "usage.max_new_tokens": self._llm_inference_max_new_tokens,
                    "usage.temperature": self._llm_inference_temperature,
                    "usage.actual_new_tokens": metadata["usage.actual_new_tokens"],
                    "model": self._llm_model_id,
                    "peft_model": self._llm_qlora_model_id,
                },
                duration=metadata["duration_milliseconds"],
            )
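

# --- Usage sketch (illustrative; not part of the original module) ---
# A minimal, hedged example of the payload `on_chain_end` expects. Logging only
# happens when a `metadata` dict with the keys accessed above is supplied; in
# practice the handler is registered as a LangChain callback and the chain is
# responsible for forwarding that dict. The project name and all values below
# are placeholders for illustration, not the project's actual wiring.
#
#   handler = CometLLMMonitoringHandler(project_name="financial-bot-monitoring")
#   handler.on_chain_end(
#       outputs={"answer": "..."},
#       metadata={
#           "prompt": "...",
#           "prompt_template": "...",
#           "prompt_template_variables": {"question": "..."},
#           "usage.prompt_tokens": 512,
#           "usage.total_tokens": 612,
#           "usage.actual_new_tokens": 100,
#           "duration_milliseconds": 1500,
#       },
#   )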