import os
import platform
from datetime import datetime
from typing import List, Literal, Tuple

import chromadb
import polars as pl
import requests
import stamina
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from huggingface_hub import InferenceClient
from tqdm.contrib.concurrent import thread_map
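# Load environment variables from a local .env file (if present) and read the
# Hugging Face token used to authenticate calls to the inference endpoint.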
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
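# Use a local directory when developing on macOS (Darwin); on the Space, write
# under /data, which is backed by persistent storage.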
def get_save_path() -> Literal["chroma/", "/data/chroma/"]:
    return "chroma/" if platform.system() == "Darwin" else "/data/chroma/"
save_path = get_save_path()
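# Persist the Chroma database to disk so the index survives restarts.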
chroma_client = chromadb.PersistentClient(
    path=save_path,
)
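# Embedding function registered on the collection; Chroma uses it to embed text
# it receives without precomputed embeddings (e.g. query strings). Document
# embeddings themselves are computed remotely via the endpoint below.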
sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="Snowflake/snowflake-arctic-embed-m-long", trust_remote_code=True
)
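# Create (or reopen) the collection holding one embedding per dataset card.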
collection = chroma_client.create_collection(
    name="dataset_cards", get_or_create=True, embedding_function=sentence_transformer_ef
)
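# Newest `last_modified` timestamp already in the collection, used to make
# subsequent runs incremental.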
def get_last_modified_in_collection() -> datetime | None:
    all_items = collection.get(
        include=[
            "metadatas",
        ]
    )
    if last_modified := [
        datetime.fromisoformat(item["last_modified"]) for item in all_items["metadatas"]
    ]:
        return max(last_modified)
    else:
        return None
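# Strip YAML front matter from each card and build a copy with the dataset ID
# prepended; the prepended version is the text that actually gets embedded.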
def parse_markdown_column(
    df: pl.DataFrame, markdown_column: str, dataset_id_column: str
) -> pl.DataFrame:
    # Extract everything after the leading `--- ... ---` front-matter block,
    # falling back to the raw markdown when no front matter is present.
    parsed = (
        pl.col(markdown_column)
        .str.extract(r"(?s)^---.*?---\s*(.*)", group_index=1)
        .fill_null(pl.col(markdown_column))
        .str.strip_chars()
    )
    return df.with_columns(
        parsed_markdown=parsed,
        prepended_markdown=pl.concat_str(
            [
                pl.lit("Dataset ID "),
                pl.col(dataset_id_column).cast(pl.Utf8),
                pl.lit("\n\n"),
                parsed,
            ]
        ),
    )
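# Load dataset cards from the Hub and filter by length, likes, and modification
# date; returns None when there is nothing new to embed.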
def load_cards(
    min_len: int = 50,
    min_likes: int | None = None,
    last_modified: datetime | None = None,
) -> Tuple[List[str], List[str], List[datetime]] | None:
    # Pre-exported dump of all dataset cards plus metadata, read directly from the Hub.
    df = pl.read_parquet(
        "hf://datasets/librarian-bots/dataset_cards_with_metadata_with_embeddings/data/train-00000-of-00001.parquet"
    )
    df = parse_markdown_column(df, "card", "datasetId")
    df = df.with_columns(pl.col("parsed_markdown").str.len_chars().alias("card_len"))
    # Skip cards too short to be informative.
    df = df.filter(pl.col("card_len") > min_len)
    if min_likes is not None:
        df = df.filter(pl.col("likes") > min_likes)
    # Only keep cards modified since the last indexing run.
    if last_modified is not None:
        df = df.filter(pl.col("last_modified") > last_modified)
    if len(df) == 0:
        return None
    cards = df.get_column("prepended_markdown").to_list()
    model_ids = df.get_column("datasetId").to_list()
    last_modifieds = df.get_column("last_modified").to_list()
    return cards, model_ids, last_modifieds
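# Client for the dedicated inference endpoint that serves the embedding model.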
client = InferenceClient(
    model="https://pqzap00ebpl1ydt4.us-east-1.aws.endpoints.huggingface.cloud",
    token=HF_TOKEN,
)
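# Retry transient HTTP errors up to 3 times, starting with a 10s backoff.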
@stamina.retry(on=requests.HTTPError, attempts=3, wait_initial=10)
def embed_card(text):
    # Truncate long cards (note: character count, not tokens) before embedding.
    text = text[:8192]
    return client.feature_extraction(text)
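# Incremental update: only fetch and embed cards newer than what is already indexed.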
most_recent = get_last_modified_in_collection()
if data := load_cards(min_len=200, min_likes=None, last_modified=most_recent):
    cards, model_ids, last_modifieds = data
    print("mapping...")
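    # Embed all cards concurrently; tqdm's thread_map shows progress.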
    results = thread_map(embed_card, cards)
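    # Upsert so re-runs update existing IDs instead of duplicating them.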
    collection.upsert(
        ids=model_ids,
        embeddings=[embedding.tolist()[0] for embedding in results],
        metadatas=[{"last_modified": str(lm)} for lm in last_modifieds],
    )
    print("done")
else:
    print("no new data")