"""SQL Object Comparator for Drift Detection.
This module provides the ObjectComparator class which compares SQL Model objects
from different sources (parsed scripts vs. database introspection) and generates
structured diff results.
Key Features:
- Compare tables, views, procedures, triggers, sequences
- Detect missing, extra, and modified objects
- Type-aware comparison using DataTypeNormalizer
- Generate structured diff results
- Handle case sensitivity and identifier normalization
"""
import logging
import re
from typing import Any, Dict, List, Optional, Tuple
from sqlmeta.comparison.diff_models import (
ColumnDiff,
ConstraintDiff,
DatabaseLinkDiff,
EventDiff,
ExtensionDiff,
ForeignDataWrapperDiff,
ForeignServerDiff,
FunctionDiff,
IndexDiff,
LinkedServerDiff,
ModuleDiff,
PackageDiff,
ProcedureDiff,
SchemaDiff,
SequenceDiff,
SynonymDiff,
TableDiff,
TriggerDiff,
UserDefinedTypeDiff,
ViewDiff,
)
from sqlmeta.comparison.type_normalizer import DataTypeNormalizer
from sqlmeta.base import ConstraintType, SqlColumn, SqlConstraint
from sqlmeta.objects.database_link import DatabaseLink
from sqlmeta.objects.event import Event
from sqlmeta.objects.extension import Extension
from sqlmeta.objects.foreign_data_wrapper import ForeignDataWrapper
from sqlmeta.objects.foreign_server import ForeignServer
from sqlmeta.objects.index import Index
from sqlmeta.objects.linked_server import LinkedServer
from sqlmeta.objects.module import Module
from sqlmeta.objects.package import Package
from sqlmeta.objects.procedure import Procedure
from sqlmeta.objects.sequence import Sequence
from sqlmeta.objects.synonym import Synonym
from sqlmeta.objects.table import Table
from sqlmeta.objects.trigger import Trigger
from sqlmeta.objects.user_defined_type import UserDefinedType
from sqlmeta.objects.view import View
logger = logging.getLogger(__name__)
def _is_system_generated_constraint_name(name: str) -> bool:
"""Check if constraint name is system-generated by the database.
System-generated names should be ignored when matching constraints,
as they vary between database instances.
Args:
name: Constraint name to check
Returns:
True if the name appears to be system-generated
Examples:
Oracle: SYS_C0013220 -> True
SQL Server: PK__users__3213E83F -> True
User-defined: pk_users_id -> False
"""
if not name:
return False
name_upper = name.upper()
# Oracle: SYS_C followed by numbers
if name_upper.startswith("SYS_C"):
# Check if rest is numeric
suffix = name_upper[5:]
if suffix and suffix.isdigit():
return True
# SQL Server auto-generated: PK__TableName__Hash or FK__TableName__Hash
if "__" in name_upper and (name_upper.startswith("PK__") or name_upper.startswith("FK__")):
return True
# MySQL auto-generated: Often just the column name for single-column constraints
# We can't reliably detect these, so we don't mark them as system-generated
return False
def _extract_base_identity_type(data_type: str, dialect: str) -> str:
"""Extract base data type from identity column definition.
Identity columns can be defined with various syntax:
- Oracle: NUMBER GENERATED ALWAYS AS IDENTITY
- PostgreSQL: SERIAL, BIGSERIAL
- SQL Server: INT IDENTITY(1,1)
- MySQL: INT AUTO_INCREMENT
This function extracts just the base type for comparison.
Args:
data_type: Full data type string
dialect: Database dialect
Returns:
Base data type without identity keywords
"""
if not data_type:
return data_type
data_type_upper = data_type.upper().strip()
# PostgreSQL SERIAL types map to INTEGER/BIGINT
if dialect == "postgresql":
if data_type_upper.startswith("BIGSERIAL"):
return "BIGINT"
if data_type_upper.startswith("SMALLSERIAL"):
return "SMALLINT"
if data_type_upper.startswith("SERIAL"):
return "INTEGER"
# Remove identity-related keywords for other dialects
identity_keywords = [
"GENERATED ALWAYS AS IDENTITY",
"GENERATED BY DEFAULT AS IDENTITY",
"AUTO_INCREMENT",
"IDENTITY",
]
result = data_type_upper
for keyword in identity_keywords:
result = result.replace(keyword, "").strip()
# Remove IDENTITY(seed, increment) syntax
import re
result = re.sub(r"IDENTITY\s*\(\s*\d+\s*,\s*\d+\s*\)", "", result).strip()
result = re.sub(r"\(\s*\d+\s*,\s*\d+\s*\)", "", result).strip()
return result if result else data_type
[docs]
class ObjectComparator:
"""Compares SQL Model objects and generates diff results.
This class provides methods to compare SQL objects from different sources
(e.g., parsed SQL scripts vs. database metadata) and identify differences.
Example:
>>> normalizer = DataTypeNormalizer()
>>> comparator = ObjectComparator(normalizer)
>>> diff = comparator.compare_tables(script_table, db_table, "postgresql")
>>> if diff.has_diffs:
... print(f"Found differences: {diff}")
"""
def __init__(self, type_normalizer: DataTypeNormalizer):
    """Initialize the object comparator.

    Args:
        type_normalizer: DataTypeNormalizer whose ``normalize`` and
            ``are_equivalent`` methods are used when comparing column
            data types.
    """
    self.type_normalizer = type_normalizer
def compare_tables(
    self, expected: Table, actual: Table, dialect: str = "postgresql"
) -> TableDiff:
    """Compare two table objects.

    Columns and constraints are compared first; then the TEMPORARY flag,
    dialect-specific storage options (SQL Server, DB2) and the partition
    scheme.

    Args:
        expected: Expected table (from scripts)
        actual: Actual table (from database)
        dialect: SQL dialect for type normalization; also selects which
            dialect-specific properties are compared

    Returns:
        TableDiff object with comparison results

    Example:
        >>> diff = comparator.compare_tables(script_table, db_table, "postgresql")
        >>> print(f"Missing columns: {diff.missing_columns}")
    """
    if getattr(expected, "derived_from", None) is not None:
        # Derived tables (CREATE TABLE AS SELECT / CREATE TABLE LIKE) take their
        # columns and constraints from the source at execution time, so there
        # is nothing in the script to compare them against.
        missing_cols: List[SqlColumn] = []
        extra_cols: List[SqlColumn] = []
        modified_cols: List[ColumnDiff] = []
        missing_consts: List[SqlConstraint] = []
        extra_consts: List[SqlConstraint] = []
        modified_consts: List[ConstraintDiff] = []
    else:
        missing_cols, extra_cols, modified_cols = self._compare_columns(
            expected.columns, actual.columns, dialect
        )
        missing_consts, extra_consts, modified_consts = self._compare_constraints(
            expected.constraints, actual.constraints, dialect
        )

    table_diff = TableDiff(
        object_name=expected.name,
        table_name=expected.name,
        missing_columns=[col.name for col in missing_cols],
        extra_columns=[col.name for col in extra_cols],
        modified_columns=modified_cols,
        missing_constraints=[c.name or f"unnamed_{i}" for i, c in enumerate(missing_consts)],
        extra_constraints=[c.name or f"unnamed_{i}" for i, c in enumerate(extra_consts)],
        modified_constraints=modified_consts,
    )

    # Grammar-based: TEMPORARY is compared for every dialect.
    if getattr(expected, "temporary", False) != getattr(actual, "temporary", False):
        table_diff.temporary_changed = True

    if dialect == "sqlserver":
        self._compare_sqlserver_table_options(expected, actual, table_diff)
    elif dialect == "db2":
        self._compare_db2_table_options(expected, actual, table_diff)

    self._compare_partition_scheme(expected, actual, table_diff)

    # Recalculate has_diffs and severity now that the property flags are set.
    table_diff._calculate_diffs()
    return table_diff

def _compare_sqlserver_table_options(
    self, expected: Table, actual: Table, table_diff: TableDiff
) -> None:
    """Flag SQL Server filegroup, memory-optimized and system-versioning drift.

    Identifiers are compared case-insensitively, like every other identifier
    in this comparator.

    Args:
        expected: Expected table (from scripts)
        actual: Actual table (from database)
        table_diff: Diff whose ``*_changed`` flags are set in place
    """

    def filegroup_key(table: Table) -> str:
        # PRIMARY is the default filegroup, so an unspecified filegroup and an
        # explicit PRIMARY are equivalent.
        filegroup = self._normalize_identifier(getattr(table, "filegroup", None))
        return "" if filegroup == "primary" else filegroup

    if filegroup_key(expected) != filegroup_key(actual):
        table_diff.filegroup_changed = True

    if getattr(expected, "memory_optimized", False) != getattr(
        actual, "memory_optimized", False
    ):
        table_diff.memory_optimized_changed = True

    expected_sys_ver = getattr(expected, "system_versioned", False)
    actual_sys_ver = getattr(actual, "system_versioned", False)
    if expected_sys_ver != actual_sys_ver:
        table_diff.system_versioned_changed = True
    if not (expected_sys_ver and actual_sys_ver):
        return

    # History table is only comparable when both sides are system-versioned.
    expected_hist = self._normalize_identifier(getattr(expected, "history_table", None))
    actual_hist = self._normalize_identifier(getattr(actual, "history_table", None))
    if not expected_hist and actual_hist.startswith("mssql_temporalhistoryfor_"):
        # The script left the history table unnamed and SQL Server generated
        # its default MSSQL_TemporalHistoryFor_<object_id> table; not drift.
        return
    expected_hist_schema = self._normalize_identifier(getattr(expected, "history_schema", None))
    actual_hist_schema = self._normalize_identifier(getattr(actual, "history_schema", None))
    if expected_hist != actual_hist or expected_hist_schema != actual_hist_schema:
        table_diff.history_table_changed = True

def _compare_db2_table_options(
    self, expected: Table, actual: Table, table_diff: TableDiff
) -> None:
    """Flag DB2 COMPRESS, LOGGED and ORGANIZE BY drift.

    Args:
        expected: Expected table (from scripts)
        actual: Actual table (from database)
        table_diff: Diff whose ``*_changed`` flags are set in place
    """
    # COMPRESS: an unset value (None) means DB2's default, i.e. not compressed.
    expected_compress = getattr(expected, "compress", None)
    actual_compress = getattr(actual, "compress", None)
    expected_compress_norm = False if expected_compress is None else expected_compress
    actual_compress_norm = False if actual_compress is None else actual_compress
    if expected_compress_norm != actual_compress_norm:
        table_diff.compress_changed = True
    if expected_compress and actual_compress:
        if getattr(expected, "compress_type", None) != getattr(actual, "compress_type", None):
            table_diff.compress_type_changed = True

    # LOGGED and ORGANIZE BY are only reported when both sides set them
    # explicitly; an unset value just means "use the default".
    expected_logged = getattr(expected, "logged", None)
    actual_logged = getattr(actual, "logged", None)
    if expected_logged is not None and actual_logged is not None:
        if expected_logged != actual_logged:
            table_diff.logged_changed = True

    expected_organize = getattr(expected, "organize_by", None)
    actual_organize = getattr(actual, "organize_by", None)
    if expected_organize is not None and actual_organize is not None:
        if expected_organize != actual_organize:
            table_diff.organize_by_changed = True

def _compare_partition_scheme(
    self, expected: Table, actual: Table, table_diff: TableDiff
) -> None:
    """Flag partition method and partition key column drift.

    Individual partitions are deliberately not compared because some are
    created automatically (e.g. Oracle INTERVAL partitions).

    Args:
        expected: Expected table (from scripts)
        actual: Actual table (from database)
        table_diff: Diff whose ``partition_*_changed`` flags are set in place
    """

    def method_key(method: Any) -> Any:
        # Compared case-insensitively, like the partition columns below.
        return method.upper() if isinstance(method, str) else method

    expected_method = getattr(expected, "partition_method", None)
    actual_method = getattr(actual, "partition_method", None)
    expected_cols = getattr(expected, "partition_columns", None)
    actual_cols = getattr(actual, "partition_columns", None)

    if expected_method or actual_method:
        logger.debug(
            "Partition comparison for table '%s': "
            "expected_method=%s, actual_method=%s, "
            "expected_cols=%s, actual_cols=%s",
            expected.name,
            expected_method,
            actual_method,
            expected_cols,
            actual_cols,
        )

    if method_key(expected_method) != method_key(actual_method):
        table_diff.partition_method_changed = True
        logger.debug("Partition method changed: %s != %s", expected_method, actual_method)

    if expected_method and actual_method:
        # Only compare key columns when both tables are partitioned.
        exp_cols_norm = sorted(c.lower() for c in (expected_cols or []))
        act_cols_norm = sorted(c.lower() for c in (actual_cols or []))
        if exp_cols_norm != act_cols_norm:
            table_diff.partition_columns_changed = True
            logger.debug("Partition columns changed: %s != %s", exp_cols_norm, act_cols_norm)
def compare_schemas(
    self,
    expected_tables: List[Table],
    actual_tables: List[Table],
    dialect: str = "postgresql",
    schema_name: str = "public",
) -> SchemaDiff:
    """Compare lists of tables from two schemas.

    Tables are matched by lower-cased name. Missing and extra table names
    are reported lower-cased, and all results are sorted by that name so the
    report is identical from one run to the next (set iteration order of
    strings is not).

    Args:
        expected_tables: Expected tables (from scripts)
        actual_tables: Actual tables (from database)
        dialect: SQL dialect for type normalization
        schema_name: Name of the schema being compared

    Returns:
        SchemaDiff object with comparison results
    """
    expected_map = {t.name.lower(): t for t in expected_tables}
    actual_map = {t.name.lower(): t for t in actual_tables}
    expected_names = expected_map.keys()
    actual_names = actual_map.keys()

    missing_table_names = sorted(expected_names - actual_names)
    extra_table_names = sorted(actual_names - expected_names)

    modified_tables = []
    for table_name in sorted(expected_names & actual_names):
        table_diff = self.compare_tables(
            expected_map[table_name], actual_map[table_name], dialect
        )
        if not table_diff.has_diffs:
            continue
        modified_tables.append(table_diff)

        # Log details about what changed.
        logger.debug(
            "Table '%s' has differences (severity: %s)", table_name, table_diff.severity
        )
        if table_diff.modified_columns:
            logger.debug(
                " Modified columns: %s",
                [c.object_name for c in table_diff.modified_columns],
            )
            for col_diff in table_diff.modified_columns:
                logger.debug(
                    " Column '%s' (severity: %s):", col_diff.object_name, col_diff.severity
                )
                if col_diff.data_type_diff:
                    logger.debug(" Data type: %s", col_diff.data_type_diff)
                if col_diff.nullable_diff:
                    logger.debug(" Nullable: %s", col_diff.nullable_diff)
                if col_diff.default_diff:
                    logger.debug(" Default: %s", col_diff.default_diff)

    return SchemaDiff(
        object_name=schema_name,
        schema_name=schema_name,
        missing_tables=missing_table_names,
        extra_tables=extra_table_names,
        modified_tables=modified_tables,
    )
def _compare_columns(
    self, expected_columns: List[SqlColumn], actual_columns: List[SqlColumn], dialect: str
) -> Tuple[List[SqlColumn], List[SqlColumn], List[ColumnDiff]]:
    """Compare column lists and identify differences.

    Columns are matched by lower-cased name. Results follow declaration
    order (expected order for missing/modified, actual order for extra) so
    reports are stable between runs.

    Args:
        expected_columns: Expected columns list
        actual_columns: Actual columns list
        dialect: SQL dialect

    Returns:
        Tuple of (missing, extra, modified) where:
        - missing: Columns in expected but not in actual
        - extra: Columns in actual but not in expected
        - modified: ColumnDiff objects for columns with differences
    """
    expected_map = {col.name.lower(): col for col in expected_columns}
    actual_map = {col.name.lower(): col for col in actual_columns}

    # Walk the insertion-ordered maps instead of set differences, whose
    # iteration order changes from process to process.
    missing = [col for key, col in expected_map.items() if key not in actual_map]
    extra = [col for key, col in actual_map.items() if key not in expected_map]

    modified = []
    for key, expected_col in expected_map.items():
        if key not in actual_map:
            continue
        col_diff = self._compare_column_details(expected_col, actual_map[key], dialect)
        if col_diff is not None and col_diff.has_diffs:
            modified.append(col_diff)

    return missing, extra, modified
def _compare_constraints(
    self,
    expected_constraints: List[SqlConstraint],
    actual_constraints: List[SqlConstraint],
    dialect: str,
) -> Tuple[List[SqlConstraint], List[SqlConstraint], List[ConstraintDiff]]:
    """Compare constraint lists and identify differences.

    Constraints are matched by a signature (type + columns, plus referenced
    table/columns for foreign keys) rather than by name, because names are
    often system-generated. Several constraints can share one signature
    (e.g. two CHECK constraints without column lists); those are paired one
    to one instead of overwriting each other. A second pass pairs leftover
    missing/extra constraints that share a user-chosen name and type.

    Args:
        expected_constraints: Expected constraints list
        actual_constraints: Actual constraints list
        dialect: SQL dialect (not used by the matching logic)

    Returns:
        Tuple of (missing, extra, modified) where:
        - missing: Constraints in expected but not in actual
        - extra: Constraints in actual but not in expected
        - modified: ConstraintDiff objects for constraints with differences
    """

    def column_signature(columns: Optional[List[str]]) -> str:
        """Case-insensitive, order-independent key for a column list (None-safe)."""
        return ",".join(sorted(col.lower() for col in (columns or []) if col))

    def constraint_key(c: SqlConstraint) -> str:
        """Build the matching key: type, columns, FK references, maybe name."""
        columns = column_signature(getattr(c, "column_names", None))
        key_parts = [c.constraint_type.value.lower(), columns]
        if c.constraint_type == ConstraintType.FOREIGN_KEY:
            key_parts.append((getattr(c, "reference_table", "") or "").lower())
            key_parts.append(column_signature(getattr(c, "reference_columns", None)))
        # Use the name only without a column signature, and only if it is
        # user-chosen: system-generated names (e.g. Oracle SYS_C*) differ
        # between instances.
        if not columns and c.name and not _is_system_generated_constraint_name(c.name):
            key_parts.append(c.name.lower())
        return "|".join(part for part in key_parts if part)

    def drop_pk_duplicate_uniques(constraints: List[SqlConstraint]) -> List[SqlConstraint]:
        """Remove UNIQUE constraints whose columns duplicate a PRIMARY KEY.

        Oracle and DB2 can report a UNIQUE constraint over the PRIMARY KEY
        columns under a different name, so only the column signature is used.
        """
        pk_signatures = {
            column_signature(getattr(c, "column_names", None))
            for c in constraints
            if c.constraint_type == ConstraintType.PRIMARY_KEY
        }
        # A PK with no known columns must not swallow column-less UNIQUEs.
        pk_signatures.discard("")
        return [
            c
            for c in constraints
            if not (
                c.constraint_type == ConstraintType.UNIQUE
                and column_signature(getattr(c, "column_names", None)) in pk_signatures
            )
        ]

    def group_by_key(constraints: List[SqlConstraint]) -> Dict[str, List[SqlConstraint]]:
        """Group constraints by matching key, keeping declaration order."""
        groups: Dict[str, List[SqlConstraint]] = {}
        for c in constraints:
            groups.setdefault(constraint_key(c), []).append(c)
        return groups

    def differs(diff: Optional[ConstraintDiff]) -> bool:
        return diff is not None and diff.has_diffs

    expected_constraints = drop_pk_duplicate_uniques(expected_constraints)
    actual_constraints = drop_pk_duplicate_uniques(actual_constraints)

    logger.debug(
        "[COMPARATOR] Expected constraints: %s",
        [(c.name, c.constraint_type, getattr(c, "column_names", [])) for c in expected_constraints],
    )
    logger.debug(
        "[COMPARATOR] Actual constraints: %s",
        [(c.name, c.constraint_type, getattr(c, "column_names", [])) for c in actual_constraints],
    )

    expected_groups = group_by_key(expected_constraints)
    actual_groups = group_by_key(actual_constraints)
    logger.debug("[COMPARATOR] Expected constraint keys: %s", list(expected_groups))
    logger.debug("[COMPARATOR] Actual constraint keys: %s", list(actual_groups))

    missing: List[SqlConstraint] = []
    extra: List[SqlConstraint] = []
    modified: List[ConstraintDiff] = []
    leftover_actual: Dict[str, List[SqlConstraint]] = {}

    for key, expected_group in expected_groups.items():
        pool = list(actual_groups.get(key, []))
        unmatched: List[SqlConstraint] = []
        # Pair identical constraints first so that a mere ordering difference
        # between colliding constraints is not reported as a modification.
        for expected_const in expected_group:
            twin_index = next(
                (
                    i
                    for i, actual_const in enumerate(pool)
                    if not differs(self._compare_constraint_details(expected_const, actual_const))
                ),
                None,
            )
            if twin_index is None:
                unmatched.append(expected_const)
            else:
                del pool[twin_index]
        # Remaining same-key constraints are paired in order as modifications.
        for expected_const, actual_const in zip(unmatched, pool):
            const_diff = self._compare_constraint_details(expected_const, actual_const)
            if differs(const_diff):
                modified.append(const_diff)
        missing.extend(unmatched[len(pool):])
        leftover_actual[key] = pool[len(unmatched):]

    for key, actual_group in actual_groups.items():
        extra.extend(leftover_actual.get(key, actual_group))

    # Second pass: a missing and an extra constraint with the same user-chosen
    # name and type are most likely one constraint whose definition changed.
    missing_by_name = {
        c.name.lower(): c
        for c in missing
        if c.name and not _is_system_generated_constraint_name(c.name)
    }
    extra_by_name = {
        c.name.lower(): c
        for c in extra
        if c.name and not _is_system_generated_constraint_name(c.name)
    }
    for name in sorted(missing_by_name.keys() & extra_by_name.keys()):
        expected_const = missing_by_name[name]
        actual_const = extra_by_name[name]
        if expected_const.constraint_type != actual_const.constraint_type:
            continue
        const_diff = self._compare_constraint_details(expected_const, actual_const)
        if differs(const_diff):
            modified.append(const_diff)
        missing = [c for c in missing if not (c.name and c.name.lower() == name)]
        extra = [c for c in extra if not (c.name and c.name.lower() == name)]

    return missing, extra, modified
def _compare_column_details(
    self, expected_col: SqlColumn, actual_col: SqlColumn, dialect: str
) -> Optional[ColumnDiff]:
    """Compare two column objects in detail.

    Compares data type (via the type normalizer), nullability, default value,
    identity flag and computed-column expression. When both columns are
    identity columns, nullable/default/identity-flag differences are ignored.
    When both are computed, default differences are ignored.

    Args:
        expected_col: Expected column (from scripts)
        actual_col: Actual column (from database)
        dialect: SQL dialect used for type normalization

    Returns:
        ColumnDiff if differences found, None otherwise
    """
    expected_raw_type = (expected_col.data_type or "").upper()

    # Identity / auto-increment detection. On the expected side, PostgreSQL
    # serial pseudo-types (SMALLSERIAL, SERIAL, BIGSERIAL and the SERIAL2/4/8
    # aliases) are recognised from the type name as well.
    expected_is_identity = bool(
        getattr(expected_col, "is_identity", False)
        or getattr(expected_col, "is_autoincrement", False)
        or re.match(r"(?:SMALL|BIG)?SERIAL[248]?\b", expected_raw_type)
    )
    actual_is_identity = bool(
        getattr(actual_col, "is_identity", False)
        or getattr(actual_col, "is_autoincrement", False)
    )
    both_identity = expected_is_identity and actual_is_identity

    logger.debug(
        "[COMPARATOR] Comparing column '%s': expected_type='%s', actual_type='%s'",
        expected_col.name,
        expected_col.data_type,
        actual_col.data_type,
    )
    logger.debug(
        "[COMPARATOR] expected_is_identity=%s, actual_is_identity=%s",
        expected_is_identity,
        actual_is_identity,
    )

    # Identity columns are compared on their base type without identity syntax.
    if both_identity:
        expected_type = self.type_normalizer.normalize(
            _extract_base_identity_type(expected_col.data_type, dialect), dialect
        )
        actual_type = self.type_normalizer.normalize(
            _extract_base_identity_type(actual_col.data_type, dialect), dialect
        )
    else:
        expected_type = self.type_normalizer.normalize(expected_col.data_type, dialect)
        actual_type = self.type_normalizer.normalize(actual_col.data_type, dialect)

    data_type_diff = None
    if expected_type.upper() != actual_type.upper() and not self.type_normalizer.are_equivalent(
        expected_col.data_type, actual_col.data_type, dialect, dialect
    ):
        data_type_diff = (expected_type, actual_type)

    nullable_diff = None
    if expected_col.nullable != actual_col.nullable:
        nullable_diff = (expected_col.nullable, actual_col.nullable)

    default_diff = None
    expected_default = self._normalize_default_value(expected_col.default_value)
    actual_default = self._normalize_default_value(actual_col.default_value)
    if expected_default != actual_default:
        default_diff = (expected_col.default_value, actual_col.default_value)
        logger.debug(
            "[COMPARATOR] Default value difference for column '%s':", expected_col.name
        )
        logger.debug("[COMPARATOR] Expected (raw): '%s'", expected_col.default_value)
        logger.debug("[COMPARATOR] Actual (raw): '%s'", actual_col.default_value)
        logger.debug("[COMPARATOR] Expected (normalized): '%s'", expected_default)
        logger.debug("[COMPARATOR] Actual (normalized): '%s'", actual_default)

    identity_diff = None
    if expected_col.is_identity != actual_col.is_identity:
        identity_diff = (expected_col.is_identity, actual_col.is_identity)

    computed_diff = None
    if expected_col.is_computed != actual_col.is_computed:
        computed_diff = (expected_col.is_computed, actual_col.is_computed)
    elif expected_col.is_computed and actual_col.is_computed:
        expected_expr = self._normalize_expression(expected_col.computed_expression)
        # actual_source is the raw text that actual_expr was derived from, so
        # the reported diff shows what was really compared.
        actual_source = actual_col.computed_expression
        actual_expr = self._normalize_expression(actual_source)
        # Oracle workaround: the virtual-column expression may only be present
        # in the default field. Use it when it looks like an expression
        # (contains an operator) rather than a literal.
        if not actual_expr and actual_col.default_value:
            default_val = actual_col.default_value
            if any(
                op in default_val
                for op in ["*", "+", "-", "/", "(", ")", "||", "AND", "OR", "CASE"]
            ):
                actual_source = default_val
                actual_expr = self._normalize_expression(default_val)
        if expected_expr and actual_expr and expected_expr != actual_expr:
            computed_diff = (expected_col.computed_expression, actual_source)  # type: ignore[assignment]
        elif expected_expr and not actual_expr:
            # Expected has an expression but actual does not.
            computed_diff = (expected_col.computed_expression, None)  # type: ignore[assignment]
        elif not expected_expr and actual_expr:
            # Actual has an expression but expected does not.
            computed_diff = (None, actual_source)  # type: ignore[assignment]

    # Nullability, default and identity flag are implicit for identity columns.
    if both_identity:
        nullable_diff = None
        default_diff = None
        identity_diff = None

    # For computed columns the default field may hold the expression (Oracle),
    # so it is not a real default.
    if expected_col.is_computed and actual_col.is_computed:
        default_diff = None

    if not any([data_type_diff, nullable_diff, default_diff, identity_diff, computed_diff]):
        return None

    logger.debug("[COMPARATOR] Column '%s' HAS DIFFERENCES:", expected_col.name)
    logger.debug("[COMPARATOR] data_type_diff=%s", data_type_diff)
    logger.debug("[COMPARATOR] nullable_diff=%s", nullable_diff)
    logger.debug("[COMPARATOR] default_diff=%s", default_diff)
    logger.debug("[COMPARATOR] identity_diff=%s", identity_diff)
    logger.debug("[COMPARATOR] computed_diff=%s", computed_diff)
    return ColumnDiff(
        object_name=expected_col.name,
        column_name=expected_col.name,
        data_type_diff=data_type_diff,
        nullable_diff=nullable_diff,
        default_diff=default_diff,
        identity_diff=identity_diff,
        computed_diff=computed_diff,
    )
def _compare_constraint_details(
    self, expected_const: SqlConstraint, actual_const: SqlConstraint
) -> Optional[ConstraintDiff]:
    """Compare two constraint objects in detail.

    Compares column lists and referenced table/columns case-insensitively and
    order-independently. Check expressions are compared after
    ``_normalize_expression``. Missing (None) column lists are treated as
    empty, and empty column names are ignored.

    Args:
        expected_const: Expected constraint
        actual_const: Actual constraint

    Returns:
        ConstraintDiff if differences found, None otherwise
    """

    def column_key(columns: Optional[List[str]]) -> List[str]:
        return sorted(c.lower() for c in (columns or []) if c)

    columns_diff = None
    if column_key(expected_const.column_names) != column_key(actual_const.column_names):
        columns_diff = (expected_const.column_names, actual_const.column_names)

    references_diff = None
    if expected_const.reference_table or actual_const.reference_table:
        expected_ref = expected_const.reference_table
        actual_ref = actual_const.reference_table
        if (expected_ref or "").lower() != (actual_ref or "").lower():
            references_diff = (expected_ref, actual_ref)
        elif column_key(expected_const.reference_columns) != column_key(
            actual_const.reference_columns
        ):
            # Same referenced table, different referenced columns.
            references_diff = (  # type: ignore[assignment]
                expected_const.reference_columns,
                actual_const.reference_columns,
            )

    check_clause_diff = None
    if expected_const.check_expression or actual_const.check_expression:
        expected_expr = self._normalize_expression(expected_const.check_expression)
        actual_expr = self._normalize_expression(actual_const.check_expression)
        if expected_expr != actual_expr:
            check_clause_diff = (
                expected_const.check_expression,
                actual_const.check_expression,
            )

    if not any([columns_diff, references_diff, check_clause_diff]):
        return None

    constraint_name = expected_const.name or actual_const.name or "unnamed"
    return ConstraintDiff(
        object_name=constraint_name,
        constraint_name=constraint_name,
        columns_diff=columns_diff,
        references_diff=references_diff,
        check_clause_diff=check_clause_diff,
    )
def _normalize_default_value(self, value: Optional[str]) -> Optional[str]:
"""Normalize a default value for comparison.
Args:
value: Default value string
Returns:
Normalized value
"""
if value is None:
return None
# Remove quotes and whitespace
normalized = value.strip().strip("'").strip('"')
# Convert common variations
if normalized.upper() in ["NULL", "NONE", ""]:
return None
# Normalize boolean values
if normalized.upper() in ["TRUE", "T", "1", "YES", "Y"]:
return "TRUE"
if normalized.upper() in ["FALSE", "F", "0", "NO", "N"]:
return "FALSE"
# SQL Server specific normalization
# Remove outer parentheses: (getdate()) -> getdate()
if normalized.startswith("(") and normalized.endswith(")"):
normalized = normalized[1:-1].strip()
# DB2 specific: Normalize CURRENT TIMESTAMP, CURRENT DATE, etc.
# DB2 allows both "CURRENT" and "CURRENT TIMESTAMP" as synonyms
normalized_upper = normalized.upper()
if normalized_upper in ("CURRENT TIMESTAMP", "CURRENT_TIMESTAMP"):
return "CURRENT_TIMESTAMP"
elif normalized_upper in ("CURRENT DATE", "CURRENT_DATE"):
return "CURRENT_DATE"
elif normalized_upper in ("CURRENT TIME", "CURRENT_TIME"):
return "CURRENT_TIME"
elif normalized_upper == "CURRENT":
# DB2: "CURRENT" alone is a synonym for "CURRENT TIMESTAMP"
return "CURRENT_TIMESTAMP"
# Normalize function names to uppercase for consistency
# This handles: getdate() -> GETDATE(), suser_name() -> SUSER_NAME()
import re
func_pattern = r"^([a-zA-Z_][a-zA-Z0-9_]*)\s*(\(.*\))$"
func_match = re.match(func_pattern, normalized)
if func_match:
func_name = func_match.group(1).upper()
func_args = func_match.group(2)
# For timestamp/datetime functions with empty parentheses, remove them
# MySQL accepts CURRENT_TIMESTAMP and CURRENT_TIMESTAMP() as equivalent
if (
func_name
in (
"CURRENT_TIMESTAMP",
"CURRENT_DATE",
"CURRENT_TIME",
"LOCALTIMESTAMP",
"LOCALTIME",
"NOW",
)
and func_args == "()"
):
normalized = func_name
else:
normalized = f"{func_name}{func_args}"
return normalized
def _normalize_expression(self, expr: Optional[str]) -> Optional[str]:
"""Normalize an expression for comparison.
Args:
expr: Expression string
Returns:
Normalized expression
Note:
For Oracle, expressions may contain quoted identifiers like "PRICE"*"QUANTITY".
We normalize by removing quotes and normalizing whitespace/case.
"""
if expr is None:
return None
# Remove extra whitespace and normalize case
normalized = " ".join(expr.split()).upper()
# Oracle-specific: Remove quotes from identifiers in expressions
# Oracle returns: "PRICE"*"QUANTITY" but migration has: price * quantity
# We remove quotes to normalize: PRICE*QUANTITY
# But preserve quotes in string literals (e.g., 'text')
# Pattern: Match quoted identifiers (double quotes) but not string literals (single quotes)
normalized = re.sub(r'"([^"]+)"', r"\1", normalized)
# Normalize whitespace around operators (*, +, -, /, =, etc.)
# Add single space around operators for consistent comparison
normalized = re.sub(r"\s*([*/+\-=<>])\s*", r" \1 ", normalized)
normalized = " ".join(normalized.split()) # Normalize multiple spaces
return normalized
def _normalize_identifier(self, identifier: Optional[str]) -> str:
    """Fold an identifier to lower case for case-insensitive matching.

    Args:
        identifier: Identifier to fold; None is accepted.

    Returns:
        The lower-cased identifier, or an empty string when ``identifier``
        is None.
    """
    return "" if identifier is None else identifier.lower()
[docs]
def compare_views(
    self, expected: View, actual: View, dialect: str = "postgresql"
) -> Optional[ViewDiff]:
    """Compare two view objects.

    The query text is compared after comment removal, whitespace collapsing
    and upper-casing. Dialect-specific options are compared afterwards:

    * PostgreSQL: UNLOGGED, only when both sides report a value.
    * MySQL/MariaDB: ALGORITHM, SQL SECURITY and DEFINER, only when the
      expected view sets them. A migration that omits one of these clauses
      leaves the attribute unset, while the server applies a default
      (ALGORITHM=UNDEFINED, SQL SECURITY DEFINER, DEFINER=CURRENT_USER).
      Comparing an unset expectation would therefore report drift that does
      not exist. ALGORITHM and SQL SECURITY are keywords and are compared
      case-insensitively. DEFINER is compared with backticks and quotes
      removed.
    * Oracle: FORCE/NOFORCE, only when both sides report a value.

    Materialized-view properties (populated state, refresh method and mode,
    fast-refresh capability) are compared only when both views are
    materialized. The dialect name is matched case-insensitively.

    Args:
        expected: Expected view from migrations
        actual: Actual view from database
        dialect: SQL dialect

    Returns:
        ViewDiff if differences found, None otherwise
    """

    def _keyword(value: Any) -> str:
        # Case-insensitive comparison key. None maps to "" so that a missing
        # actual value still differs from an explicitly expected one.
        return "" if value is None else str(value).strip().upper()

    def _account(value: Any) -> str:
        # MySQL account names may be quoted (`app`@`%`, 'app'@'%');
        # compare them without the quoting characters.
        return "" if value is None else re.sub(r"[`'\"]", "", str(value)).strip()

    dialect = (dialect or "").lower()
    view_name = expected.name or actual.name
    diff = ViewDiff(object_name=view_name, view_name=view_name)

    # Compare definitions (comments, whitespace and case are ignored)
    expected_def = self._normalize_view_definition(expected.query)
    actual_def = self._normalize_view_definition(actual.query)
    if expected_def != actual_def:
        diff.definition_changed = True
        diff.expected_definition = expected.query
        diff.actual_definition = actual.query
        logger.info(f"View '{view_name}': definition changed")

    # Compare materialized status; an unset (None) flag means "not materialized"
    expected_mat = bool(getattr(expected, "materialized", False))
    actual_mat = bool(getattr(actual, "materialized", False))
    if expected_mat != actual_mat:
        diff.materialized_changed = (expected_mat, actual_mat)
        logger.info(
            f"View '{view_name}': materialized status changed from {expected_mat} to {actual_mat}"
        )

    # PostgreSQL UNLOGGED (materialized views): only when both sides report it
    if dialect in ("postgresql", "postgres"):
        expected_unlogged = getattr(expected, "unlogged", None)
        actual_unlogged = getattr(actual, "unlogged", None)
        if (
            expected_unlogged is not None
            and actual_unlogged is not None
            and expected_unlogged != actual_unlogged
        ):
            diff.unlogged_changed = (expected_unlogged, actual_unlogged)
            logger.info(
                f"View '{view_name}': UNLOGGED status changed from {expected_unlogged} to {actual_unlogged}"
            )

    # MySQL view options: only compared when the migration sets them,
    # because the server supplies a default for an omitted clause.
    if dialect in ("mysql", "mariadb"):
        expected_algorithm = getattr(expected, "algorithm", None)
        actual_algorithm = getattr(actual, "algorithm", None)
        expected_key = _keyword(expected_algorithm)
        if expected_key and expected_key != _keyword(actual_algorithm):
            diff.algorithm_changed = (expected_algorithm, actual_algorithm)
            logger.info(
                f"View '{view_name}': algorithm changed from {expected_algorithm} to {actual_algorithm}"
            )

        expected_sql_sec = getattr(expected, "sql_security", None)
        actual_sql_sec = getattr(actual, "sql_security", None)
        expected_key = _keyword(expected_sql_sec)
        if expected_key and expected_key != _keyword(actual_sql_sec):
            diff.sql_security_changed = (expected_sql_sec, actual_sql_sec)
            logger.info(
                f"View '{view_name}': SQL SECURITY changed from {expected_sql_sec} to {actual_sql_sec}"
            )

        expected_definer = getattr(expected, "definer", None)
        actual_definer = getattr(actual, "definer", None)
        expected_key = _account(expected_definer)
        if expected_key and expected_key != _account(actual_definer):
            diff.definer_changed = (expected_definer, actual_definer)
            logger.info(
                f"View '{view_name}': definer changed from {expected_definer} to {actual_definer}"
            )

    # Oracle FORCE/NOFORCE: only when both sides report it
    if dialect == "oracle":
        expected_force = getattr(expected, "force", None)
        actual_force = getattr(actual, "force", None)
        if (
            expected_force is not None
            and actual_force is not None
            and expected_force != actual_force
        ):
            diff.force_changed = (expected_force, actual_force)
            logger.info(
                f"View '{view_name}': FORCE/NOFORCE changed from {expected_force} to {actual_force}"
            )

    # Materialized-view specific properties (only if both are materialized)
    if expected_mat and actual_mat:
        expected_populated = getattr(expected, "is_populated", None)
        actual_populated = getattr(actual, "is_populated", None)
        if expected_populated is not None and actual_populated is not None:
            if expected_populated != actual_populated:
                diff.is_populated_changed = (expected_populated, actual_populated)
                logger.info(
                    f"Materialized view '{view_name}': populated status changed from {expected_populated} to {actual_populated}"
                )

        # Refresh method: case-insensitive, only when both sides have one
        expected_method = getattr(expected, "refresh_method", None)
        actual_method = getattr(actual, "refresh_method", None)
        if expected_method and actual_method:
            if expected_method.upper() != actual_method.upper():
                diff.refresh_method_changed = (expected_method, actual_method)
                logger.info(
                    f"Materialized view '{view_name}': refresh method changed from {expected_method} to {actual_method}"
                )

        # Refresh mode (Oracle): case-insensitive, only when both sides have one
        expected_mode = getattr(expected, "refresh_mode", None)
        actual_mode = getattr(actual, "refresh_mode", None)
        if expected_mode and actual_mode:
            if expected_mode.upper() != actual_mode.upper():
                diff.refresh_mode_changed = (expected_mode, actual_mode)
                logger.info(
                    f"Materialized view '{view_name}': refresh mode changed from {expected_mode} to {actual_mode}"
                )

        # Fast-refresh capability (Oracle)
        expected_fast = getattr(expected, "fast_refreshable", None)
        actual_fast = getattr(actual, "fast_refreshable", None)
        if expected_fast is not None and actual_fast is not None:
            if expected_fast != actual_fast:
                diff.fast_refreshable_changed = (expected_fast, actual_fast)
                logger.info(
                    f"Materialized view '{view_name}': fast refresh capability changed from {expected_fast} to {actual_fast}"
                )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
def _normalize_view_definition(self, definition: Optional[str]) -> str:
    """Canonicalize a view's SQL text so two definitions can be compared.

    Line comments (``-- ...``) are dropped first, then block comments
    (``/* ... */``). Whitespace runs are then collapsed to single spaces and
    the whole text is upper-cased.

    Args:
        definition: The view's query text; None or "" is accepted.

    Returns:
        The canonical text, or "" when there is no definition.
    """
    if not definition:
        return ""
    # Order matters for text such as "/* a -- b */": the line-comment pass
    # runs first, exactly as before.
    without_line_comments = re.sub(r"--.*$", "", definition, flags=re.MULTILINE)
    without_comments = re.sub(r"/\*.*?\*/", "", without_line_comments, flags=re.DOTALL)
    return " ".join(without_comments.split()).upper()
[docs]
def compare_indexes(
    self, expected: Index, actual: Index, dialect: str = "postgresql"
) -> Optional[IndexDiff]:
    """Compare two index objects.

    Checks, in order:

    * indexed columns (order-sensitive, case-insensitive),
    * uniqueness,
    * index method,
    * dialect-specific options: MySQL ONLINE/OFFLINE, PostgreSQL
      CONCURRENTLY, Oracle TABLESPACE.

    An index whose ``type`` is unset (missing, None or empty) is treated as
    ``btree``. This stops a migration that omits ``USING ...`` from being
    reported as different from an introspected btree index. An unset
    ``unique`` flag is treated as False.

    The Oracle tablespace is only compared when the expected index names one,
    and the comparison is case-insensitive (Oracle stores unquoted names in
    upper case). The dialect name is matched case-insensitively.

    Args:
        expected: Expected index from migrations
        actual: Actual index from database
        dialect: SQL dialect

    Returns:
        IndexDiff if differences found, None otherwise
    """
    dialect = (dialect or "").lower()
    index_name = expected.name or actual.name
    table_name = expected.table_name or actual.table_name
    diff = IndexDiff(object_name=index_name, index_name=index_name, table_name=table_name)

    # Compare columns (order matters; names are case-insensitive)
    expected_cols = [self._normalize_identifier(c) for c in (expected.columns or [])]
    actual_cols = [self._normalize_identifier(c) for c in (actual.columns or [])]
    if expected_cols != actual_cols:
        diff.columns_changed = True
        diff.expected_columns = expected.columns
        diff.actual_columns = actual.columns
        logger.info(
            f"Index '{index_name}': columns changed from {expected_cols} to {actual_cols}"
        )

    # Compare uniqueness; an unset (None) flag means non-unique
    expected_unique = bool(getattr(expected, "unique", False))
    actual_unique = bool(getattr(actual, "unique", False))
    if expected_unique != actual_unique:
        diff.uniqueness_changed = (expected_unique, actual_unique)
        logger.info(
            f"Index '{index_name}': uniqueness changed from {expected_unique} to {actual_unique}"
        )

    # Compare index type; None/"" falls back to the default just like a missing attribute
    expected_type = self._normalize_identifier(getattr(expected, "type", None) or "btree")
    actual_type = self._normalize_identifier(getattr(actual, "type", None) or "btree")
    if expected_type != actual_type:
        diff.type_changed = (expected_type, actual_type)
        logger.info(f"Index '{index_name}': type changed from {expected_type} to {actual_type}")

    # MySQL ONLINE/OFFLINE: only when both sides report it
    if dialect in ("mysql", "mariadb"):
        expected_online = getattr(expected, "online", None)
        actual_online = getattr(actual, "online", None)
        if expected_online is not None and actual_online is not None:
            if expected_online != actual_online:
                diff.online_changed = (expected_online, actual_online)
                logger.info(
                    f"Index '{index_name}': ONLINE/OFFLINE status changed from {expected_online} to {actual_online}"
                )

    # PostgreSQL CONCURRENTLY.
    # NOTE(review): CONCURRENTLY is a build option of CREATE INDEX and is
    # presumably not recorded in the system catalogs. If introspection cannot
    # populate it, an index built CONCURRENTLY will always be reported as
    # changed. Confirm how the introspector sets this attribute.
    if dialect in ("postgresql", "postgres"):
        expected_concurrently = bool(getattr(expected, "concurrently", False))
        actual_concurrently = bool(getattr(actual, "concurrently", False))
        if expected_concurrently != actual_concurrently:
            diff.concurrently_changed = (expected_concurrently, actual_concurrently)
            logger.info(
                f"Index '{index_name}': CONCURRENTLY status changed from {expected_concurrently} to {actual_concurrently}"
            )

    # Oracle TABLESPACE: only an explicitly requested tablespace can drift
    if dialect == "oracle":
        expected_tablespace = getattr(expected, "tablespace", None)
        actual_tablespace = getattr(actual, "tablespace", None)
        expected_ts_key = str(expected_tablespace or "").strip().upper()
        actual_ts_key = str(actual_tablespace or "").strip().upper()
        if expected_ts_key and expected_ts_key != actual_ts_key:
            diff.tablespace_changed = (expected_tablespace, actual_tablespace)
            logger.info(
                f"Index '{index_name}': TABLESPACE changed from {expected_tablespace} to {actual_tablespace}"
            )

    # Calculate final diff status
    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_sequences(
    self, expected: Sequence, actual: Sequence, dialect: str = "postgresql"
) -> Optional[SequenceDiff]:
    """Diff the options of two sequences.

    The numeric options (start, increment, min, max) are reported only when
    the expected sequence sets them, so a value the migration left out is
    never flagged. CYCLE is always compared. TEMPORARY is compared for the
    PostgreSQL dialect.

    Args:
        expected: Sequence as defined by the migrations.
        actual: Sequence as introspected from the database.
        dialect: SQL dialect name.

    Returns:
        A SequenceDiff when any option differs, otherwise None.
    """
    seq_name = expected.name or actual.name
    diff = SequenceDiff(object_name=seq_name, sequence_name=seq_name)

    # Sequence objects expose start_with/increment_by; the start_value and
    # increment spellings are used only when those attributes are absent.
    exp_start = getattr(expected, "start_with", getattr(expected, "start_value", None))
    act_start = getattr(actual, "start_with", getattr(actual, "start_value", None))
    if exp_start is not None and exp_start != act_start:
        diff.start_value_changed = (exp_start, act_start)
        logger.info(f"Sequence '{seq_name}': start value changed from {exp_start} to {act_start}")

    exp_step = getattr(expected, "increment_by", getattr(expected, "increment", None))
    act_step = getattr(actual, "increment_by", getattr(actual, "increment", None))
    if exp_step is not None and exp_step != act_step:
        diff.increment_changed = (exp_step, act_step)
        logger.info(f"Sequence '{seq_name}': increment changed from {exp_step} to {act_step}")

    exp_low, act_low = getattr(expected, "min_value", None), getattr(actual, "min_value", None)
    if exp_low is not None and exp_low != act_low:
        diff.min_value_changed = (exp_low, act_low)
        logger.info(f"Sequence '{seq_name}': min value changed from {exp_low} to {act_low}")

    exp_high, act_high = getattr(expected, "max_value", None), getattr(actual, "max_value", None)
    if exp_high is not None and exp_high != act_high:
        diff.max_value_changed = (exp_high, act_high)
        logger.info(f"Sequence '{seq_name}': max value changed from {exp_high} to {act_high}")

    exp_wraps, act_wraps = getattr(expected, "cycle", False), getattr(actual, "cycle", False)
    if exp_wraps != act_wraps:
        diff.cycle_changed = (exp_wraps, act_wraps)
        logger.info(f"Sequence '{seq_name}': cycle changed from {exp_wraps} to {act_wraps}")

    # TEMPORARY sequences exist only in PostgreSQL
    if dialect in ("postgresql", "postgres"):
        exp_temp, act_temp = getattr(expected, "temp", False), getattr(actual, "temp", False)
        if exp_temp != act_temp:
            diff.temp_changed = (exp_temp, act_temp)
            logger.info(f"Sequence '{seq_name}': TEMPORARY changed from {exp_temp} to {act_temp}")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_triggers(
    self, expected: Trigger, actual: Trigger, dialect: str = "postgresql"
) -> Optional[TriggerDiff]:
    """Compare two trigger objects.

    Events are compared as a set. ``INSERT OR UPDATE`` and
    ``UPDATE OR INSERT`` describe the same trigger, so each side's events are
    lower-cased and sorted before comparison. Events come from the ``events``
    attribute. When that is missing, None or empty, the single ``event``
    attribute is used instead and split on ``OR``, so ``"INSERT OR UPDATE"``
    matches ``["insert", "update"]``.

    An unset (None) ``enabled`` flag is treated as enabled. The MySQL DEFINER
    is only compared when the expected trigger sets it, with backticks and
    quotes ignored, because MySQL defaults an omitted DEFINER to the creating
    account. The dialect name is matched case-insensitively.

    Args:
        expected: Expected trigger from migrations
        actual: Actual trigger from database
        dialect: SQL dialect

    Returns:
        TriggerDiff if differences found, None otherwise
    """

    def _event_list(trigger: Any) -> List[str]:
        # Sorted, lower-cased events; tolerates events=None.
        events = getattr(trigger, "events", None) or []
        if not events:
            single = getattr(trigger, "event", None)
            if single:
                events = re.split(r"\s+OR\s+", single.strip(), flags=re.IGNORECASE)
        return sorted(self._normalize_identifier(event) for event in events)

    def _enabled(trigger: Any) -> Any:
        # None means "not specified", i.e. the default state (enabled).
        flag = getattr(trigger, "enabled", None)
        return True if flag is None else flag

    def _account(value: Any) -> str:
        # MySQL account names may be quoted (`app`@`%`, 'app'@'%').
        return "" if value is None else re.sub(r"[`'\"]", "", str(value)).strip()

    dialect = (dialect or "").lower()
    trigger_name = expected.name or actual.name
    table_name = expected.table_name or actual.table_name
    diff = TriggerDiff(
        object_name=trigger_name, trigger_name=trigger_name, table_name=table_name
    )

    # Compare timing (BEFORE/AFTER/INSTEAD OF), case-insensitively
    expected_timing = self._normalize_identifier(getattr(expected, "timing", ""))
    actual_timing = self._normalize_identifier(getattr(actual, "timing", ""))
    if expected_timing != actual_timing:
        diff.timing_changed = (expected_timing, actual_timing)
        logger.info(
            f"Trigger '{trigger_name}': timing changed from {expected_timing} to {actual_timing}"
        )

    # Compare events as an order-insensitive collection
    expected_events = _event_list(expected)
    actual_events = _event_list(actual)
    if expected_events != actual_events:
        diff.event_changed = (expected_events, actual_events)
        logger.info(
            f"Trigger '{trigger_name}': event changed from {expected_events} to {actual_events}"
        )

    # Compare definition
    expected_def = self._normalize_expression(expected.definition)
    actual_def = self._normalize_expression(actual.definition)
    if expected_def != actual_def:
        diff.definition_changed = True
        logger.info(f"Trigger '{trigger_name}': definition changed")

    # Compare enabled status
    expected_enabled = _enabled(expected)
    actual_enabled = _enabled(actual)
    if expected_enabled != actual_enabled:
        diff.enabled_changed = (expected_enabled, actual_enabled)
        logger.info(
            f"Trigger '{trigger_name}': enabled status changed from {expected_enabled} to {actual_enabled}"
        )

    # PostgreSQL CONSTRAINT TRIGGER; an unset flag means a plain trigger
    if dialect in ("postgresql", "postgres"):
        expected_constraint = bool(getattr(expected, "is_constraint_trigger", False))
        actual_constraint = bool(getattr(actual, "is_constraint_trigger", False))
        if expected_constraint != actual_constraint:
            diff.constraint_trigger_changed = (expected_constraint, actual_constraint)
            logger.info(
                f"Trigger '{trigger_name}': CONSTRAINT TRIGGER status changed from {expected_constraint} to {actual_constraint}"
            )

    # MySQL DEFINER: only an explicitly specified definer can drift
    if dialect in ("mysql", "mariadb"):
        expected_definer = getattr(expected, "definer", None)
        actual_definer = getattr(actual, "definer", None)
        expected_key = _account(expected_definer)
        if expected_key and expected_key != _account(actual_definer):
            diff.definer_changed = (expected_definer, actual_definer)
            logger.info(
                f"Trigger '{trigger_name}': definer changed from {expected_definer} to {actual_definer}"
            )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_procedures(
    self, expected: Procedure, actual: Procedure, dialect: str = "postgresql"
) -> Optional[ProcedureDiff]:
    """Diff two stored procedures by signature and body.

    The signature is compared through ``_normalize_parameters`` and the body
    through ``_normalize_expression``. The ``dialect`` argument is accepted
    for interface parity with the other comparators but is not consulted.

    Args:
        expected: Procedure as defined by the migrations.
        actual: Procedure as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        A ProcedureDiff when the signature or body differs, otherwise None.
    """
    proc_name = expected.name or actual.name
    diff = ProcedureDiff(object_name=proc_name, procedure_name=proc_name)

    want_sig = self._normalize_parameters(expected.parameters)
    have_sig = self._normalize_parameters(actual.parameters)
    if want_sig != have_sig:
        diff.parameters_changed = True
        # Record the raw parameters (as strings) for reporting.
        diff.expected_parameters = list(map(str, expected.parameters or []))
        diff.actual_parameters = list(map(str, actual.parameters or []))
        logger.info(
            f"Procedure '{proc_name}': parameters changed from {want_sig} to {have_sig}"
        )

    body_before = self._normalize_expression(expected.body)
    body_after = self._normalize_expression(actual.body)
    if body_before != body_after:
        diff.definition_changed = True
        logger.info(f"Procedure '{proc_name}': definition changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
def _normalize_parameters(self, parameters: Optional[List]) -> str:
    """Build an upper-case signature key for a routine's parameters.

    Each entry is either a plain string (upper-cased and trimmed), or an
    object exposing both ``name`` and ``data_type``, rendered as
    ``"<name> <data_type>"`` and upper-cased. Entries of any other shape are
    skipped. The rendered entries are sorted and joined with commas.

    NOTE(review): because of the sort, the key ignores parameter order, so
    routines whose parameters were merely reordered compare equal. Confirm
    this is intended (e.g. to tolerate introspection that does not return
    parameters in ordinal order).

    Args:
        parameters: Parameter strings or objects; None or empty is accepted.

    Returns:
        The comma-joined sorted signature, or "" when there are no parameters.
    """
    if not parameters:
        return ""

    def _render(param: Any) -> Optional[str]:
        # Render one parameter, or return None for an unsupported shape.
        if isinstance(param, str):
            return param.upper().strip()
        if hasattr(param, "name") and hasattr(param, "data_type"):
            return f"{param.name} {param.data_type}".upper().strip()
        return None

    rendered = (_render(param) for param in parameters)
    return ",".join(sorted(text for text in rendered if text is not None))
[docs]
def compare_functions(
    self, expected: Procedure, actual: Procedure, dialect: str = "postgresql"
) -> Optional[FunctionDiff]:
    """Diff two functions (Procedure objects with is_function=True).

    Three aspects are compared:

    * the parameter signature, via ``_normalize_parameters``;
    * the return type, upper-cased and trimmed (a missing return type is "");
    * the body, via ``_normalize_expression``.

    ``dialect`` is accepted for interface parity but not consulted.

    Args:
        expected: Function as defined by the migrations.
        actual: Function as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        A FunctionDiff when anything differs, otherwise None.
    """
    func_name = expected.name or actual.name
    diff = FunctionDiff(object_name=func_name, function_name=func_name)

    want_sig = self._normalize_parameters(expected.parameters)
    have_sig = self._normalize_parameters(actual.parameters)
    if want_sig != have_sig:
        diff.parameters_changed = True
        diff.expected_parameters = list(map(str, expected.parameters or []))
        diff.actual_parameters = list(map(str, actual.parameters or []))
        logger.info(
            f"Function '{func_name}': parameters changed from {want_sig} to {have_sig}"
        )

    want_ret = (expected.return_type or "").upper().strip()
    have_ret = (actual.return_type or "").upper().strip()
    if want_ret != have_ret:
        diff.return_type_changed = (want_ret, have_ret)
        logger.info(
            f"Function '{func_name}': return type changed from {want_ret} to {have_ret}"
        )

    if self._normalize_expression(expected.body) != self._normalize_expression(actual.body):
        diff.definition_changed = True
        logger.info(f"Function '{func_name}': definition changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_synonyms(
    self, expected: Synonym, actual: Synonym, dialect: str = "postgresql"
) -> Optional[SynonymDiff]:
    """Diff two synonyms by what they point at.

    The target object, target schema, target database (SQL Server) and
    database link (Oracle) are compared as identifiers. A name wrapped in
    ``"..."``, ``[...]`` or backticks loses its quotes and keeps its case.
    An unquoted name is upper-cased for Oracle/DB2 and lower-cased for any
    other dialect.

    Args:
        expected: Synonym as defined by the migrations.
        actual: Synonym as introspected from the database.
        dialect: SQL dialect name; selects the unquoted-identifier case rule.

    Returns:
        A SynonymDiff when any target component differs, otherwise None.
    """
    syn_name = expected.name or actual.name
    diff = SynonymDiff(object_name=syn_name, synonym_name=syn_name)

    quote_pairs = (('"', '"'), ("[", "]"), ("`", "`"))

    def _key(name: Optional[str]) -> str:
        """Return the comparison key for a possibly-quoted identifier."""
        if not name:
            return ""
        text = name.strip()
        for opener, closer in quote_pairs:
            if text.startswith(opener) and text.endswith(closer):
                # Quoted: strip the delimiters, keep the case as written.
                return text[1:-1]
        return text.upper() if dialect.lower() in ("oracle", "db2") else text.lower()

    if _key(expected.target_object) != _key(actual.target_object):
        diff.target_changed = (expected.target_object, actual.target_object)
        diff.expected_target = expected.target_full_name
        diff.actual_target = actual.target_full_name
        logger.info(
            f"Synonym '{syn_name}': target changed from {expected.target_object} to {actual.target_object}"
        )

    if _key(expected.target_schema) != _key(actual.target_schema):
        diff.target_schema_changed = (expected.target_schema, actual.target_schema)
        logger.info(
            f"Synonym '{syn_name}': target schema changed from {expected.target_schema} to {actual.target_schema}"
        )

    # Cross-database synonyms (SQL Server)
    if _key(expected.target_database) != _key(actual.target_database):
        diff.target_database_changed = (expected.target_database, actual.target_database)
        logger.info(
            f"Synonym '{syn_name}': target database changed from {expected.target_database} to {actual.target_database}"
        )

    # Remote synonyms via a database link (Oracle)
    if _key(expected.db_link) != _key(actual.db_link):
        diff.db_link_changed = (expected.db_link, actual.db_link)
        logger.info(
            f"Synonym '{syn_name}': database link changed from {expected.db_link} to {actual.db_link}"
        )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_user_defined_types(
    self, expected: UserDefinedType, actual: UserDefinedType, dialect: str = "postgresql"
) -> Optional[UserDefinedTypeDiff]:
    """Compare two user-defined type objects.

    Compared properties:

    * Type category: single-letter catalog codes and synonyms are first
      mapped to a canonical name (e.g. ``OBJECT``/``STRUCT`` -> ``COMPOSITE``).
    * Base type (DOMAIN/DISTINCT): case-insensitive, when either side has one.
    * Attributes (composite types, both sides): compared as (lower-cased
      name, normalized upper-cased type) pairs in declaration order.
    * Enum values (enum types, both sides): compared as sorted lists.
    * Raw definition: compared when either side has one, ignoring case and
      differences in whitespace (runs of spaces, tabs and newlines).

    Args:
        expected: Expected UDT from migrations
        actual: Actual UDT from database
        dialect: SQL dialect (passed to the type normalizer)

    Returns:
        UserDefinedTypeDiff if differences found, None otherwise
    """
    type_name = expected.name or actual.name
    diff = UserDefinedTypeDiff(object_name=type_name, type_name=type_name)

    def _canonical_category(category: str) -> str:
        # Map catalog codes / dialect spellings to one canonical category name.
        mapping = {
            "C": "COMPOSITE",
            "R": "COMPOSITE",
            "S": "COMPOSITE",
            "STRUCT": "COMPOSITE",
            "STRUCTURED": "COMPOSITE",
            "OBJECT": "COMPOSITE",  # Oracle OBJECT types are composite types
            "E": "ENUM",
            "ENUM": "ENUM",
            "D": "DOMAIN",
            "DOMAIN": "DOMAIN",
            "DISTINCT": "DISTINCT",
        }
        value = category.strip().upper() if category else "UNKNOWN"
        return mapping.get(value, value)

    def _canonical_text(text: Optional[str]) -> str:
        # Case- and whitespace-insensitive form of a definition.
        return " ".join((text or "").split()).upper()

    # Compare type category (COMPOSITE, ENUM, DOMAIN, DISTINCT, etc.)
    expected_category = _canonical_category(expected.type_category)
    actual_category = _canonical_category(actual.type_category)
    if expected_category != actual_category:
        diff.type_category_changed = (expected.type_category, actual.type_category)
        diff.expected_type_category = expected.type_category
        diff.actual_type_category = actual.type_category
        logger.info(
            f"User-defined type '{type_name}': category changed from {expected.type_category} to {actual.type_category}"
        )

    # Compare base type (for DOMAIN and DISTINCT types)
    if expected.base_type or actual.base_type:
        expected_base = (expected.base_type or "").upper()
        actual_base = (actual.base_type or "").upper()
        if expected_base != actual_base:
            diff.base_type_changed = (expected.base_type, actual.base_type)
            diff.expected_base_type = expected.base_type
            diff.actual_base_type = actual.base_type
            logger.info(
                f"User-defined type '{type_name}': base type changed from {expected.base_type} to {actual.base_type}"
            )

    # Compare attributes (for COMPOSITE types)
    if expected.is_composite and actual.is_composite:

        def _normalize_attributes(
            attrs: Optional[List[Dict[str, Any]]],
        ) -> List[Tuple[str, str]]:
            # (lower-cased name, normalized upper-cased type) pairs, in order.
            normalized: List[Tuple[str, str]] = []
            if not attrs:
                return normalized
            for attr in attrs:
                name = (attr.get("name") or "").strip().lower()
                attr_type_raw = (attr.get("type") or "").strip()
                normalized_type = self.type_normalizer.normalize(
                    attr_type_raw,
                    dialect,
                )
                normalized.append((name, normalized_type.upper()))
            return normalized

        expected_attrs = _normalize_attributes(expected.attributes)
        actual_attrs = _normalize_attributes(actual.attributes)
        if expected_attrs != actual_attrs:
            diff.attributes_changed = True
            diff.expected_attributes = expected.attributes
            diff.actual_attributes = actual.attributes
            logger.info(f"User-defined type '{type_name}': attributes changed")
            logger.info(f"  Expected attributes (normalized): {expected_attrs}")
            logger.info(f"  Actual attributes (normalized): {actual_attrs}")
            logger.info(f"  Expected attributes (raw): {expected.attributes}")
            logger.info(f"  Actual attributes (raw): {actual.attributes}")

    # Compare enum values (for ENUM types).
    # NOTE(review): values are sorted, so a reordering is not reported even
    # though PostgreSQL enum ordering is significant. Confirm this is intended.
    if expected.is_enum and actual.is_enum:
        expected_values = sorted(expected.enum_values or [])
        actual_values = sorted(actual.enum_values or [])
        if expected_values != actual_values:
            diff.enum_values_changed = True
            diff.expected_enum_values = expected.enum_values
            diff.actual_enum_values = actual.enum_values
            logger.info(f"User-defined type '{type_name}': enum values changed")

    # Compare definition (for types with explicit definitions), ignoring
    # case and whitespace layout like the other text comparisons in this class
    if expected.definition or actual.definition:
        if _canonical_text(expected.definition) != _canonical_text(actual.definition):
            diff.definition_changed = True
            logger.info(f"User-defined type '{type_name}': definition changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_packages(
    self, expected: Package, actual: Package, dialect: str = "oracle"
) -> Optional[PackageDiff]:
    """Diff two Oracle packages by specification and body.

    Both parts are canonicalized with ``_normalize_package_code`` before
    comparison. When a part differs, the raw expected and actual texts are
    stored on the diff. ``dialect`` is accepted for interface parity but not
    consulted.

    Args:
        expected: Package as defined by the migrations.
        actual: Package as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        A PackageDiff when the spec or body differs, otherwise None.
    """
    package_name = expected.name or actual.name
    diff = PackageDiff(object_name=package_name, package_name=package_name)

    if self._normalize_package_code(expected.spec) != self._normalize_package_code(actual.spec):
        diff.spec_changed = True
        diff.expected_spec, diff.actual_spec = expected.spec, actual.spec
        logger.info(f"Package '{package_name}': specification changed")

    if self._normalize_package_code(expected.body) != self._normalize_package_code(actual.body):
        diff.body_changed = True
        diff.expected_body, diff.actual_body = expected.body, actual.body
        logger.info(f"Package '{package_name}': body changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
def _normalize_package_code(self, code: Optional[str]) -> str:
    """Canonicalize PL/SQL package spec or body text for comparison.

    Steps, in order:

    1. Strip ``--`` line comments.
    2. Strip ``/* ... */`` block comments.
    3. Collapse whitespace runs to single spaces.
    4. Drop any whitespace next to ``(``, ``)``, ``,``, ``;`` or ``.``.
    5. Upper-case the result.

    NOTE(review): upper-casing covers the whole text, including quoted
    identifiers and string literals, so ``"MyCol"`` vs ``"MYCOL"`` and
    ``'abc'`` vs ``'ABC'`` compare equal. Comment markers inside string
    literals are also treated as comments.

    Args:
        code: Package specification or body; None or "" is accepted.

    Returns:
        The canonical text, or "" when there is no code.
    """
    if not code:
        return ""
    # Line comments are removed before block comments; the order matters for
    # text such as "/* a -- b */" and is kept deliberately.
    stripped = re.sub(r"--.*$", "", code, flags=re.MULTILINE)
    stripped = re.sub(r"/\*.*?\*/", "", stripped, flags=re.DOTALL)
    compact = " ".join(stripped.split())
    # One pass removes whitespace hugging any punctuation mark.
    compact = re.sub(r"\s*([(),;.])\s*", r"\1", compact)
    return compact.upper()
[docs]
def compare_extensions(
    self, expected: Extension, actual: Extension, dialect: str = "postgresql"
) -> Optional[ExtensionDiff]:
    """Diff two PostgreSQL extensions by version and installation schema.

    Versions are compared after trimming surrounding whitespace (a missing
    version counts as ""). Schemas are compared case-insensitively.

    Args:
        expected: Extension as defined by the migrations.
        actual: Extension as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        An ExtensionDiff when the version or schema differs, otherwise None.
    """
    extension_name = expected.name or actual.name
    diff = ExtensionDiff(object_name=extension_name, extension_name=extension_name)

    if (expected.version or "").strip() != (actual.version or "").strip():
        diff.version_changed = (expected.version, actual.version)
        diff.expected_version, diff.actual_version = expected.version, actual.version
        logger.info(
            f"Extension '{extension_name}': version changed from {expected.version} to {actual.version}"
        )

    if self._normalize_identifier(expected.schema) != self._normalize_identifier(actual.schema):
        diff.schema_changed = (expected.schema, actual.schema)
        logger.info(
            f"Extension '{extension_name}': schema changed from {expected.schema} to {actual.schema}"
        )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_events(
    self, expected: Event, actual: Event, dialect: str = "mysql"
) -> Optional[EventDiff]:
    """Diff two MySQL scheduled events.

    The body and the schedule are compared through ``_normalize_expression``.
    The enabled flag is compared as-is. The event type is compared
    case-insensitively (a missing type counts as "").

    Args:
        expected: Event as defined by the migrations.
        actual: Event as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        An EventDiff when anything differs, otherwise None.
    """
    event_name = expected.name or actual.name
    diff = EventDiff(object_name=event_name, event_name=event_name)
    normalize = self._normalize_expression

    if normalize(expected.definition) != normalize(actual.definition):
        diff.definition_changed = True
        logger.info(f"Event '{event_name}': definition changed")

    if normalize(expected.schedule) != normalize(actual.schedule):
        diff.schedule_changed = (expected.schedule, actual.schedule)
        logger.info(
            f"Event '{event_name}': schedule changed from {expected.schedule} to {actual.schedule}"
        )

    if expected.enabled != actual.enabled:
        diff.enabled_changed = (expected.enabled, actual.enabled)
        logger.info(
            f"Event '{event_name}': enabled status changed from {expected.enabled} to {actual.enabled}"
        )

    if (expected.event_type or "").upper() != (actual.event_type or "").upper():
        diff.event_type_changed = (expected.event_type, actual.event_type)
        logger.info(
            f"Event '{event_name}': event type changed from {expected.event_type} to {actual.event_type}"
        )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_database_links(
    self, expected: DatabaseLink, actual: DatabaseLink, dialect: str = "oracle"
) -> Optional[DatabaseLinkDiff]:
    """Diff two Oracle database links.

    The remote endpoint is the ``host`` when set, otherwise the
    ``connect_string``. Endpoint and username are compared
    case-insensitively. The public/private flag is compared as-is.

    Args:
        expected: Database link as defined by the migrations.
        actual: Database link as introspected from the database.
        dialect: SQL dialect name (unused here).

    Returns:
        A DatabaseLinkDiff when anything differs, otherwise None.
    """
    link_name = expected.name or actual.name
    diff = DatabaseLinkDiff(object_name=link_name, link_name=link_name)

    # Prefer the explicit host; fall back to the full connect string.
    expected_endpoint = expected.host or expected.connect_string
    actual_endpoint = actual.host or actual.connect_string
    if self._normalize_identifier(expected_endpoint) != self._normalize_identifier(actual_endpoint):
        diff.host_changed = (expected_endpoint, actual_endpoint)
        diff.expected_host = expected_endpoint
        diff.actual_host = actual_endpoint
        logger.info(
            f"Database link '{link_name}': host/connect string changed from {diff.expected_host} to {diff.actual_host}"
        )

    if self._normalize_identifier(expected.username) != self._normalize_identifier(actual.username):
        diff.username_changed = (expected.username, actual.username)
        logger.info(
            f"Database link '{link_name}': username changed from {expected.username} to {actual.username}"
        )

    if expected.public != actual.public:
        diff.public_changed = (expected.public, actual.public)
        logger.info(
            f"Database link '{link_name}': public status changed from {expected.public} to {actual.public}"
        )

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_linked_servers(
    self, expected: LinkedServer, actual: LinkedServer, dialect: str = "sqlserver"
) -> Optional[LinkedServerDiff]:
    """Diff an expected SQL Server linked server against the introspected one.

    Each of product, provider, data source, catalog and username is compared
    after ``_normalize_identifier``. For every mismatch, the raw
    ``(expected, actual)`` pair is stored on the diff's matching
    ``<attr>_changed`` field.

    Args:
        expected: Linked server as declared by the migration scripts.
        actual: Linked server as reported by database introspection.
        dialect: SQL dialect name (typically sqlserver); kept for signature
            parity with the other ``compare_*`` methods and not read here.

    Returns:
        A populated ``LinkedServerDiff`` when anything differs, else ``None``.
    """
    import logging

    log = logging.getLogger(__name__)

    name = expected.name or actual.name
    result = LinkedServerDiff(object_name=name, server_name=name)

    # (LinkedServer attribute == prefix of the diff's *_changed field, label used in the log line)
    checked_fields = (
        ("product", "product"),
        ("provider", "provider"),
        ("data_source", "data source"),
        ("catalog", "catalog"),
        ("username", "username"),
    )
    for attr, label in checked_fields:
        old_value = getattr(expected, attr)
        new_value = getattr(actual, attr)
        if self._normalize_identifier(old_value) == self._normalize_identifier(new_value):
            continue
        setattr(result, f"{attr}_changed", (old_value, new_value))
        log.info(
            f"Linked server '{name}': {label} changed from {old_value} to {new_value}"
        )

    result._calculate_diffs()
    return result if result.has_diffs else None
[docs]
def compare_foreign_data_wrappers(
    self, expected: ForeignDataWrapper, actual: ForeignDataWrapper, dialect: str = "postgresql"
) -> Optional[ForeignDataWrapperDiff]:
    """Compare two foreign data wrapper objects (PostgreSQL).

    The handler and validator function names are compared after
    ``_normalize_identifier``. The options mapping is compared as a dict, and a
    missing (``None``) mapping is treated the same as an empty one.

    Args:
        expected: Expected FDW from migrations.
        actual: Actual FDW from database.
        dialect: SQL dialect (typically postgresql); not read by this method.

    Returns:
        ForeignDataWrapperDiff if differences found, None otherwise.
    """
    import logging

    logger = logging.getLogger(__name__)

    fdw_name = expected.name or actual.name
    diff = ForeignDataWrapperDiff(object_name=fdw_name, fdw_name=fdw_name)

    # Compare handler
    expected_handler = self._normalize_identifier(expected.handler)
    actual_handler = self._normalize_identifier(actual.handler)
    if expected_handler != actual_handler:
        diff.handler_changed = (expected.handler, actual.handler)
        logger.info(
            f"Foreign data wrapper '{fdw_name}': handler changed from {expected.handler} to {actual.handler}"
        )

    # Compare validator
    expected_validator = self._normalize_identifier(expected.validator)
    actual_validator = self._normalize_identifier(actual.validator)
    if expected_validator != actual_validator:
        diff.validator_changed = (expected.validator, actual.validator)
        logger.info(
            f"Foreign data wrapper '{fdw_name}': validator changed from {expected.validator} to {actual.validator}"
        )

    # Compare options. ``None`` and ``{}`` both mean "no OPTIONS"; without this
    # normalization a wrapper with no options would be reported as drifted
    # purely because the two sources represent "absent" differently.
    expected_opts = expected.options or {}
    actual_opts = actual.options or {}
    if expected_opts != actual_opts:
        diff.options_changed = (expected.options, actual.options)
        logger.info(f"Foreign data wrapper '{fdw_name}': options changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_foreign_servers(
    self, expected: ForeignServer, actual: ForeignServer, dialect: str = "postgresql"
) -> Optional[ForeignServerDiff]:
    """Compare two foreign server objects (PostgreSQL).

    What is compared:
    - FDW name, host and dbname, after ``_normalize_identifier``.
    - The port, by its string form, so that ``5432`` and ``"5432"`` are equal.
    - The remaining options, as a dict with ``host``/``port``/``dbname``
      removed, because those keys are already diffed individually.

    A missing (``None``) options mapping is treated as empty.

    Args:
        expected: Expected foreign server from migrations.
        actual: Actual foreign server from database.
        dialect: SQL dialect (typically postgresql); not read by this method.

    Returns:
        ForeignServerDiff if differences found, None otherwise.
    """
    import logging

    logger = logging.getLogger(__name__)

    # Option keys that are promoted to dedicated attributes and diffed separately.
    promoted_option_keys = frozenset({"host", "port", "dbname"})

    def _port_key(port: object) -> Optional[str]:
        # A port may be an int from one source and a str from the other
        # (e.g. a quoted OPTIONS value); compare the textual form.
        return None if port is None else str(port).strip()

    server_name = expected.name or actual.name
    diff = ForeignServerDiff(object_name=server_name, server_name=server_name)

    # Compare FDW name
    expected_fdw = self._normalize_identifier(expected.fdw_name)
    actual_fdw = self._normalize_identifier(actual.fdw_name)
    if expected_fdw != actual_fdw:
        diff.fdw_changed = (expected.fdw_name, actual.fdw_name)
        logger.info(
            f"Foreign server '{server_name}': FDW changed from {expected.fdw_name} to {actual.fdw_name}"
        )

    # Compare host
    expected_host = self._normalize_identifier(expected.host)
    actual_host = self._normalize_identifier(actual.host)
    if expected_host != actual_host:
        diff.host_changed = (expected.host, actual.host)
        logger.info(
            f"Foreign server '{server_name}': host changed from {expected.host} to {actual.host}"
        )

    # Compare port (type-insensitive); report the raw values.
    if _port_key(expected.port) != _port_key(actual.port):
        diff.port_changed = (expected.port, actual.port)
        logger.info(
            f"Foreign server '{server_name}': port changed from {expected.port} to {actual.port}"
        )

    # Compare database name
    expected_dbname = self._normalize_identifier(expected.dbname)
    actual_dbname = self._normalize_identifier(actual.dbname)
    if expected_dbname != actual_dbname:
        diff.dbname_changed = (expected.dbname, actual.dbname)
        logger.info(
            f"Foreign server '{server_name}': dbname changed from {expected.dbname} to {actual.dbname}"
        )

    # Compare the remaining options. Guard against a None mapping, which would
    # otherwise raise AttributeError on .items().
    expected_opts = {
        k: v
        for k, v in (expected.options or {}).items()
        if k not in promoted_option_keys
    }
    actual_opts = {
        k: v
        for k, v in (actual.options or {}).items()
        if k not in promoted_option_keys
    }
    if expected_opts != actual_opts:
        diff.options_changed = (expected_opts, actual_opts)
        logger.info(f"Foreign server '{server_name}': options changed")

    diff._calculate_diffs()
    return diff if diff.has_diffs else None
[docs]
def compare_modules(
    self, expected: Module, actual: Module, dialect: str = "db2"
) -> Optional[ModuleDiff]:
    """Diff an expected DB2 module against the introspected one.

    Only the module definition text is compared, after both sides pass
    through ``_normalize_module_code``.

    Args:
        expected: Module as declared by the migration scripts.
        actual: Module as reported by database introspection.
        dialect: SQL dialect name (typically db2); kept for signature parity
            with the other ``compare_*`` methods and not read here.

    Returns:
        A populated ``ModuleDiff`` when the definitions differ, else ``None``.
    """
    import logging

    log = logging.getLogger(__name__)

    name = expected.name or actual.name
    result = ModuleDiff(object_name=name, module_name=name)

    normalized_expected = self._normalize_module_code(expected.definition)
    normalized_actual = self._normalize_module_code(actual.definition)
    definitions_match = normalized_expected == normalized_actual
    if not definitions_match:
        result.definition_changed = True
        log.info(f"Module '{name}': definition changed")

    result._calculate_diffs()
    return result if result.has_diffs else None
def _normalize_module_code(self, code: str) -> str:
"""Normalize DB2 module code for comparison.
Similar to package normalization but for DB2 modules.
Args:
code: Module definition code
Returns:
Normalized code
"""
if not code:
return ""
# Use the same normalization as packages since modules are conceptually similar
normalized = self._normalize_package_code(code)
return normalized