aittalam commited on
Commit
7ba9119
·
verified ·
1 Parent(s): ca737e0

Upload 13 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ bin/llamafiler filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Create container
FROM python:3.11-slim AS out

# wget to fetch the embedding model, git for cloning once the repo is public.
# --no-install-recommends and the apt cache cleanup keep the image small.
RUN apt-get update && \
    apt-get install -y --no-install-recommends wget git && \
    rm -rf /var/lib/apt/lists/*

# Create a non-root user
RUN addgroup --gid 1000 user && \
    adduser --uid 1000 --gid 1000 --disabled-password --gecos "" user

# Set working directory
WORKDIR /home/user

# Download default embedding model
RUN wget https://huggingface.co/leliuga/all-MiniLM-L6-v2-GGUF/resolve/main/all-MiniLM-L6-v2.F16.gguf

# Copy the repo's code
COPY . /home/user/byota
## Clone the repo's code - when the repo is public
#RUN git clone https://github.com/mozilla-ai/byota.git && \

# Make entrypoint and bundled llamafiler executable; --no-cache-dir avoids
# keeping pip's download cache in the image layer.
RUN chown -R user:user /home/user && \
    chmod +x /home/user/byota/entrypoint.sh && \
    chmod +x /home/user/byota/bin/llamafiler && \
    pip install --no-cache-dir -r /home/user/byota/requirements.txt

ENV PATH="/home/user/:${PATH}"

# Switch to user
USER user

# Set entrypoint; CMD is forwarded to marimo by entrypoint.sh
ENTRYPOINT ["/home/user/byota/entrypoint.sh"]
CMD ["demo.py"]
README.md CHANGED
@@ -1,8 +1,8 @@
1
  ---
2
  title: Byota
3
- emoji: 🐠
4
- colorFrom: red
5
- colorTo: green
6
  sdk: docker
7
  pinned: false
8
  ---
 
1
  ---
2
  title: Byota
3
+ emoji: 🦣
4
+ colorFrom: purple
5
+ colorTo: pink
6
  sdk: docker
7
  pinned: false
8
  ---
bin/llamafiler ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39a12593adf6b6ab055ff339fd44fab6c8444646400968a8eef3183dd9084e9e
3
+ size 10492893
entrypoint.sh ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
set -e

# Start llamafiler (local embedding server) in the background; its stderr is
# redirected to a log file so the first lines can be echoed below.
echo "Starting llamafiler..."
byota/bin/llamafiler -m all-MiniLM-L6-v2.F16.gguf -l 0.0.0.0:8080 -H "Access-Control-Allow-Origin: *" --trust 127.0.0.1/32 2> /tmp/llamafiler.logs &

# show llamafile start messages
sleep 1
head /tmp/llamafiler.logs

# Start marimo on the HF-Spaces port. "$@" (quoted!) forwards the container's
# CMD args (e.g. "demo.py") without word-splitting arguments that contain
# spaces — the unquoted $@ form is a classic shell bug.
cd byota/src && marimo run --headless --host 0.0.0.0 --port 7860 "$@"
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ altair==5.5.0
2
+ beautifulsoup4==4.13.3
3
+ loguru==0.7.3
4
+ marimo==0.11.21
5
+ Mastodon.py==2.0.1
6
+ pandas==2.2.3
7
+ platformdirs>=2.1
8
+ pyarrow==19.0.1
9
+ scikit-learn==1.6.1
src/byota/__init__.py ADDED
File without changes
src/byota/embeddings.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import numpy as np
3
+ import requests
4
+
5
+ # -- Embeddings --------------------------------------------------------------
6
+
7
+
8
+ class EmbeddingService:
9
+ def __init__(self, url: str, model: str = None):
10
+ self._url = url
11
+ self._model = model
12
+
13
+ def is_working(self) -> bool:
14
+ """Checks if the service is there and working by trying
15
+ to send an actual embedding request.
16
+ """
17
+ pass
18
+
19
+ def get_embedding(self, text: str) -> list:
20
+ """Given an input text, returns the embeddings as calculated
21
+ by the embedding service.
22
+ """
23
+ pass
24
+
25
+ def calculate_embeddings(self, texts: list[str], bar=None) -> np.ndarray:
26
+ """Given a list of input texts, returns all the embeddings
27
+ as a numpy array.
28
+ """
29
+
30
+ embeddings = []
31
+ for i, t in enumerate(texts):
32
+ embeddings.append(self.get_embedding(str(t)))
33
+ if bar is not None:
34
+ bar.update()
35
+ if not (i % 10):
36
+ print(".", end="")
37
+ return np.array(embeddings)
38
+
39
+
40
class LLamafileEmbeddingService(EmbeddingService):
    """EmbeddingService backed by a llamafile/llamafiler HTTP endpoint."""

    def is_working(self) -> bool:
        """Return True iff the embedding endpoint answers a POST with 200.

        BUGFIX: connection errors (server not up yet, refused) used to
        propagate; callers such as demo.py do `not is_working()` and expect
        a plain boolean, so report failures as False instead of raising.
        """
        try:
            response = requests.request(
                url=self._url,
                method="POST",
                timeout=10,  # never hang a health probe forever
            )
        except requests.RequestException:
            return False
        return response.status_code == 200

    def get_embedding(self, text: str) -> list:
        """Return the embedding vector for `text` from the llamafile server.

        Raises:
            requests.RequestException: re-raised after logging on any
                HTTP/network failure.
        """
        try:
            response = requests.request(
                url=self._url,
                method="POST",
                data={"content": text},
                timeout=60,  # a bounded wait instead of the default "forever"
            )
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Request failed: {e}")
            raise

        return json.loads(response.text)["embedding"]
61
+
62
+
63
class OllamaEmbeddingService(EmbeddingService):
    """EmbeddingService backed by an ollama /api/embed endpoint."""

    def __init__(self, url: str, model: str):
        # model is compulsory for ollama
        super().__init__(url, model)

    def is_working(self) -> bool:
        """Return True iff the server answers a probe request with 200.

        BUGFIX: this used to `return response.status_code`, so an error
        status such as 404 was truthy and the service looked healthy.
        Network errors are reported as False instead of propagating.
        """
        try:
            response = requests.request(
                url=self._url,
                method="POST",
                data=json.dumps({"model": self._model, "input": ""}),
                timeout=10,
            )
        except requests.RequestException:
            return False
        return response.status_code == 200

    def get_embedding(self, text: str):
        """Return the embedding vector for `text` from the ollama server.

        Raises:
            requests.RequestException: re-raised after logging on any
                HTTP/network failure.
        """
        # workaround for ollama breaking with empty input text
        if not text:
            text = " "

        try:
            response = requests.request(
                url=self._url,
                method="POST",
                data=json.dumps({"model": self._model, "input": text}),
                timeout=60,
            )
            response.raise_for_status()
        except requests.RequestException as e:
            print(f"Request failed: {e}")
            raise

        # ollama returns a batch: take the single result's vector
        return json.loads(response.text)["embeddings"][0]
src/byota/mastodon.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mastodon
2
+ import marimo as mo
3
+ from loguru import logger
4
+
5
+ # -- Mastodon ----------------------------------------------------------------
6
+
7
+
8
def login(access_token: str, api_base_url: str):
    """Create an authenticated Mastodon client.

    Args:
        access_token: user access token for the instance.
        api_base_url: base URL of the Mastodon instance.

    Returns:
        A credential-verified ``mastodon.Mastodon`` client, or ``None``
        when the token is rejected by the server.
    """
    try:
        mastodon_client = mastodon.Mastodon(
            access_token=access_token, api_base_url=api_base_url
        )
        # verify_credentials raises MastodonUnauthorizedError on bad tokens
        logger.debug(mastodon_client.app_verify_credentials())
    except mastodon.errors.MastodonUnauthorizedError as e:
        # use the module's loguru logger instead of print() for consistency
        logger.error(f"Mastodon auth error: {e}")
        mastodon_client = None

    return mastodon_client
22
+
23
+
24
def get_paginated_data(
    mastodon_client: mastodon.Mastodon, timeline_type: str, max_pages: int = 40
):
    """Gets paginated statuses from one of the following timelines:
    `home`, `local`, `public`, `tag/hashtag` or `list/id`.

    See https://mastodonpy.readthedocs.io/en/stable/07_timelines.html
    and https://docs.joinmastodon.org/methods/timelines/#home

    Args:
        mastodon_client: authenticated client.
        timeline_type: timeline identifier accepted by Mastodon.timeline().
        max_pages: maximum number of pages to download.

    Returns:
        A list of pages, each one a list of statuses.
    """
    paginated_data = []
    max_id = None
    with mo.status.progress_bar(
        total=max_pages,
        title=f"Downloading {max_pages} pages of posts from: {timeline_type}",
    ) as bar:
        # BUGFIX: the first page used to be fetched twice (once before the
        # loop and again with max_id=None inside it); fetch each page once.
        for page in range(1, max_pages + 1):
            print(f"Loading page {page}: max_id = {max_id}")
            tl = mastodon_client.timeline(timeline_type, max_id=max_id)
            if len(tl) > 0:
                paginated_data.append(tl)

            bar.update()
            # mastodon.py attaches pagination info to the result list
            if hasattr(tl, "_pagination_next") and tl._pagination_next is not None:
                max_id = tl._pagination_next.get("max_id")
            else:
                print("No more pages available.")
                break

    return paginated_data
58
+
59
+
60
def get_paginated_statuses(
    mastodon_client: mastodon.Mastodon,
    max_pages: int = 1,
    exclude_replies=False,
    exclude_reblogs=False,
):
    """Gets paginated statuses posted (or boosted) by the logged-in account.

    See https://mastodonpy.readthedocs.io/en/stable/07_timelines.html
    and https://docs.joinmastodon.org/methods/timelines/#home

    Args:
        mastodon_client: authenticated client.
        max_pages: maximum number of pages to download.
        exclude_replies: skip replies when True.
        exclude_reblogs: skip boosts when True.

    Returns:
        A list of pages, each one a list of statuses.
    """
    # resolve the account id once instead of once per page
    account_id = mastodon_client.me()["id"]

    paginated_data = []
    max_id = None
    with mo.status.progress_bar(
        total=max_pages,
        title=f"Account Statuses (replies={not exclude_replies}, reblogs={not exclude_reblogs})",
    ) as bar:
        # BUGFIX: the first page used to be fetched twice (once before the
        # loop and again with max_id=None inside it); fetch each page once.
        for page in range(1, max_pages + 1):
            print(f"Loading page {page}: max_id = {max_id}")
            tl = mastodon_client.account_statuses(
                account_id,
                exclude_replies=exclude_replies,
                exclude_reblogs=exclude_reblogs,
                max_id=max_id,
            )
            if len(tl) > 0:
                paginated_data.append(tl)

            bar.update()
            # mastodon.py attaches pagination info to the result list
            if hasattr(tl, "_pagination_next") and tl._pagination_next is not None:
                max_id = tl._pagination_next.get("max_id")
            else:
                print("No more pages available.")
                break
    return paginated_data
src/byota/search.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from scipy import spatial
2
+ import numpy as np
3
+ from byota.embeddings import EmbeddingService
4
+ from loguru import logger
5
+
6
+ # -- Similarity --------------------------------------------------------------
7
+
8
+
9
class SearchService:
    """k-nearest-neighbour search over a fixed set of status embeddings."""

    def __init__(self, embeddings: np.ndarray, embedding_service: "EmbeddingService"):
        self._embeddings = embeddings
        self._embedding_service = embedding_service
        # KD-tree built once so repeated queries are cheap
        self._tree = spatial.KDTree(self._embeddings)

    def prepare_query(self, query):
        """A query can either be an integer ID (index in the dataframe)
        or a string. As similarity is calculated among embeddings, this
        method makes sure we always return an embedding.
        """

        def is_integer_string(s):
            try:
                int(s)
                return True
            except (ValueError, TypeError):
                # TypeError covers non-str/non-int queries (robustness)
                return False

        if is_integer_string(query):
            return self._embeddings[int(query)]
        else:
            return self._embedding_service.get_embedding(query)

    def most_similar_indices(self, query, k=5):
        """Given a query (whether as an integer index to a status or plain
        text), return the indices of the k+1 nearest embeddings (the extra
        neighbour accounts for the query itself when it belongs to the set).
        """
        # BUGFIX: KDTree.query pads with the out-of-range index n when asked
        # for more neighbours than there are samples, which later crashed
        # most_similar_embeddings. Clamp so that k + 1 <= num_samples.
        if k + 1 > len(self._embeddings):
            logger.warning(
                "The number of neighbors k is greater than the number of samples. Setting k=num_samples"
            )
            k = len(self._embeddings) - 1

        q = self.prepare_query(query)

        # get the k nearest neighbors' indices
        return self._tree.query(q, k=k + 1)[1]

    def most_similar_embeddings(self, query, k=5):
        """Given a query (whether as an integer index to a status or plain
        text), return the k+1 most similar embeddings."""
        indices = self.most_similar_indices(query, k)

        return self._embeddings[indices]
src/data/dump_dataframes_demo.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3ee52a4cb0e53367e049ee575dde509498ce5464b0addff64b98a9c447b0fb44
3
+ size 1296696
src/data/dump_embeddings_demo.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e53fb44ddaa7860f699bf9cf4f6d093b856ea65c550ba777ebef1df0bfa4584
3
+ size 7373084
src/data/dump_user_statuses_demo.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c7a2958dbf83679fb0f2a28e2c6bdc53b7ea57b0410813b61922c1efb24e3a2c
3
+ size 50811
src/demo.py ADDED
@@ -0,0 +1,645 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# BYOTA demo notebook (marimo). Each @app.cell function is one reactive cell.
import marimo

__generated_with = "0.11.21"
app = marimo.App(width="medium")


@app.cell
def _():
    # Placeholder cell: enable these installs only on marimo cloud/WASM.
    # # Uncomment this code if you want to run the notebook on marimo cloud
    # import micropip  # type: ignore

    # await micropip.install("Mastodon.py")
    # await micropip.install("loguru")
    return
15
+
16
+
17
@app.cell
def _():
    # Shared imports; other cells receive these names via the returned tuple.
    import marimo as mo
    import pickle
    import time
    import altair as alt
    from sklearn.manifold import TSNE
    import pandas as pd
    from pathlib import Path
    import json
    import os
    import numpy as np

    from byota.embeddings import EmbeddingService, LLamafileEmbeddingService

    from byota.search import SearchService

    return (
        EmbeddingService,
        LLamafileEmbeddingService,
        Path,
        SearchService,
        TSNE,
        alt,
        json,
        mo,
        np,
        os,
        pd,
        pickle,
        time,
    )
49
+
50
+
51
@app.cell
def _():
    # internal variables

    # dump files for offline mode (pre-computed demo data shipped with the app)
    dataframes_data_file = "data/dump_dataframes_demo.pkl"
    embeddings_data_file = "data/dump_embeddings_demo.pkl"
    user_statuses_data_file = "data/dump_user_statuses_demo.pkl"
    return dataframes_data_file, embeddings_data_file, user_statuses_data_file


@app.cell
def _(mo):
    # Intro text shown at the top of the app.
    mo.md(
        """
        # Build Your Own Timeline Algorithm

        Welcome to BYOTA's demo!

        This small Web application shows some of the things you could do running BYOTA's code on your own timeline.
        As this is open for anyone to use, this version of the code does not connect to any real social network, but uses either synthetic data (to simulate posts in the home, local, and public timelines) or posts from [my Mastodon account](http://fosstodon.org/@mala).

        If you want to use BYOTA with your own data, feel free to check its [⌨️ code](https://github.com/mozilla-ai/byota)
        and [📖 documentation](https://mozilla-ai.github.io/byota/).

        So, feel free to just click "submit" in the following Configuration form and... see what happens!
        """
    )
    return
80
+
81
+
82
@app.cell
def _(configuration_form):
    # Render the configuration form.
    configuration_form
    return


@app.cell
def _(
    LLamafileEmbeddingService,
    configuration_form,
    dataframes_data_file,
    invalid_form,
    load_dataframes,
    mo,
):
    # Gate on a submitted form, then build the embedding service and load data.
    mo.stop(
        invalid_form(configuration_form),
        mo.md("**Submit the form to continue.**").center(),
    )

    # NOTE(review): the form's emb_server_url/emb_server values are ignored
    # here; the demo always talks to the bundled local llamafiler instance.
    embedding_service = LLamafileEmbeddingService("http://localhost:8080/embedding")

    mo.stop(
        not embedding_service.is_working(),
        mo.md("**Cannot access embedding server.**"),
    )

    # choose what to read from cache
    cached_embeddings = configuration_form.value["offline_mode"]

    dataframes = load_dataframes(dataframes_data_file)
    mo.stop(dataframes is None, mo.md("**Issues loading dataframes**"))
    return cached_embeddings, dataframes, embedding_service
115
+
116
+
117
@app.cell
def _(dataframes, mo):
    # Status header shown while embeddings are computed.
    mo.stop(dataframes is None)
    mo.md(f"""
    ### Calculating embeddings for the downloaded timeline{"s" if len(dataframes.keys())>1 else ""}.
    """).center()
    return


@app.cell
def _(
    build_cache_embeddings,
    cached_embeddings,
    dataframes,
    embedding_service,
    embeddings_data_file,
    mo,
):
    # calculate embeddings (or load them from the pickle cache in offline mode)
    embeddings = build_cache_embeddings(
        embedding_service, dataframes, cached_embeddings, embeddings_data_file
    )
    mo.stop(embeddings is None, mo.md("**Issues calculating embeddings**"))
    return (embeddings,)
141
+
142
+
143
@app.cell
def _(TSNE, alt, dataframes, embeddings, mo, np, pd):
    def tsne(dataframes, embeddings, perplexity, random_state=42):
        """Runs dimensionality reduction using TSNE on the input embeddings.
        Returns dataframes containing status id, text, and 2D coordinates
        for plotting.
        """
        tsne = TSNE(n_components=2, random_state=random_state, perplexity=perplexity)

        # project all timelines together so points share one 2D space
        all_embeddings = np.concatenate([v for v in embeddings.values()])
        all_projections = tsne.fit_transform(all_embeddings)

        # slice the concatenated projection back into per-timeline frames;
        # NOTE(review): this mutates the input dataframes in place (adds
        # x/y/label columns) — later cells rely on df_ instead
        dfs = []
        start_idx = 0
        end_idx = 0
        for kk in embeddings:
            end_idx += len(embeddings[kk])
            df = dataframes[kk]
            df["x"] = all_projections[start_idx:end_idx, 0]
            df["y"] = all_projections[start_idx:end_idx, 1]
            df["label"] = kk
            dfs.append(df)
            start_idx = end_idx

        return pd.concat(dfs, ignore_index=True), all_embeddings

    df_, all_embeddings = tsne(dataframes, embeddings, perplexity=4)

    # interactive scatter plot of the 2D projections, coloured by timeline
    chart = mo.ui.altair_chart(
        alt.Chart(df_, title="Timeline Visualization", height=500)
        .mark_point()
        .encode(x="x", y="y", color="label")
    )
    return all_embeddings, chart, df_, tsne
177
+
178
+
179
@app.cell
def _(chart, mo):
    # Scatter plot plus a table of any points selected on the chart.
    mo.vstack(
        [
            mo.md("# Embeddings visualization").center(),
            mo.md("""
            In this section, you can see posts from different timelines represented as points on a plane:
            You can click on a timeline label on the top right to highlight only posts from that timeline.
            If you select one or more points, you will see them in the table below the plot.
            By clicking on the column names (e.g. `label`, `text`) you can sort them, wrap text (to see full
            post contents), or search their content.
            """),
            chart,
            chart.value[["id", "label", "text"]]
            if len(chart.value) > 0
            else chart.value,
        ]
    )
    return


@app.cell
def _(embeddings, mo, query_form):
    mo.stop(embeddings is None)

    # Similarity-search section: explanation plus the query input box.
    mo.vstack(
        [
            mo.md("# Timeline search"),
            mo.md("""
            Here you can search for the most similar posts to a given one.
            You can either provide a row id (the leftmost column in the previous table) to refer to an existing post,
            or freeform text to look for posts which are similar in content to what you wrote. Some examples:

            - Book suggestions for scifi lovers
            - Digital rights and free software
            - Recipes for vegetarians (warning: sadly you won't get recipes from this dataset!)
            - I like retrocomputing but also bouldering, now what?

            """),
            query_form,
        ]
    )
    return
222
+
223
+
224
@app.cell
def _(SearchService, all_embeddings, df_, embedding_service, query_form):
    # Nearest-neighbour search over all timeline embeddings; the resulting
    # indices point into df_ (the concatenated frame from the t-SNE cell).
    search_service = SearchService(all_embeddings, embedding_service)
    indices = search_service.most_similar_indices(query_form.value)
    df_.iloc[indices][["label", "text"]]
    return indices, search_service
230
+
231
+
232
@app.cell
def _(embeddings, mo, rerank_form):
    mo.stop(embeddings is None)

    # Explanatory text for the re-ranking section plus its settings form.
    mo.vstack(
        [
            mo.md("# Timeline Re-ranking"),
            mo.md("""
            In the previous sections, you saw that embeddings are reasonable descriptors for social media posts,
            as they allow semantic similar statuses to be close in the embedding space. This allows you to use
            the simple concept of *distance between points* to group statuses and search them.

            In this section, you will perform actual timeline re-ranking. To do this, you'll still rely on the
            concept of text similarity, assigning a higher score to those posts which are most similar to *a set
            of other posts*. The set you'll use as a reference is the one of the posts you wrote or
            reposted from others.

            **NOTE**: For the sake of this open demo, the posts are not the ones *you* wrote, but I provided a subset of
            those posted by https://fosstodon.org/@mala (that's me!). This way, you can get a better sense of
            how this would work with some real data rather than a fully synthetic dataset.
            """),
            rerank_form,
        ]
    )
    return
257
+
258
+
259
@app.cell
def _(
    dataframes,
    embedding_service,
    embeddings,
    load_dataframes,
    mo,
    np,
    rerank_form,
    time,
    user_statuses_data_file,
):
    mo.stop(embeddings is None)

    # check for anything invalid in the form
    mo.stop(rerank_form.value is None, mo.md("**Submit the form to continue.**"))

    timeline_to_rerank = rerank_form.value["timeline_to_rerank"]

    # one form "page" corresponds to 20 statuses
    user_statuses_df = load_dataframes(user_statuses_data_file)[
        : 20 * rerank_form.value["num_user_status_pages"]
    ]

    mo.stop(user_statuses_df is None, mo.md("**Issues loading dataframes**"))

    user_statuses_embeddings = embedding_service.calculate_embeddings(
        user_statuses_df["text"]
    )

    # build an index of most similar statuses to the ones
    # published / boosted by the user
    rerank_start_time = time.time()
    # index is in reverse order (from largest to smallest similarity)
    idx = np.flip(
        # return indices of the sorted list, instead of values
        # we want to get pointers to statuses, not actual similarities
        np.argsort(
            # to measure how much I might like a timeline status,
            # I sum all the similarity values calculated between
            # that status and all the statuses in my feed
            np.sum(
                # dot product is a decent quick'n'dirty way to calculate
                # similarity between two vectors (the more similar they
                # are, the larger the product)
                np.dot(user_statuses_embeddings, embeddings[timeline_to_rerank].T),
                axis=0,
            )
        )
    )

    # timing printout for the re-ranking step (console only)
    print(time.time() - rerank_start_time)

    # show everything
    mo.vstack(
        [
            mo.md("""## Your statuses:
            This table shows the content of the posts that are used for re-ranking the timeline. You can change
            their number in the form above (1 page = 20 posts), check them out here, and verify in the table below
            this one how ranking changes depending on the contents you include.
            """),
            user_statuses_df,
            mo.md("""## Your re-ranked timeline:
            This table shows posts from the synthetic timelines (you can choose between home, local, and public
            in the form above), re-ranked to prioritize the main topics inferred from the posts in the previous table.
            """),
            # show statuses sorted by idx
            dataframes[timeline_to_rerank].iloc[idx][["label", "text"]],
        ]
    )
    return (
        idx,
        rerank_start_time,
        timeline_to_rerank,
        user_statuses_df,
        user_statuses_embeddings,
    )
335
+
336
+
337
@app.cell
def _():
    # Educational scratch cell (kept commented out on purpose).
    # # Wanna get some intuition re: the similarity measure?
    # # Here's a simple example: the seven values you get are
    # # the scores for the seven vectors in bbb (the higher
    # # they are, the more similar vectors they have in aaa).
    # # ... Can you tell why the third vector in bbb ([1,1,0,0])
    # # is the most similar to vectors found in aaa?

    # aaa = np.array([
    #     [1,0,0,0],
    #     [0,1,0,0],
    #     [0,0,1,0],
    #     [1,1,0,0],
    # ]).astype(np.float32)

    # bbb = np.array([
    #     [1,0,0,0],
    #     [0,1,0,0],
    #     [1,1,0,0],
    #     [0,0,1,0],
    #     [0,1,1,0],
    #     [0,0,0,1],
    #     [0,0,1,1],
    # ]).astype(np.float32)

    # np.sum(np.dot(aaa, bbb.T), axis=0)
    return
365
+
366
+
367
@app.cell
def _(mo, rerank_form, tag_form):
    mo.stop(rerank_form.value is None)

    # Explanatory text for re-ranking one's own posts against a tag/term.
    mo.vstack(
        [
            mo.md("""
            # Re-Ranking your own posts
            Depending on the timeline you are considering, it might be more or less hard
            to understand how well the re-ranking worked.
            To give you a better sense of the effect of re-ranking, let us take the posts
            you wrote and re-rank them according to some well-known tag.
            Feel free to test the following code with different tags, depending on your
            various interests, and see whether your own posts related to a given interest
            are surfaced by a related tag.

            **NOTE: a couple of changes have been applied for the sake of having a functional demo:**

            1. Posts are not actually your own (see above).

            2. The word(s) that you enter below will be used to filter the existing posts in the
            (synthetic) public timeline, rather than running a new tag search on the mastodon server.
            This allows you to still get meaningful posts back without having to connect to an instance.

            Some example search terms you could use: `#AI`, `bouldering`, `books`, `scifi`, `retrogaming`, `movies`.
            If a search term is not found, you will simply see no results.
            """),
            tag_form,
        ]
    )
    return
398
+
399
+
400
@app.cell
def _(
    dataframes,
    embedding_service,
    mo,
    np,
    tag_form,
    user_statuses_df,
    user_statuses_embeddings,
):
    # Re-rank the user's own posts against public posts matching a tag/term.
    tag_name = tag_form.value

    # BUGFIX: filter with regex=False so user input such as "(" or "c++" is
    # matched literally instead of being parsed as a regular expression
    # (which raised from re.compile); the UI documents literal search terms.
    tag_posts_df = dataframes["public"][
        dataframes["public"]["text"].str.contains(tag_name, regex=False)
    ]
    tag_posts_embeddings = embedding_service.calculate_embeddings(tag_posts_df["text"])

    # calculate the re-ranking index (largest summed similarity first)
    my_idx = np.flip(
        np.argsort(
            np.sum(np.dot(tag_posts_embeddings, user_statuses_embeddings.T), axis=0)
        )
    )
    # let us also show the similarity scores used to calculate the index
    user_statuses_df["scores"] = np.sum(
        np.dot(tag_posts_embeddings, user_statuses_embeddings.T), axis=0
    )

    mo.vstack(
        [
            mo.md(
                f"### Your own posts, re-ranked according to their similarity to posts in {tag_name}"
            ),
            user_statuses_df.iloc[my_idx][["text", "scores"]],
        ]
    )
    # my_posts_df[['text', 'scores']]
    return my_idx, tag_name, tag_posts_df, tag_posts_embeddings
438
+
439
+
440
@app.cell
def _(mo):
    # Create the Configuration form

    configuration_form = (
        mo.md(
            """
            # Configuration
            (NOTE: settings will be ignored in this demo, data will be loaded from a file)

            **Timelines**

            {tl_home} {tl_local} {tl_public}

            {tl_hashtag} {tl_hashtag_txt} {tl_list} {tl_list_txt}

            **Embeddings**

            {emb_server}

            {emb_server_url}

            {emb_server_model}

            **Caching**

            {offline_mode}
            """
        )
        .batch(
            tl_home=mo.ui.checkbox(label="Home", value=True),
            tl_local=mo.ui.checkbox(label="Local", value=True),
            tl_public=mo.ui.checkbox(label="Public", value=True),
            tl_hashtag=mo.ui.checkbox(label="Hashtag"),
            tl_list=mo.ui.checkbox(label="List"),
            tl_hashtag_txt=mo.ui.text(),
            tl_list_txt=mo.ui.text(),
            emb_server=mo.ui.radio(
                label="Server type:",
                options=["llamafile", "ollama"],
                value="llamafile",
                inline=True,
            ),
            emb_server_url=mo.ui.text(
                label="Embedding server URL:",
                value="http://localhost:8080/embedding",
                full_width=True,
            ),
            emb_server_model=mo.ui.text(
                label="Embedding server model:", value="all-minilm"
            ),
            offline_mode=mo.ui.checkbox(label="Run in offline mode (experimental)"),
        )
        .form(show_clear_button=True, bordered=True)
    )

    # a dictionary mapping Timeline UI checkboxes with the respective
    # strings that identify them in the Mastodon API
    timelines_dict = {
        "tl_home": "home",
        "tl_local": "local",
        "tl_public": "public",
        "tl_hashtag": "tag",
        "tl_list": "list",
    }

    def invalid_form(form):
        """A form (e.g. login) is invalid if it has no value,
        or if any of its keys have no value."""
        if form.value is None:
            return True

        for k in form.value.keys():
            if form.value[k] is None:
                return True

        return False

    return configuration_form, invalid_form, timelines_dict
519
+
520
+
521
@app.cell
def _(mo):
    # Create a form for timeline re-ranking
    rerank_form = (
        mo.md(
            """
            # Re-ranking settings

            **User statuses** (NOTE: data will be loaded from a file)


            {num_user_status_pages} {exclude_reblogs}

            **Timeline to rerank**

            {timeline_to_rerank}
            """
        )
        .batch(
            num_user_status_pages=mo.ui.slider(
                start=1, stop=20, label="Number of pages to load", value=1
            ),
            timeline_to_rerank=mo.ui.radio(
                options=["home", "local", "public"], value="public"
            ),
            exclude_reblogs=mo.ui.checkbox(label="Exclude reblogs", value=True),
        )
        .form(show_clear_button=True, bordered=True)
    )
    return (rerank_form,)


@app.cell
def _(mo):
    # Free-form similarity-search query (a row id or plain text).
    query_form = mo.ui.text(
        value="42",
        label="Enter a status id or some free-form text to find the most similar statuses:\n",
        full_width=True,
    )
    return (query_form,)


@app.cell
def _(mo):
    # Tag/term used to re-rank the user's own posts.
    tag_form = mo.ui.text(
        value="retrogaming",
        label="Enter a tag name:\n",
    )
    return (tag_form,)
570
+
571
+
572
@app.cell
def _(EmbeddingService, mo, pickle, time):
    def load_dataframes(data_file):
        """Load pickled timeline dataframes from `data_file`.

        Returns:
            The unpickled object, or None when the file does not exist.
        """
        dataframes = None
        print(f"Loading cached dataframes from {data_file}")
        try:
            with open(data_file, "rb") as f:
                dataframes = pickle.load(f)
        except FileNotFoundError:
            print(f"File {data_file} not found.")

        return dataframes

    def build_cache_embeddings(
        embedding_service: EmbeddingService,  # type: ignore
        dataframes: dict[str, any],
        cached: bool,
        embeddings_data_file: str,
    ) -> dict[str, any]:
        """Given a dictionary with dataframes from different timelines,
        return another dictionary that contains, for each timeline, the
        respective embeddings calculated with the provided embedding service.
        If cached==True, the `embeddings_data_file` file will be loaded.
        """
        if not cached:
            embeddings = {}
            for k in dataframes:
                with mo.status.progress_bar(
                    total=len(dataframes[k]), title=f"Embedding posts from: {k}"
                ) as bar:
                    print(f"Embedding statuses from timeline: {k}")
                    tt_ = time.time()
                    embeddings[k] = embedding_service.calculate_embeddings(
                        dataframes[k]["text"], bar
                    )
                    print(time.time() - tt_)
            # refresh the on-disk cache after recomputing
            with open(embeddings_data_file, "wb") as f:
                pickle.dump(embeddings, f)
        else:
            print(f"Loading cached embeddings from {embeddings_data_file}")
            try:
                with open(embeddings_data_file, "rb") as f:
                    embeddings = pickle.load(f)
            except FileNotFoundError:
                print(f"File {embeddings_data_file} not found.")
                return None

        return embeddings

    def get_compact_data(paginated_data: list) -> list[tuple[int, str]]:
        """Extract compact (id, text) pairs from a paginated list of statuses."""
        # BUGFIX: BeautifulSoup used to be declared as a parameter of this
        # cell, but no cell in the notebook ever defined that name, so the
        # cell could not be resolved by marimo. Import it locally instead
        # (beautifulsoup4 is pinned in requirements.txt).
        from bs4 import BeautifulSoup

        compact_data = []
        for page in paginated_data:
            for toot in page:
                id = toot.id
                cont = toot.content
                if toot.reblog:
                    # for boosts, index the original status, not the wrapper
                    id = toot.reblog.id
                    cont = toot.reblog.content
                soup = BeautifulSoup(cont, features="html.parser")
                # print(f"{id}: {soup.get_text()}")
                compact_data.append((id, soup.get_text()))
        return compact_data

    return build_cache_embeddings, get_compact_data, load_dataframes
637
+
638
+
639
@app.cell
def _():
    # intentionally empty cell (kept by marimo)
    return


if __name__ == "__main__":
    app.run()