# Last change: "Refactor #6" (commit f0fe0fd) by LiKenun.
from asyncio import TaskGroup
from re import compile as compile_re, Pattern
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set

from dependency_injector.resources import AsyncResource
from httpx import AsyncClient
from loguru import logger
from openai import OpenAI
from pydantic import ConfigDict, PrivateAttr
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
from slack_bolt.async_app import AsyncApp
from slack_sdk.errors import SlackApiError
from slack_sdk.web.async_slack_response import AsyncSlackResponse

from ctp_slack_bot.core import HealthReportingApplicationComponentBase, Settings
from ctp_slack_bot.enums import EventType
from ctp_slack_bot.models import SlackMessage, SlackResponse
from .event_brokerage_service import EventBrokerageService
class SlackService(HealthReportingApplicationComponentBase):
"""
Service for interfacing with Slack.
"""
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
_SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+")
_SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>")
settings: Settings
event_brokerage_service: EventBrokerageService
http_client: AsyncClient
slack_bolt_app: AsyncApp
socket_mode_handler: AsyncSocketModeHandler
_id_name_map: MutableMapping[str, str] = PrivateAttr(default={}) # TODO: Spin message processing out into its own service.
async def initialize(self: Self) -> None:
async def get_users_list():
cursor = None
while True:
try:
response = await self.slack_bolt_app.client.users_list(cursor=cursor, limit=200)
except SlackApiError as e:
logger.warning("Could not get a list of users: {}", e)
break
match response:
case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
for user in users:
yield user
match response.data:
case {"response_metadata": {"next_cursor": cursor}} if cursor:
continue
case AsyncSlackResponse(status_code=status_code) if status_code != 200:
logger.warning("Could not get a list of users: response status {}", status_code)
case AsyncSlackResponse(data={"ok": False}):
logger.warning("Could not get a list of users: non-OK response")
case _:
logger.warning("Could not get a list of users.")
break
id_name_map = {user["id"]: self._resolve_user_name(user)
async for user
in get_users_list()}
self._id_name_map.update(id_name_map)
logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
await self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message)
self.slack_bolt_app.event("message")(self._handle_message_event)
self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
async def start(self: Self) -> None:
await self.socket_mode_handler.start_async()
async def stop(self: Self) -> None:
await self.socket_mode_handler.close_async()
logger.info("Stopped Slack Bolt socket mode handler and Slack service.")
async def send_message(self: Self, message: SlackResponse) -> None:
await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
@property
def name(self: Self) -> str:
return "slack_service"
async def is_healthy(self: Self) -> bool:
response = await self.http_client.get("https://slack-status.com/api/v2.0.0/current")
return response.status_code == 200
async def _adapt_event_payload(self: Self, event: Mapping[str, Any]) -> SlackMessage:
text = await self._resolve_user_mentions(event.get("text", ""))
user_id = event.get("user")
return SlackMessage(
type=event.get("type"),
subtype=event.get("subtype"),
channel=event.get("channel"),
channel_type=event.get("channel_type"),
user=await self._get_name(user_id),
bot_id=event.get("bot_id"),
thread_ts=event.get("thread_ts"),
text=text,
ts=event.get("ts"),
event_ts=event.get("event_ts")
)
async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
unknown_ids = ids - self._id_name_map.keys()
if len(unknown_ids) == 0:
return
async with TaskGroup() as task_group:
update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
self._id_name_map.update({id: task.result() for id, task in update_tasks.items() if task.result()})
async def _get_name(self: Self, id: str) -> str:
await self._ensure_ids_in_id_name_map({id})
return self._id_name_map.get(id, id)
async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
# await self._process_message(body)
async def _handle_app_mention_event(self: Self, body: Mapping[str, Any]) -> None:
logger.debug("Received app mention for processing: {}", body.get("event", {}).get("text"))
await self._process_message(body)
async def _look_up_name(self: Self, id: str) -> Optional[str]:
if self._SLACK_USER_ID_PATTERN.fullmatch(id):
match await self.slack_bolt_app.client.users_info(id):
case AsyncSlackResponse(data={"ok": True, "user": user}):
match user:
case {"real_name": real_name}:
return real_name
case {"profile": {"display_name": display_name}}:
return display_name
case {"name": name}:
return name
case AsyncSlackResponse(data={"ok": False, "error": "user_not_found"}):
logger.error("An attempt to look up a user failed (user not found): {}", id)
return None
async def _process_message(self: Self, event: Mapping[str, Any]) -> None:
slack_message = await self._adapt_event_payload(event.get("event", {}))
logger.debug("Received message from Slack: {}", slack_message)
await self.event_brokerage_service.publish(EventType.INCOMING_SLACK_MESSAGE, slack_message)
async def _resolve_user_mentions(self: Self, text: str) -> str:
matches = tuple(self._SLACK_USER_MENTION_PATTERN.finditer(text))
unique_ids = frozenset(match.group(1) for match in matches)
await self._ensure_ids_in_id_name_map(unique_ids)
parts = []
previous_end = 0
for match in matches:
start, end = match.span()
parts.append(text[previous_end:start])
user_id = match.group(1)
parts.append(f"@{self._id_name_map.get(user_id, user_id)}")
previous_end = end
parts.append(text[previous_end:])
return ''.join(parts)
@staticmethod
def _resolve_user_name(user: Mapping[str, Any]):
match user:
case {"real_name": real_name}:
return real_name
case {"profile": {"display_name": display_name}}:
return display_name
case {"name": name}:
return name
class SlackServiceResource(AsyncResource):
    """
    Dependency-injector resource that builds a :class:`SlackService` and awaits its initialization.
    """

    async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, http_client: AsyncClient) -> SlackService:
        """
        Create the Bolt app from the bot token and the Socket Mode handler from the app token.
        Then assemble the service, run its ``initialize`` step, and return it.
        """
        bolt_app = AsyncApp(token=settings.slack_bot_token.get_secret_value())
        service = SlackService(
            settings=settings,
            event_brokerage_service=event_brokerage_service,
            http_client=http_client,
            slack_bolt_app=bolt_app,
            socket_mode_handler=AsyncSocketModeHandler(bolt_app, settings.slack_app_token.get_secret_value()),
        )
        await service.initialize()
        return service