"""This module defines functions used by several dialects."""

import sqlalchemy
from packaging import version
from sqlalchemy import Column
from sqlalchemy.sql import expression
from sqlalchemy.types import TypeDecorator

from geoalchemy2.types import Geometry

# True when the installed SQLAlchemy is older than 1.4; evaluated once at import time.
_SQLALCHEMY_VERSION_BEFORE_14 = version.parse("1.4") > version.parse(sqlalchemy.__version__)


def _spatial_idx_name(table_name, column_name):
    return "idx_{}_{}".format(table_name, column_name)


def _format_select_args(*args):
    """Shape positional arguments according to the installed SQLAlchemy version.

    For SQLAlchemy < 1.4 the whole argument tuple is wrapped in a one-element
    list; otherwise the argument tuple is returned unchanged. Presumably callers
    star-unpack the result into ``select()`` (legacy ``select([cols])`` vs.
    ``select(*cols)``) — TODO confirm against call sites.
    """
    return [args] if _SQLALCHEMY_VERSION_BEFORE_14 else args


def check_management(*args):
    """Report every column as managed.

    This is the fallback predicate used when no column-management check is
    supplied. Any positional arguments are accepted and ignored.

    Returns:
        Always ``True``.
    """
    del args  # accepted only so the predicate fits any call signature
    return True


def _get_gis_cols(table, spatial_types, dialect, check_col_management=None):
    """Return the managed spatial columns of ``table``, in table order.

    A column is kept when it is a ``Column`` instance, its type matches
    ``spatial_types`` (directly or through a ``TypeDecorator`` resolved for
    ``dialect``), and the management predicate accepts it.

    Args:
        table: Table whose ``columns`` are scanned.
        spatial_types: A type or tuple of types accepted by ``isinstance``.
        dialect: Dialect used to resolve ``TypeDecorator`` implementations.
        check_col_management: Optional predicate ``col -> bool``. When ``None``,
            ``check_management`` (which accepts everything) is used.

    Returns:
        A list of the matching columns.
    """
    is_managed = check_management if check_col_management is None else check_col_management

    gis_cols = []
    for col in table.columns:
        # Checks run in this order and stop at the first failure, so the
        # predicate is only invoked on spatial Column objects.
        if not isinstance(col, Column):
            continue
        if not _check_spatial_type(col.type, spatial_types, dialect):
            continue
        if not is_managed(col):
            continue
        gis_cols.append(col)
    return gis_cols


def _check_spatial_type(tested_type, spatial_types, dialect=None):
    return isinstance(tested_type, spatial_types) or (
        isinstance(tested_type, TypeDecorator)
        and isinstance(tested_type.load_dialect_impl(dialect), spatial_types)
    )


def _get_dispatch_info(table, bind, check_col_management=None):
    """Get info required for dispatch events.

    Args:
        table: The table whose columns are split into Geometry / other columns.
        bind: Connection or engine; only its ``dialect`` attribute is used here.
        check_col_management: Optional predicate ``col -> bool`` forwarded to
            ``_get_gis_cols`` to decide whether a Geometry column is managed.

    Returns:
        A ``(dialect, gis_cols, regular_cols)`` tuple: ``gis_cols`` lists the
        managed Geometry columns, and ``regular_cols`` lists every other column
        of ``table``, in table order.
    """
    dialect = bind.dialect

    # Filter Geometry columns from the table
    # Note: Geography and PostGIS >= 2.0 don't need this
    gis_cols = _get_gis_cols(table, Geometry, dialect, check_col_management=check_col_management)

    # Find all other columns that are not managed Geometries.
    # Membership is tested by object identity. A plain ``col in gis_cols`` list
    # lookup would call the overloaded ``Column.__eq__``, which builds a SQL
    # expression instead of comparing objects. It would also rescan
    # ``gis_cols`` for every column. ``gis_cols`` holds references to the
    # columns, so their ids remain valid for this comparison.
    gis_col_ids = {id(col) for col in gis_cols}
    regular_cols = [col for col in table.columns if id(col) not in gis_col_ids]

    return dialect, gis_cols, regular_cols


def _update_table_for_dispatch(table, regular_cols):
    """Update the table before dispatch events.

    The original column collection is stored in
    ``table.info["_saved_columns"]`` so it can be restored later. It is then
    temporarily replaced by a new ``ColumnCollection`` that holds only
    ``regular_cols``, which leaves out the managed Geometry columns.

    Args:
        table: Table to patch in place.
        regular_cols: Iterable of the columns to keep visible on the table.
    """
    # Build the reduced collection first; nothing on the table changes until
    # the final swap below.
    reduced_columns = expression.ColumnCollection()
    for regular_col in regular_cols:
        reduced_columns.add(regular_col)

    table.info["_saved_columns"] = table.columns
    table.columns = reduced_columns


def setup_create_drop(table, bind, check_col_management=None):
    """Prepare the table for before_create and before_drop events.

    This collects the dispatch info and patches ``table`` so that only the
    non-Geometry columns stay visible.

    Returns:
        The ``(dialect, gis_cols, regular_cols)`` tuple from ``_get_dispatch_info``.
    """
    dispatch_info = _get_dispatch_info(table, bind, check_col_management)
    _dialect, _gis_cols, regular_cols = dispatch_info
    _update_table_for_dispatch(table, regular_cols)
    return dispatch_info


def reflect_geometry_column(inspector, table, column_info):
    """Default geometry-column reflection hook; does nothing.

    All arguments are accepted and ignored.

    Returns:
        ``None``.
    """
    return None  # pragma: no cover


def before_create(table, bind, **kw):
    """Default ``before_create`` hook; does nothing.

    ``table``, ``bind`` and any extra keyword arguments are ignored.

    Returns:
        ``None``.
    """
    return None  # pragma: no cover


def after_create(table, bind, **kw):
    """Default ``after_create`` hook; does nothing.

    ``table``, ``bind`` and any extra keyword arguments are ignored.

    Returns:
        ``None``.
    """
    return None  # pragma: no cover


def before_drop(table, bind, **kw):
    """Default ``before_drop`` hook; does nothing.

    ``table``, ``bind`` and any extra keyword arguments are ignored.

    Returns:
        ``None``.
    """
    return None  # pragma: no cover


def after_drop(table, bind, **kw):
    """Default ``after_drop`` hook; does nothing.

    ``table``, ``bind`` and any extra keyword arguments are ignored.

    Returns:
        ``None``.
    """
    return None  # pragma: no cover
