diff --git a/packages/bigframes/bigframes/core/events.py b/packages/bigframes/bigframes/core/events.py index 0724cc5414bb..3b3e6013f7a3 100644 --- a/packages/bigframes/bigframes/core/events.py +++ b/packages/bigframes/bigframes/core/events.py @@ -18,7 +18,7 @@ import datetime import threading import uuid -from typing import Any, Callable, Optional, Set +from typing import Any, Callable, Literal, Set import google.cloud.bigquery._job_helpers import google.cloud.bigquery.job.query @@ -26,9 +26,13 @@ import bigframes.session.executor +_DEFAULT: Literal["default"] = "default" + class Subscriber: - def __init__(self, callback: Callable[[Event], None], *, publisher: Publisher): + def __init__( + self, callback: Callable[[Event], None], *, publisher: Publisher + ): # noqa: E501 self._publisher = publisher self._callback = callback self._subscriber_id = uuid.uuid4() @@ -81,16 +85,22 @@ def unsubscribe(self, subscriber: Subscriber): with self._subscribers_lock: self._subscribers.remove(subscriber) - def publish(self, event: Event): + def publish(self, envelope: "EventEnvelope"): with self._subscribers_lock: for subscriber in self._subscribers: - subscriber(event) + subscriber(envelope) class Event: pass +@dataclasses.dataclass(frozen=True) +class EventEnvelope: + event: Event + progress_bar: Literal["default", "auto", "notebook", "terminal"] | None = None + + @dataclasses.dataclass(frozen=True) class SessionClosed(Event): session_id: str @@ -106,7 +116,7 @@ class ExecutionRunning(Event): @dataclasses.dataclass(frozen=True) class ExecutionFinished(Event): - result: Optional[bigframes.session.executor.ExecuteResult] = None + result: bigframes.session.executor.ExecuteResult | None = None @dataclasses.dataclass(frozen=True) @@ -121,13 +131,16 @@ class BigQuerySentEvent(ExecutionRunning): """Query sent to BigQuery.""" query: str - billing_project: Optional[str] = None - location: Optional[str] = None - job_id: Optional[str] = None - request_id: Optional[str] = None + billing_project: str | None = None + location: str | None = None + job_id: str | None = None + request_id: str | None = None @classmethod - def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QuerySentEvent): + def from_bqclient( + cls, + event: google.cloud.bigquery._job_helpers.QuerySentEvent, + ): return cls( query=event.query, billing_project=event.billing_project, @@ -142,13 +155,16 @@ class BigQueryRetryEvent(ExecutionRunning): """Query sent another time because the previous attempt failed.""" query: str - billing_project: Optional[str] = None - location: Optional[str] = None - job_id: Optional[str] = None - request_id: Optional[str] = None + billing_project: str | None = None + location: str | None = None + job_id: str | None = None + request_id: str | None = None @classmethod - def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent): + def from_bqclient( + cls, + event: google.cloud.bigquery._job_helpers.QueryRetryEvent, + ): return cls( query=event.query, billing_project=event.billing_project, @@ -162,19 +178,20 @@ def from_bqclient(cls, event: google.cloud.bigquery._job_helpers.QueryRetryEvent class BigQueryReceivedEvent(ExecutionRunning): """Query received and acknowledged by the BigQuery API.""" - billing_project: Optional[str] = None - location: Optional[str] = None - job_id: Optional[str] = None - statement_type: Optional[str] = None - state: Optional[str] = None - query_plan: Optional[list[google.cloud.bigquery.job.query.QueryPlanEntry]] = None - created: Optional[datetime.datetime] = None - started: Optional[datetime.datetime] = None - ended: Optional[datetime.datetime] = None + billing_project: str | None = None + location: str | None = None + job_id: str | None = None + statement_type: str | None = None + state: str | None = None + query_plan: list[google.cloud.bigquery.job.query.QueryPlanEntry] | None = None + created: datetime.datetime | None = None + started: datetime.datetime | None = None + ended: datetime.datetime | None = None @classmethod def from_bqclient( - cls, event: google.cloud.bigquery._job_helpers.QueryReceivedEvent + cls, + event: google.cloud.bigquery._job_helpers.QueryReceivedEvent, ): return cls( billing_project=event.billing_project, @@ -193,21 +210,22 @@ def from_bqclient( class BigQueryFinishedEvent(ExecutionRunning): """Query finished successfully.""" - billing_project: Optional[str] = None - location: Optional[str] = None - query_id: Optional[str] = None - job_id: Optional[str] = None - destination: Optional[google.cloud.bigquery.table.TableReference] = None - total_rows: Optional[int] = None - total_bytes_processed: Optional[int] = None - slot_millis: Optional[int] = None - created: Optional[datetime.datetime] = None - started: Optional[datetime.datetime] = None - ended: Optional[datetime.datetime] = None + billing_project: str | None = None + location: str | None = None + query_id: str | None = None + job_id: str | None = None + destination: google.cloud.bigquery.table.TableReference | None = None + total_rows: int | None = None + total_bytes_processed: int | None = None + slot_millis: int | None = None + created: datetime.datetime | None = None + started: datetime.datetime | None = None + ended: datetime.datetime | None = None @classmethod def from_bqclient( - cls, event: google.cloud.bigquery._job_helpers.QueryFinishedEvent + cls, + event: google.cloud.bigquery._job_helpers.QueryFinishedEvent, ): return cls( billing_project=event.billing_project, diff --git a/packages/bigframes/bigframes/formatting_helpers.py b/packages/bigframes/bigframes/formatting_helpers.py index cef14d39a3f6..1676d6dc35a1 100644 --- a/packages/bigframes/bigframes/formatting_helpers.py +++ b/packages/bigframes/bigframes/formatting_helpers.py @@ -71,7 +71,10 @@ def repr_query_job(query_job: Optional[bigquery.QueryJob]): if query_job is None: return "No job information available" if query_job.dry_run: - return f"Computation deferred. Computation will process {get_formatted_bytes(query_job.total_bytes_processed)}" + return ( + f"Computation deferred. Computation will process " + f"{get_formatted_bytes(query_job.total_bytes_processed)}" + ) res = "Query Job Info" for key, value in query_job_prop_pairs.items(): job_val = getattr(query_job, value) @@ -105,11 +108,15 @@ def repr_query_job_html(query_job: Optional[bigquery.QueryJob]): if query_job is None: return "No job information available" if query_job.dry_run: - return f"Computation deferred. Computation will process {get_formatted_bytes(query_job.total_bytes_processed)}" + return ( + f"Computation deferred. Computation will process " + f"{get_formatted_bytes(query_job.total_bytes_processed)}" + ) # We can reuse the plaintext repr for now or make a nicer table. - # For deferred mode consistency, let's just wrap the text in a pre block or similar, - # but the request implies we want a distinct HTML representation if possible. + # For deferred mode consistency, let's just wrap the text in a pre + # block or similar, but the request implies we want a distinct HTML + # representation if possible. # However, existing repr_query_job returns a simple string. # Let's format it as a simple table or list. @@ -123,7 +130,10 @@ def repr_query_job_html(query_job: Optional[bigquery.QueryJob]): location=query_job.location, job_id=query_job.job_id, ) - res += f'
  • Job: {query_job.job_id}
  • ' + res += ( + f'
  • Job: ' + f"{query_job.job_id}
  • " + ) elif key == "Slot Time": res += f"
  • {key}: {get_formatted_time(job_val)}
  • " elif key == "Bytes Processed": @@ -138,7 +148,7 @@ def repr_query_job_html(query_job: Optional[bigquery.QueryJob]): def progress_callback( - event: bigframes.core.events.Event, + envelope: Any, ): """Displays a progress bar while the query is running""" global current_display_id @@ -147,12 +157,21 @@ def progress_callback( import bigframes._config import bigframes.core.events except ImportError: - # Since this gets called from __del__, skip if the import fails to avoid + # Since this gets called from __del__, skip if the import fails + # to avoid # ImportError: sys.meta_path is None, Python is likely shutting down. # This will allow cleanup to continue. return - progress_bar = bigframes._config.options.display.progress_bar + if isinstance(envelope, bigframes.core.events.EventEnvelope): + event = envelope.event + progress_bar = envelope.progress_bar + else: + event = envelope + progress_bar = bigframes.core.events._DEFAULT + + if progress_bar == bigframes.core.events._DEFAULT: + progress_bar = bigframes._config.options.display.progress_bar if progress_bar == "auto": progress_bar = "notebook" if in_ipython() else "terminal" @@ -232,7 +251,8 @@ def wait_for_job(job: GenericJob, progress_bar: Optional[str] = None): job.result() job.reload() display.update_display( - display.HTML(get_base_job_loading_html(job)), display_id=display_id + display.HTML(get_base_job_loading_html(job)), + display_id=display_id, ) elif progress_bar == "terminal": inital_loading_bar = get_base_job_loading_string(job) @@ -286,7 +306,10 @@ def render_job_link_html( job_id=job_id, ) if job_url: - job_link = f' [Job {project_id}:{location}.{job_id} details]' + job_link = ( + f' [' + f"Job {project_id}:{location}.{job_id} details]" + ) else: job_link = "" return job_link @@ -323,7 +346,10 @@ def get_job_url( """ if project_id is None or location is None or job_id is None: return None - return f"""https://console.cloud.google.com/bigquery?project={project_id}&j=bq:{location}:{job_id}&page=queryresults""" + return ( + f"https://console.cloud.google.com/bigquery?project={project_id}" + f"&j=bq:{location}:{job_id}&page=queryresults" + ) def render_bqquery_sent_event_html( @@ -348,7 +374,10 @@ def render_bqquery_sent_event_html( job_id=event.job_id, request_id=event.request_id, ) - query_text_details = f"
    SQL
    {html.escape(event.query)}
    " + query_text_details = ( + f"
    SQL
    "
    +        f"{html.escape(event.query)}
    " + ) return f""" Query started{query_id}.{job_link}{query_text_details} @@ -397,7 +426,10 @@ def render_bqquery_retry_event_html( job_id=event.job_id, request_id=event.request_id, ) - query_text_details = f"
    SQL
    {html.escape(event.query)}
    " + query_text_details = ( + f"
    SQL
    "
    +        f"{html.escape(event.query)}
    " + ) return f""" Retrying query{query_id}.{job_link}{query_text_details} @@ -443,7 +475,10 @@ def render_bqquery_received_event_html( query_plan_details = "" if event.query_plan: plan_str = "\n".join([str(entry) for entry in event.query_plan]) - query_plan_details = f"
    Query Plan
    {html.escape(plan_str)}
    " + query_plan_details = ( + f"
    Query Plan
    "
    +            f"{html.escape(plan_str)}
    " + ) return f""" Query{query_id} is {event.state}.{job_link}{query_plan_details} @@ -506,7 +541,8 @@ def render_bqquery_finished_event_plaintext( bytes_str = "" if event.total_bytes_processed is not None: - bytes_str = f" {humanize.naturalsize(event.total_bytes_processed)} processed." + size_str = humanize.naturalsize(event.total_bytes_processed) + bytes_str = f" {size_str} processed." slot_time_str = "" if event.slot_millis is not None: @@ -572,7 +608,8 @@ def get_formatted_time(val): Duration string """ try: - return humanize.naturaldelta(datetime.timedelta(milliseconds=float(val))) + delta = datetime.timedelta(milliseconds=float(val)) + return humanize.naturaldelta(delta) except Exception: return val @@ -591,7 +628,10 @@ def get_formatted_bytes(val): def get_bytes_processed_string(val: Any): - """Try to get bytes processed string. Return empty if passed non int value""" + """Try to get bytes processed string. + + Return empty if passed non int value. + """ bytes_processed_string = "" if isinstance(val, int): bytes_processed_string = f"""{get_formatted_bytes(val)} processed. """ diff --git a/packages/bigframes/bigframes/functions/function.py b/packages/bigframes/bigframes/functions/function.py index 4dee14674042..ac6d3c541dbd 100644 --- a/packages/bigframes/bigframes/functions/function.py +++ b/packages/bigframes/bigframes/functions/function.py @@ -162,7 +162,8 @@ class Udf(Protocol): """ @property - def udf_def(self) -> udf_def.BigqueryUdf: ... + def udf_def(self) -> udf_def.BigqueryUdf: + ... class BigqueryCallableRoutine: diff --git a/packages/bigframes/bigframes/session/_io/bigquery/__init__.py b/packages/bigframes/bigframes/session/_io/bigquery/__init__.py index 780ba55c50db..b6e3b72af7fc 100644 --- a/packages/bigframes/bigframes/session/_io/bigquery/__init__.py +++ b/packages/bigframes/bigframes/session/_io/bigquery/__init__.py @@ -22,7 +22,16 @@ import textwrap import types import typing -from typing import Dict, Iterable, Literal, Mapping, Optional, Tuple, Union, overload +from typing import ( + Dict, + Iterable, + Literal, + Mapping, + Optional, + Tuple, + Union, + overload, +) import bigframes_vendored.google_cloud_bigquery.retry as third_party_gcb_retry import bigframes_vendored.pandas.io.gbq as third_party_pandas_gbq @@ -38,7 +47,10 @@ from bigframes.core.compile.sqlglot import sql as sg_sql from bigframes.core.logging import log_adapter -CHECK_DRIVE_PERMISSIONS = "\nCheck https://cloud.google.com/bigquery/docs/query-drive-data#Google_Drive_permissions." +CHECK_DRIVE_PERMISSIONS = ( + "\nCheck https://cloud.google.com/bigquery/docs/" + "query-drive-data#Google_Drive_permissions." +) IO_ORDERING_ID = "bqdf_row_nums" @@ -85,7 +97,10 @@ def create_job_configs_labels( def create_export_data_statement( - table_id: str, uri: str, format: str, export_options: Dict[str, Union[bool, str]] + table_id: str, + uri: str, + format: str, + export_options: Dict[str, Union[bool, str]], ) -> str: all_options: Dict[str, Union[bool, str]] = { "uri": uri, @@ -142,7 +157,8 @@ def create_temp_table( destination.encryption_configuration = bigquery.EncryptionConfiguration( kms_key_name=kms_key ) - # Ok if already exists, since this will only happen from retries internal to this method + # Ok if already exists, since this will only happen from retries + # internal to this method # as the requested table id has a random UUID4 component. bqclient.create_table(destination, exists_ok=True) return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" @@ -165,7 +181,8 @@ def create_temp_view( destination.expires = expiration destination.view_query = sql - # Ok if already exists, since this will only happen from retries internal to this method + # Ok if already exists, since this will only happen from retries + # internal to this method # as the requested table id has a random UUID4 component. bqclient.create_table(destination, exists_ok=True) return f"{table_ref.project}.{table_ref.dataset_id}.{table_ref.table_id}" @@ -199,7 +216,10 @@ def bq_field_to_type_sql(field: bigquery.SchemaField): if field.mode == "REPEATED": nested_type = bq_field_to_type_sql( bigquery.SchemaField( - field.name, field.field_type, mode="NULLABLE", fields=field.fields + field.name, + field.field_type, + mode="NULLABLE", + fields=field.fields, ) ) return f"ARRAY<{nested_type}>" @@ -232,8 +252,9 @@ def format_option(key: str, value: Union[bool, str]) -> str: def add_and_trim_labels(job_config, session=None): """ - Add additional labels to the job configuration and trim the total number of labels - to ensure they do not exceed MAX_LABELS_COUNT labels per job. + Add additional labels to the job configuration and trim the total + number of labels to ensure they do not exceed MAX_LABELS_COUNT labels + per job. """ api_methods = log_adapter.get_and_reset_api_methods( dry_run=job_config.dry_run, session=session @@ -245,19 +266,35 @@ def add_and_trim_labels(job_config, session=None): def create_bq_event_callback(publisher): - def publish_bq_event(event): - if isinstance(event, google.cloud.bigquery._job_helpers.QueryFinishedEvent): - bf_event = bigframes.core.events.BigQueryFinishedEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QueryReceivedEvent): - bf_event = bigframes.core.events.BigQueryReceivedEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QueryRetryEvent): - bf_event = bigframes.core.events.BigQueryRetryEvent.from_bqclient(event) - elif isinstance(event, google.cloud.bigquery._job_helpers.QuerySentEvent): - bf_event = bigframes.core.events.BigQuerySentEvent.from_bqclient(event) - else: - bf_event = bigframes.core.events.BigQueryUnknownEvent(event) + import bigframes._config + + progress_bar = bigframes._config.options.display.progress_bar + + event_map = { + google.cloud.bigquery._job_helpers.QueryFinishedEvent: ( + bigframes.core.events.BigQueryFinishedEvent + ), + google.cloud.bigquery._job_helpers.QueryReceivedEvent: ( + bigframes.core.events.BigQueryReceivedEvent + ), + google.cloud.bigquery._job_helpers.QueryRetryEvent: ( + bigframes.core.events.BigQueryRetryEvent + ), + google.cloud.bigquery._job_helpers.QuerySentEvent: ( + bigframes.core.events.BigQuerySentEvent + ), + } - publisher.publish(bf_event) + def publish_bq_event(event): + bf_event = bigframes.core.events.BigQueryUnknownEvent(event) + for bq_type, bf_type in event_map.items(): + if isinstance(event, bq_type): + bf_event = bf_type.from_bqclient(event) # type: ignore + break + envelope = bigframes.core.events.EventEnvelope( + event=bf_event, progress_bar=progress_bar + ) + publisher.publish(envelope) return publish_bq_event @@ -275,7 +312,8 @@ def start_query_with_client( query_with_job: Literal[True], publisher: bigframes.core.events.Publisher, session=None, -) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ... +) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: + ... @overload @@ -291,7 +329,8 @@ def start_query_with_client( query_with_job: Literal[False], publisher: bigframes.core.events.Publisher, session=None, -) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ... +) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: + ... @overload @@ -308,7 +347,8 @@ def start_query_with_client( job_retry: google.api_core.retry.Retry, publisher: bigframes.core.events.Publisher, session=None, -) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: ... +) -> Tuple[google.cloud.bigquery.table.RowIterator, bigquery.QueryJob]: + ... @overload @@ -325,7 +365,8 @@ def start_query_with_client( job_retry: google.api_core.retry.Retry, publisher: bigframes.core.events.Publisher, session=None, -) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: ... +) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: + ... def start_query_with_client( @@ -342,7 +383,7 @@ def start_query_with_client( # google-cloud-bigquery version with # https://github.com/googleapis/python-bigquery/pull/2256 merged, likely # version 3.36.0 or later. - job_retry: google.api_core.retry.Retry = third_party_gcb_retry.DEFAULT_JOB_RETRY, + job_retry: google.api_core.retry.Retry = (third_party_gcb_retry.DEFAULT_JOB_RETRY), publisher: bigframes.core.events.Publisher, session=None, ) -> Tuple[google.cloud.bigquery.table.RowIterator, Optional[bigquery.QueryJob]]: @@ -415,7 +456,9 @@ def start_query_with_client( def delete_tables_matching_session_id( - client: bigquery.Client, dataset: bigquery.DatasetReference, session_id: str + client: bigquery.Client, + dataset: bigquery.DatasetReference, + session_id: str, ) -> None: """Searches within the dataset for tables conforming to the expected session_id form, and instructs bigquery to delete them. @@ -469,7 +512,8 @@ def create_bq_dataset_reference( The project id of the project to create the dataset in. Returns: - bigquery.DatasetReference: The constructed reference to the anonymous dataset. + bigquery.DatasetReference: The constructed reference to the + anonymous dataset. """ job_config = google.cloud.bigquery.QueryJobConfig() @@ -503,7 +547,8 @@ def is_query(query_or_table: str) -> bool: def is_table_with_wildcard_suffix(query_or_table: str) -> bool: - """Determine if `query_or_table` is a table and contains a wildcard suffix.""" + """Determine if `query_or_table` is a table and contains a wildcard + suffix.""" return not is_query(query_or_table) and query_or_table.endswith("*") @@ -519,7 +564,8 @@ def to_query( from_item = f"({query_or_table})" else: # Table ID can have 1, 2, 3, or 4 parts. Quoting all parts to be safe. - # See: https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + # See: + # https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers parts = query_or_table.split(".") from_item = ".".join(f"`{part}`" for part in parts) @@ -569,7 +615,8 @@ def compile_filters(filters: third_party_pandas_gbq.FiltersType) -> str: "!=": "!=", } - # If single layer filter, add another pseudo layer. So the single layer represents "and" logic. + # If single layer filter, add another pseudo layer. So the single + # layer represents "and" logic. filters_list: list = list(filters) if isinstance(filters_list[0], tuple) and ( len(filters_list[0]) == 0 or not isinstance(list(filters_list[0])[0], tuple) @@ -586,14 +633,16 @@ def compile_filters(filters: third_party_pandas_gbq.FiltersType) -> str: for filter_item in group: if not isinstance(filter_item, tuple) or (len(filter_item) != 3): raise ValueError( - f"Elements of filters must be tuples of length 3, but got {repr(filter_item)}.", + f"Elements of filters must be tuples of length 3, " + f"but got {repr(filter_item)}.", ) column, operator, value = filter_item if not isinstance(column, str): raise ValueError( - f"Column name should be a string, but received '{column}' of type {type(column).__name__}." + f"Column name should be a string, but received " + f"'{column}' of type {type(column).__name__}." ) if operator not in valid_operators: diff --git a/packages/bigframes/bigframes/session/metrics.py b/packages/bigframes/bigframes/session/metrics.py index d2682bbcaf7f..206d3da2f4d8 100644 --- a/packages/bigframes/bigframes/session/metrics.py +++ b/packages/bigframes/bigframes/session/metrics.py @@ -236,13 +236,18 @@ def count_job_stats( exec_seconds=exec_seconds, ) - def on_event(self, event: Any): + def on_event(self, envelope: Any): try: import bigframes.core.events from bigframes.session.executor import LocalExecuteResult except ImportError: return + if isinstance(envelope, bigframes.core.events.EventEnvelope): + event = envelope.event + else: + event = envelope + if isinstance(event, bigframes.core.events.ExecutionFinished): if event.result and isinstance(event.result, LocalExecuteResult): self.execution_count += 1 diff --git a/packages/bigframes/tests/system/small/test_progress_bar.py b/packages/bigframes/tests/system/small/test_progress_bar.py index bc247f6078ce..a179e18332af 100644 --- a/packages/bigframes/tests/system/small/test_progress_bar.py +++ b/packages/bigframes/tests/system/small/test_progress_bar.py @@ -104,6 +104,23 @@ def test_progress_bar_load_jobs( assert_loading_msg_exist(capsys.readouterr().out, pattern="Load") +def test_progress_bar_uniqueness_check(session: bf.Session, capsys): + # Ensure strictly_ordered is True (default) to trigger uniqueness check + assert session._strictly_ordered + + capsys.readouterr() # clear output + + with bf.option_context("display.progress_bar", "terminal"): + # Read a table and specify a non-unique index_col to trigger the check. + # We use a public table to make it a "real" test. + session.read_gbq_table( + "bigquery-public-data.ml_datasets.penguins", + index_col="island", + ) + + assert_loading_msg_exist(capsys.readouterr().out) + + def assert_loading_msg_exist(capstdout: str, pattern=job_load_message_regex): num_loading_msg = 0 lines = capstdout.split("\n") diff --git a/packages/bigframes/tests/unit/test_formatting_helpers.py b/packages/bigframes/tests/unit/test_formatting_helpers.py index ec681b36ab05..a90d372e9dbe 100644 --- a/packages/bigframes/tests/unit/test_formatting_helpers.py +++ b/packages/bigframes/tests/unit/test_formatting_helpers.py @@ -212,3 +212,29 @@ def test_get_job_url(): job_id=job_id, location=location, project_id=project_id ) assert actual_url == expected_url + + +def test_progress_callback_falls_back_to_global(): + event = bfevents.BigQuerySentEvent( + query="SELECT * FROM my_table", + ) + envelope = bfevents.EventEnvelope(event=event, progress_bar=bfevents._DEFAULT) + + with mock.patch("bigframes._config.options.display.progress_bar", "terminal"): + with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False): + with mock.patch("builtins.print") as mock_print: + formatting_helpers.progress_callback(envelope) + mock_print.assert_called_once() + + +def test_progress_callback_respects_envelope_progress_bar(): + event = bfevents.BigQuerySentEvent( + query="SELECT * FROM my_table", + ) + envelope = bfevents.EventEnvelope(event=event, progress_bar=None) + + with mock.patch("bigframes._config.options.display.progress_bar", "terminal"): + with mock.patch("bigframes.formatting_helpers.in_ipython", return_value=False): + with mock.patch("builtins.print") as mock_print: + formatting_helpers.progress_callback(envelope) + mock_print.assert_not_called()