使用sqlx执行事务

使用sqlx执行事务

August 19, 2023
数据库
数据库

前文我们已经介绍了sqlx的基本使用方法,今天我们来看一下怎样使用sqlx库来执行事务。

数据库:mysql。

准备 #

创建一张用户账户表:

CREATE TABLE `user_account` (
  `id` int unsigned NOT NULL AUTO_INCREMENT,
  `user_id` int unsigned NOT NULL COMMENT '用户id',
  `balance` decimal(10,0) NOT NULL COMMENT '账户余额。',
  `created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '添加时间',
  `modified_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
  PRIMARY KEY (`id`),
  UNIQUE KEY `udx_userid` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci

我们向这张表中添加两条记录: 初始化数据 可以看到,表中现在有两个用户,id分别为1和2,账户余额都是100.

在实际中,一个用户可能有多个账户,这里为了简单起见,只创建了一个账户。且使用了唯一索引。本文后面部分说到用户1和用户2对应的账户记录时,均是指id=1和id=2的这两条记录。

事务 #

我们现在进行一个事务: 用户1向用户2转账10元。

type UserAccount struct {
	Id         int64     `db:"id"`
	UserId     int64     `db:"user_id"`
	Balance    float64   `db:"balance"` // 账户余额
	CreatedAt  time.Time `db:"created_at"`
	ModifiedAt time.Time `db:"modified_at"`
}

// GetUserAccountById 根据id获取用户账户
func GetUserAccountById(db *sqlx.DB, id int64) (*UserAccount, error) {
	sql := `SELECT * FROM user_account WHERE id = ?`
	userAccount := new(UserAccount)
	err := db.Get(userAccount, sql, id)
	if err != nil {
		return nil, err
	}

	return userAccount, nil
}

// UpdateUserAccountBalance 更新用户账户余额,balance为正数表示增加,为负数表示减少
func UpdateUserAccountBalance(db *sqlx.Tx, id int64, balance float64) error {
	sql := `UPDATE user_account SET balance = balance + ? WHERE id = ?`
	_, err := db.Exec(sql, balance, id)
	if err != nil {
		return err
	}

	return nil
}

// Transaction 包装事务方法
func Transaction(db *sqlx.DB, fn func(tx *sqlx.Tx) error) error {
	var tx *sqlx.Tx
	var err error

	tx, err = db.Beginx()

	if err != nil {
		return errors.Wrap(err, "开启事务失败")
	}

	defer func() {
		if err != nil {
			if err1 := tx.Rollback(); err1 != nil {
				err = errors.WithMessagef(err, "execute failed,rollback failed:%v", err1)
			}
		}
	}()

	if err = fn(tx); err != nil {
		return err
	}

	return tx.Commit()
}

// Transfer 转账
func Transfer(db *sqlx.DB, fromUserId int64, toUserId int64, balance float64) error {
	return Transaction(db, func(tx *sqlx.Tx) error {
		// 扣除fromUserId账户余额
		if err := UpdateUserAccountBalance(tx, fromUserId, -balance); err != nil {
			return err
		}

		// 增加toUserId账户余额
		if err := UpdateUserAccountBalance(tx, toUserId, balance); err != nil {
			return err
		}

		return nil
	})
}

上面,我们定义了一个UserAccount结构体,用来表示用户账户信息。定义了一个Transaction方法,用来包装事务操作。定义了一个Transfer方法,用来执行转账操作。

我们下面来写测试用例:

func TestTransferMoney(t *testing.T) {
	err := Transfer(conn, 1, 2, 10.0)

	assert.Equal(t, err, nil)

	// 查询余额
	account1, err := GetUserAccountById(conn, 1)
	assert.Equal(t, err, nil)
	assert.Equal(t, account1.Balance, 90.0)

	account2, err := GetUserAccountById(conn, 2)

	assert.Equal(t, err, nil)
	assert.Equal(t, account2.Balance, 110.0)
}

在这个测试用例中,我们执行了一次转账操作,并在执行成功后,查询了两个账户的余额,校验了余额是否正确。

运行测试用例,测试通过,查询结果如下:

执行结果