package database

import (
	"context"
	"errors"

	"github.com/jmoiron/sqlx"
	"github.com/rs/zerolog"
)

// ErrTxNotOpen is returned by Commit and Rollback when the transaction has
// already been committed or rolled back through this wrapper. Callers can
// test for it with errors.Is.
var ErrTxNotOpen = errors.New("transaction is not open")

// TxState is an enum that describes the state of a transaction.
type TxState int

const (
	// TxOpened means the transaction has begun and is not yet closed.
	TxOpened TxState = iota
	// TxCommitted means Commit was called (the commit itself may have failed).
	TxCommitted
	// TxRolledback means Rollback was called (the rollback itself may have failed).
	TxRolledback
)

// Begin starts a transaction on db bound to ctx, using default transaction
// options, and returns it wrapped in a *Tx in the TxOpened state.
// Any error from db.BeginTxx is returned unchanged.
func Begin(ctx context.Context, db *sqlx.DB) (*Tx, error) {
	tx, err := db.BeginTxx(ctx, nil)
	if err != nil {
		return nil, err
	}
	return &Tx{Tx: tx, state: TxOpened}, nil
}

// Tx wraps a *sqlx.Tx and records whether it has been committed or rolled
// back. Because of this, a deferred RollbackIfOpened is a no-op once Commit
// or Rollback has been called.
//
// The state field is read and written without synchronization, so a Tx must
// not be committed or rolled back from several goroutines concurrently.
// Values should be created with Begin; the zero value has a nil *sqlx.Tx.
type Tx struct {
	*sqlx.Tx
	state TxState
}

// Commit commits the transaction.
//
// It returns ErrTxNotOpen if Commit or Rollback was already called on tx.
// Otherwise it returns the result of the underlying sqlx commit.
func (tx *Tx) Commit() error {
	if tx.state != TxOpened {
		return ErrTxNotOpen
	}
	// The state is switched before committing on purpose. database/sql ends
	// the transaction after any Commit attempt, so a later RollbackIfOpened
	// must not try to roll back a failed commit.
	tx.state = TxCommitted
	return tx.Tx.Commit()
}

// Rollback rolls back the transaction.
//
// It returns ErrTxNotOpen if Commit or Rollback was already called on tx.
// Otherwise it returns the result of the underlying sqlx rollback.
func (tx *Tx) Rollback() error {
	if tx.state != TxOpened {
		return ErrTxNotOpen
	}
	// As in Commit, mark the transaction closed regardless of the outcome.
	tx.state = TxRolledback
	return tx.Tx.Rollback()
}

// LoggedRollback rolls back the transaction and logs the error with log if
// the rollback fails. This includes ErrTxNotOpen when the transaction was
// already closed; use RollbackIfOpened to skip closed transactions silently.
func (tx *Tx) LoggedRollback(log zerolog.Logger) {
	if err := tx.Rollback(); err != nil {
		log.Err(err).Msg("rollback failed")
	}
}

// RollbackIfOpened rolls back the transaction only if neither Commit nor
// Rollback has been called yet, logging any rollback failure with log.
// It is intended to be deferred right after Begin.
func (tx *Tx) RollbackIfOpened(log zerolog.Logger) {
	if tx.state == TxOpened {
		tx.LoggedRollback(log)
	}
}