xet-repo-data-collection / list_reconstructions.py
znation's picture
znation HF staff
let's try this
f624d68
import os
import json
import sys
import time
import grequests
import sqlite3
from tqdm import tqdm
import list_files
import list_repos
# Path of the SQLite database file that holds the reconstructions table.
SQLITE3_DB = "data/reconstructions.sqlite3"

# Service endpoints; each falls back to the default below when the
# environment variable of the same name is not set.
HF_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
XET_CAS_ENDPOINT = os.environ.get(
    "XET_CAS_ENDPOINT", "https://cas-server.xethub.hf.co"
)

# Hub URL template for a repo's /resolve/main prefix; the "{}" placeholder
# is filled in with the repo id.
RESOLVE_URL_TEMPLATE = "".join((HF_ENDPOINT, "/{}/resolve/main"))
def exception_handler(req, exc):
    """Report a failed request on stderr.

    Passed as the ``exception_handler`` callback to the grequests calls in
    this module. Nothing is returned, so the failure is only logged.

    Args:
        req: The request that failed (unused).
        exc: The exception that was raised; its string form is logged.
    """
    sys.stderr.write("{}\n".format(exc))
def list_reconstructions_from_hub(repo):
    """Fetch the Xet reconstruction terms for every LFS file in ``repo``.

    The work happens in three phases, each reporting progress on stderr:

    1. List the repo's LFS files via ``list_files.list_lfs_files``.
    2. Send a HEAD request (redirects disabled) to each file's
       ``/resolve/main`` URL and read the ``x-xet-hash`` and
       ``x-xet-access-token`` response headers.
    3. For every file that reported a non-empty Xet hash, GET
       ``{XET_CAS_ENDPOINT}/reconstruction/{hash}`` and flatten the
       ``terms`` of the JSON body into one entry per term.

    Requests that raise are logged by ``exception_handler`` and skipped, as
    are reconstruction responses whose status is not 200.

    Args:
        repo: Repo identifier. A leading ``"models/"`` is stripped before
            the resolve URLs are built.

    Returns:
        A list of dicts with the keys ``start``, ``end``, ``file_path``,
        ``xorb_id`` and ``unpacked_length``.
    """
    print(
        "Listing reconstructions using:\nHF Hub Endpoint: {}\nXet CAS Endpoint: {}".format(
            HF_ENDPOINT, XET_CAS_ENDPOINT
        ),
        file=sys.stderr,
    )

    print("Listing files for repo {}".format(repo), file=sys.stderr)
    files = [lfs_file["name"] for lfs_file in tqdm(list_files.list_lfs_files(repo))]

    # The listing uses the caller's identifier; the resolve URLs (and the
    # log lines below) use it without the "models/" prefix.
    if repo.startswith("models/"):
        repo = repo.replace("models/", "", 1)

    # Only authenticate when a token is configured: formatting a missing
    # token would send the literal header "Bearer None".
    hf_token = os.getenv("HF_TOKEN")
    hub_headers = {"Authorization": "Bearer {}".format(hf_token)} if hf_token else {}

    resolve_base = RESOLVE_URL_TEMPLATE.format(repo)
    resolve_reqs = [
        grequests.head(
            name.replace(repo, resolve_base, 1),
            headers=hub_headers,
            allow_redirects=False,
        )
        for name in files
    ]

    print("", file=sys.stderr)
    print("Calling /resolve/ for repo {}".format(repo), file=sys.stderr)
    reconstruct_reqs = []
    # reconstruct_paths[k] is the file that reconstruct_reqs[k] belongs to.
    # Responses arrive in arbitrary order and some files yield no request at
    # all, so the association has to be recorded explicitly.
    reconstruct_paths = []
    for file_idx, resp in tqdm(
        grequests.imap_enumerated(
            resolve_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(resolve_reqs),
    ):
        if resp is None:
            # Request failed; exception_handler has already logged it.
            continue
        xet_hash = resp.headers.get("x-xet-hash")
        if not xet_hash:
            # No Xet hash reported for this file: nothing to reconstruct.
            continue
        # TODO: use the "x-xet-refresh-route" response header to refresh
        # access_token once it has expired.
        access_token = resp.headers.get("x-xet-access-token")
        url = "{}/reconstruction/{}".format(XET_CAS_ENDPOINT, xet_hash)
        headers = {"Authorization": "Bearer {}".format(access_token)}
        reconstruct_reqs.append(grequests.get(url, headers=headers))
        # file_idx is the request's position in resolve_reqs, which lines up
        # one-to-one with files.
        reconstruct_paths.append(files[file_idx])

    print("", file=sys.stderr)
    print(
        "Calling /reconstruct/ with grequests for repo {}".format(repo),
        file=sys.stderr,
    )
    ret = []
    for req_idx, resp in tqdm(
        grequests.imap_enumerated(
            reconstruct_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(reconstruct_reqs),
    ):
        if resp is None or resp.status_code != 200:
            continue
        file_path = reconstruct_paths[req_idx]
        for term in resp.json()["terms"]:
            ret.append(
                {
                    "start": term["range"]["start"],
                    "end": term["range"]["end"],
                    "file_path": file_path,
                    "xorb_id": term["hash"],
                    "unpacked_length": term["unpacked_length"],
                }
            )
    return ret
def list_reconstructions(repos, limit=None):
    """Read stored reconstruction rows for the given repos from SQLite.

    Args:
        repos: Iterable of repo identifiers; each is matched exactly against
            the ``repo`` column of the ``reconstructions`` table.
        limit: Optional maximum number of rows to return *per repo*.
            ``None`` (the default) returns every matching row.

    Returns:
        A list of dicts, one per row, with the keys ``xorb_id``,
        ``last_updated_timestamp``, ``repo``, ``file_path``,
        ``unpacked_length``, ``start`` and ``end``.

    Raises:
        sqlite3.OperationalError: If the ``reconstructions`` table does not
            exist in the database at ``SQLITE3_DB``.
    """
    # Values are bound as parameters rather than formatted into the SQL, so
    # repo names containing quotes are matched literally.
    query = "SELECT * FROM reconstructions WHERE repo = ?"
    ret = []
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        for repo in repos:
            if limit is None:
                res = cur.execute(query, (repo,))
            else:
                res = cur.execute(query + " LIMIT ?", (repo, limit))
            for row in res.fetchall():
                # row[0] is the autoincrement id, which is not exposed.
                ret.append(
                    {
                        "xorb_id": row[1],
                        "last_updated_timestamp": row[2],
                        "repo": row[3],
                        "file_path": row[4],
                        "unpacked_length": row[5],
                        "start": row[6],
                        "end": row[7],
                    }
                )
    finally:
        # Release the database handle even when a query fails.
        con.close()
    return ret
def write_files_to_db(repo):
    """Replace the stored reconstruction rows for ``repo`` with fresh data.

    Creates the ``reconstructions`` table when it is missing, fetches the
    current reconstructions with ``list_reconstructions_from_hub``, and then
    swaps the repo's rows in a single transaction. Progress messages go to
    stderr.

    Args:
        repo: Repo identifier; used both for the hub query and as the value
            of the ``repo`` column.
    """
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        print("Creating reconstructions table if not exists", file=sys.stderr)
        cur.execute(
            "CREATE TABLE IF NOT EXISTS reconstructions (id INTEGER PRIMARY KEY AUTOINCREMENT, xorb_id TEXT, last_updated_datetime INTEGER, repo TEXT, file_path TEXT, unpacked_length INTEGER, start INTEGER, end INTEGER)"
        )
        con.commit()

        # Fetch before touching existing rows: if the hub query raises, the
        # previously stored rows for this repo are left intact.
        reconstructions = list_reconstructions_from_hub(repo)

        print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
        # Bound parameters keep quotes in repo names and file paths from
        # breaking (or altering) the statements.
        cur.execute("DELETE FROM reconstructions WHERE repo = ?", (repo,))

        print("Inserting rows from HFFileSystem query", file=sys.stderr)
        cur.executemany(
            "INSERT INTO reconstructions VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)",
            [
                (
                    reconstruction["xorb_id"],
                    int(time.time()),
                    repo,
                    reconstruction["file_path"],
                    reconstruction["unpacked_length"],
                    reconstruction["start"],
                    reconstruction["end"],
                )
                for reconstruction in reconstructions
            ],
        )
        # One commit covers the delete and the inserts, so readers never see
        # the repo half-replaced.
        con.commit()
    finally:
        # Closing without a commit discards any uncommitted changes.
        con.close()
if __name__ == "__main__":
    # Script entry point: refresh the stored rows for every repo that
    # list_repos reports.
    for repo in list_repos.list_repos():
        write_files_to_db(repo)

    # Then print a small sample (up to 5 rows per repo) as pretty-printed
    # JSON on stdout.
    print("Done writing to DB. Sample of 5 rows:")
    sample_rows = list_reconstructions(list_repos.list_repos(), limit=5)
    sys.stdout.write(json.dumps(sample_rows, sort_keys=True, indent=4))