Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python | |
| # -*- coding: utf-8 -*- | |
| """ | |
| @Time : 2023/5/8 22:12 | |
| @Author : alexanderwu | |
| @File : schema.py | |
| @Desc : mashenquan, 2023/8/22. Add tags to enable custom message classification. | |
| """ | |
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from typing import Type, TypedDict, Set, Optional, List | |
| from pydantic import BaseModel | |
| from metagpt.logs import logger | |
| class MessageTag(Enum): | |
| Prerequisite = "prerequisite" | |
| class RawMessage(TypedDict): | |
| content: str | |
| role: str | |
| class Message: | |
| """list[<role>: <content>]""" | |
| content: str | |
| instruct_content: BaseModel = field(default=None) | |
| role: str = field(default='user') # system / user / assistant | |
| cause_by: Type["Action"] = field(default="") | |
| sent_from: str = field(default="") | |
| send_to: str = field(default="") | |
| tags: Optional[Set] = field(default=None) | |
| def __str__(self): | |
| # prefix = '-'.join([self.role, str(self.cause_by)]) | |
| return f"{self.role}: {self.content}" | |
| def __repr__(self): | |
| return self.__str__() | |
| def to_dict(self) -> dict: | |
| return { | |
| "role": self.role, | |
| "content": self.content | |
| } | |
| def add_tag(self, tag): | |
| if self.tags is None: | |
| self.tags = set() | |
| self.tags.add(tag) | |
| def remove_tag(self, tag): | |
| if self.tags is None or tag not in self.tags: | |
| return | |
| self.tags.remove(tag) | |
| def is_contain_tags(self, tags: list) -> bool: | |
| """Determine whether the message contains tags.""" | |
| if not tags or not self.tags: | |
| return False | |
| intersection = set(tags) & self.tags | |
| return len(intersection) > 0 | |
| def is_contain(self, tag): | |
| return self.is_contain_tags([tag]) | |
| def dict(self): | |
| """pydantic-like `dict` function""" | |
| full = { | |
| "instruct_content": self.instruct_content, | |
| "sent_from": self.sent_from, | |
| "send_to": self.send_to, | |
| "tags": self.tags | |
| } | |
| m = {"content": self.content} | |
| for k, v in full.items(): | |
| if v: | |
| m[k] = v | |
| return m | |
| class UserMessage(Message): | |
| """便于支持OpenAI的消息 | |
| Facilitate support for OpenAI messages | |
| """ | |
| def __init__(self, content: str): | |
| super().__init__(content, 'user') | |
| class SystemMessage(Message): | |
| """便于支持OpenAI的消息 | |
| Facilitate support for OpenAI messages | |
| """ | |
| def __init__(self, content: str): | |
| super().__init__(content, 'system') | |
| class AIMessage(Message): | |
| """便于支持OpenAI的消息 | |
| Facilitate support for OpenAI messages | |
| """ | |
| def __init__(self, content: str): | |
| super().__init__(content, 'assistant') | |
| if __name__ == '__main__': | |
| test_content = 'test_message' | |
| msgs = [ | |
| UserMessage(test_content), | |
| SystemMessage(test_content), | |
| AIMessage(test_content), | |
| Message(test_content, role='QA') | |
| ] | |
| logger.info(msgs) | |