| from models import * | |
| from utils import * | |
| from .extraction_agent import ExtractionAgent | |
| from .knowledge_base.case_repository import CaseRepositoryHandler | |
class ReflectionGenerator:
    """Prompt-level helper that asks the LLM to reflect on one extraction result."""

    def __init__(self, llm: BaseEngine):
        """Keep a reference to the engine used for the reflection chat call."""
        self.llm = llm

    def get_reflection(self, instruction="", examples="", text="", schema="", result=""):
        """Ask the LLM to revise ``result`` and return the parsed reply.

        Args:
            instruction: Task instruction inserted into the prompt.
            examples: Bad-case examples; passed through ``bad_case_wrapper``
                before being inserted into the prompt.
            text: Source text the result was extracted from.
            schema: Output schema inserted into the prompt.
            result: Current extraction result; JSON-serialized with
                ``json.dumps`` before being inserted into the prompt.

        Returns:
            Whatever ``extract_json_dict`` produces from the raw chat response.
        """
        # Keep the original evaluation order: serialize the result first,
        # then wrap the examples, then build the prompt.
        serialized_result = json.dumps(result)
        wrapped_examples = bad_case_wrapper(examples)
        prompt_fields = {
            "instruction": instruction,
            "examples": wrapped_examples,
            "text": text,
            "schema": schema,
            "result": serialized_result,
        }
        prompt = reflect_instruction.format(**prompt_fields)
        return extract_json_dict(self.llm.get_chat_response(prompt))
class ReflectionAgent:
    """Checks extraction results for self-consistency and reflects on unstable ones.

    The agent re-runs the extractor at different temperatures, keeps the
    results that at least two runs agree on, and sends the remaining chunks
    to ``ReflectionGenerator`` together with bad cases from the case
    repository.
    """

    def __init__(self, llm: BaseEngine, case_repo: CaseRepositoryHandler):
        """Wire up the reflection module and the extractor on the same engine.

        Args:
            llm: Engine shared by the reflection generator and the extractor.
            case_repo: Repository queried for bad cases during reflection.
        """
        self.llm = llm
        self.module = ReflectionGenerator(llm=llm)
        self.extractor = ExtractionAgent(llm=llm, case_repo=case_repo)
        self.case_repo = case_repo
        # Name(s) of the public reflection method(s) defined on this class.
        self.methods = ["reflect_with_case"]

    def __select_result(self, result_list):
        """Pick the candidate with the longest JSON serialization.

        Dict candidates take priority: if any element of ``result_list`` is a
        ``dict``, only dicts are considered; otherwise every element is.
        """
        dict_objects = [obj for obj in result_list if isinstance(obj, dict)]
        candidates = dict_objects if dict_objects else result_list
        return max(candidates, key=lambda obj: len(json.dumps(obj)))

    def __self_consistance_check(self, data: DataPoint):
        """Re-run extraction twice and keep the per-chunk majority result.

        The extractor method named by the last key of
        ``data.result_trajectory`` is called again at temperatures 0.5 and 1.
        For each chunk position, a result that at least two of the three runs
        agree on (after ``normalize_obj``) is kept; otherwise the longest
        candidate is kept and the position is flagged for reflection. The
        merged list is stored with ``data.set_result_list``.

        Returns:
            list[int]: Indices of chunks whose runs did not agree. Empty when
            the trajectory is empty or the extractor has no such method, so
            callers can always iterate over the return value.
        """
        trajectory_steps = list(data.result_trajectory.keys())
        if not trajectory_steps:
            # Nothing has been extracted yet, so there is nothing to re-run.
            return []

        extract_func = getattr(self.extractor, trajectory_steps[-1], None)
        if extract_func is None:
            # The last step was not produced by the extractor; skip the check
            # instead of failing on an undefined result.
            return []

        result_trails = [data.result_list]
        try:
            for temperature in (0.5, 1):
                self.module.llm.set_hyperparameter(temperature=temperature)
                data = extract_func(data)
                result_trails.append(data.result_list)
        finally:
            # Always undo the temperature override, even if a re-extraction
            # raised; the engine is shared with the rest of the agent.
            # Presumably the no-argument call restores defaults -- confirm.
            self.module.llm.set_hyperparameter()

        consistent_result = []
        reflect_index = []
        # zip() stops at the shortest trail, so positions missing from any
        # run are dropped from the merged result.
        for index, elements in enumerate(zip(*result_trails)):
            # NOTE(review): Counter needs hashable keys; assumes normalize_obj
            # returns a hashable form -- confirm.
            normalized_elements = [normalize_obj(e) for e in elements]
            element_counts = Counter(normalized_elements)
            selected_element = next(
                (
                    element
                    for element, key in zip(elements, normalized_elements)
                    if element_counts[key] >= 2
                ),
                None,
            )
            # A majority value that is itself None cannot be told apart from
            # "no majority" here and is therefore flagged for reflection too.
            if selected_element is None:
                selected_element = self.__select_result(elements)
                reflect_index.append(index)
            consistent_result.append(selected_element)

        data.set_result_list(consistent_result)
        return reflect_index

    def reflect_with_case(self, data: DataPoint):
        """Reflect on inconsistent chunk results using bad cases as examples.

        Args:
            data: Data point holding ``result_list``, ``chunk_text_list``,
                ``instruction`` and ``output_schema``.

        Returns:
            The data point. When ``result_list`` is empty it is returned
            untouched; otherwise its result list is updated and the step is
            recorded with ``data.update_trajectory``.
        """
        if not data.result_list:
            return data

        reflect_index = self.__self_consistance_check(data)
        # Same list object as data.result_list: replacements are visible
        # through both names while the loop runs.
        reflected_result_list = data.result_list
        for idx in reflect_index:
            examples = json.dumps(self.case_repo.query_bad_case(data))
            reflected_result_list[idx] = self.module.get_reflection(
                instruction=data.instruction,
                examples=examples,
                text=data.chunk_text_list[idx],
                schema=data.output_schema,
                result=data.result_list[idx],
            )
        data.set_result_list(reflected_result_list)

        # Called directly in this method so the reported name is this one.
        function_name = current_function_name()
        data.update_trajectory(function_name, data.result_list)
        return data