Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	File size: 6,845 Bytes
			
			| 870e141 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | from utils import get_relevant_history, get_embedding
import json
import os

import torch

from LLM.base_LLM import *
from Memory import Memory
from Prompt import *
class Environment:
    """
    The place where the agents act; stores the memories shared between them.

    ``shared_memory`` holds:
      * ``"long_term_memory"``: every message recorded via ``update_memory``.
      * ``"short_term_memory"``: the latest LLM-generated summary (``None``
        until the first summary is produced).
      * ``"chat_embeddings"``: created lazily by ``update_memory``; the
        embeddings of the long-term messages, concatenated along dim 0.

    NOTE(review): ``summary`` and ``_observe`` build prompts by ``eval``-ing
    template strings imported from ``Prompt`` without passing any variables,
    so those templates appear to read the caller's local variables by name.
    Local variable names in ``__init__``, ``summary`` and ``_observe`` are
    therefore part of the contract with ``Prompt``; do not rename them
    without checking the templates.
    """

    def __init__(self, config) -> None:
        """
        Build the environment from a parsed config dict.

        Args:
            config: mapping with a ``"states"`` entry (state name -> state
                config dict) and an optional ``"environment_type"``
                (defaults to ``"cooperative"``). Each state config may
                provide ``"summary_system_prompt"``, ``"summary_last_prompt"``
                and ``"environment_prompt"``; the whole state config is also
                forwarded to ``init_LLM`` as keyword arguments.
        """
        self.shared_memory = {"long_term_memory": [], "short_term_memory": None}
        self.agents = None
        self.summary_system_prompt = {}
        self.summary_last_prompt = {}
        self.environment_prompt = {}
        self.environment_type = config.get("environment_type", "cooperative")
        self.current_chat_history_idx = 0
        self.LLMs = {}

        # Initialize the summary prompts and the LLM for each state.
        # The terminal "end_state" needs neither.
        for state_name, state_dict in config["states"].items():
            if state_name == "end_state":
                continue
            # The defaults are eval'd lazily, only when the state omits the key.
            self.summary_system_prompt[state_name] = (
                state_dict["summary_system_prompt"]
                if "summary_system_prompt" in state_dict
                else eval(Default_environment_summary_system_prompt)
            )
            self.summary_last_prompt[state_name] = (
                state_dict["summary_last_prompt"]
                if "summary_last_prompt" in state_dict
                else eval(Default_environment_summary_last_prompt)
            )
            self.environment_prompt[state_name] = (
                state_dict["environment_prompt"]
                if "environment_prompt" in state_dict
                else " "
            )
            # First argument is "logs/<state_name>" (presumably a log location
            # for this state's LLM — confirm against init_LLM).
            self.LLMs[state_name] = init_LLM(f"logs/{state_name}", **state_dict)
        self.roles_to_names = None
        self.names_to_roles = None

    @classmethod
    def from_config(cls, config_path):
        """
        Create an environment from a JSON config file.

        The file is read as bytes so ``json.load`` auto-detects its encoding
        (UTF-8/16/32, with or without a BOM). That way configs containing
        non-ASCII prompt text load the same way regardless of the platform's
        default text encoding.
        """
        with open(config_path, "rb") as f:
            config = json.load(f)
        return cls(config)

    @staticmethod
    def _max_chat_history():
        """
        Return the MAX_CHAT_HISTORY setting from the environment.

        Raises KeyError if the variable is unset.

        NOTE(review): the value is passed through ``eval`` (as it always has
        been here), which executes arbitrary expressions from the
        environment; ``int(...)`` or ``ast.literal_eval`` would be safer.
        """
        return eval(os.environ["MAX_CHAT_HISTORY"])

    @staticmethod
    def _tail(items, n):
        """Return the last ``n`` items of ``items``; an empty list when ``n <= 0``."""
        if n <= 0:
            return []
        return items[-n:]

    def summary(self, current_state):
        """
        Summarize the situation in the current environment.

        Called by ``update_memory`` every ``MAX_CHAT_HISTORY`` messages. The
        prompt is built from the previous summary, the most recent messages
        and the earlier messages most relevant to the latest one.

        Args:
            current_state: the active state; only its ``name`` is read here.

        Returns:
            The response of that state's LLM (requested with ``stream=False``).
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        current_state_name = current_state.name
        query = self.shared_memory["long_term_memory"][-1].content
        relevant_history = get_relevant_history(
            query,
            self.shared_memory["long_term_memory"][:-1],
            self.shared_memory["chat_embeddings"][:-1],
        )
        relevant_history = Memory.get_chat_history(relevant_history)
        # The most recent MAX_CHAT_HISTORY - 1 messages. _tail is used instead
        # of a negative slice, which would return the whole history when
        # MAX_CHAT_HISTORY <= 1.
        chat_history = Memory.get_chat_history(
            self._tail(self.shared_memory["long_term_memory"], MAX_CHAT_HISTORY - 1)
        )
        summary = self.shared_memory["short_term_memory"]

        # The eval'd templates below read the local variables above by name.
        # current_memory = summary + chat history + relevant history
        current_memory = eval(Environment_summary_memory)
        environment_prompt = self.environment_prompt[current_state_name]
        summary_system_prompt = self.summary_system_prompt[current_state_name]

        # system prompt = environment prompt + current memory + system prompt
        environment_summary_system_prompt = eval(Environment_summary_system_prompt)
        response = self.LLMs[current_state_name].get_response(
            None, environment_summary_system_prompt, stream=False
        )
        return response

    def update_memory(self, memory, current_state):
        """
        Record a new message in the shared memory.

        The method:
          * appends ``memory`` to the long-term memory;
          * appends the embedding of ``memory.content`` to ``chat_embeddings``;
          * regenerates the short-term summary whenever the long-term memory
            length is a multiple of ``MAX_CHAT_HISTORY``;
          * forwards the message to the agent named ``memory.send_name``.
        """
        MAX_CHAT_HISTORY = self._max_chat_history()
        self.shared_memory["long_term_memory"].append(memory)
        current_embedding = get_embedding(memory.content)
        if "chat_embeddings" not in self.shared_memory:
            self.shared_memory["chat_embeddings"] = current_embedding
        else:
            # Assumes get_embedding returns a tensor with a leading row
            # dimension, so embedding rows stay aligned with
            # long_term_memory entries — TODO confirm.
            self.shared_memory["chat_embeddings"] = torch.cat(
                [self.shared_memory["chat_embeddings"], current_embedding], dim=0
            )
        if len(self.shared_memory["long_term_memory"]) % MAX_CHAT_HISTORY == 0:
            self.shared_memory["short_term_memory"] = self.summary(current_state)
        self.agents[memory.send_name].update_memory(memory)

    def _get_agent_last_conversation_idx(self, agent, current_long_term_memory):
        """Return the index of the last message sent by ``agent``, or -1 if none."""
        # Scan from the end so the first match is the most recent one.
        for i in range(len(current_long_term_memory) - 1, -1, -1):
            if current_long_term_memory[i].send_name == agent.name:
                return i
        return -1

    def _get_agent_new_memory(self, agent, current_long_term_memory):
        """
        Return the chat history of the messages posted after ``agent`` last spoke.

        If the agent has not spoken yet, this covers all of
        ``current_long_term_memory``. If the agent sent the latest message,
        it covers nothing.
        """
        last_conversation_idx = self._get_agent_last_conversation_idx(
            agent, current_long_term_memory
        )
        # The start index is 0 when the agent never spoke (idx == -1) and
        # len(...) when it spoke last, so a single slice covers every case.
        new_conversation = current_long_term_memory[last_conversation_idx + 1 :]
        return Memory.get_chat_history(new_conversation)

    def _get_history_start_idx(self):
        """
        Return the index in long_term_memory at which an agent's view begins.

        In a competitive environment no information is shared between
        different states, so the view starts at ``current_chat_history_idx``.
        Otherwise (cooperative), the whole history is visible. Both the
        historical spelling "competive" and "competitive" are accepted.
        """
        if self.environment_type in ("competive", "competitive"):
            return self.current_chat_history_idx
        return 0

    def _observe(self, agent):
        """
        Build the observation message for ``agent``.

        Side effect: stores the rendered relevant memory on
        ``agent.relevant_memory``.

        Returns:
            ``{"role": "user", "content": ...}``, where the content is produced
            by the ``Agent_observe_memory`` template.
        """
        # MAX_CHAT_HISTORY and current_component_dict are not used directly
        # below; they are kept because the eval'd templates may read them.
        MAX_CHAT_HISTORY = self._max_chat_history()
        current_state = agent.current_state
        current_role = agent.state_roles[current_state.name]
        current_component_dict = current_state.components[current_role]

        # cooperative: information is shared between different states;
        # competitive: no information is shared between different states.
        current_chat_history_idx = self._get_history_start_idx()
        current_long_term_memory = self.shared_memory["long_term_memory"][
            current_chat_history_idx:
        ]
        current_chat_embbedings = self.shared_memory["chat_embeddings"][
            current_chat_history_idx:
        ]

        # relevant_memory
        query = current_long_term_memory[-1].content
        relevant_memory = get_relevant_history(
            query,
            current_long_term_memory[:-1],
            current_chat_embbedings[:-1],
        )
        relevant_memory = Memory.get_chat_history(relevant_memory, agent.name)

        relevant_memory = eval(Agent_observe_relevant_memory)
        agent.relevant_memory = relevant_memory

        # get chat history from new conversation
        conversations = self._get_agent_new_memory(agent, current_long_term_memory)
        # memory = relevant_memory + summary + history + query
        # query is deliberately rebound to the latest message object (not its
        # content) before the template is evaluated.
        query = current_long_term_memory[-1]
        current_memory = eval(Agent_observe_memory)
        return {"role": "user", "content": current_memory}
 | 
