Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	
		Marc-Antoine Rondeau
		
	committed on
		
		
					Commit 
							
							·
						
						97aefb5
	
1
								Parent(s):
							
							71e7dd8
								
New db schema
Browse files
- buster/db/__init__.py +3 -0
- buster/db/backward.py +108 -0
- buster/db/documents.py +169 -0
- buster/db/schema.py +135 -0
    	
        buster/db/__init__.py
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .documents import DocumentsDB
         | 
| 2 | 
            +
             | 
# Public names exported by ``from buster.db import *``. Entries must be the
# names as strings; listing the class object itself makes the star-import fail.
__all__ = ["DocumentsDB"]
    	
        buster/db/backward.py
    ADDED
    
    | @@ -0,0 +1,108 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """Used to import existing DB as a new DB."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import argparse
         | 
| 4 | 
            +
            import itertools
         | 
| 5 | 
            +
            from typing import Iterable, NamedTuple
         | 
| 6 | 
            +
             | 
| 7 | 
            +
            import numpy as np
         | 
| 8 | 
            +
            import sqlite3
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            from buster.db import DocumentsDB
         | 
| 11 | 
            +
             | 
| 12 | 
            +
            import buster.db.documents as dest
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
# Text of every current chunk, ordered so that the chunks of one
# (source, url, title) section are adjacent and in insertion (id) order.
IMPORT_QUERY = (
    r"""SELECT source, url, title, content FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
)
# Same rows with token counts and embeddings. The ordering must be identical to
# IMPORT_QUERY: both passes group rows by (url, title) and number the sections by
# enumeration, so a different order would attach chunks to the wrong section.
CHUNK_QUERY = r"""SELECT source, url, title, content, n_tokens, embedding FROM documents WHERE current = 1 ORDER BY source, url, title, id"""
| 19 | 
            +
             | 
| 20 | 
            +
             | 
class Document(NamedTuple):
    """A single text row read from the original db."""

    # Field order follows the column order of the import SELECT, so a result
    # row can be unpacked positionally into this tuple.
    source: str
    url: str
    title: str
    content: str
| 28 | 
            +
             | 
| 29 | 
            +
             | 
class Section(NamedTuple):
    """A section of the original db, rebuilt by joining its chunks' content."""

    url: str
    title: str
    content: str
| 36 | 
            +
             | 
| 37 | 
            +
             | 
class Chunk(NamedTuple):
    """A single chunk row, with its embedding, read from the original db."""

    # Field order follows the column order of the chunk SELECT.
    source: str
    url: str
    title: str
    content: str
    n_tokens: int
    # NOTE(review): annotated as an array, but the value is whatever the source
    # connection returns for the column - confirm the actual runtime type.
    embedding: np.ndarray
| 47 | 
            +
             | 
| 48 | 
            +
             | 
def get_documents(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Section]]]:
    """Rebuild the sections of every source from the source db's chunk rows.

    Yields one ``(source, sections)`` pair per source. ``sections`` is a lazy
    iterable of ``Section`` whose content is the concatenation of all rows
    sharing the same ``(url, title)``.

    Everything is streamed through ``itertools.groupby`` over one cursor, so
    each ``sections`` iterable has to be consumed before the next pair is
    requested.
    """

    def assemble(rows):
        # Consecutive rows with the same (url, title) make up one section.
        for (url, title), parts in itertools.groupby(rows, lambda row: (row.url, row.title)):
            yield Section(url, title, "".join(part.content for part in parts))

    rows = itertools.starmap(Document, conn.execute(IMPORT_QUERY))
    for source, source_rows in itertools.groupby(rows, lambda row: row.source):
        yield source, assemble(source_rows)
| 59 | 
            +
             | 
| 60 | 
            +
             | 
def get_max_size(conn: sqlite3.Connection) -> int:
    """Return the length of the longest ``content`` value in ``documents``.

    Args:
        conn: Connection to the source db.

    Returns:
        The maximum ``length(content)`` over all rows, or 0 when the table has
        no rows (the aggregate is NULL in that case).
    """
    # An aggregate without GROUP BY always produces exactly one row.
    (size,) = conn.execute("select max(length(content)) FROM documents").fetchone()
    return 0 if size is None else size
| 66 | 
            +
             | 
| 67 | 
            +
             | 
def get_chunks(conn: sqlite3.Connection) -> Iterable[tuple[str, Iterable[Iterable[dest.Chunk]]]]:
    """Stream the chunks of the source db, grouped by source and then section.

    Yields one ``(source, sections)`` pair per source. ``sections`` is a lazy
    iterable with one entry per consecutive ``(url, title)`` run, and each
    entry is itself a lazy iterable of destination ``Chunk`` tuples
    (content, n_tokens, embedding).

    All levels share one cursor through ``itertools.groupby``, so the inner
    iterables have to be consumed before moving on to the next group.
    """

    def convert(section_rows):
        # Drop the source/url/title columns; the destination chunk keeps the rest.
        return (dest.Chunk(row.content, row.n_tokens, row.embedding) for row in section_rows)

    rows = (Chunk(*row) for row in conn.execute(CHUNK_QUERY))
    for source, source_rows in itertools.groupby(rows, lambda row: row.source):
        by_section = itertools.groupby(source_rows, lambda row: (row.url, row.title))
        yield source, (convert(section_rows) for _, section_rows in by_section)
| 80 | 
            +
             | 
| 81 | 
            +
             | 
def main():
    """Import the source db into the destination db.

    Command line: ``source destination [--size N]``.

    First pass: every source of the old db gets a new version in the
    destination and its reassembled sections are stored. Second pass: one
    chunking is registered per source and the original chunks are copied into
    it. The recorded chunk size is the larger of ``--size`` and the longest
    chunk found in the source db. Everything is committed once at the end.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("source")
    parser.add_argument("destination")
    parser.add_argument("--size", type=int, default=2000)
    args = parser.parse_args()

    org = sqlite3.connect(args.source)
    try:
        db = DocumentsDB(args.destination)

        for source, content in get_documents(org):
            sid, vid = db.start_version(source)
            # The destination Section takes (title, url, content), not (url, title, content).
            sections = (dest.Section(section.title, section.url, section.content) for section in content)
            db.add_sections(sid, vid, sections)

        # `or 0`: an empty source table has no maximum length to compare with.
        size = max(args.size, get_max_size(org) or 0)
        for source, chunks in get_chunks(org):
            sid, vid = db.get_current_version(source)
            cid = db.add_chunking(sid, vid, size)
            db.add_chunks(sid, vid, cid, chunks)
        db.conn.commit()
    finally:
        # The source connection is only read from; release it even on failure.
        org.close()


if __name__ == "__main__":
    main()
    	
        buster/db/documents.py
    ADDED
    
    | @@ -0,0 +1,169 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sqlite3
         | 
| 2 | 
            +
            from typing import Iterable, NamedTuple
         | 
| 3 | 
            +
            import warnings
         | 
| 4 | 
            +
            import zlib
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            import numpy as np
         | 
| 7 | 
            +
            import pandas as pd
         | 
| 8 | 
            +
             | 
| 9 | 
            +
            import buster.db.schema as schema
         | 
| 10 | 
            +
             | 
| 11 | 
            +
             | 
| 12 | 
            +
            class Section(NamedTuple):
         | 
| 13 | 
            +
                title: str
         | 
| 14 | 
            +
                url: str
         | 
| 15 | 
            +
                content: str
         | 
| 16 | 
            +
                parent: int | None = None
         | 
| 17 | 
            +
                type: str = "section"
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
class Chunk(NamedTuple):
    """Values for one row of the ``chunks`` table, minus its position keys."""

    content: str
    n_tokens: int
    # Embedding of ``content``; written to the ``embedding`` column.
    emb: np.ndarray
| 24 | 
            +
             | 
| 25 | 
            +
             | 
class DocumentsDB:
    """Simple SQLite database for storing documents and questions/answers.

    The database is just a file on disk. It can store documents from different sources, and it can store multiple versions of the same document (e.g. if the document is updated).
    Questions/answers refer to the version of the document that was used at the time.

    Example:
        >>> db = DocumentsDB("/path/to/the/db.db")
        >>> db.write_documents("source", df)  # df is a DataFrame containing the documents from a given source, obtained e.g. by using buster.docparser.generate_embeddings
        >>> df = db.get_documents("source")
    """

    def __init__(self, db_path: sqlite3.Connection | str):
        """Open (or adopt) a connection and make sure the schema is in place.

        Args:
            db_path: Path of the SQLite file to open, or an already open
                connection. A connection that is passed in stays owned by the
                caller: it is not closed when this object is finalized.
        """
        if isinstance(db_path, str):
            self.db_path = db_path
            self.conn = sqlite3.connect(db_path)
        else:
            self.db_path = None
            self.conn = db_path
        self.cursor = self.conn.cursor()
        schema.initialize_db(self.conn)
        schema.setup_db(self.conn)

    def __del__(self):
        # Only close connections opened here. getattr() keeps finalization
        # quiet when __init__ failed before the attributes were assigned.
        if getattr(self, "db_path", None) is None:
            return
        conn = getattr(self, "conn", None)
        if conn is not None:
            conn.close()

    def _latest_version(self, source: str) -> tuple[int, int] | None:
        """Return ``(source id, newest version)`` for a source name, or None if it has no version."""
        cur = self.conn.execute("SELECT source, version FROM latest_version WHERE name = ?", (source,))
        return cur.fetchone()

    def get_current_version(self, source: str) -> tuple[int, int]:
        """Return ``(source id, newest version)`` for the named source.

        Raises:
            KeyError: If the source has no recorded version.
        """
        row = self._latest_version(source)
        if row is None:
            raise KeyError(f'"{source}" is not a known source')
        sid, vid = row
        return sid, vid

    def get_source(self, source: str) -> int:
        """Return the id of the named source, inserting the source first if needed."""
        cur = self.conn.execute("SELECT id FROM sources WHERE name = ?", (source,))
        row = cur.fetchone()
        if row is not None:
            (sid,) = row
            return sid

        # `id` is the INTEGER PRIMARY KEY, i.e. the rowid reported by the insert.
        cur = self.conn.execute("INSERT INTO sources (name) VALUES (?)", (source,))
        return cur.lastrowid

    def start_version(self, source: str) -> tuple[int, int]:
        """Record a new version for the named source and return ``(source id, version)``.

        The first version of a source is 0; later ones are the newest version
        plus one. Unknown sources are created. Nothing is committed here.
        """
        row = self._latest_version(source)
        if row is None:
            sid = self.get_source(source)
            vid = 0
        else:
            sid, vid = row
            vid = vid + 1
        self.conn.execute("INSERT INTO versions (source, version) VALUES (?, ?)", (sid, vid))
        return sid, vid

    def add_sections(self, sid: int, vid: int, sections: Iterable[Section]):
        """Insert the sections of one source version.

        Sections are numbered from 0 in iteration order; that number is the
        ``section`` key chunks refer to. Nothing is committed here.
        """
        values = (
            (sid, vid, ind, section.title, section.url, section.content, section.parent, section.type)
            for ind, section in enumerate(sections)
        )
        self.conn.executemany(
            "INSERT INTO sections "
            "(source, version, section, title, url, content, parent, type) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )

    def add_chunking(self, sid: int, vid: int, size: int, overlap: int = 0, strategy: str = "simple") -> int:
        """Register a chunking of one source version and return its id.

        Every call inserts a new row. The id is taken from the insert itself
        rather than looked up by value afterwards: ``chunker`` is left NULL, so
        the table's UNIQUE constraint does not merge repeated parameter sets and
        a lookup by value could match several rows. Nothing is committed here.
        """
        cur = self.conn.execute(
            "INSERT INTO chunkings (size, overlap, strategy, source, version) VALUES (?, ?, ?, ?, ?)",
            (size, overlap, strategy, sid, vid),
        )
        # `chunking` is the INTEGER PRIMARY KEY, i.e. the rowid of the new row.
        return cur.lastrowid

    def add_chunks(self, sid: int, vid: int, cid: int, sections: Iterable[Iterable[Chunk]]):
        """Insert the chunks of one chunking.

        Args:
            sid: Source id.
            vid: Version of the source.
            cid: Chunking id, as returned by ``add_chunking``.
            sections: One iterable of chunks per section. The position of the
                outer entry is stored as ``section`` and the position inside
                it as ``sequence``, so the outer order has to match the order
                used when the sections were added.

        Nothing is committed here.
        """
        chunks = ((ind, jnd, chunk) for ind, section in enumerate(sections) for jnd, chunk in enumerate(section))
        values = ((sid, vid, ind, cid, jnd, chunk.content, chunk.n_tokens, chunk.emb) for ind, jnd, chunk in chunks)
        self.conn.executemany(
            "INSERT INTO chunks "
            "(source, version, section, chunking, sequence, content, n_tokens, embedding) "
            "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
            values,
        )

    def write_documents(self, source: str, df: pd.DataFrame):
        """Write all documents from the dataframe into the db. All previous documents from that source will be set to `current = 0`.

        The frame needs ``title``, ``url`` and ``content`` columns. When it has
        an ``embedding`` column, ``n_tokens`` is written too and embeddings are
        stored as zlib-compressed float32 bytes. The change is committed.

        NOTE(review): this updates and inserts into ``documents`` and relies on
        a ``current`` column, while the schema module of this commit declares
        ``documents`` as a view without that column - confirm which databases
        this method is still expected to work on.
        """
        df = df.copy()

        # Prepare the rows
        df["source"] = source
        df["current"] = 1
        columns = ["source", "title", "url", "content", "current"]
        if "embedding" in df.columns:
            columns.extend(["n_tokens", "embedding"])

            # Check that the embeddings are float32 (judged from the first row;
            # an empty frame has no row to inspect and nothing to convert).
            if not df.empty and df["embedding"].iloc[0].dtype != np.float32:
                warnings.warn(
                    f"Embeddings are not float32, converting them to float32 from {df['embedding'].iloc[0].dtype}.",
                    RuntimeWarning,
                )
                df["embedding"] = df["embedding"].apply(lambda x: x.astype(np.float32))

            # ZLIB compress the embeddings
            df["embedding"] = df["embedding"].apply(lambda x: sqlite3.Binary(zlib.compress(x.tobytes())))

        data = df[columns].values.tolist()

        # Set `current` to 0 for all previous documents from that source
        self.cursor.execute("UPDATE documents SET current = 0 WHERE source = ?", (source,))

        # Insert the new documents
        insert_statement = f"INSERT INTO documents ({', '.join(columns)}) VALUES ({', '.join(['?']*len(columns))})"
        self.cursor.executemany(insert_statement, data)

        self.conn.commit()

    def get_documents(self, source: str) -> pd.DataFrame:
        """Return every row of ``documents`` for the given source as a DataFrame.

        Column names are taken from the query result. No further filtering is
        applied here.
        """
        results = self.cursor.execute("SELECT * FROM documents WHERE source = ?", (source,))
        rows = results.fetchall()

        # Convert the results to a pandas DataFrame
        df = pd.DataFrame(rows, columns=[description[0] for description in results.description])
        return df
    	
        buster/db/schema.py
    ADDED
    
    | @@ -0,0 +1,135 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import zlib
         | 
| 2 | 
            +
             | 
| 3 | 
            +
             | 
| 4 | 
            +
            import numpy as np
         | 
| 5 | 
            +
            import sqlite3
         | 
| 6 | 
            +
             | 
| 7 | 
            +
             | 
# One row per document source, identified by a unique name.
SOURCE_TABLE = r"""CREATE TABLE IF NOT EXISTS sources (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    note TEXT,
    UNIQUE(name)
)"""


# One row per (source, version, parser).
VERSION_TABLE = r"""CREATE TABLE IF NOT EXISTS versions (
    source INTEGER,
    version INTEGER,
    parser TEXT,
    note TEXT,
    PRIMARY KEY (version, source, parser),
    FOREIGN KEY (source) REFERENCES sources (id)
)"""


# One row per chunking configuration applied to a source version.
CHUNKING_TABLE = r"""CREATE TABLE IF NOT EXISTS chunkings (
    chunking INTEGER PRIMARY KEY AUTOINCREMENT,
    size INTEGER,
    overlap INTEGER,
    strategy TEXT,
    chunker TEXT,
    source INTEGER,
    version INTEGER,
    UNIQUE (size, overlap, strategy, chunker, source, version),
    FOREIGN KEY (source, version) REFERENCES versions (source, version)
)"""


# Sections of a source version, keyed by their position within that version.
# The link to `versions` is a single composite key, as in `chunkings`.
# NOTE(review): SQLite only enforces a foreign key whose parent columns are a
# PRIMARY KEY or UNIQUE; `versions` has no such key on (source, version) -
# confirm whether enforcement is intended.
SECTION_TABLE = r"""CREATE TABLE IF NOT EXISTS sections (
    source INTEGER,
    version INTEGER,
    section INTEGER,
    title TEXT NOT NULL,
    url TEXT NOT NULL,
    content TEXT NOT NULL,
    parent INTEGER,
    type TEXT,
    PRIMARY KEY (version, source, section),
    FOREIGN KEY (source, version) REFERENCES versions (source, version)
)"""


# Chunks of a section under one chunking, ordered by `sequence`.
CHUNK_TABLE = r"""CREATE TABLE IF NOT EXISTS chunks (
    source INTEGER,
    version INTEGER,
    section INTEGER,
    chunking INTEGER,
    sequence INTEGER,
    content TEXT NOT NULL,
    n_tokens INTEGER,
    embedding VECTOR,
    PRIMARY KEY (source, version, section, chunking, sequence),
    FOREIGN KEY (source, version, section) REFERENCES sections (source, version, section),
    FOREIGN KEY (source, version, chunking) REFERENCES chunkings (source, version, chunking)
)"""


# Highest version number of every source that has at least one version.
VERSION_VIEW = r"""CREATE VIEW IF NOT EXISTS latest_version (
    name, source, version) AS
    SELECT sources.name, versions.source, max(versions.version)
    FROM sources INNER JOIN versions on sources.id = versions.source
    GROUP BY sources.id
"""

# Highest chunking id registered for the latest version of every source.
CHUNKING_VIEW = r"""CREATE VIEW IF NOT EXISTS latest_chunking (
    name, source, version, chunking) AS
    SELECT name, source, version, max(chunking) FROM
    chunkings INNER JOIN latest_version USING (source, version)
    GROUP by source, version
"""

# Chunks of the latest chunking of the latest version, with section metadata.
DOCUMENT_VIEW = r"""CREATE VIEW IF NOT EXISTS documents (
    source, title, url, content, n_tokens, embedding)
    AS SELECT latest_chunking.name, sections.title, sections.url,
    chunks.content, chunks.n_tokens, chunks.embedding
    FROM chunks INNER JOIN sections USING (source, version, section)
    INNER JOIN latest_chunking USING (source, version, chunking)
"""


# Execution order matters: the views are created after the objects they select from.
INIT_STATEMENTS = [
    SOURCE_TABLE,
    VERSION_TABLE,
    CHUNKING_TABLE,
    SECTION_TABLE,
    CHUNK_TABLE,
    VERSION_VIEW,
    CHUNKING_VIEW,
    DOCUMENT_VIEW,
]
| 102 | 
            +
             | 
| 103 | 
            +
             | 
def initialize_db(connection: sqlite3.Connection) -> sqlite3.Connection:
    """Create the tables and views of the schema if they do not exist yet.

    The statements of ``INIT_STATEMENTS`` are executed in order and then
    committed. If one of them fails with a ``sqlite3.Error``, the connection
    is rolled back and the error is re-raised.

    Returns:
        The connection that was passed in.
    """
    try:
        for statement in INIT_STATEMENTS:
            connection.execute(statement)
    except sqlite3.Error:
        connection.rollback()
        raise
    connection.commit()
    return connection
| 113 | 
            +
             | 
| 114 | 
            +
             | 
def adapt_vector(vector: np.ndarray) -> bytes:
    """Serialize a vector for storage: float32 bytes, zlib-compressed, wrapped as a SQLite blob."""
    raw = vector.astype(np.float32).tobytes()
    packed = zlib.compress(raw)
    return sqlite3.Binary(packed)
| 117 | 
            +
             | 
| 118 | 
            +
             | 
def convert_vector(buffer: bytes) -> np.ndarray:
    """Decode a stored blob: zlib-decompress it and view the bytes as a float32 array."""
    raw = zlib.decompress(buffer)
    return np.frombuffer(raw, dtype=np.float32)
| 121 | 
            +
             | 
| 122 | 
            +
             | 
def cosine_similarity(a: bytes, b: bytes) -> float:
    """Return the cosine similarity of two stored vectors, rescaled to ``0.5 * cos + 0.5``.

    Both arguments are blobs in the stored format and are decoded with
    ``convert_vector`` first. The division by each norm is not guarded
    against zero-length vectors.
    """
    left = convert_vector(a)
    right = convert_vector(b)
    cosine = np.dot(left / np.linalg.norm(left), right / np.linalg.norm(right))
    return float(0.5 * cosine + 0.5)
| 130 | 
            +
             | 
| 131 | 
            +
             | 
def setup_db(connection: sqlite3.Connection):
    """Install the vector support used by the schema.

    Registers, for the whole ``sqlite3`` module, the converter for columns
    declared ``VECTOR`` and the adapter for ``np.ndarray`` parameters, and
    adds the two-argument SQL function ``sim`` to this connection.

    NOTE(review): sqlite3 only applies converters on connections opened with
    ``detect_types`` - confirm how callers open their connections.
    """
    sqlite3.register_converter("VECTOR", convert_vector)
    sqlite3.register_adapter(np.ndarray, adapt_vector)
    connection.create_function("sim", 2, cosine_similarity, deterministic=True)
