""" Helper class to simplify common read-only BigQuery tasks. """ import pandas as pd import time from google.cloud import bigquery class BigQueryHelper(object): """ Helper class to simplify common BigQuery tasks like executing queries, showing table schemas, etc without worrying about table or dataset pointers. See the BigQuery docs for details of the steps this class lets you skip: https://googlecloudplatform.github.io/google-cloud-python/latest/bigquery/reference.html """ def __init__(self, active_project, dataset_name, max_wait_seconds=180): self.project_name = active_project self.dataset_name = dataset_name self.max_wait_seconds = max_wait_seconds self.client = bigquery.Client() self.__dataset_ref = self.client.dataset(self.dataset_name, project=self.project_name) self.dataset = None self.tables = dict() # {table name (str): table object} self.__table_refs = dict() # {table name (str): table reference} self.total_gb_used_net_cache = 0 self.BYTES_PER_GB = 2**30 def __fetch_dataset(self): """ Lazy loading of dataset. For example, if the user only calls `self.query_to_pandas` then the dataset never has to be fetched. """ if self.dataset is None: self.dataset = self.client.get_dataset(self.__dataset_ref) def __fetch_table(self, table_name): """ Lazy loading of table """ self.__fetch_dataset() if table_name not in self.__table_refs: self.__table_refs[table_name] = self.dataset.table(table_name) if table_name not in self.tables: self.tables[table_name] = self.client.get_table(self.__table_refs[table_name]) def __handle_record_field(self, row, schema_details, top_level_name=''): """ Unpack a single row, including any nested fields. """ name = row['name'] if top_level_name != '': name = top_level_name + '.' + name schema_details.append([{ 'name': name, 'type': row['type'], 'mode': row['mode'], 'fields': pd.np.nan, 'description': row['description'] }]) # float check is to dodge row['fields'] == np.nan if type(row.get('fields', 0.0)) == float: return None for entry in row['fields']: self.__handle_record_field(entry, schema_details, name) def __unpack_all_schema_fields(self, schema): """ Unrolls nested schemas. Returns dataframe with one row per field, and the field names in the format accepted by the API. Results will look similar to the website schema, such as: https://bigquery.cloud.google.com/table/bigquery-public-data:github_repos.commits?pli=1 Args: schema: DataFrame derived from api repr of raw table.schema Returns: Dataframe of the unrolled schema. """ schema_details = [] schema.apply(lambda row: self.__handle_record_field(row, schema_details), axis=1) result = pd.concat([pd.DataFrame.from_dict(x) for x in schema_details]) result.reset_index(drop=True, inplace=True) del result['fields'] return result def table_schema(self, table_name): """ Get the schema for a specific table from a dataset. Unrolls nested field names into the format that can be copied directly into queries. For example, for the `github.commits` table, the this will return `committer.name`. This is a very different return signature than BigQuery's table.schema. """ self.__fetch_table(table_name) raw_schema = self.tables[table_name].schema schema = pd.DataFrame.from_dict([x.to_api_repr() for x in raw_schema]) # the api_repr only has the fields column for tables with nested data if 'fields' in schema.columns: schema = self.__unpack_all_schema_fields(schema) # Set the column order schema = schema[['name', 'type', 'mode', 'description']] return schema def list_tables(self): """ List the names of the tables in a dataset """ self.__fetch_dataset() return([x.table_id for x in self.client.list_tables(self.dataset)]) def estimate_query_size(self, query): """ Estimate gigabytes scanned by query. Does not consider if there is a cached query table. See https://cloud.google.com/bigquery/docs/reference/rest/v2/jobs#configuration.dryRun """ my_job_config = bigquery.job.QueryJobConfig() my_job_config.dry_run = True my_job = self.client.query(query, job_config=my_job_config) return my_job.total_bytes_processed / self.BYTES_PER_GB def query_to_pandas(self, query): """ Execute a SQL query & return a pandas dataframe """ my_job = self.client.query(query) start_time = time.time() while not my_job.done(): if (time.time() - start_time) > self.max_wait_seconds: print("Max wait time elapsed, query cancelled.") self.client.cancel_job(my_job.job_id) return None time.sleep(0.1) # Queries that hit errors will return an exception type. # Those exceptions don't get raised until we call my_job.to_dataframe() # In that case, my_job.total_bytes_billed can be called but is None if my_job.total_bytes_billed: self.total_gb_used_net_cache += my_job.total_bytes_billed / self.BYTES_PER_GB return my_job.to_dataframe() def query_to_pandas_safe(self, query, max_gb_scanned=1): """ Execute a query, but only if the query would scan less than `max_gb_scanned` of data. """ query_size = self.estimate_query_size(query) if query_size <= max_gb_scanned: return self.query_to_pandas(query) msg = "Query cancelled; estimated size of {0} exceeds limit of {1} GB" print(msg.format(query_size, max_gb_scanned)) def head(self, table_name, num_rows=5, start_index=None, selected_columns=None): """ Get the first n rows of a table as a DataFrame. Does not perform a full table scan; should use a trivial amount of data as long as n is small. """ self.__fetch_table(table_name) active_table = self.tables[table_name] schema_subset = None if selected_columns: schema_subset = [col for col in active_table.schema if col.name in selected_columns] results = self.client.list_rows(active_table, selected_fields=schema_subset, max_results=num_rows, start_index=start_index) results = [x for x in results] return pd.DataFrame( data=[list(x.values()) for x in results], columns=list(results[0].keys()))