Spaces:
Runtime error
Runtime error
from asyncio import TaskGroup
from re import compile as compile_re, Pattern
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set

from dependency_injector.resources import AsyncResource
from httpx import AsyncClient
from loguru import logger
from openai import OpenAI
from pydantic import ConfigDict, PrivateAttr
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp
from slack_sdk.errors import SlackApiError
from slack_sdk.web.async_slack_response import AsyncSlackResponse

from ctp_slack_bot.core import HealthReportingApplicationComponentBase, Settings
from ctp_slack_bot.enums import EventType
from ctp_slack_bot.models import SlackMessage, SlackResponse
from .event_brokerage_service import EventBrokerageService
| class SlackService(HealthReportingApplicationComponentBase): | |
| """ | |
| Service for interfacing with Slack. | |
| """ | |
| model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True) | |
| _SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+") | |
| _SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>") | |
| settings: Settings | |
| event_brokerage_service: EventBrokerageService | |
| http_client: AsyncClient | |
| slack_bolt_app: AsyncApp | |
| socket_mode_handler: AsyncSocketModeHandler | |
| _id_name_map: MutableMapping[str, str] = PrivateAttr(default={}) # TODO: Spin message processing out into its own service. | |
| async def initialize(self: Self) -> None: | |
| async def get_users_list(): | |
| cursor = None | |
| while True: | |
| try: | |
| response = await self.slack_bolt_app.client.users_list(cursor=cursor, limit=200) | |
| except SlackApiError as e: | |
| logger.warning("Could not get a list of users: {}", e) | |
| break | |
| match response: | |
| case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}): | |
| for user in users: | |
| yield user | |
| match response.data: | |
| case {"response_metadata": {"next_cursor": cursor}} if cursor: | |
| continue | |
| case AsyncSlackResponse(status_code=status_code) if status_code != 200: | |
| logger.warning("Could not get a list of users: response status {}", status_code) | |
| case AsyncSlackResponse(data={"ok": False}): | |
| logger.warning("Could not get a list of users: non-OK response") | |
| case _: | |
| logger.warning("Could not get a list of users.") | |
| break | |
| id_name_map = {user["id"]: self._resolve_user_name(user) | |
| async for user | |
| in get_users_list()} | |
| self._id_name_map.update(id_name_map) | |
| logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map) | |
| await self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message) | |
| self.slack_bolt_app.event("message")(self._handle_message_event) | |
| self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event) | |
| logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.") | |
| async def start(self: Self) -> None: | |
| await self.socket_mode_handler.start_async() | |
| async def stop(self: Self) -> None: | |
| await self.socket_mode_handler.close_async() | |
| logger.info("Stopped Slack Bolt socket mode handler and Slack service.") | |
| async def send_message(self: Self, message: SlackResponse) -> None: | |
| await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts) | |
| def name(self: Self) -> str: | |
| return "slack_service" | |
| async def is_healthy(self: Self) -> bool: | |
| response = await self.http_client.get("https://slack-status.com/api/v2.0.0/current") | |
| return response.status_code == 200 | |
| async def _adapt_event_payload(self: Self, event: Mapping[str, Any]) -> SlackMessage: | |
| text = await self._resolve_user_mentions(event.get("text", "")) | |
| user_id = event.get("user") | |
| return SlackMessage( | |
| type=event.get("type"), | |
| subtype=event.get("subtype"), | |
| channel=event.get("channel"), | |
| channel_type=event.get("channel_type"), | |
| user=await self._get_name(user_id), | |
| bot_id=event.get("bot_id"), | |
| thread_ts=event.get("thread_ts"), | |
| text=text, | |
| ts=event.get("ts"), | |
| event_ts=event.get("event_ts") | |
| ) | |
| async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None: | |
| unknown_ids = ids - self._id_name_map.keys() | |
| if len(unknown_ids) == 0: | |
| return | |
| async with TaskGroup() as task_group: | |
| update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids} | |
| self._id_name_map.update({id: task.result() for id, task in update_tasks.items() if task.result()}) | |
| async def _get_name(self: Self, id: str) -> str: | |
| await self._ensure_ids_in_id_name_map({id}) | |
| return self._id_name_map.get(id, id) | |
| async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None: | |
| logger.debug("Ignored regular message: {}", body.get("event", {}).get("text")) | |
| # await self._process_message(body) | |
| async def _handle_app_mention_event(self: Self, body: Mapping[str, Any]) -> None: | |
| logger.debug("Received app mention for processing: {}", body.get("event", {}).get("text")) | |
| await self._process_message(body) | |
| async def _look_up_name(self: Self, id: str) -> Optional[str]: | |
| if self._SLACK_USER_ID_PATTERN.fullmatch(id): | |
| match await self.slack_bolt_app.client.users_info(id): | |
| case AsyncSlackResponse(data={"ok": True, "user": user}): | |
| match user: | |
| case {"real_name": real_name}: | |
| return real_name | |
| case {"profile": {"display_name": display_name}}: | |
| return display_name | |
| case {"name": name}: | |
| return name | |
| case AsyncSlackResponse(data={"ok": False, "error": "user_not_found"}): | |
| logger.error("An attempt to look up a user failed (user not found): {}", id) | |
| return None | |
| async def _process_message(self: Self, event: Mapping[str, Any]) -> None: | |
| slack_message = await self._adapt_event_payload(event.get("event", {})) | |
| logger.debug("Received message from Slack: {}", slack_message) | |
| await self.event_brokerage_service.publish(EventType.INCOMING_SLACK_MESSAGE, slack_message) | |
| async def _resolve_user_mentions(self: Self, text: str) -> str: | |
| matches = tuple(self._SLACK_USER_MENTION_PATTERN.finditer(text)) | |
| unique_ids = frozenset(match.group(1) for match in matches) | |
| await self._ensure_ids_in_id_name_map(unique_ids) | |
| parts = [] | |
| previous_end = 0 | |
| for match in matches: | |
| start, end = match.span() | |
| parts.append(text[previous_end:start]) | |
| user_id = match.group(1) | |
| parts.append(f"@{self._id_name_map.get(user_id, user_id)}") | |
| previous_end = end | |
| parts.append(text[previous_end:]) | |
| return ''.join(parts) | |
| def _resolve_user_name(user: Mapping[str, Any]): | |
| match user: | |
| case {"real_name": real_name}: | |
| return real_name | |
| case {"profile": {"display_name": display_name}}: | |
| return display_name | |
| case {"name": name}: | |
| return name | |
class SlackServiceResource(AsyncResource):
    """
    Dependency-injector async resource that builds and initializes a ``SlackService``.

    NOTE(review): no ``shutdown`` override is defined here, so stopping the service is presumably
    handled elsewhere — confirm.
    """

    async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, http_client: AsyncClient) -> SlackService:
        """
        Construct the Slack Bolt app, its socket-mode handler and the service, then initialize it.

        The bot token authenticates the Bolt app; the app-level token drives the socket-mode
        connection.
        """
        bot_token = settings.slack_bot_token.get_secret_value()
        app_token = settings.slack_app_token.get_secret_value()
        bolt_app = AsyncApp(token=bot_token)
        service = SlackService(
            settings=settings,
            event_brokerage_service=event_brokerage_service,
            http_client=http_client,
            slack_bolt_app=bolt_app,
            socket_mode_handler=AsyncSocketModeHandler(bolt_app, app_token),
        )
        await service.initialize()
        return service