"""Extension SQL model class (PostgreSQL-specific)."""
from typing import Any, Optional
from sqlmeta.base import SqlObject, SqlObjectType
[docs]
class Extension(SqlObject):
"""
Represents a database extension (PostgreSQL-specific).
Extensions are add-on modules that provide additional functionality
to the database, such as PostGIS for geographic data, pgcrypto for
cryptographic functions, or pg_trgm for trigram text search.
"""
[docs]
def __init__(
self,
name: str,
version: Optional[str] = None,
schema: Optional[str] = None,
description: Optional[str] = None,
relocatable: bool = False,
dialect: Optional[str] = None,
):
"""Initialize an extension.
Args:
name: Extension name
version: Extension version (optional)
schema: Schema where the extension is installed (optional)
description: Extension description (optional)
relocatable: Whether the extension can be relocated to another schema
dialect: SQL dialect (typically 'postgresql')
"""
# Extensions are a PostgreSQL-specific feature
super().__init__(name, SqlObjectType.EXTENSION, schema, dialect or "postgresql")
self.version = version
self.description = description
self.relocatable = relocatable
@property
def create_statement(self) -> str:
"""
Generate CREATE EXTENSION statement.
Returns:
PostgreSQL CREATE EXTENSION statement
"""
stmt = f"CREATE EXTENSION IF NOT EXISTS {self.format_identifier(self.name)}"
# Add schema if specified
if self.schema:
stmt += f"\nSCHEMA {self.format_identifier(self.schema)}"
# Add version if specified
if self.version:
stmt += f"\nVERSION '{self.version}'"
return stmt
@property
def drop_statement(self) -> str:
"""
Generate DROP EXTENSION statement.
Returns:
PostgreSQL DROP EXTENSION statement
"""
return f"DROP EXTENSION IF EXISTS {self.format_identifier(self.name)}"
[docs]
def __str__(self) -> str:
"""Return string representation of the extension."""
info = f"EXTENSION {self.name}"
if self.version:
info += f" (v{self.version})"
if self.description:
info += f" - {self.description}"
return info
[docs]
def __eq__(self, other: Any) -> bool:
"""Check if two extensions are equal."""
if not isinstance(other, Extension):
return False
return super().__eq__(other) and (self.version or "") == (other.version or "")
[docs]
def __hash__(self) -> int:
"""Return hash of the extension."""
return hash(
(
self.name.lower(),
self.object_type,
(self.schema or "").lower(),
(self.version or ""),
)
)