使用sqlx执行事务
August 19, 2023
前文我们已经介绍了sqlx的基本使用方法,今天我们来看一下怎样使用sqlx库来执行事务。
数据库:mysql。
准备 #
创建一张用户账户表:
CREATE TABLE `user_account` (
`id` int unsigned NOT NULL AUTO_INCREMENT,
`user_id` int unsigned NOT NULL COMMENT '用户id',
`balance` decimal(10,0) NOT NULL COMMENT '账户余额。',
`created_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP COMMENT '添加时间',
`modified_at` datetime NOT NULL DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP COMMENT '更新时间',
PRIMARY KEY (`id`),
UNIQUE KEY `udx_userid` (`user_id`) USING BTREE
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_general_ci
我们向这张表中添加两条记录:
可以看到,表中现在有两个用户,id分别为1和2,账户余额都是100.
在实际中,一个用户可能有多个账户,这里为了简单起见,只创建了一个账户。且使用了唯一索引。本文后面部分说到用户1和用户2对应的账户记录时,均是指id=1和id=2的这两条记录。
事务 #
我们现在进行一个事务: 用户1向用户2转账10元。
type UserAccount struct {
Id int64 `db:"id"`
UserId int64 `db:"user_id"`
Balance float64 `db:"balance"` // 账户余额
CreatedAt time.Time `db:"created_at"`
ModifiedAt time.Time `db:"modified_at"`
}
// GetUserAccountById 根据id获取用户账户
func GetUserAccountById(db *sqlx.DB, id int64) (*UserAccount, error) {
sql := `SELECT * FROM user_account WHERE id = ?`
userAccount := new(UserAccount)
err := db.Get(userAccount, sql, id)
if err != nil {
return nil, err
}
return userAccount, nil
}
// UpdateUserAccountBalance 更新用户账户余额,balance为正数表示增加,为负数表示减少
func UpdateUserAccountBalance(db *sqlx.Tx, id int64, balance float64) error {
sql := `UPDATE user_account SET balance = balance + ? WHERE id = ?`
_, err := db.Exec(sql, balance, id)
if err != nil {
return err
}
return nil
}
// Transaction 包装事务方法
func Transaction(db *sqlx.DB, fn func(tx *sqlx.Tx) error) error {
var tx *sqlx.Tx
var err error
tx, err = db.Beginx()
if err != nil {
return errors.Wrap(err, "开启事务失败")
}
defer func() {
if err != nil {
if err1 := tx.Rollback(); err1 != nil {
err = errors.WithMessagef(err, "execute failed,rollback failed:%v", err1)
}
}
}()
if err = fn(tx); err != nil {
return err
}
return tx.Commit()
}
// Transfer 转账
func Transfer(db *sqlx.DB, fromUserId int64, toUserId int64, balance float64) error {
return Transaction(db, func(tx *sqlx.Tx) error {
// 扣除fromUserId账户余额
if err := UpdateUserAccountBalance(tx, fromUserId, -balance); err != nil {
return err
}
// 增加toUserId账户余额
if err := UpdateUserAccountBalance(tx, toUserId, balance); err != nil {
return err
}
return nil
})
}
上面,我们定义了一个UserAccount结构体,用来表示用户账户信息。定义了一个Transaction方法,用来包装事务操作。定义了一个Transfer方法,用来执行转账操作。
我们下面来写测试用例:
func TestTransferMoney(t *testing.T) {
err := Transfer(conn, 1, 2, 10.0)
assert.Equal(t, err, nil)
// 查询余额
account1, err := GetUserAccountById(conn, 1)
assert.Equal(t, err, nil)
assert.Equal(t, account1.Balance, 90.0)
account2, err := GetUserAccountById(conn, 2)
assert.Equal(t, err, nil)
assert.Equal(t, account2.Balance, 110.0)
}
在这个测试用例中,我们执行了一次转账操作,并在执行成功后,查询了两个账户的余额,校验了余额是否正确。
运行测试用例,测试通过,查询结果如下:
