Spaces:
Runtime error
Runtime error
Refactor #6
Browse files- README.md +1 -1
- src/ctp_slack_bot/app.py +14 -13
- src/ctp_slack_bot/containers.py +39 -34
- src/ctp_slack_bot/controllers/__init__.py +1 -1
- src/ctp_slack_bot/controllers/application_health_controller.py +5 -3
- src/ctp_slack_bot/controllers/base.py +54 -20
- src/ctp_slack_bot/db/mongo_db.py +11 -33
- src/ctp_slack_bot/mime_type_handlers/__init__.py +1 -1
- src/ctp_slack_bot/mime_type_handlers/base.py +29 -11
- src/ctp_slack_bot/mime_type_handlers/text/vtt.py +2 -2
- src/ctp_slack_bot/services/__init__.py +1 -1
- src/ctp_slack_bot/services/content_ingestion_service.py +9 -6
- src/ctp_slack_bot/services/event_brokerage_service.py +9 -8
- src/ctp_slack_bot/services/google_drive_service.py +0 -1
- src/ctp_slack_bot/services/question_dispatch_service.py +9 -6
- src/ctp_slack_bot/services/slack_service.py +59 -45
- src/ctp_slack_bot/services/{schedule_service.py → task_service.py} +13 -17
README.md
CHANGED
@@ -101,7 +101,7 @@ Not every file or folder is listed, but the important stuff is here.
|
|
101 |
* `google_drive_service.py`: interfaces with Google Drive
|
102 |
* `language_model_service.py`: answers questions using relevant context
|
103 |
* `question_dispatch_service.py`: listens for questions and retrieves relevant context to get answers
|
104 |
-
* `
|
105 |
* `slack_service.py`: handles events from Slack and sends back responses
|
106 |
* `vectorization_service.py`: converts chunks into chunks with embeddings
|
107 |
* `tasks/`: scheduled tasks to run in the background
|
|
|
101 |
* `google_drive_service.py`: interfaces with Google Drive
|
102 |
* `language_model_service.py`: answers questions using relevant context
|
103 |
* `question_dispatch_service.py`: listens for questions and retrieves relevant context to get answers
|
104 |
+
* `task_service.py`: runs periodic background tasks
|
105 |
* `slack_service.py`: handles events from Slack and sends back responses
|
106 |
* `vectorization_service.py`: converts chunks into chunks with embeddings
|
107 |
* `tasks/`: scheduled tasks to run in the background
|
src/ctp_slack_bot/app.py
CHANGED
@@ -7,8 +7,11 @@ from containers import Container
|
|
7 |
from core.logging import setup_logging
|
8 |
|
9 |
|
10 |
-
async def handle_shutdown_signal() -> None:
|
11 |
logger.info("Received shutdown signal.")
|
|
|
|
|
|
|
12 |
for task in all_tasks():
|
13 |
if task is not current_task() and not task.done():
|
14 |
task.cancel()
|
@@ -16,9 +19,9 @@ async def handle_shutdown_signal() -> None:
|
|
16 |
logger.info("Cancelled all tasks.")
|
17 |
|
18 |
|
19 |
-
def create_shutdown_signal_handler() -> Callable[[], None]:
|
20 |
def shutdown_signal_handler() -> None:
|
21 |
-
create_task(handle_shutdown_signal())
|
22 |
return shutdown_signal_handler
|
23 |
|
24 |
|
@@ -32,29 +35,27 @@ async def main() -> None:
|
|
32 |
container.wire(packages=["ctp_slack_bot"])
|
33 |
logger.debug("Created dependency injection container with providers: {}", '; '.join(container.providers))
|
34 |
|
35 |
-
# Initialize/instantiate services which should be
|
36 |
-
container.content_ingestion_service()
|
37 |
-
container.question_dispatch_service()
|
38 |
http_server = await container.http_server()
|
39 |
-
|
40 |
-
container.
|
41 |
logger.debug("Initialized services.")
|
42 |
|
43 |
# Install the shutdown signal handler.
|
44 |
-
shutdown_signal_handler = create_shutdown_signal_handler()
|
45 |
loop = get_running_loop()
|
46 |
loop.add_signal_handler(SIGINT, shutdown_signal_handler)
|
47 |
loop.add_signal_handler(SIGTERM, shutdown_signal_handler)
|
48 |
|
49 |
# Start the HTTP server and Slack socket mode handler in the background; clean up resources when shut down.
|
50 |
try:
|
51 |
-
logger.info("Starting
|
52 |
-
await gather(http_server.start(),
|
53 |
except CancelledError:
|
54 |
logger.info("Shutting down application…")
|
55 |
finally:
|
56 |
-
await socket_mode_handler.close_async()
|
57 |
-
logger.info("Stopped Slack Socket Mode handler.")
|
58 |
await container.shutdown_resources()
|
59 |
|
60 |
|
|
|
7 |
from core.logging import setup_logging
|
8 |
|
9 |
|
10 |
+
async def handle_shutdown_signal(*args) -> None:
|
11 |
logger.info("Received shutdown signal.")
|
12 |
+
for arg in args:
|
13 |
+
await arg()
|
14 |
+
logger.info("Executed shutdown tasks.")
|
15 |
for task in all_tasks():
|
16 |
if task is not current_task() and not task.done():
|
17 |
task.cancel()
|
|
|
19 |
logger.info("Cancelled all tasks.")
|
20 |
|
21 |
|
22 |
+
def create_shutdown_signal_handler(*args) -> Callable[[], None]:
|
23 |
def shutdown_signal_handler() -> None:
|
24 |
+
create_task(handle_shutdown_signal(*args))
|
25 |
return shutdown_signal_handler
|
26 |
|
27 |
|
|
|
35 |
container.wire(packages=["ctp_slack_bot"])
|
36 |
logger.debug("Created dependency injection container with providers: {}", '; '.join(container.providers))
|
37 |
|
38 |
+
# Initialize/instantiate services which should be available from the start.
|
39 |
+
await container.content_ingestion_service()
|
40 |
+
await container.question_dispatch_service()
|
41 |
http_server = await container.http_server()
|
42 |
+
slack_service = await container.slack_service()
|
43 |
+
task_service = await container.task_service()
|
44 |
logger.debug("Initialized services.")
|
45 |
|
46 |
# Install the shutdown signal handler.
|
47 |
+
shutdown_signal_handler = create_shutdown_signal_handler(http_server.stop, slack_service.stop, task_service.stop)
|
48 |
loop = get_running_loop()
|
49 |
loop.add_signal_handler(SIGINT, shutdown_signal_handler)
|
50 |
loop.add_signal_handler(SIGTERM, shutdown_signal_handler)
|
51 |
|
52 |
# Start the HTTP server and Slack socket mode handler in the background; clean up resources when shut down.
|
53 |
try:
|
54 |
+
logger.info("Starting services…")
|
55 |
+
await gather(http_server.start(), slack_service.start(), task_service.start())
|
56 |
except CancelledError:
|
57 |
logger.info("Shutting down application…")
|
58 |
finally:
|
|
|
|
|
59 |
await container.shutdown_resources()
|
60 |
|
61 |
|
src/ctp_slack_bot/containers.py
CHANGED
@@ -1,13 +1,13 @@
|
|
|
|
|
|
1 |
from dependency_injector.containers import DeclarativeContainer
|
2 |
-
from dependency_injector.providers import
|
3 |
from importlib import import_module
|
4 |
from itertools import chain
|
5 |
from openai import AsyncOpenAI
|
6 |
from pkgutil import iter_modules
|
7 |
-
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
8 |
-
from slack_bolt.async_app import AsyncApp
|
9 |
from types import ModuleType
|
10 |
-
from typing import Sequence
|
11 |
|
12 |
from ctp_slack_bot.controllers import ControllerBase, ControllerRegistry
|
13 |
from ctp_slack_bot.core import Settings
|
@@ -16,7 +16,7 @@ from ctp_slack_bot.db.repositories.mongo_db_vectorized_chunk_repository import M
|
|
16 |
from ctp_slack_bot.mime_type_handlers import MimeTypeHandlerRegistry
|
17 |
from ctp_slack_bot.services.answer_retrieval_service import AnswerRetrievalService
|
18 |
from ctp_slack_bot.services.application_health_service import ApplicationHealthService
|
19 |
-
from ctp_slack_bot.services.content_ingestion_service import
|
20 |
from ctp_slack_bot.services.context_retrieval_service import ContextRetrievalService
|
21 |
from ctp_slack_bot.services.embeddings_model_service import EmbeddingsModelService
|
22 |
from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
|
@@ -24,26 +24,38 @@ from ctp_slack_bot.services.google_drive_service import GoogleDriveService
|
|
24 |
from ctp_slack_bot.services.http_client_service import HTTPClientServiceResource
|
25 |
from ctp_slack_bot.services.http_server_service import HTTPServerResource
|
26 |
from ctp_slack_bot.services.language_model_service import LanguageModelService
|
27 |
-
from ctp_slack_bot.services.question_dispatch_service import
|
28 |
-
from ctp_slack_bot.services.schedule_service import ScheduleServiceResource
|
29 |
from ctp_slack_bot.services.slack_service import SlackServiceResource
|
|
|
30 |
from ctp_slack_bot.services.vectorization_service import VectorizationService
|
31 |
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
class Container(DeclarativeContainer): # TODO: audit for potential async-related bugs.
|
34 |
-
async def
|
35 |
-
return [controller_class(**{dependency_name: await
|
36 |
-
for dependency_name
|
37 |
-
in
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
__self__ = Self()
|
43 |
settings = Singleton(Settings)
|
44 |
event_brokerage_service = Singleton(EventBrokerageService)
|
45 |
-
schedule_service = Resource (ScheduleServiceResource,
|
46 |
-
settings=settings)
|
47 |
http_client = Resource (HTTPClientServiceResource)
|
48 |
mongo_db = Resource (MongoDBResource,
|
49 |
settings=settings)
|
@@ -58,7 +70,7 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
58 |
vectorization_service = Singleton(VectorizationService,
|
59 |
settings=settings,
|
60 |
embeddings_model_service=embeddings_model_service)
|
61 |
-
content_ingestion_service =
|
62 |
settings=settings,
|
63 |
event_brokerage_service=event_brokerage_service,
|
64 |
vectorized_chunk_repository=vectorized_chunk_repository,
|
@@ -74,25 +86,18 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
74 |
settings=settings,
|
75 |
event_brokerage_service=event_brokerage_service,
|
76 |
language_model_service=language_model_service)
|
77 |
-
question_dispatch_service =
|
78 |
settings=settings,
|
79 |
event_brokerage_service=event_brokerage_service,
|
80 |
-
content_ingestion_service=content_ingestion_service,
|
81 |
context_retrieval_service=context_retrieval_service,
|
82 |
answer_retrieval_service=answer_retrieval_service)
|
83 |
-
slack_bolt_app = Singleton(lambda settings: AsyncApp(token=settings.slack_bot_token.get_secret_value()),
|
84 |
-
settings)
|
85 |
slack_service = Resource (SlackServiceResource,
|
|
|
86 |
event_brokerage_service=event_brokerage_service,
|
87 |
-
http_client=http_client
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
slack_bolt_app,
|
92 |
-
settings)
|
93 |
-
mime_type_handlers = Dict ({mime_type: Singleton(handler)
|
94 |
-
for mime_type, handler
|
95 |
-
in MimeTypeHandlerRegistry.get_registry().items()})
|
96 |
google_drive_service = Singleton(GoogleDriveService,
|
97 |
settings=settings)
|
98 |
# file_monitor_service = Singleton(FileMonitorService,
|
@@ -101,8 +106,8 @@ class Container(DeclarativeContainer): # TODO: audit for potential async-related
|
|
101 |
# mime_type_handler_factory=mime_type_handler_factory)
|
102 |
application_health_service = Singleton(ApplicationHealthService,
|
103 |
services=List(mongo_db, slack_service))
|
104 |
-
|
105 |
-
|
106 |
http_server = Resource (HTTPServerResource,
|
107 |
settings=settings,
|
108 |
-
controllers=
|
|
|
1 |
+
from asyncio import iscoroutine, isfuture
|
2 |
+
from dependency_injector import providers
|
3 |
from dependency_injector.containers import DeclarativeContainer
|
4 |
+
from dependency_injector.providers import Callable, Configuration, Dict, List, Resource, Singleton
|
5 |
from importlib import import_module
|
6 |
from itertools import chain
|
7 |
from openai import AsyncOpenAI
|
8 |
from pkgutil import iter_modules
|
|
|
|
|
9 |
from types import ModuleType
|
10 |
+
from typing import Any, Iterator, Sequence
|
11 |
|
12 |
from ctp_slack_bot.controllers import ControllerBase, ControllerRegistry
|
13 |
from ctp_slack_bot.core import Settings
|
|
|
16 |
from ctp_slack_bot.mime_type_handlers import MimeTypeHandlerRegistry
|
17 |
from ctp_slack_bot.services.answer_retrieval_service import AnswerRetrievalService
|
18 |
from ctp_slack_bot.services.application_health_service import ApplicationHealthService
|
19 |
+
from ctp_slack_bot.services.content_ingestion_service import ContentIngestionServiceResource
|
20 |
from ctp_slack_bot.services.context_retrieval_service import ContextRetrievalService
|
21 |
from ctp_slack_bot.services.embeddings_model_service import EmbeddingsModelService
|
22 |
from ctp_slack_bot.services.event_brokerage_service import EventBrokerageService
|
|
|
24 |
from ctp_slack_bot.services.http_client_service import HTTPClientServiceResource
|
25 |
from ctp_slack_bot.services.http_server_service import HTTPServerResource
|
26 |
from ctp_slack_bot.services.language_model_service import LanguageModelService
|
27 |
+
from ctp_slack_bot.services.question_dispatch_service import QuestionDispatchServiceResource
|
|
|
28 |
from ctp_slack_bot.services.slack_service import SlackServiceResource
|
29 |
+
from ctp_slack_bot.services.task_service import TaskServiceResource
|
30 |
from ctp_slack_bot.services.vectorization_service import VectorizationService
|
31 |
|
32 |
|
33 |
+
async def _await_or_return(value):
|
34 |
+
if iscoroutine(value) or isfuture(value):
|
35 |
+
return await value
|
36 |
+
return value
|
37 |
+
|
38 |
+
|
39 |
class Container(DeclarativeContainer): # TODO: audit for potential async-related bugs.
|
40 |
+
async def __get_http_controller_providers(container) -> Sequence[ControllerBase]:
|
41 |
+
return [controller_class(**{dependency_name: await _await_or_return(container.providers[dependency_name]())
|
42 |
+
for dependency_name
|
43 |
+
in controller_class.model_fields.keys() & container.providers.keys()})
|
44 |
+
for controller_class in ControllerRegistry.get_registry()]
|
45 |
+
|
46 |
+
def __iter_mime_type_handler_providers() -> Iterator[tuple[str, Singleton]]:
|
47 |
+
handler_provider_map = {}
|
48 |
+
for mime_type, handler in MimeTypeHandlerRegistry.get_registry().items():
|
49 |
+
if handler in handler_provider_map:
|
50 |
+
provider = handler_provider_map[handler]
|
51 |
+
else:
|
52 |
+
provider = Singleton(handler)
|
53 |
+
handler_provider_map[handler] = provider
|
54 |
+
yield (mime_type, provider)
|
55 |
|
56 |
+
__self__ = providers.Self()
|
57 |
settings = Singleton(Settings)
|
58 |
event_brokerage_service = Singleton(EventBrokerageService)
|
|
|
|
|
59 |
http_client = Resource (HTTPClientServiceResource)
|
60 |
mongo_db = Resource (MongoDBResource,
|
61 |
settings=settings)
|
|
|
70 |
vectorization_service = Singleton(VectorizationService,
|
71 |
settings=settings,
|
72 |
embeddings_model_service=embeddings_model_service)
|
73 |
+
content_ingestion_service = Resource (ContentIngestionServiceResource,
|
74 |
settings=settings,
|
75 |
event_brokerage_service=event_brokerage_service,
|
76 |
vectorized_chunk_repository=vectorized_chunk_repository,
|
|
|
86 |
settings=settings,
|
87 |
event_brokerage_service=event_brokerage_service,
|
88 |
language_model_service=language_model_service)
|
89 |
+
question_dispatch_service = Resource (QuestionDispatchServiceResource,
|
90 |
settings=settings,
|
91 |
event_brokerage_service=event_brokerage_service,
|
|
|
92 |
context_retrieval_service=context_retrieval_service,
|
93 |
answer_retrieval_service=answer_retrieval_service)
|
|
|
|
|
94 |
slack_service = Resource (SlackServiceResource,
|
95 |
+
settings=settings,
|
96 |
event_brokerage_service=event_brokerage_service,
|
97 |
+
http_client=http_client)
|
98 |
+
mime_type_handlers = Dict ({mime_type: handler_provider
|
99 |
+
for mime_type, handler_provider
|
100 |
+
in __iter_mime_type_handler_providers()})
|
|
|
|
|
|
|
|
|
|
|
101 |
google_drive_service = Singleton(GoogleDriveService,
|
102 |
settings=settings)
|
103 |
# file_monitor_service = Singleton(FileMonitorService,
|
|
|
106 |
# mime_type_handler_factory=mime_type_handler_factory)
|
107 |
application_health_service = Singleton(ApplicationHealthService,
|
108 |
services=List(mongo_db, slack_service))
|
109 |
+
task_service = Resource (TaskServiceResource,
|
110 |
+
settings=settings)
|
111 |
http_server = Resource (HTTPServerResource,
|
112 |
settings=settings,
|
113 |
+
controllers=Callable(__get_http_controller_providers, __self__))
|
src/ctp_slack_bot/controllers/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
from .application_health_controller import ApplicationHealthController
|
2 |
-
from .base import ControllerBase, ControllerRegistry, Route
|
|
|
1 |
from .application_health_controller import ApplicationHealthController
|
2 |
+
from .base import controller, ControllerBase, ControllerRegistry, delete, get, patch, post, put, Route, route
|
src/ctp_slack_bot/controllers/application_health_controller.py
CHANGED
@@ -2,19 +2,21 @@ from aiohttp.web import json_response, Request, Response
|
|
2 |
from pydantic import ConfigDict
|
3 |
from typing import Self
|
4 |
|
5 |
-
from .base import ControllerBase,
|
6 |
from ctp_slack_bot.services import ApplicationHealthService
|
7 |
|
8 |
|
9 |
-
@
|
10 |
class ApplicationHealthController(ControllerBase):
|
11 |
"""
|
12 |
Application health reporting endpoints.
|
13 |
"""
|
14 |
|
|
|
|
|
15 |
application_health_service: ApplicationHealthService
|
16 |
|
17 |
-
@
|
18 |
async def get_health(self: Self, request: Request) -> Response:
|
19 |
health_statuses = await self.application_health_service.get_health()
|
20 |
return json_response(dict(health_statuses), status=200 if all(health_statuses.values()) else 503)
|
|
|
2 |
from pydantic import ConfigDict
|
3 |
from typing import Self
|
4 |
|
5 |
+
from .base import ControllerBase, controller, get
|
6 |
from ctp_slack_bot.services import ApplicationHealthService
|
7 |
|
8 |
|
9 |
+
@controller("/health")
|
10 |
class ApplicationHealthController(ControllerBase):
|
11 |
"""
|
12 |
Application health reporting endpoints.
|
13 |
"""
|
14 |
|
15 |
+
model_config = ConfigDict(frozen=True)
|
16 |
+
|
17 |
application_health_service: ApplicationHealthService
|
18 |
|
19 |
+
@get("")
|
20 |
async def get_health(self: Self, request: Request) -> Response:
|
21 |
health_statuses = await self.application_health_service.get_health()
|
22 |
return json_response(dict(health_statuses), status=200 if all(health_statuses.values()) else 503)
|
src/ctp_slack_bot/controllers/base.py
CHANGED
@@ -1,9 +1,10 @@
|
|
|
|
1 |
from aiohttp.web import Request, Response
|
2 |
-
from functools import partial
|
3 |
from importlib import import_module
|
4 |
from inspect import getmembers, ismethod
|
5 |
from pydantic import BaseModel, ConfigDict
|
6 |
-
from typing import Awaitable, Callable, ClassVar, Mapping, Self, Sequence, TypeVar
|
7 |
|
8 |
from ctp_slack_bot.core import ApplicationComponentBase
|
9 |
|
@@ -18,21 +19,20 @@ class Route(BaseModel):
|
|
18 |
path: str
|
19 |
handler: AsyncHandler
|
20 |
|
21 |
-
@staticmethod
|
22 |
-
def get(path: str) -> Callable[[AsyncHandler], AsyncHandler]:
|
23 |
-
def decorator(function: AsyncHandler) -> AsyncHandler:
|
24 |
-
function._http_method = "GET"
|
25 |
-
function._http_path = path
|
26 |
-
return function
|
27 |
-
return decorator
|
28 |
-
|
29 |
|
30 |
class ControllerBase(ApplicationComponentBase):
|
31 |
|
32 |
def get_routes(self: Self) -> Sequence[Route]:
|
33 |
-
return tuple(Route(method=method._http_method,
|
|
|
|
|
34 |
for name, method in getmembers(self, predicate=ismethod)
|
35 |
-
if name != 'get_routes' and hasattr(method, "_http_method") and hasattr(method, "_http_path"))
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
|
38 |
T = TypeVar('T', bound=ControllerBase)
|
@@ -40,16 +40,50 @@ T = TypeVar('T', bound=ControllerBase)
|
|
40 |
|
41 |
class ControllerRegistry:
|
42 |
|
43 |
-
|
44 |
|
45 |
@classmethod
|
46 |
-
def get_registry(cls) ->
|
47 |
import_module(__package__)
|
48 |
-
return tuple(cls.
|
49 |
|
50 |
@classmethod
|
51 |
-
def register(cls):
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
from aiohttp.web import Request, Response
|
3 |
+
from functools import partial, wraps
|
4 |
from importlib import import_module
|
5 |
from inspect import getmembers, ismethod
|
6 |
from pydantic import BaseModel, ConfigDict
|
7 |
+
from typing import Awaitable, Callable, ClassVar, Collection, Mapping, Optional, overload, ParamSpec, Self, Sequence, TypeVar
|
8 |
|
9 |
from ctp_slack_bot.core import ApplicationComponentBase
|
10 |
|
|
|
19 |
path: str
|
20 |
handler: AsyncHandler
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
class ControllerBase(ApplicationComponentBase):
|
24 |
|
25 |
def get_routes(self: Self) -> Sequence[Route]:
|
26 |
+
return tuple(Route(method=method._http_method,
|
27 |
+
path="/".join(filter(None, (self.prefix, method._http_path))),
|
28 |
+
handler=method)
|
29 |
for name, method in getmembers(self, predicate=ismethod)
|
30 |
+
if name != 'get_routes' and name != 'prefix' and hasattr(method, "_http_method") and hasattr(method, "_http_path"))
|
31 |
+
|
32 |
+
@property
|
33 |
+
@abstractmethod
|
34 |
+
def prefix(self: Self) -> str:
|
35 |
+
pass
|
36 |
|
37 |
|
38 |
T = TypeVar('T', bound=ControllerBase)
|
|
|
40 |
|
41 |
class ControllerRegistry:
|
42 |
|
43 |
+
__registry: ClassVar[list[T]] = []
|
44 |
|
45 |
@classmethod
|
46 |
+
def get_registry(cls) -> Collection[T]:
|
47 |
import_module(__package__)
|
48 |
+
return tuple(cls.__registry)
|
49 |
|
50 |
@classmethod
|
51 |
+
def register(cls, controller_cls: T) -> None:
|
52 |
+
cls.__registry.append(controller_cls)
|
53 |
+
|
54 |
+
|
55 |
+
@overload
|
56 |
+
def controller(cls: T) -> T: ...
|
57 |
+
|
58 |
+
@overload
|
59 |
+
def controller(prefix: str = "/") -> Callable[[T], T]: ...
|
60 |
+
|
61 |
+
def controller(cls_or_prefix=None):
|
62 |
+
def implement_prefix_property_and_register_controller(cls: T, prefix: Optional[str] = "/") -> T:
|
63 |
+
def prefix_getter(self: T) -> str:
|
64 |
+
return prefix
|
65 |
+
setattr(cls, 'prefix', property(prefix_getter))
|
66 |
+
if hasattr(cls, '__abstractmethods__'):
|
67 |
+
cls.__abstractmethods__ = frozenset(method for method in cls.__abstractmethods__ if method != 'prefix')
|
68 |
+
ControllerRegistry.register(cls)
|
69 |
+
return cls
|
70 |
+
if isinstance(cls_or_prefix, type):
|
71 |
+
return implement_prefix_property_and_register_controller(cls_or_prefix)
|
72 |
+
def decorator(cls: T) -> T:
|
73 |
+
return implement_prefix_property_and_register_controller(cls, cls_or_prefix)
|
74 |
+
return decorator
|
75 |
+
|
76 |
+
|
77 |
+
def route(method: str, path: str = "") -> Callable[[AsyncHandler], AsyncHandler]:
|
78 |
+
def decorator(function: AsyncHandler) -> AsyncHandler:
|
79 |
+
function._http_method = method
|
80 |
+
function._http_path = path
|
81 |
+
return function
|
82 |
+
return decorator
|
83 |
+
|
84 |
+
|
85 |
+
get = partial(route, "GET")
|
86 |
+
post = partial(route, "POST")
|
87 |
+
put = partial(route, "PUT")
|
88 |
+
delete = partial(route, "DELETE")
|
89 |
+
patch = partial(route, "PATCH")
|
src/ctp_slack_bot/db/mongo_db.py
CHANGED
@@ -65,30 +65,16 @@ class MongoDB(HealthReportingApplicationComponentBase):
|
|
65 |
"""
|
66 |
Get a collection by name or creates it if it doesn’t exist.
|
67 |
"""
|
68 |
-
# First ensure we can connect at all.
|
69 |
-
if not await self.ping():
|
70 |
-
logger.error("Cannot get collection '{}' because a MongoDB connection is not available.", name)
|
71 |
-
raise ConnectionError("MongoDB connection is not available.")
|
72 |
-
|
73 |
try:
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
if name not in collection_names:
|
79 |
-
logger.info("Collection '{}' does not exist. Creating it…", name)
|
80 |
-
|
81 |
-
# Create the collection.
|
82 |
-
await self._db.create_collection(name)
|
83 |
-
logger.debug("Successfully created collection: {}", name)
|
84 |
else:
|
85 |
-
|
86 |
-
|
87 |
-
# Get and return the collection.
|
88 |
-
collection = self._db[name]
|
89 |
return collection
|
90 |
except Exception as e:
|
91 |
-
logger.error("Error accessing collection
|
92 |
raise e
|
93 |
|
94 |
def close(self: Self) -> None:
|
@@ -117,19 +103,11 @@ class MongoDBResource(AsyncResource):
|
|
117 |
|
118 |
async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
|
119 |
"""Test MongoDB connection and log the result."""
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
else:
|
125 |
-
logger.error("MongoDB connection test failed!")
|
126 |
-
except Exception as e:
|
127 |
-
logger.error("Error testing MongoDB connection: {}", e)
|
128 |
-
raise e
|
129 |
|
130 |
async def shutdown(self: Self, mongo_db: MongoDB) -> None:
|
131 |
"""Close MongoDB connection on shutdown."""
|
132 |
-
|
133 |
-
mongo_db.close()
|
134 |
-
except Exception as e:
|
135 |
-
logger.error("Error closing MongoDB connection: {}", e)
|
|
|
65 |
"""
|
66 |
Get a collection by name or creates it if it doesn’t exist.
|
67 |
"""
|
|
|
|
|
|
|
|
|
|
|
68 |
try:
|
69 |
+
if name not in await self._db.list_collection_names():
|
70 |
+
collection = await self._db.create_collection(name)
|
71 |
+
logger.debug("Created previously nonexistent collection, {}.", name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
else:
|
73 |
+
collection = self._db[name]
|
74 |
+
logger.debug("Retrieved collection, {}.", name)
|
|
|
|
|
75 |
return collection
|
76 |
except Exception as e:
|
77 |
+
logger.error("Error accessing collection, {}: {}", name, e)
|
78 |
raise e
|
79 |
|
80 |
def close(self: Self) -> None:
|
|
|
103 |
|
104 |
async def _test_connection(self: Self, mongo_db: MongoDB) -> None:
|
105 |
"""Test MongoDB connection and log the result."""
|
106 |
+
if await mongo_db.ping():
|
107 |
+
logger.info("MongoDB connection test successful!")
|
108 |
+
else:
|
109 |
+
logger.error("MongoDB connection test failed!")
|
|
|
|
|
|
|
|
|
|
|
110 |
|
111 |
async def shutdown(self: Self, mongo_db: MongoDB) -> None:
|
112 |
"""Close MongoDB connection on shutdown."""
|
113 |
+
mongo_db.close()
|
|
|
|
|
|
src/ctp_slack_bot/mime_type_handlers/__init__.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
from .base import MimeTypeHandler, MimeTypeHandlerRegistry
|
2 |
from .text.vtt import WebVTTMimeTypeHandler
|
|
|
1 |
+
from .base import MimeTypeHandler, mime_type_handler, MimeTypeHandlerRegistry
|
2 |
from .text.vtt import WebVTTMimeTypeHandler
|
src/ctp_slack_bot/mime_type_handlers/base.py
CHANGED
@@ -1,7 +1,7 @@
|
|
1 |
-
from abc import
|
2 |
from importlib import import_module
|
3 |
from types import MappingProxyType
|
4 |
-
from typing import Any, ClassVar, Mapping, Optional
|
5 |
|
6 |
from ctp_slack_bot.core import ApplicationComponentBase
|
7 |
from ctp_slack_bot.models import Content
|
@@ -14,20 +14,38 @@ class MimeTypeHandler(ApplicationComponentBase):
|
|
14 |
pass
|
15 |
|
16 |
|
|
|
|
|
|
|
17 |
class MimeTypeHandlerRegistry:
|
18 |
|
19 |
-
|
20 |
|
21 |
@classmethod
|
22 |
-
def get_registry(cls) -> Mapping[str,
|
23 |
import_module(__package__)
|
24 |
-
return MappingProxyType(cls.
|
25 |
|
26 |
@classmethod
|
27 |
-
def register(cls,
|
28 |
-
|
29 |
-
if mime_type in cls.
|
30 |
raise ValueError(f"The MIME type, {mime_type}, is already registered.")
|
31 |
-
cls.
|
32 |
-
|
33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
from importlib import import_module
|
3 |
from types import MappingProxyType
|
4 |
+
from typing import Any, Callable, ClassVar, Mapping, Optional, overload, Set, TypeVar
|
5 |
|
6 |
from ctp_slack_bot.core import ApplicationComponentBase
|
7 |
from ctp_slack_bot.models import Content
|
|
|
14 |
pass
|
15 |
|
16 |
|
17 |
+
T = TypeVar('T', bound=MimeTypeHandler)
|
18 |
+
|
19 |
+
|
20 |
class MimeTypeHandlerRegistry:
|
21 |
|
22 |
+
__registry: ClassVar[dict[str, T]] = {}
|
23 |
|
24 |
@classmethod
|
25 |
+
def get_registry(cls) -> Mapping[str, T]:
|
26 |
import_module(__package__)
|
27 |
+
return MappingProxyType(cls.__registry)
|
28 |
|
29 |
@classmethod
|
30 |
+
def register(cls, mime_types: Set[str], handler_cls: T):
|
31 |
+
for mime_type in mime_types:
|
32 |
+
if mime_type in cls.__registry:
|
33 |
raise ValueError(f"The MIME type, {mime_type}, is already registered.")
|
34 |
+
cls.__registry[mime_type] = handler_cls
|
35 |
+
|
36 |
+
|
37 |
+
@overload
|
38 |
+
def mime_type_handler(cls: T) -> T: ...
|
39 |
+
|
40 |
+
@overload
|
41 |
+
def mime_type_handler(mime_types: Optional[Set[str] | str] = None) -> Callable[[T], T]: ...
|
42 |
+
|
43 |
+
def mime_type_handler(cls_or_mime_types=None):
|
44 |
+
def register_mime_type_handler(cls: T, mime_types: Optional[Set[str]] = None) -> T:
|
45 |
+
MimeTypeHandlerRegistry.register({mime_types} if isinstance(mime_types, str) else mime_types, cls)
|
46 |
+
return cls
|
47 |
+
if isinstance(cls_or_mime_types, type):
|
48 |
+
return register_mime_type_handler(cls_or_mime_types)
|
49 |
+
def decorator(cls: T) -> T:
|
50 |
+
return register_mime_type_handler(cls, cls_or_mime_types)
|
51 |
+
return decorator
|
src/ctp_slack_bot/mime_type_handlers/text/vtt.py
CHANGED
@@ -6,11 +6,11 @@ from types import MappingProxyType
|
|
6 |
from typing import Any, ClassVar, Mapping, Optional, Self
|
7 |
from webvtt import WebVTT
|
8 |
|
9 |
-
from ctp_slack_bot.mime_type_handlers.base import MimeTypeHandler,
|
10 |
from ctp_slack_bot.models import Content, WebVTTContent, WebVTTFrame
|
11 |
|
12 |
|
13 |
-
@
|
14 |
class WebVTTMimeTypeHandler(MimeTypeHandler):
|
15 |
|
16 |
model_config = ConfigDict(frozen=True)
|
|
|
6 |
from typing import Any, ClassVar, Mapping, Optional, Self
|
7 |
from webvtt import WebVTT
|
8 |
|
9 |
+
from ctp_slack_bot.mime_type_handlers.base import MimeTypeHandler, mime_type_handler
|
10 |
from ctp_slack_bot.models import Content, WebVTTContent, WebVTTFrame
|
11 |
|
12 |
|
13 |
+
@mime_type_handler("text/vtt")
|
14 |
class WebVTTMimeTypeHandler(MimeTypeHandler):
|
15 |
|
16 |
model_config = ConfigDict(frozen=True)
|
src/ctp_slack_bot/services/__init__.py
CHANGED
@@ -8,6 +8,6 @@ from .google_drive_service import GoogleDriveService
|
|
8 |
from .http_server_service import HTTPServer
|
9 |
from .language_model_service import LanguageModelService
|
10 |
from .question_dispatch_service import QuestionDispatchService
|
11 |
-
from. schedule_service import ScheduleService
|
12 |
from .slack_service import SlackService
|
|
|
13 |
from .vectorization_service import VectorizationService
|
|
|
8 |
from .http_server_service import HTTPServer
|
9 |
from .language_model_service import LanguageModelService
|
10 |
from .question_dispatch_service import QuestionDispatchService
|
|
|
11 |
from .slack_service import SlackService
|
12 |
+
from. task_service import TaskService
|
13 |
from .vectorization_service import VectorizationService
|
src/ctp_slack_bot/services/content_ingestion_service.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from loguru import logger
|
2 |
from pydantic import ConfigDict
|
3 |
from typing import Any, Self, Sequence, Set
|
@@ -18,15 +19,9 @@ class ContentIngestionService(ApplicationComponentBase):
|
|
18 |
model_config = ConfigDict(frozen=True)
|
19 |
|
20 |
settings: Settings
|
21 |
-
event_brokerage_service: EventBrokerageService
|
22 |
vectorized_chunk_repository: VectorizedChunkRepository
|
23 |
vectorization_service: VectorizationService
|
24 |
|
25 |
-
def model_post_init(self: Self, context: Any, /) -> None:
|
26 |
-
super().model_post_init(context)
|
27 |
-
self.event_brokerage_service.subscribe(EventType.INCOMING_CONTENT, self.process_incoming_content)
|
28 |
-
# self.event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, self.process_incoming_slack_message)
|
29 |
-
|
30 |
async def process_incoming_content(self: Self, content: Content) -> None:
|
31 |
logger.debug("Content ingestion service received content with metadata: {}", content.get_metadata())
|
32 |
if self.vectorized_chunk_repository.count_by_id(content.get_id()):
|
@@ -49,3 +44,11 @@ class ContentIngestionService(ApplicationComponentBase):
|
|
49 |
@property
|
50 |
def name(self: Self) -> str:
|
51 |
return "content_ingestion_service"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dependency_injector.resources import AsyncResource
|
2 |
from loguru import logger
|
3 |
from pydantic import ConfigDict
|
4 |
from typing import Any, Self, Sequence, Set
|
|
|
19 |
model_config = ConfigDict(frozen=True)
|
20 |
|
21 |
settings: Settings
|
|
|
22 |
vectorized_chunk_repository: VectorizedChunkRepository
|
23 |
vectorization_service: VectorizationService
|
24 |
|
|
|
|
|
|
|
|
|
|
|
25 |
async def process_incoming_content(self: Self, content: Content) -> None:
|
26 |
logger.debug("Content ingestion service received content with metadata: {}", content.get_metadata())
|
27 |
if self.vectorized_chunk_repository.count_by_id(content.get_id()):
|
|
|
44 |
@property
|
45 |
def name(self: Self) -> str:
|
46 |
return "content_ingestion_service"
|
47 |
+
|
48 |
+
|
49 |
+
class ContentIngestionServiceResource(AsyncResource):
|
50 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, vectorized_chunk_repository: VectorizedChunkRepository, vectorization_service: VectorizationService) -> ContentIngestionService:
|
51 |
+
content_ingestion_service = ContentIngestionService(settings=settings, vectorized_chunk_repository=vectorized_chunk_repository, vectorization_service=vectorization_service)
|
52 |
+
await event_brokerage_service.subscribe(EventType.INCOMING_CONTENT, content_ingestion_service.process_incoming_content)
|
53 |
+
# await event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, content_ingestion_service.process_incoming_slack_message)
|
54 |
+
return content_ingestion_service
|
src/ctp_slack_bot/services/event_brokerage_service.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from asyncio import create_task, iscoroutinefunction, to_thread
|
2 |
from collections import defaultdict
|
3 |
from loguru import logger
|
4 |
from pydantic import ConfigDict, PrivateAttr
|
@@ -15,18 +15,19 @@ class EventBrokerageService(ApplicationComponentBase):
|
|
15 |
|
16 |
model_config = ConfigDict(frozen=True)
|
17 |
|
18 |
-
|
|
|
19 |
|
20 |
-
def subscribe(self: Self, type: EventType, callback: Callable) -> None:
|
21 |
"""Subscribe to an event type with a callback function."""
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
|
27 |
async def publish(self: Self, type: EventType, data: Any = None) -> None:
|
28 |
"""Publish an event with optional data to all subscribers."""
|
29 |
-
subscribers = self.
|
30 |
if not subscribers:
|
31 |
logger.debug("No subscribers handle event {}: {}", type, len(subscribers), data)
|
32 |
return
|
|
|
1 |
+
from asyncio import create_task, iscoroutinefunction, Lock, to_thread
|
2 |
from collections import defaultdict
|
3 |
from loguru import logger
|
4 |
from pydantic import ConfigDict, PrivateAttr
|
|
|
15 |
|
16 |
model_config = ConfigDict(frozen=True)
|
17 |
|
18 |
+
__write_lock: Lock = PrivateAttr(default_factory=Lock)
|
19 |
+
__subscribers: MutableMapping[EventType, tuple[Callable]] = PrivateAttr(default_factory=lambda: defaultdict(tuple))
|
20 |
|
21 |
+
async def subscribe(self: Self, type: EventType, callback: Callable) -> None:
|
22 |
"""Subscribe to an event type with a callback function."""
|
23 |
+
async with self.__write_lock:
|
24 |
+
subscribers = self.__subscribers[type]
|
25 |
+
self.__subscribers[type] = subscribers + (callback, )
|
26 |
+
logger.debug("One new subscriber was added for event type {} ({} subscriber(s) in total).", type, len(subscribers))
|
27 |
|
28 |
async def publish(self: Self, type: EventType, data: Any = None) -> None:
|
29 |
"""Publish an event with optional data to all subscribers."""
|
30 |
+
subscribers = self.__subscribers[type]
|
31 |
if not subscribers:
|
32 |
logger.debug("No subscribers handle event {}: {}", type, len(subscribers), data)
|
33 |
return
|
src/ctp_slack_bot/services/google_drive_service.py
CHANGED
@@ -39,7 +39,6 @@ class GoogleDriveService(ApplicationComponentBase):
|
|
39 |
"token_uri": self.settings.google_token_uri,
|
40 |
}, scopes=["https://www.googleapis.com/auth/drive"])
|
41 |
self._google_drive_client = build('drive', 'v3', credentials=credentials)
|
42 |
-
logger.info(type(self._google_drive_client))
|
43 |
|
44 |
def _resolve_folder_id(self: Self, folder_path: str) -> Optional[str]:
|
45 |
"""Resolve a folder path to a Google Drive ID."""
|
|
|
39 |
"token_uri": self.settings.google_token_uri,
|
40 |
}, scopes=["https://www.googleapis.com/auth/drive"])
|
41 |
self._google_drive_client = build('drive', 'v3', credentials=credentials)
|
|
|
42 |
|
43 |
def _resolve_folder_id(self: Self, folder_path: str) -> Optional[str]:
|
44 |
"""Resolve a folder path to a Google Drive ID."""
|
src/ctp_slack_bot/services/question_dispatch_service.py
CHANGED
@@ -1,3 +1,4 @@
|
|
|
|
1 |
from loguru import logger
|
2 |
from pydantic import ConfigDict
|
3 |
from typing import Any, Self
|
@@ -18,15 +19,10 @@ class QuestionDispatchService(ApplicationComponentBase):
|
|
18 |
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
19 |
|
20 |
settings: Settings
|
21 |
-
event_brokerage_service: EventBrokerageService
|
22 |
context_retrieval_service: ContextRetrievalService
|
23 |
answer_retrieval_service: AnswerRetrievalService
|
24 |
|
25 |
-
def
|
26 |
-
super().model_post_init(context)
|
27 |
-
self.event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, self.__process_incoming_slack_message)
|
28 |
-
|
29 |
-
async def __process_incoming_slack_message(self: Self, message: SlackMessage) -> None:
|
30 |
if message.subtype != 'bot_message':
|
31 |
logger.debug("Question dispatch service received an answerable question: {}", message.text)
|
32 |
context = await self.context_retrieval_service.get_context(message)
|
@@ -35,3 +31,10 @@ class QuestionDispatchService(ApplicationComponentBase):
|
|
35 |
@property
|
36 |
def name(self: Self) -> str:
|
37 |
return "question_dispatch_service"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dependency_injector.resources import AsyncResource
|
2 |
from loguru import logger
|
3 |
from pydantic import ConfigDict
|
4 |
from typing import Any, Self
|
|
|
19 |
model_config = ConfigDict(arbitrary_types_allowed=True, frozen=True)
|
20 |
|
21 |
settings: Settings
|
|
|
22 |
context_retrieval_service: ContextRetrievalService
|
23 |
answer_retrieval_service: AnswerRetrievalService
|
24 |
|
25 |
+
async def process_incoming_slack_message(self: Self, message: SlackMessage) -> None:
|
|
|
|
|
|
|
|
|
26 |
if message.subtype != 'bot_message':
|
27 |
logger.debug("Question dispatch service received an answerable question: {}", message.text)
|
28 |
context = await self.context_retrieval_service.get_context(message)
|
|
|
31 |
@property
|
32 |
def name(self: Self) -> str:
|
33 |
return "question_dispatch_service"
|
34 |
+
|
35 |
+
|
36 |
+
class QuestionDispatchServiceResource(AsyncResource):
|
37 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, context_retrieval_service: ContextRetrievalService, answer_retrieval_service: AnswerRetrievalService) -> QuestionDispatchService:
|
38 |
+
question_dispatch_service = QuestionDispatchService(settings=settings, context_retrieval_service=context_retrieval_service, answer_retrieval_service=answer_retrieval_service)
|
39 |
+
await event_brokerage_service.subscribe(EventType.INCOMING_SLACK_MESSAGE, question_dispatch_service.process_incoming_slack_message)
|
40 |
+
return question_dispatch_service
|
src/ctp_slack_bot/services/slack_service.py
CHANGED
@@ -2,14 +2,15 @@ from dependency_injector.resources import AsyncResource
|
|
2 |
from httpx import AsyncClient
|
3 |
from loguru import logger
|
4 |
from openai import OpenAI
|
5 |
-
from pydantic import ConfigDict
|
6 |
from re import compile as compile_re, Pattern
|
7 |
-
from
|
8 |
from slack_bolt.async_app import AsyncApp
|
|
|
9 |
from slack_sdk.web.async_slack_response import AsyncSlackResponse
|
10 |
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set
|
11 |
|
12 |
-
from ctp_slack_bot.core import HealthReportingApplicationComponentBase
|
13 |
from ctp_slack_bot.enums import EventType
|
14 |
from ctp_slack_bot.models import SlackMessage, SlackResponse
|
15 |
from .event_brokerage_service import EventBrokerageService
|
@@ -25,17 +26,54 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
25 |
_SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+")
|
26 |
_SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>")
|
27 |
|
|
|
28 |
event_brokerage_service: EventBrokerageService
|
29 |
http_client: AsyncClient
|
30 |
slack_bolt_app: AsyncApp
|
31 |
-
|
|
|
32 |
|
33 |
-
def initialize(self: Self) -> None:
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
self.slack_bolt_app.event("message")(self._handle_message_event)
|
36 |
self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
|
37 |
logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
async def send_message(self: Self, message: SlackResponse) -> None:
|
40 |
await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
|
41 |
|
@@ -64,16 +102,16 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
64 |
)
|
65 |
|
66 |
async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
|
67 |
-
unknown_ids = ids - self.
|
68 |
if len(unknown_ids) == 0:
|
69 |
return
|
70 |
async with TaskGroup() as task_group:
|
71 |
update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
|
72 |
-
self.
|
73 |
|
74 |
async def _get_name(self: Self, id: str) -> str:
|
75 |
await self._ensure_ids_in_id_name_map({id})
|
76 |
-
return self.
|
77 |
|
78 |
async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
|
79 |
logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
|
@@ -113,46 +151,13 @@ class SlackService(HealthReportingApplicationComponentBase):
|
|
113 |
start, end = match.span()
|
114 |
parts.append(text[previous_end:start])
|
115 |
user_id = match.group(1)
|
116 |
-
parts.append(f"@{self.
|
117 |
previous_end = end
|
118 |
parts.append(text[previous_end:])
|
119 |
return ''.join(parts)
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
async def init(self: Self, event_brokerage_service: EventBrokerageService, http_client: AsyncClient, slack_bolt_app: AsyncApp) -> SlackService:
|
124 |
-
async def get_users_list():
|
125 |
-
cursor = None
|
126 |
-
while True:
|
127 |
-
try:
|
128 |
-
response = await slack_bolt_app.client.users_list(cursor=cursor, limit=200)
|
129 |
-
except SlackApiError as e:
|
130 |
-
logger.warning("Could not get a list of users: {}", e)
|
131 |
-
break
|
132 |
-
match response:
|
133 |
-
case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
|
134 |
-
for user in users:
|
135 |
-
yield user
|
136 |
-
match response.data:
|
137 |
-
case {"response_metadata": {"next_cursor": cursor}} if cursor:
|
138 |
-
continue
|
139 |
-
case AsyncSlackResponse(status_code=status_code) if status_code != 200:
|
140 |
-
logger.warning("Could not get a list of users: response status {}", status_code)
|
141 |
-
case AsyncSlackResponse(data={"ok": False}):
|
142 |
-
logger.warning("Could not get a list of users: non-OK response")
|
143 |
-
case _:
|
144 |
-
logger.warning("Could not get a list of users.")
|
145 |
-
break
|
146 |
-
id_name_map = {user["id"]: self._get_name(user)
|
147 |
-
async for user
|
148 |
-
in get_users_list()}
|
149 |
-
logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
|
150 |
-
slack_service = SlackService(event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, id_name_map=id_name_map)
|
151 |
-
slack_service.initialize()
|
152 |
-
return slack_service
|
153 |
-
|
154 |
-
@classmethod
|
155 |
-
def _get_name(cls, user: Mapping[str, Any]):
|
156 |
match user:
|
157 |
case {"real_name": real_name}:
|
158 |
return real_name
|
@@ -160,3 +165,12 @@ class SlackServiceResource(AsyncResource):
|
|
160 |
return display_name
|
161 |
case {"name": name}:
|
162 |
return name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
from httpx import AsyncClient
|
3 |
from loguru import logger
|
4 |
from openai import OpenAI
|
5 |
+
from pydantic import ConfigDict, PrivateAttr
|
6 |
from re import compile as compile_re, Pattern
|
7 |
+
from slack_bolt.adapter.socket_mode.async_handler import AsyncSocketModeHandler
|
8 |
from slack_bolt.async_app import AsyncApp
|
9 |
+
from slack_sdk.errors import SlackApiError
|
10 |
from slack_sdk.web.async_slack_response import AsyncSlackResponse
|
11 |
from typing import Any, ClassVar, Mapping, MutableMapping, Optional, Self, Set
|
12 |
|
13 |
+
from ctp_slack_bot.core import HealthReportingApplicationComponentBase, Settings
|
14 |
from ctp_slack_bot.enums import EventType
|
15 |
from ctp_slack_bot.models import SlackMessage, SlackResponse
|
16 |
from .event_brokerage_service import EventBrokerageService
|
|
|
26 |
_SLACK_USER_ID_PATTERN: ClassVar[Pattern] = compile_re(r"U\d+")
|
27 |
_SLACK_USER_MENTION_PATTERN: ClassVar[Pattern] = compile_re(r"<@(U[A-Z0-9]+)>")
|
28 |
|
29 |
+
settings: Settings
|
30 |
event_brokerage_service: EventBrokerageService
|
31 |
http_client: AsyncClient
|
32 |
slack_bolt_app: AsyncApp
|
33 |
+
socket_mode_handler: AsyncSocketModeHandler
|
34 |
+
_id_name_map: MutableMapping[str, str] = PrivateAttr(default={}) # TODO: Spin message processing out into its own service.
|
35 |
|
36 |
+
async def initialize(self: Self) -> None:
|
37 |
+
async def get_users_list():
|
38 |
+
cursor = None
|
39 |
+
while True:
|
40 |
+
try:
|
41 |
+
response = await self.slack_bolt_app.client.users_list(cursor=cursor, limit=200)
|
42 |
+
except SlackApiError as e:
|
43 |
+
logger.warning("Could not get a list of users: {}", e)
|
44 |
+
break
|
45 |
+
match response:
|
46 |
+
case AsyncSlackResponse(status_code=200, data={"ok": True, "members": users}):
|
47 |
+
for user in users:
|
48 |
+
yield user
|
49 |
+
match response.data:
|
50 |
+
case {"response_metadata": {"next_cursor": cursor}} if cursor:
|
51 |
+
continue
|
52 |
+
case AsyncSlackResponse(status_code=status_code) if status_code != 200:
|
53 |
+
logger.warning("Could not get a list of users: response status {}", status_code)
|
54 |
+
case AsyncSlackResponse(data={"ok": False}):
|
55 |
+
logger.warning("Could not get a list of users: non-OK response")
|
56 |
+
case _:
|
57 |
+
logger.warning("Could not get a list of users.")
|
58 |
+
break
|
59 |
+
id_name_map = {user["id"]: self._resolve_user_name(user)
|
60 |
+
async for user
|
61 |
+
in get_users_list()}
|
62 |
+
self._id_name_map.update(id_name_map)
|
63 |
+
logger.debug("Obtained a list of {} user name(s) for the workspace: {}", len(id_name_map), id_name_map)
|
64 |
+
|
65 |
+
await self.event_brokerage_service.subscribe(EventType.OUTGOING_SLACK_RESPONSE, self.send_message)
|
66 |
self.slack_bolt_app.event("message")(self._handle_message_event)
|
67 |
self.slack_bolt_app.event("app_mention")(self._handle_app_mention_event)
|
68 |
logger.debug("Registered 2 handlers for Slack Bolt message and app mention events.")
|
69 |
|
70 |
+
async def start(self: Self) -> None:
|
71 |
+
await self.socket_mode_handler.start_async()
|
72 |
+
|
73 |
+
async def stop(self: Self) -> None:
|
74 |
+
await self.socket_mode_handler.close_async()
|
75 |
+
logger.info("Stopped Slack Bolt socket mode handler and Slack service.")
|
76 |
+
|
77 |
async def send_message(self: Self, message: SlackResponse) -> None:
|
78 |
await self.slack_bolt_app.client.chat_postMessage(channel=message.channel, text=message.text, thread_ts=message.thread_ts)
|
79 |
|
|
|
102 |
)
|
103 |
|
104 |
async def _ensure_ids_in_id_name_map(self: Self, ids: Set[str]) -> None:
|
105 |
+
unknown_ids = ids - self._id_name_map.keys()
|
106 |
if len(unknown_ids) == 0:
|
107 |
return
|
108 |
async with TaskGroup() as task_group:
|
109 |
update_tasks = {unknown_id: task_group.create_task(self._look_up_name(unknown_id)) for unknown_id in unknown_ids}
|
110 |
+
self._id_name_map.update({id: task.result() for id, task in update_tasks.items() if task.result()})
|
111 |
|
112 |
async def _get_name(self: Self, id: str) -> str:
|
113 |
await self._ensure_ids_in_id_name_map({id})
|
114 |
+
return self._id_name_map.get(id, id)
|
115 |
|
116 |
async def _handle_message_event(self: Self, body: Mapping[str, Any]) -> None:
|
117 |
logger.debug("Ignored regular message: {}", body.get("event", {}).get("text"))
|
|
|
151 |
start, end = match.span()
|
152 |
parts.append(text[previous_end:start])
|
153 |
user_id = match.group(1)
|
154 |
+
parts.append(f"@{self._id_name_map.get(user_id, user_id)}")
|
155 |
previous_end = end
|
156 |
parts.append(text[previous_end:])
|
157 |
return ''.join(parts)
|
158 |
|
159 |
+
@staticmethod
|
160 |
+
def _resolve_user_name(user: Mapping[str, Any]):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
match user:
|
162 |
case {"real_name": real_name}:
|
163 |
return real_name
|
|
|
165 |
return display_name
|
166 |
case {"name": name}:
|
167 |
return name
|
168 |
+
|
169 |
+
|
170 |
+
class SlackServiceResource(AsyncResource):
|
171 |
+
async def init(self: Self, settings: Settings, event_brokerage_service: EventBrokerageService, http_client: AsyncClient) -> SlackService:
|
172 |
+
slack_bolt_app = AsyncApp(token=settings.slack_bot_token.get_secret_value())
|
173 |
+
socket_mode_handler = AsyncSocketModeHandler(slack_bolt_app, settings.slack_app_token.get_secret_value())
|
174 |
+
slack_service = SlackService(settings=settings, event_brokerage_service=event_brokerage_service, http_client=http_client, slack_bolt_app=slack_bolt_app, socket_mode_handler=socket_mode_handler)
|
175 |
+
await slack_service.initialize()
|
176 |
+
return slack_service
|
src/ctp_slack_bot/services/{schedule_service.py → task_service.py}
RENAMED
@@ -1,17 +1,16 @@
|
|
1 |
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
2 |
from apscheduler.triggers.cron import CronTrigger
|
3 |
-
from asyncio import create_task, iscoroutinefunction, to_thread
|
4 |
from datetime import datetime
|
5 |
-
from dependency_injector.resources import
|
6 |
from loguru import logger
|
7 |
from pydantic import ConfigDict
|
8 |
from pytz import timezone
|
9 |
-
from typing import Any,
|
10 |
|
11 |
from ctp_slack_bot.core import ApplicationComponentBase, Settings
|
12 |
|
13 |
|
14 |
-
class
|
15 |
"""
|
16 |
Service for running scheduled tasks.
|
17 |
"""
|
@@ -43,28 +42,25 @@ class ScheduleService(ApplicationComponentBase):
|
|
43 |
# )
|
44 |
pass
|
45 |
|
46 |
-
def start(self: Self) -> None:
|
|
|
47 |
self._scheduler.start()
|
48 |
|
49 |
-
def stop(self: Self) -> None:
|
50 |
if self._scheduler.running:
|
51 |
self._scheduler.shutdown()
|
|
|
52 |
else:
|
53 |
logger.debug("The scheduler is not running. There is no scheduler to shut down.")
|
54 |
|
55 |
@property
|
56 |
def name(self: Self) -> str:
|
57 |
-
return "
|
58 |
|
59 |
|
60 |
-
class
|
61 |
-
def init(self: Self, settings: Settings) ->
|
62 |
-
|
63 |
-
schedule_service = ScheduleService(settings=settings)
|
64 |
-
schedule_service.start()
|
65 |
-
return schedule_service
|
66 |
|
67 |
-
def shutdown(self: Self,
|
68 |
-
|
69 |
-
schedule_service.stop()
|
70 |
-
logger.info("Stopped scheduler.")
|
|
|
1 |
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
2 |
from apscheduler.triggers.cron import CronTrigger
|
|
|
3 |
from datetime import datetime
|
4 |
+
from dependency_injector.resources import AsyncResource
|
5 |
from loguru import logger
|
6 |
from pydantic import ConfigDict
|
7 |
from pytz import timezone
|
8 |
+
from typing import Any, Self
|
9 |
|
10 |
from ctp_slack_bot.core import ApplicationComponentBase, Settings
|
11 |
|
12 |
|
13 |
+
class TaskService(ApplicationComponentBase):
|
14 |
"""
|
15 |
Service for running scheduled tasks.
|
16 |
"""
|
|
|
42 |
# )
|
43 |
pass
|
44 |
|
45 |
+
async def start(self: Self) -> None:
|
46 |
+
logger.info("Starting scheduler…")
|
47 |
self._scheduler.start()
|
48 |
|
49 |
+
async def stop(self: Self) -> None:
|
50 |
if self._scheduler.running:
|
51 |
self._scheduler.shutdown()
|
52 |
+
logger.info("Stopped scheduler.")
|
53 |
else:
|
54 |
logger.debug("The scheduler is not running. There is no scheduler to shut down.")
|
55 |
|
56 |
@property
|
57 |
def name(self: Self) -> str:
|
58 |
+
return "task_service"
|
59 |
|
60 |
|
61 |
+
class TaskServiceResource(AsyncResource):
|
62 |
+
async def init(self: Self, settings: Settings) -> TaskService:
|
63 |
+
return TaskService(settings=settings)
|
|
|
|
|
|
|
64 |
|
65 |
+
async def shutdown(self: Self, task_service: TaskService) -> None:
|
66 |
+
await task_service.stop()
|
|
|
|