Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
| import os | |
| from string import Formatter | |
| from typing import Any, Dict, List, Union | |
| from pydantic import model_validator | |
| from ....utils.registry import registry | |
| from .base import (DEFAULT_FORMATTER_MAPPING, BasePromptTemplate, | |
| _get_jinja2_variables_from_template, check_valid_template) | |
| class PromptTemplate(BasePromptTemplate): | |
| """Schema to represent a prompt for an LLM. | |
| Example: | |
| .. code-block:: python | |
| prompt = PromptTemplate(input_variables=["foo"], template="Say {foo}") | |
| """ | |
| # input_variables: List[str] | |
| template: str | |
| """The prompt template.""" | |
| template_format: str = "jinja2" | |
| """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" | |
| validate_template: bool = True | |
| """Whether or not to try validating the template.""" | |
| role: str = "user" | |
| class Config: | |
| """Configuration for this pydantic object.""" | |
| extra = "allow" | |
| def __init__(self, **kwargs: Any): | |
| super().__init__(**kwargs) | |
| input_variables = kwargs.get("input_variables", []) | |
| pre_filled_kv = {key: kwargs[key] for key in input_variables if key in kwargs.keys()} | |
| if pre_filled_kv: | |
| self.template = self.format(**pre_filled_kv) | |
| input_variables = list(set(input_variables) - set(pre_filled_kv.keys())) | |
| self.input_variables = input_variables | |
| def format(self, **kwargs: Any) -> str: | |
| """Format the prompt with the inputs. | |
| Args: | |
| kwargs: Any arguments to be passed to the prompt template. | |
| Returns: | |
| A formatted string. | |
| Example: | |
| .. code-block:: python | |
| prompt.format(variable1="foo") | |
| """ | |
| kwargs = self._merge_partial_and_user_variables(**kwargs) | |
| return DEFAULT_FORMATTER_MAPPING[self.template_format].format( | |
| self.template, **kwargs | |
| ) | |
| def template_is_valid(self) -> "PromptTemplate": | |
| """Check that template and input variables are consistent.""" | |
| if self.validate_template: | |
| all_inputs = self.input_variables + list(self.partial_variables) | |
| check_valid_template(self.template, self.template_format, all_inputs) | |
| return self | |
| def from_examples( | |
| cls, | |
| examples: List[str], | |
| suffix: str, | |
| input_variables: List[str], | |
| example_separator: str = "\n\n", | |
| prefix: str = "", | |
| **kwargs: Any, | |
| ) -> PromptTemplate: | |
| """Take examples in list format with prefix and suffix to create a prompt. | |
| Intended to be used as a way to dynamically create a prompt from examples. | |
| Args: | |
| examples: List of examples to use in the prompt. | |
| suffix: String to go after the list of examples. Should generally | |
| set up the user's input. | |
| input_variables: A list of variable names the final prompt template | |
| will expect. | |
| example_separator: The separator to use in between examples. Defaults | |
| to two new line characters. | |
| prefix: String that should go before any examples. Generally includes | |
| examples. Default to an empty string. | |
| Returns: | |
| The final prompt generated. | |
| """ | |
| template = example_separator.join([prefix, *examples, suffix]) | |
| return cls(input_variables=input_variables, template=template, **kwargs) | |
| def find_file(cls, start_dir, file_name): | |
| for root, dirs, files in os.walk(start_dir): | |
| if file_name in files: | |
| return os.path.join(root, file_name) | |
| return None | |
| def from_file( | |
| cls, template_file: Union[str, Path], **kwargs: Any | |
| ) -> PromptTemplate: | |
| """Load a prompt from a file. | |
| Args: | |
| template_file: The path to the file containing the prompt template. | |
| input_variables: A list of variable names the final prompt template | |
| will expect. | |
| Returns: | |
| The prompt loaded from the file. | |
| """ | |
| original_file = template_file | |
| while True: | |
| if os.path.exists(template_file): | |
| with open(template_file, "r") as f: | |
| template = f.read() | |
| return cls.from_template(template=template, **kwargs) | |
| if "/" in template_file: | |
| template_file = "/".join(template_file.split("/", 1)[1:]) | |
| else: | |
| raise ValueError(f"the prompt file path ({original_file}) is not valid") | |
| def from_template( | |
| cls, template: str, template_format: str = "jinja2", **kwargs: Any | |
| ) -> PromptTemplate: | |
| """Load a prompt template from a template.""" | |
| if template_format == "jinja2": | |
| # Get the variables for the template | |
| input_variables = _get_jinja2_variables_from_template(template) | |
| else: | |
| input_variables = { | |
| v for _, v, _, _ in Formatter().parse(template) if v is not None | |
| } | |
| return cls( | |
| input_variables=list(sorted(input_variables)), | |
| template=template, | |
| template_format=template_format, | |
| **kwargs, | |
| ) | |
| def from_config(cls, config: Dict) -> PromptTemplate: | |
| """Load a prompt template from a config.""" | |
| template = config.pop("template") | |
| if template.endswith(".prompt"): | |
| return cls.from_file(template, **config) | |
| else: | |
| return cls.from_template(template, **config) | |