Spaces:
Runtime error
Runtime error
| """Used to import existing DB as a new DB.""" | |
| import argparse | |
| import itertools | |
| import sqlite3 | |
| from typing import Iterable, NamedTuple | |
| import numpy as np | |
| import buster.documents.sqlite.documents as dest | |
| from buster.documents.sqlite import DocumentsDB | |
# Rows must arrive sorted by the exact keys the itertools.groupby calls below
# group on (source, then (url, title)); groupby only merges *adjacent* equal
# keys, so a mismatched ORDER BY silently fractures documents into split groups.
IMPORT_QUERY = (
    r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
)
# NOTE: previously ordered by (source, url, id) only; `title` was added so the
# ordering matches the (url, title) grouping key used in get_chunks.
CHUNK_QUERY = r"""SELECT source, url, title, content, n_tokens, embedding FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
class Document(NamedTuple):
    """Document from the original db.

    One row of IMPORT_QUERY, fields in column order.
    """

    source: str  # top-level grouping key for re-import
    url: str
    title: str
    content: str  # chunk text; concatenated per (url, title) in get_documents
class Section(NamedTuple):
    """Reassemble section from the original db.

    Produced by get_documents: one (url, title) group of Document rows,
    with the chunk contents joined back into a single string.
    """

    url: str
    title: str
    content: str
class Chunk(NamedTuple):
    """Chunk from the original db.

    One row of CHUNK_QUERY, fields in column order.
    """

    source: str
    url: str
    title: str
    content: str
    n_tokens: int
    # NOTE(review): typed as ndarray — assumes a sqlite converter decodes the
    # stored embedding blob; the raw row value may be bytes. TODO confirm.
    embedding: np.ndarray
def get_documents(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Section]]]:
    """Reassemble documents from the source db's chunks.

    Yields (source, sections) pairs, where sections lazily produces one
    Section per (url, title) group with its chunk contents concatenated.
    The query orders rows by the same keys, as groupby requires.
    """
    rows = (Document(*row) for row in conn.execute(IMPORT_QUERY))
    for source, docs_for_source in itertools.groupby(rows, lambda doc: doc.source):
        grouped = itertools.groupby(docs_for_source, lambda doc: (doc.url, doc.title))
        sections = (
            Section(url, title, "".join(piece.content for piece in pieces))
            for (url, title), pieces in grouped
        )
        # Lazy all the way down: the caller must consume sections before
        # advancing to the next source, or groupby invalidates the iterator.
        yield source, sections
def get_max_size(conn: sqlite3.Connection) -> int:
    """Get the maximum chunk size from the source db.

    Returns max(length(content)) over all rows; the aggregate query always
    yields exactly one row.
    """
    row = conn.execute("select max(length(content)) FROM documents").fetchone()
    return row[0]
def get_chunks(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Iterable[dest.Chunk]]]]:
    """Retrieve chunks from the source db.

    Yields (source, sections) pairs; sections lazily produces, per
    (url, title) group, an iterator of destination-schema Chunk objects.
    """
    rows = (Chunk(*row) for row in conn.execute(CHUNK_QUERY))
    for source, source_chunks in itertools.groupby(rows, lambda chunk: chunk.source):
        grouped = itertools.groupby(source_chunks, lambda chunk: (chunk.url, chunk.title))
        sections = (
            (dest.Chunk(c.content, c.n_tokens, c.embedding) for c in section_chunks)
            for _, section_chunks in grouped
        )
        # Nested generators stay lazy; consume in order (see groupby docs).
        yield source, sections
def main():
    """Import the source db into the destination db.

    Usage: script SOURCE DESTINATION [--size N]. Copies documents first,
    then re-attaches their chunks to the current version of each source.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("source")
    parser.add_argument("destination")
    parser.add_argument("--size", type=int, default=2000)
    args = parser.parse_args()

    origin = sqlite3.connect(args.source)
    db = DocumentsDB(args.destination)

    # Pass 1: re-create every source's documents in the destination schema.
    for source, content in get_documents(origin):
        # NOTE(review): dest.Section is built as (title, url, content) while the
        # local Section is (url, title, content) — assumed to match dest's
        # field order; confirm against buster.documents.sqlite.documents.
        sections = (dest.Section(section.title, section.url, section.content) for section in content)
        db.add_parse(source, sections)

    # The declared chunk size must cover the largest chunk already stored.
    chunk_size = max(args.size, get_max_size(origin))

    # Pass 2: attach the original chunks to each source's current version.
    for source, chunks in get_chunks(origin):
        sid, vid = db.get_current_version(source)
        db.add_chunking(sid, vid, chunk_size, chunks)
    db.conn.commit()
| if __name__ == "__main__": | |
| main() | |