LiKenun commited on
Commit
488a150
·
1 Parent(s): d1ed688

Enable `SlackService` to update look-up table for new user identifiers as encountered in incoming messages

Browse files
src/ctp_slack_bot/services/slack_service.py CHANGED
@@ -3,10 +3,11 @@ from httpx import AsyncClient
3
  from loguru import logger
4
  from openai import OpenAI
5
  from pydantic import ConfigDict
6
- from re import compile as compile_re
 
7
  from slack_bolt.async_app import AsyncApp
8
  from slack_sdk.web.async_slack_response import AsyncSlackResponse
9
- from typing import Any, Mapping, MutableMapping, Self
10
 
11
  from ctp_slack_bot.core import HealthReportingApplicationComponentBase
12
  from ctp_slack_bot.enums import EventType
@@ -14,9 +15,6 @@ from ctp_slack_bot.models import SlackMessage, SlackResponse
14
  from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
15
 
16
 
17
- SLACK_USER_MENTION_PATTERN = compile_re(r"<@([A-Z0-9]+)>")
18
-
19
-
20
  class SlackService(HealthReportingApplicationComponentBase):
21
  """
22
  Service for interfacing with Slack.
@@ -24,24 +22,40 @@ class SlackService(HealthReportingApplicationComponentBase):
24
 
25
  model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
26
 
 
 
 
27
  event_brokerage_service: EventBrokerageService
28
  http_client: AsyncClient
29
  slack_bolt_app: AsyncApp
30
- user_id_name_map: MutableMapping[str, str]
31
 
32
- def model_post_init(self: Self, context: Any, /) -> None:
33
- super().model_post_init(context)
34
  self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
- async def adapt_event_payload(self: Self, event: Mapping[str, Any]) -> SlackMessage:
37
- text = SLACK_USER_MENTION_PATTERN.sub(lambda match: f"@{self.user_id_name_map.get(match.group(1))}", event.get("text", "")) # TODO: permit look-up of Slack again when not found.
38
  user_id = event.get("user")
39
  return SlackMessage(
40
  type=event.get("type"),
41
  subtype=event.get("subtype"),
42
  channel=event.get("channel"),
43
  channel_type=event.get("channel_type"),
44
- user=await self._get_user_display_name(user_id),
45
  bot_id=event.get("bot_id"),
46
  thread_ts=event.get("thread_ts"),
47
  text=text,
@@ -49,66 +63,100 @@ class SlackService(HealthReportingApplicationComponentBase):
49
  event_ts=event.get("event_ts")
50
  )
51
 
52
- async def process_message(self: Self, event: Mapping[str, Any]) -> None:
53
- slack_message = await self.adapt_event_payload(event.get("event", {}))
54
- logger.debug("Received message from Slack: {}", slack_message)
55
- await self.event_brokerage_service.publish(EventType.INCOMING_SLACK_MESSAGE, slack_message)
 
 
 
56
 
57
- async def send_message(self: Self, message: SlackResponse) -> None:
58
- await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
 
59
 
60
- async def handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
61
  logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
62
- # await self.process_message(body)
63
 
64
- async def handle_app_mention_event(self: Self, body: Mapping[str, Any]) -> None:
65
  logger.debug("Received app mention for processing: {}", body.get("event", {}).get("text"))
66
- await self.process_message(body)
67
-
68
- def initialize(self: Self) -> None:
69
- self.slack_bolt_app.event("message")(self.handle_message_event)
70
- self.slack_bolt_app.event("app_mention")(self.handle_app_mention_event)
71
- logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
72
-
73
- @property
74
- def name(self: Self) -> str:
75
- return "slack_service"
76
-
77
- async def is_healthy(self: Self) -> bool:
78
- response = await self.http_client.get("https://slack-status.com/api/v2.0.0/current")
79
- return response.status_code == 200
 
 
 
 
 
 
 
80
 
81
- async def _get_user_display_name(self: Self, user_id: str) -> str:
82
- return self.user_id_name_map.get(user_id, f"<@{user_id}>")
83
- # TODO: Handle new users.
 
 
 
 
 
 
 
 
 
 
 
84
 
85
 
86
  class SlackServiceResource(AsyncResource):
87
  async def init(self: Self, event_brokerage_service: EventBrokerageService, http_client: AsyncClient, slack_bolt_app: AsyncApp) -> SlackService:
88
- match await slack_bolt_app.client.users_list():
89
- case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
90
- user_id_name_map = {id: display_name
91
- for id, display_name
92
- in zip(map(SlackServiceResource._get_user_id, users), map(SlackServiceResource._get_user_display_name, users))
93
- if display_name}
94
- logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(user_id_name_map), user_id_name_map)
95
- case something:
96
- user_id_name_map = {}
97
- logger.error("Could not obtain a list of user name for the workspace.")
98
- slack_service = SlackService(event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, user_id_name_map=user_id_name_map)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  slack_service.initialize()
100
  return slack_service
101
 
102
  @classmethod
103
- def _get_user_id(cls, user: Mapping[str, Any]):
104
- return user["id"]
105
-
106
- @classmethod
107
- def _get_user_display_name(cls, user: Mapping[str, Any]):
108
  match user:
109
- case {"profile": {"display_name": display_name}}:
110
- return display_name
111
  case {"real_name": real_name}:
112
  return real_name
113
- case _:
114
- None
 
 
 
3
  from loguru import logger
4
  from openai import OpenAI
5
  from pydantic import ConfigDict
6
+ from re import compile as compile_re, Pattern
7
+ from slack_sdk.errors import SlackApiError
8
  from slack_bolt.async_app import AsyncApp
9
  from slack_sdk.web.async_slack_response import AsyncSlackResponse
10
+ from typing import Any, Mapping, MutableMapping, Optional, Self, Set
11
 
12
  from ctp_slack_bot.core import HealthReportingApplicationComponentBase
13
  from ctp_slack_bot.enums import EventType
 
15
  from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
16
 
17
 
 
 
 
18
  class SlackService(HealthReportingApplicationComponentBase):
19
  """
20
  Service for interfacing with Slack.
 
22
 
23
  model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
24
 
25
+ SLACK_USER_ID_PATTERN: Pattern = compile_re(r"U\d+")
26
+ SLACK_USER_MENTION_PATTERN: Pattern = compile_re(r"<@(U[A-Z0-9]+)>")
27
+
28
  event_brokerage_service: EventBrokerageService
29
  http_client: AsyncClient
30
  slack_bolt_app: AsyncApp
31
+ id_name_map: MutableMapping[str, str] # TODO: Spin message processing out into its own service.
32
 
33
+ def initialize(self: Self) -> None:
 
34
  self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message)
35
+ self.slack_bolt_app.event("message")(self._handle_message_event)
36
+ self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
37
+ logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
38
+
39
+ async def send_message(self: Self, message: SlackResponse) -> None:
40
+ await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
41
+
42
+ @property
43
+ def name(self: Self) -> str:
44
+ return "slack_service"
45
+
46
+ async def is_healthy(self: Self) -> bool:
47
+ response = await self.http_client.get("https://slack-status.com/api/v2.0.0/current")
48
+ return response.status_code == 200
49
 
50
+ async def _adapt_event_payload(self: Self, event: Mapping[str, Any]) -> SlackMessage:
51
+ text = await self._resolve_user_mentions(event.get("text", ""))
52
  user_id = event.get("user")
53
  return SlackMessage(
54
  type=event.get("type"),
55
  subtype=event.get("subtype"),
56
  channel=event.get("channel"),
57
  channel_type=event.get("channel_type"),
58
+ user=await self._get_name(user_id),
59
  bot_id=event.get("bot_id"),
60
  thread_ts=event.get("thread_ts"),
61
  text=text,
 
63
  event_ts=event.get("event_ts")
64
  )
65
 
66
+ async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
67
+ unknown_ids = ids - self.id_name_map.keys()
68
+ if len(unknown_ids) == 0:
69
+ return
70
+ async with TaskGroup() as task_group:
71
+ update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
72
+ self.id_name_map.update({id: task.result() for id, task in update_tasks.items() if task.result()})
73
 
74
+ async def _get_name(self: Self, id: str) -> str:
75
+ await self._ensure_ids_in_id_name_map({id})
76
+ return self.id_name_map.get(id, id)
77
 
78
+ async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
79
  logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
80
+ # await self._process_message(body)
81
 
82
+ async def _handle_app_mention_event(self: Self, body: Mapping[str, Any]) -> None:
83
  logger.debug("Received app mention for processing: {}", body.get("event", {}).get("text"))
84
+ await self._process_message(body)
85
+
86
+ async def _look_up_name(self: Self, id: str) -> Optional[str]:
87
+ if self.SLACK_USER_ID_PATTERN.fullmatch(id):
88
+ match await self.slack_bolt_app.client.users_info(id):
89
+ case AsyncSlackResponse(data={"ok": True, "user": user}):
90
+ match user:
91
+ case {"real_name": real_name}:
92
+ return real_name
93
+ case {"profile": {"display_name": display_name}}:
94
+ return display_name
95
+ case {"name": name}:
96
+ return name
97
+ case AsyncSlackResponse(data={"ok": False, "error": "user_not_found"}):
98
+ logger.error("An attempt to look up a user failed (user not found): {}", id)
99
+ return None
100
+
101
+ async def _process_message(self: Self, event: Mapping[str, Any]) -> None:
102
+ slack_message = await self._adapt_event_payload(event.get("event", {}))
103
+ logger.debug("Received message from Slack: {}", slack_message)
104
+ await self.event_brokerage_service.publish(EventType.INCOMING_SLACK_MESSAGE, slack_message)
105
 
106
+ async def _resolve_user_mentions(self: Self, text: str) -> str:
107
+ matches = tuple(self.SLACK_USER_MENTION_PATTERN.finditer(text))
108
+ unique_ids = frozenset(match.group(1) for match in matches)
109
+ await self._ensure_ids_in_id_name_map(unique_ids)
110
+ parts = []
111
+ previous_end = 0
112
+ for match in matches:
113
+ start, end = match.span()
114
+ parts.append(text[previous_end:start])
115
+ user_id = match.group(1)
116
+ parts.append(f"@{self.id_name_map.get(user_id, user_id)}")
117
+ previous_end = end
118
+ parts.append(text[previous_end:])
119
+ return ''.join(parts)
120
 
121
 
122
  class SlackServiceResource(AsyncResource):
123
  async def init(self: Self, event_brokerage_service: EventBrokerageService, http_client: AsyncClient, slack_bolt_app: AsyncApp) -> SlackService:
124
+ async def get_users_list():
125
+ cursor = None
126
+ while True:
127
+ try:
128
+ response = await slack_bolt_app.client.users_list(cursor=cursor, limit=200)
129
+ except SlackApiError as e:
130
+ logger.warning("Could not get a list of users: {}", e)
131
+ break
132
+ match response:
133
+ case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
134
+ for user in users:
135
+ yield user
136
+ match response.data:
137
+ case {"response_metadata": {"next_cursor": cursor}} if cursor:
138
+ continue
139
+ case AsyncSlackResponse(status_code=status_code) if status_code != 200:
140
+ logger.warning("Could not get a list of users: response status {}", status_code)
141
+ case AsyncSlackResponse(data={"ok": False}):
142
+ logger.warning("Could not get a list of users: non-OK response")
143
+ case _:
144
+ logger.warning("Could not get a list of users.")
145
+ break
146
+ id_name_map = {user["id"]: self._get_name(user)
147
+ async for user
148
+ in get_users_list()}
149
+ logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
150
+ slack_service = SlackService(event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, id_name_map=id_name_map)
151
  slack_service.initialize()
152
  return slack_service
153
 
154
  @classmethod
155
+ def _get_name(cls, user: Mapping[str, Any]):
 
 
 
 
156
  match user:
 
 
157
  case {"real_name": real_name}:
158
  return real_name
159
+ case {"profile": {"display_name": display_name}}:
160
+ return display_name
161
+ case {"name": name}:
162
+ return name