import datetime
import logging
import traceback
from queue import Queue
from sqlalchemy import MetaData, Table, and_
from sqlalchemy.exc import OperationalError
from sqlalchemy.sql import select
from bonobo.config import Configurable, ContextProcessor, Option, Service, use_context, use_raw_input
from bonobo.errors import UnrecoverableError
from bonobo_sqlalchemy.constants import INSERT, UPDATE
from bonobo_sqlalchemy.errors import ProhibitedOperationError
logger = logging.getLogger(__name__)
@use_context
@use_raw_input
class InsertOrUpdate(Configurable):
    """
    Buffered database loader: inserts rows that do not yet exist in ``table_name`` and updates
    the ones that do, matching existing rows on the ``discriminant`` columns.

    TODO: fields vs columns, choose a name (XXX)
    Maybe the obvious choice is to keep "field" for row fields, as it's already the name used by bonobo, and call the
    database columns "columns".
    """

    #: Name of the target table (autoloaded from the database, see :meth:`create_table`).
    table_name = Option(str, positional=True)  # type: str
    #: Columns to read back from the database after load, as ``{alias: column}``.
    #: NOTE(review): declared as a tuple Option but consumed as a dict (``.items()`` below,
    #: and :meth:`add_fetch_columns` builds a dict) — confirm intended type upstream.
    fetch_columns = Option(tuple, required=False, default=())  # type: tuple
    #: Columns only written on INSERT, never touched by an UPDATE.
    insert_only_fields = Option(tuple, required=False, default=())  # type: tuple
    #: Columns used to decide whether a row already exists.
    discriminant = Option(tuple, required=False, default=('id', ))  # type: tuple
    #: Timestamp column set on INSERT (ignored if the table has no such column).
    created_at_field = Option(str, required=False, default='created_at')  # type: str
    #: Timestamp column set on every INSERT/UPDATE (ignored if the table has no such column).
    updated_at_field = Option(str, required=False, default='updated_at')  # type: str
    #: Which of INSERT / UPDATE this transformation is allowed to perform.
    allowed_operations = Option(
        tuple, required=False, default=(
            INSERT,
            UPDATE,
        )
    )  # type: tuple
    #: Rows are queued and committed in batches of this size.
    buffer_size = Option(int, required=False, default=1000)  # type: int

    engine = Service('sqlalchemy.engine')  # type: str

    @ContextProcessor
    def create_connection(self, context, *, engine):
        """
        This context processor creates an sqlalchemy connection for use during the lifetime of this transformation's
        execution.

        :param engine: sqlalchemy engine service.
        :raises UnrecoverableError: if the connection cannot be established.
        """
        try:
            connection = engine.connect()
        except OperationalError as exc:
            # Newlines are stripped so the message stays on one log line.
            raise UnrecoverableError('Could not create SQLAlchemy connection: {}.'.format(str(exc).replace('\n', ''))
                                     ) from exc
        # The with-block guarantees the connection is closed when the transformation ends.
        with connection:
            yield connection

    @ContextProcessor
    def create_table(self, context, connection, *, engine):
        """SQLAlchemy table object, using metadata autoloading from database to avoid the need of column definitions."""
        yield Table(self.table_name, MetaData(), autoload=True, autoload_with=engine)

    @ContextProcessor
    def create_buffer(self, context, connection, table, *, engine):
        """
        This context processor creates a "buffer" of yet to be persisted elements, and commits the remaining elements
        when the transformation ends.

        :param engine: sqlalchemy engine service (unused here, kept for processor chaining).
        :param connection: live sqlalchemy connection from :meth:`create_connection`.
        :raises UnrecoverableError: if the final flush fails.
        """
        buffer = yield Queue()
        # Teardown: force-flush whatever is still queued when the stream is exhausted.
        try:
            for row in self.commit(table, connection, buffer, force=True):
                context.send(row)
        except Exception as exc:
            logger.exception('Flush fail')
            raise UnrecoverableError('Flushing query buffer failed.') from exc

    def __call__(self, connection, table, buffer, context, row, engine):
        """
        Main transformation method, pushing a row to the "yet to be processed elements" queue and commiting if necessary.

        :param engine: sqlalchemy engine service (unused here, kept for processor chaining).
        :param connection: live sqlalchemy connection.
        :param buffer: queue of rows awaiting persistence.
        :param row: incoming row (raw input, mapping-like).
        """
        buffer.put(row)
        # commit() is a no-op generator until the buffer reaches buffer_size.
        yield from self.commit(table, connection, buffer)

    def commit(self, table, connection, buffer, force=False):
        """Drain the buffer inside one transaction once it is full (or unconditionally if *force*).

        Yields each persisted row; a failing row yields the exception object instead of
        aborting the whole batch.
        """
        if force or (buffer.qsize() >= self.buffer_size):
            with connection.begin():
                while buffer.qsize() > 0:
                    try:
                        yield self.insert_or_update(table, connection, buffer.get())
                    except Exception as exc:
                        yield exc

    def insert_or_update(self, table, connection, row):
        """ Actual database load transformation logic, without the buffering / transaction logic.

        :raises ProhibitedOperationError: if the required operation is not in ``allowed_operations``.
        :raises ValueError: if ``fetch_columns`` is set and the row cannot be re-read after load.
        :return: the (mutated) row, possibly enriched with fetched column values.
        """
        # find line, if it exist
        dbrow = self.find(connection, table, row)

        # TODO XXX use actual database function instead of this stupid thing
        now = datetime.datetime.now()

        column_names = table.columns.keys()

        # UpdatedAt field configured ? Let's set the value in source hash
        if self.updated_at_field in column_names:
            row[self.updated_at_field] = now  # XXX not pure ...

        # Update logic
        if dbrow:
            if UPDATE not in self.allowed_operations:
                raise ProhibitedOperationError('UPDATE operations are not allowed by this transformation.')

            query = table.update().values(
                **{col: row.get(col)
                   for col in self.get_columns_for(column_names, row, dbrow)}
            ).where(and_(*(getattr(table.c, col) == row.get(col) for col in self.discriminant)))

        # INSERT
        else:
            if INSERT not in self.allowed_operations:
                raise ProhibitedOperationError('INSERT operations are not allowed by this transformation.')

            if self.created_at_field in column_names:
                row[self.created_at_field] = now  # XXX UNPURE
            else:
                if self.created_at_field in row:
                    del row[self.created_at_field]  # UNPURE

            query = table.insert().values(**{col: row.get(col) for col in self.get_columns_for(column_names, row)})

        # Execute
        try:
            connection.execute(query)
        except Exception:
            logger.exception('Rollback...')
            # NOTE(review): SQLAlchemy 1.x Connection exposes no rollback(); the enclosing
            # connection.begin() context in commit() already rolls back on error — confirm
            # this call does not mask the original exception with an AttributeError.
            connection.rollback()
            raise

        # Increment stats TODO
        # if dbrow:
        #     self._output._special_stats[UPDATE] += 1
        # else:
        #     self._output._special_stats[INSERT] += 1

        # If user required us to fetch some columns, let's query again to get their actual values.
        if self.fetch_columns:
            if not dbrow:
                # BUGFIX: find() takes (connection, table, row); the previous self.find(row)
                # call raised TypeError whenever fetch_columns was used on a freshly inserted row.
                dbrow = self.find(connection, table, row)
            if not dbrow:
                raise ValueError('Could not find matching row after load.')

            for alias, column in self.fetch_columns.items():
                row[alias] = dbrow[column]

        return row

    def find(self, connection, table, row):
        """Return the first table row matching *row* on the discriminant columns, as a dict, or None."""
        sql = select([table]).where(and_(*(getattr(table.c, col) == row.get(col)
                                           for col in self.discriminant))).limit(1)
        row = connection.execute(sql).fetchone()
        return dict(row) if row else None

    def get_columns_for(self, column_names, row, dbrow=None):
        """Retrieve list of table column names for which we have a value in given hash.

        When *dbrow* is given (UPDATE path), insert-only fields are excluded.
        """
        if dbrow:
            candidates = filter(lambda col: col not in self.insert_only_fields, column_names)
        else:
            candidates = column_names

        # Rows may be namedtuple-like (with _fields) or mapping-like (with keys()).
        try:
            fields = row._fields
        except AttributeError:
            fields = list(row.keys())

        return set(candidates).intersection(fields)

    def add_fetch_columns(self, *columns, **aliased_columns):
        """Register extra columns to read back after load; bare names are self-aliased."""
        self.fetch_columns = {
            **self.fetch_columns,
            **aliased_columns,
        }

        for column in columns:
            self.fetch_columns[column] = column