SQLAlchemy – performing a bulk upsert (if exists, update, else insert) in postgresql

Question:

I am trying to write a bulk upsert in python using the SQLAlchemy module (not in SQL!).

I am getting the following error on a SQLAlchemy add:

sqlalchemy.exc.IntegrityError: (IntegrityError) duplicate key value violates unique constraint "posts_pkey"
DETAIL:  Key (id)=(TEST1234) already exists.

I have a table called posts with a primary key on the id column.

In this example, I already have a row in the db with id=TEST1234. When I attempt to db.session.add() a new posts object with the id set to TEST1234, I get the error above. I was under the impression that if the primary key already exists, the record would get updated.

How can I upsert with Flask-SQLAlchemy based on primary key alone? Is there a simple solution?

If there is not, I can always check for and delete any record with a matching id, and then insert the new record, but that seems expensive for my situation, where I do not expect many updates.

Asked By: mgoldwasser

Answers:

There is an upsert-esque operation in SQLAlchemy:

db.session.merge()

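After I found this command, I was able to perform upserts, but it is worth mentioning that this operation is slow for a bulk “upsert”, because merge() looks up each object by primary key individually (one SELECT per object that is not already loaded in the session).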
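To make the merge() behaviour concrete, here is a minimal sketch. It assumes SQLAlchemy 1.4 or newer and uses an in-memory SQLite database in place of PostgreSQL; the Post model and its title column are hypothetical stand-ins for the posts table from the question.

```python
from sqlalchemy import Column, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Post(Base):
    # Hypothetical model standing in for the "posts" table from the question
    __tablename__ = "posts"
    id = Column(String, primary_key=True)
    title = Column(String)


# Assumption: in-memory SQLite instead of PostgreSQL, so the sketch runs anywhere
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

# Seed one row so that the first merge() below hits an existing primary key
session.add(Post(id="TEST1234", title="original"))
session.commit()

session.merge(Post(id="TEST1234", title="updated"))  # key exists: row is updated
session.merge(Post(id="TEST5678", title="new"))      # key is new: row is inserted
session.commit()

titles = {post.id: post.title for post in session.query(Post)}
print(titles)
```

Each merge() call here triggers its own primary key lookup, which is why the approach gets slow when many rows are involved.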

The alternative is to get a list of the primary keys you would like to upsert, and query the database for any matching ids:

# Imagine that post1, post5, and post1000 are posts objects with ids 1, 5 and 1000 respectively
# The goal is to "upsert" these posts.
# we initialize a dict which maps id to the post object

my_new_posts = {1: post1, 5: post5, 1000: post1000} 

for each in posts.query.filter(posts.id.in_(my_new_posts.keys())).all():
    # Only merge those posts which already exist in the database
    db.session.merge(my_new_posts.pop(each.id))

# Only add those posts which did not exist in the database 
db.session.add_all(my_new_posts.values())

# Now we commit our modifications (merges) and inserts (adds) to the database!
db.session.commit()

Answered By: mgoldwasser

An alternative approach using compilation extension (https://docs.sqlalchemy.org/en/13/core/compiler.html):

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert

@compiles(Insert, "postgresql")
def compile_upsert(insert_stmt, compiler, **kwargs):
    """
    Converts every SQL insert to an upsert, i.e.
    INSERT INTO test (foo, bar) VALUES (1, 'a')
    becomes:
    INSERT INTO test (foo, bar) VALUES (1, 'a') ON CONFLICT (foo) DO UPDATE SET foo=EXCLUDED.foo, bar=EXCLUDED.bar
    (assuming foo is a primary key)
    :param insert_stmt: Original insert statement
    :param compiler: SQL Compiler
    :param kwargs: optional arguments
    :return: upsert statement
    """
    pk = insert_stmt.table.primary_key
    insert = compiler.visit_insert(insert_stmt, **kwargs)
    ondup = f'ON CONFLICT ({",".join(c.name for c in pk)}) DO UPDATE SET'
    updates = ', '.join(f"{c.name}=EXCLUDED.{c.name}" for c in insert_stmt.table.columns)
    upsert = ' '.join((insert, ondup, updates))
    return upsert

This should make every INSERT statement behave as an upsert. Note that it rewrites all inserts globally and assumes every table has a primary key. The implementation targets the Postgres dialect, but it should be fairly easy to adapt for the MySQL dialect.
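If rewriting every insert is too broad, the same compilation hook can be attached to an opt-in construct instead. The following is only a sketch under assumptions: the Upsert class and the posts table are hypothetical names, the hook is registered for the postgresql dialect only, and primary key columns are left out of the SET clause. The statement is compiled to a string rather than executed, so no database is needed.

```python
from sqlalchemy import Column, Integer, MetaData, String, Table
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import Insert


class Upsert(Insert):
    """Hypothetical opt-in construct: only statements built with it are rewritten."""

    inherit_cache = False  # custom compilation, so opt out of statement caching


@compiles(Upsert, "postgresql")
def compile_opt_in_upsert(stmt, compiler, **kwargs):
    table = stmt.table
    pk_names = ", ".join(c.name for c in table.primary_key)
    # Render the plain INSERT first, then append the conflict clause
    sql = compiler.visit_insert(stmt, **kwargs)
    # Update only non-key columns; the key columns are what the conflict is on
    updates = ", ".join(
        f"{c.name} = EXCLUDED.{c.name}" for c in table.columns if not c.primary_key
    )
    return f"{sql} ON CONFLICT ({pk_names}) DO UPDATE SET {updates}"


# Hypothetical table used only to show the generated SQL
posts = Table(
    "posts",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("title", String),
)

stmt = Upsert(posts).values(id=1, title="a")
sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)
```

Ordinary insert() statements are unaffected by this variant; only code that explicitly builds an Upsert gets the ON CONFLICT clause.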

Answered By: danielcahall

You can leverage the on_conflict_do_update variant. A simple example would be the following:

from sqlalchemy import Column, Integer, Unicode
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import declarative_base

Base = declarative_base()

class Post(Base):
    """
    A simple class for demonstration
    """

    __tablename__ = "posts"

    id = Column(Integer, primary_key=True)
    title = Column(Unicode)

# Prepare all the values that should be "upserted" to the DB
values = [
    {"id": 1, "title": "mytitle 1"},
    {"id": 2, "title": "mytitle 2"},
    {"id": 3, "title": "mytitle 3"},
    {"id": 4, "title": "mytitle 4"},
]

stmt = insert(Post).values(values)
stmt = stmt.on_conflict_do_update(
    # Let's use the constraint name which was visible in the original post's error msg
    constraint="posts_pkey",

    # The columns that should be updated on conflict
    set_={
        "title": stmt.excluded.title
    }
)
# `session` is assumed to be an existing SQLAlchemy Session
session.execute(stmt)

See the Postgres docs for more details about ON CONFLICT DO UPDATE.

See the SQLAlchemy docs for more details about on_conflict_do_update.
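When a table has many columns, the set_ dictionary can be derived from the table definition rather than written by hand. The sketch below is an illustration under assumptions: build_upsert is a hypothetical helper, the posts table with an extra body column is made up for the example, and the conflict target is given via index_elements (the primary key columns) instead of a constraint name. It only compiles the statement for the PostgreSQL dialect, so no connection is required.

```python
from sqlalchemy import Column, Integer, MetaData, Table, Unicode
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import insert

# Hypothetical table; for a declarative model, Post.__table__ plays the same role
posts = Table(
    "posts",
    MetaData(),
    Column("id", Integer, primary_key=True),
    Column("title", Unicode),
    Column("body", Unicode),
)


def build_upsert(table, rows):
    """Build INSERT ... ON CONFLICT DO UPDATE that overwrites all non-key columns."""
    stmt = insert(table).values(rows)
    # Map every non-primary-key column to the value proposed for insertion
    update_cols = {
        c.name: stmt.excluded[c.name] for c in table.columns if not c.primary_key
    }
    return stmt.on_conflict_do_update(
        index_elements=[c.name for c in table.primary_key],
        set_=update_cols,
    )


stmt = build_upsert(
    posts,
    [
        {"id": 1, "title": "mytitle 1", "body": "text 1"},
        {"id": 2, "title": "mytitle 2", "body": "text 2"},
    ],
)
sql = str(stmt.compile(dialect=postgresql.dialect()))
print(sql)
```

With a real session, the returned statement would be passed to session.execute() in the same way as the hand-written one.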

Side-Note on duplicated column names

The above code uses the column names as dict keys both in the values list and in the argument to set_. If a column name is changed in the class definition, it needs to be changed everywhere or the code will break. This can be avoided by accessing the column definitions, which makes the code a bit uglier but more robust:

coldefs = Post.__table__.c

values = [
    {coldefs.id.name: 1, coldefs.title.name: "mytitle 1"},
    ...
]

stmt = stmt.on_conflict_do_update(
    ...
    set_={
        coldefs.title.name: stmt.excluded.title,
        ...
    }
)

Answered By: exhuma

This is not the safest method, but it is very simple and very fast. I was just trying to selectively overwrite a portion of a table. I deleted the rows that I knew would conflict and then appended the new rows from a pandas DataFrame. The DataFrame column names need to match the SQL table column names.

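from sqlalchemy import create_engine, text

eng = create_engine('postgresql://...')

# Delete the rows that would conflict; eng.begin() commits when the block exits.
# Plain SQL strings must be wrapped in text() on SQLAlchemy 2.0.
with eng.begin() as conn:
    conn.execute(text("DELETE FROM my_table WHERE col = :val"), {"val": val})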
df.to_sql('my_table', con=eng, if_exists='append')
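One way to make the delete-then-append approach less risky is to run both steps in a single transaction, so a failed insert also rolls back the delete. The sketch below illustrates that idea without pandas. It assumes an in-memory SQLite database instead of PostgreSQL, and replace_rows, my_table and its columns are hypothetical names chosen for the example.

```python
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

metadata = MetaData()
# Hypothetical table layout, only for illustration
my_table = Table(
    "my_table",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("col", String),
    Column("payload", String),
)

# Assumption: in-memory SQLite so the sketch is runnable without a server
engine = create_engine("sqlite://")
metadata.create_all(engine)


def replace_rows(conn, table, col_value, new_rows):
    """Delete all rows where col == col_value, then insert new_rows."""
    conn.execute(table.delete().where(table.c.col == col_value))
    if new_rows:
        conn.execute(table.insert(), new_rows)


# Seed some existing data
with engine.begin() as conn:
    conn.execute(
        my_table.insert(),
        [
            {"id": 1, "col": "a", "payload": "old"},
            {"id": 2, "col": "a", "payload": "old"},
            {"id": 3, "col": "b", "payload": "keep"},
        ],
    )

# engine.begin() commits on success and rolls back both steps on any error
with engine.begin() as conn:
    replace_rows(
        conn,
        my_table,
        "a",
        [
            {"id": 1, "col": "a", "payload": "new"},
            {"id": 4, "col": "a", "payload": "new"},
        ],
    )

with engine.connect() as conn:
    rows = [tuple(r) for r in conn.execute(my_table.select().order_by(my_table.c.id))]
print(rows)
```

Rows outside the deleted slice (here the one with col "b") are left untouched.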
Answered By: user1071182

I started looking at this and I think I’ve found a pretty efficient way to do upserts in sqlalchemy with a mix of bulk_insert_mappings and bulk_update_mappings instead of merge.

import time

from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import scoped_session, sessionmaker
from contextlib import contextmanager


engine = None
Session = sessionmaker()
Base = declarative_base()


def creat_new_database(db_name="sqlite:///bulk_upsert_sqlalchemy.db"):
    global engine
    engine = create_engine(db_name, echo=False)
    local_session = scoped_session(Session)
    local_session.remove()
    local_session.configure(bind=engine, autoflush=False, expire_on_commit=False)
    Base.metadata.drop_all(engine)
    Base.metadata.create_all(engine)


@contextmanager
def db_session():
    local_session = scoped_session(Session)
    session = local_session()

    session.expire_on_commit = False

    try:
        yield session
    except BaseException:
        session.rollback()
        raise
    finally:
        session.close()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def bulk_upsert_mappings(customers):

    entries_to_update = []
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and build mappings
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            customer = customers.pop(each.id)
            entries_to_update.append({"id": customer["id"], "name": customer["name"]})

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.bulk_update_mappings(Customer, entries_to_update)
        sess.commit()

    print(
        "Total time for upsert with MAPPING update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(len(entries_to_update))
    )


def bulk_upsert_merge(customers):

    entries_to_update = 0
    entries_to_put = []
    with db_session() as sess:
        t0 = time.time()

        # Find all customers that need to be updated and merge
        for each in (
            sess.query(Customer.id).filter(Customer.id.in_(customers.keys())).all()
        ):
            values = customers.pop(each.id)
            sess.merge(Customer(id=values["id"], name=values["name"]))
            entries_to_update += 1

        # Bulk mappings for everything that needs to be inserted
        for customer in customers.values():
            entries_to_put.append({"id": customer["id"], "name": customer["name"]})

        sess.bulk_insert_mappings(Customer, entries_to_put)
        sess.commit()

    print(
        "Total time for upsert with MERGE update "
        + str(len(customers))
        + " records "
        + str(time.time() - t0)
        + " sec"
        + " inserted : "
        + str(len(entries_to_put))
        + " - updated : "
        + str(entries_to_update)
    )


if __name__ == "__main__":

    batch_size = 10000

    # Only inserts
    customers_insert = {
        i: {"id": i, "name": "customer_" + str(i)} for i in range(batch_size)
    }

    # Half inserts, half updates
    customers_upsert = {
        i: {"id": i, "name": "customer_2_" + str(i)}
        for i in range(int(batch_size / 2), batch_size + int(batch_size / 2))
    }

    creat_new_database()
    bulk_upsert_mappings(customers_insert.copy())
    bulk_upsert_mappings(customers_upsert.copy())
    bulk_upsert_mappings(customers_insert.copy())

    creat_new_database()
    bulk_upsert_merge(customers_insert.copy())
    bulk_upsert_merge(customers_upsert.copy())
    bulk_upsert_merge(customers_insert.copy())

The results for the benchmark:

Total time for upsert with MAPPING: 0.17138004302978516 sec inserted : 10000 - updated : 0
Total time for upsert with MAPPING: 0.22074174880981445 sec inserted : 5000 - updated : 5000
Total time for upsert with MAPPING: 0.22307634353637695 sec inserted : 0 - updated : 10000
Total time for upsert with MERGE: 0.1724097728729248 sec inserted : 10000 - updated : 0
Total time for upsert with MERGE: 7.852903842926025 sec inserted : 5000 - updated : 5000
Total time for upsert with MERGE: 15.11970829963684 sec inserted : 0 - updated : 10000

Answered By: Emil Wåreus

I know this is kind of late, but I have built on the answer given by @Emil Wåreus and turned it into a function that can be used on any model (table):

def upsert_data(self, entries, model, key):
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()

entries should be a dictionary keyed by primary key value (as strings, since the lookup above uses str(...)), whose values are mappings of column names to the values to store.

model is the ORM model that you want to upsert to.

key is the name of the primary key column of the table.
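As a usage illustration of the three arguments described above, here is a self-contained sketch. It is a variation rather than the exact method: upsert_rows is a hypothetical free function that takes the session explicitly, uses the key values directly instead of converting them with str(), and works on a copy of entries. It assumes SQLAlchemy 1.4 or newer and an in-memory SQLite database; the Customer model mirrors the one from the benchmark answer.

```python
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Customer(Base):
    __tablename__ = "customer"
    id = Column(Integer, primary_key=True)
    name = Column(String(255))


def upsert_rows(session, model, key, entries):
    """Split entries (key value -> column mapping) into updates and inserts."""
    pending = dict(entries)  # copy, so the caller's dict is not mutated
    key_col = getattr(model, key)
    to_update = []
    # One query finds which keys already exist in the table
    for (existing_key,) in session.query(key_col).filter(key_col.in_(list(pending))):
        to_update.append(pending.pop(existing_key))
    to_insert = list(pending.values())  # whatever is left does not exist yet
    session.bulk_insert_mappings(model, to_insert)
    session.bulk_update_mappings(model, to_update)
    session.commit()
    return len(to_insert), len(to_update)


# Assumption: in-memory SQLite so the sketch runs without a server
engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

first = upsert_rows(
    session, Customer, "id",
    {1: {"id": 1, "name": "a"}, 2: {"id": 2, "name": "b"}},
)
second = upsert_rows(
    session, Customer, "id",
    {2: {"id": 2, "name": "B"}, 3: {"id": 3, "name": "c"}},
)
names = [name for (name,) in session.query(Customer.name).order_by(Customer.id)]
print(first, second, names)
```

Each mapping must include the primary key value, since bulk_update_mappings uses it to locate the row to update.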

You can also use the following function to look up the model for a table from its name:

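def get_table(self, table_name):
    # Note: _decl_class_registry is a private attribute; on SQLAlchemy 1.4+
    # the equivalent is self.base.registry._class_registry
    for c in self.base._decl_class_registry.values():
        if hasattr(c, '__tablename__') and c.__tablename__ == table_name:
            return c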

Using this, you can pass the name of the table as a string to the upsert_data function:

def upsert_data(self, entries, table, key):
    model = self.get_table(table)
    entries_to_update = []
    entries_to_insert = []
    
    # get all entries to be updated
    for each in session.query(model).filter(getattr(model, key).in_(entries.keys())).all():
        entry = entries.pop(str(getattr(each, key)))
        entries_to_update.append(entry)
        
    # get all entries to be inserted
    for entry in entries.values():
        entries_to_insert.append(entry)

    session.bulk_insert_mappings(model, entries_to_insert)
    session.bulk_update_mappings(model, entries_to_update)

    session.commit()

Answered By: Minura Punchihewa