|
import os |
|
import json |
|
import sys |
|
import time |
|
|
|
import grequests |
|
import sqlite3 |
|
from tqdm import tqdm |
|
|
|
import list_files |
|
import list_repos |
|
|
|
# Local sqlite cache of reconstruction terms (relative to the working dir).
SQLITE3_DB = "data/reconstructions.sqlite3"


# Hub and Xet CAS endpoints; overridable via environment for staging/testing.
HF_ENDPOINT = os.getenv("HF_ENDPOINT", "https://huggingface.co")

XET_CAS_ENDPOINT = os.getenv("XET_CAS_ENDPOINT", "https://cas-server.xethub.hf.co")


# Filled with the repo id, e.g. "{HF_ENDPOINT}/org/name/resolve/main".
RESOLVE_URL_TEMPLATE = HF_ENDPOINT + "/{}/resolve/main"
|
|
|
|
|
def exception_handler(req, exc):
    """Report a failed grequests request's exception on stderr.

    The request object is accepted to satisfy the grequests callback
    signature but is not used.
    """
    sys.stderr.write("{}\n".format(exc))
|
|
|
|
|
def list_reconstructions_from_hub(repo):
    """Fetch xet reconstruction terms for every LFS file in *repo*.

    Issues HEAD requests against the Hub's ``/resolve/`` route to obtain
    each file's xet hash and short-lived CAS access token, then GETs
    ``/reconstruction/{hash}`` from the CAS server and flattens the
    returned terms into flat dicts.

    Args:
        repo: Repo identifier as produced by ``list_repos`` (a leading
            ``models/`` prefix is stripped when building resolve URLs).

    Returns:
        A list of dicts with keys ``start``, ``end``, ``file_path``,
        ``xorb_id`` and ``unpacked_length``, one per reconstruction term.
    """
    print(
        "Listing reconstructions using:\nHF Hub Endpoint: {}\nXet CAS Endpoint: {}".format(
            HF_ENDPOINT, XET_CAS_ENDPOINT
        ),
        file=sys.stderr,
    )

    ret = []
    files = []
    resolve_reqs = []
    reconstruct_reqs = []
    # Parallel to reconstruct_reqs: the file path each request belongs to.
    # BUG FIX: the original indexed `files[i + err_count]` in the second
    # loop, which mis-attributes file paths as soon as an error occurs
    # after a success, or a resolved file has no xet hash (both skip a
    # request without a corresponding gap in `files`).
    reconstruct_files = []
    err_count = 0

    print("Listing files for repo {}".format(repo), file=sys.stderr)
    total = 0
    for i, file in tqdm(enumerate(list_files.list_lfs_files(repo))):
        total += 1
        files.append(file["name"])
        if repo.startswith("models/"):
            # Resolve URLs use the bare "org/name" form for models.
            repo = repo.replace("models/", "", 1)
        url = file["name"].replace(repo, RESOLVE_URL_TEMPLATE.format(repo), 1)
        headers = {"Authorization": "Bearer {}".format(os.getenv("HF_TOKEN"))}
        # HEAD without redirects: we only need the x-xet-* response headers.
        resolve_reqs.append(
            grequests.head(url, headers=headers, allow_redirects=False)
        )

    print("", file=sys.stderr)
    print("Calling /resolve/ for repo {}".format(repo), file=sys.stderr)
    for i, resp in tqdm(
        grequests.imap_enumerated(
            resolve_reqs, size=4, exception_handler=exception_handler
        ),
        total=total,
    ):
        if resp is None:
            err_count += 1
            continue

        xet_hash = resp.headers.get("x-xet-hash")
        access_token = resp.headers.get("x-xet-access-token")
        if xet_hash:  # plain-LFS files carry no xet hash; skip them
            url = "{}/reconstruction/{}".format(XET_CAS_ENDPOINT, xet_hash)
            headers = {"Authorization": "Bearer {}".format(access_token)}
            reconstruct_reqs.append(grequests.get(url, headers=headers))
            # `i` indexes resolve_reqs, which is parallel to `files`.
            reconstruct_files.append(files[i])

    if err_count:
        print("{} resolve requests failed".format(err_count), file=sys.stderr)
    print("", file=sys.stderr)
    print(
        "Calling /reconstruct/ with grequests for repo {}".format(repo),
        file=sys.stderr,
    )

    for i, resp in tqdm(
        grequests.imap_enumerated(
            reconstruct_reqs, size=4, exception_handler=exception_handler
        ),
        total=len(reconstruct_reqs),
    ):
        if resp is None:
            continue
        if resp.status_code != 200:
            continue
        body = resp.json()
        for term in body["terms"]:
            ret.append(
                {
                    "start": term["range"]["start"],
                    "end": term["range"]["end"],
                    # `i` indexes reconstruct_reqs, parallel to reconstruct_files.
                    "file_path": reconstruct_files[i],
                    "xorb_id": term["hash"],
                    "unpacked_length": term["unpacked_length"],
                }
            )

    return ret
|
|
|
|
|
def list_reconstructions(repos, limit=None):
    """Read cached reconstruction rows for *repos* from the sqlite DB.

    Args:
        repos: Iterable of repo identifiers to look up.
        limit: Optional per-repo cap on the number of rows returned.

    Returns:
        A list of dicts with keys ``xorb_id``, ``last_updated_timestamp``,
        ``repo``, ``file_path``, ``unpacked_length``, ``start`` and ``end``,
        in the order rows come back from sqlite.
    """
    ret = []
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        for repo in repos:
            # Parameterized queries: repo names come from external listings,
            # so never splice them into the SQL text (quoting/injection bugs).
            if limit is None:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ?", (repo,)
                )
            else:
                res = cur.execute(
                    "SELECT * FROM reconstructions WHERE repo = ? LIMIT ?",
                    (repo, limit),
                )
            for row in res.fetchall():
                # Column order matches the CREATE TABLE in write_files_to_db;
                # row[0] is the autoincrement id and is intentionally dropped.
                ret.append(
                    {
                        "xorb_id": row[1],
                        "last_updated_timestamp": row[2],
                        "repo": row[3],
                        "file_path": row[4],
                        "unpacked_length": row[5],
                        "start": row[6],
                        "end": row[7],
                    }
                )
    finally:
        con.close()
    return ret
|
|
|
|
|
def write_files_to_db(repo):
    """Refresh the reconstructions table for *repo* from the Hub.

    Creates the table if needed, deletes any existing rows for the repo,
    then inserts one row per reconstruction term returned by
    ``list_reconstructions_from_hub``. Commits once per phase and always
    closes the connection.

    Args:
        repo: Repo identifier as produced by ``list_repos``.
    """
    print("Opening database", SQLITE3_DB, file=sys.stderr)
    con = sqlite3.connect(SQLITE3_DB)
    try:
        cur = con.cursor()
        print("Creating reconstructions table if not exists", file=sys.stderr)
        cur.execute(
            "CREATE TABLE IF NOT EXISTS reconstructions (id INTEGER PRIMARY KEY AUTOINCREMENT, xorb_id TEXT, last_updated_datetime INTEGER, repo TEXT, file_path TEXT, unpacked_length INTEGER, start INTEGER, end INTEGER)"
        )
        con.commit()
        print("Deleting existing rows for repo {}".format(repo), file=sys.stderr)
        # Parameterized to avoid SQL quoting/injection bugs in repo names.
        cur.execute("DELETE FROM reconstructions WHERE repo = ?", (repo,))
        con.commit()
        print("Inserting rows from HFFileSystem query", file=sys.stderr)
        # One timestamp for the whole batch so rows from a single refresh
        # are mutually consistent.
        now = int(time.time())
        for reconstruction in list_reconstructions_from_hub(repo):
            cur.execute(
                "INSERT INTO reconstructions VALUES (NULL, ?, ?, ?, ?, ?, ?, ?)",
                (
                    reconstruction["xorb_id"],
                    now,
                    repo,
                    reconstruction["file_path"],
                    reconstruction["unpacked_length"],
                    reconstruction["start"],
                    reconstruction["end"],
                ),
            )
        con.commit()
    finally:
        con.close()
|
|
|
|
|
if __name__ == "__main__":
    # Refresh the cache for every known repo, then dump a small sample of
    # rows to stdout so the run can be eyeballed.
    for current_repo in list_repos.list_repos():
        write_files_to_db(current_repo)
    print("Done writing to DB. Sample of 5 rows:")
    sample = list_reconstructions(list_repos.list_repos(), limit=5)
    json.dump(sample, sys.stdout, sort_keys=True, indent=4)
|
|