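Every database needs its own driver, and no language can ship drivers for all databases in its standard library. How, then, should a language design the database-related part of its standard library? Golang defines a set of driver-facing interfaces that decouple the driver logic from any concrete implementation, which solves the problem elegantly.
In addition, the database/sql package in the Golang standard library reuses connections through a pool of idle connections, which improves performance.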
Driver
Interface
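Figure 1: Interface design
Figure 1 shows the design of the database interfaces. A Connector returns the database driver through its Driver method and establishes a connection to the database through its Connect method, so it is reasonable to guess that a Connector obtains connections by way of its Driver. A connection can begin a transaction or prepare a SQL statement. Depending on the kind of statement, call Exec to obtain a Result or Query to obtain a Rows. The interfaces, which live in the database/sql/driver package, are defined as follows: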
type Connector interface {
Connect(context.Context) (Conn, error)
Driver() Driver
}
type Driver interface {
Open(name string) (Conn, error)
}
type Conn interface {
Prepare(query string) (Stmt, error)
Close() error
Begin() (Tx, error)
}
type Tx interface {
Commit() error
Rollback() error
}
type Stmt interface {
Close() error
NumInput() int
Exec(args []Value) (Result, error)
Query(args []Value) (Rows, error)
}
type Result interface {
LastInsertId() (int64, error)
RowsAffected() (int64, error)
}
type Rows interface {
Columns() []string
Close() error
Next(dest []Value) error
}
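To make the chain of interfaces concrete, here is a minimal sketch of a driver that satisfies them. The driver name "mem" and every mem* type are hypothetical and exist only for illustration; the statement text is ignored and the query result is hard-coded, so this shows the shape of a driver, not a usable one.

```go
package main

import (
	"database/sql"
	"database/sql/driver"
	"fmt"
	"io"
)

// memDriver is a hypothetical in-memory driver; it exists only for this sketch.
type memDriver struct{}

func (memDriver) Open(name string) (driver.Conn, error) { return &memConn{}, nil }

type memConn struct{}

func (c *memConn) Prepare(query string) (driver.Stmt, error) { return &memStmt{}, nil }
func (c *memConn) Close() error                              { return nil }
func (c *memConn) Begin() (driver.Tx, error)                 { return memTx{}, nil }

type memTx struct{}

func (memTx) Commit() error   { return nil }
func (memTx) Rollback() error { return nil }

type memStmt struct{}

func (s *memStmt) Close() error  { return nil }
func (s *memStmt) NumInput() int { return -1 } // -1: database/sql skips the argument-count check

// Exec pretends that every statement affects exactly one row.
func (s *memStmt) Exec(args []driver.Value) (driver.Result, error) {
	return driver.RowsAffected(1), nil
}

// Query ignores the statement text and always returns two fixed rows.
func (s *memStmt) Query(args []driver.Value) (driver.Rows, error) {
	return &memRows{names: []string{"alice", "bob"}}, nil
}

type memRows struct {
	names []string
	pos   int
}

func (r *memRows) Columns() []string { return []string{"name"} }
func (r *memRows) Close() error      { return nil }

// Next copies the current row into dest and reports io.EOF when no rows remain.
func (r *memRows) Next(dest []driver.Value) error {
	if r.pos >= len(r.names) {
		return io.EOF
	}
	dest[0] = r.names[r.pos]
	r.pos++
	return nil
}

func init() { sql.Register("mem", memDriver{}) }

// run drives the chain Driver -> Conn -> Stmt -> Result/Rows through database/sql.
func run() (affected int64, names []string, err error) {
	db, err := sql.Open("mem", "")
	if err != nil {
		return 0, nil, err
	}
	defer db.Close()
	res, err := db.Exec("INSERT ...")
	if err != nil {
		return 0, nil, err
	}
	if affected, err = res.RowsAffected(); err != nil {
		return 0, nil, err
	}
	rows, err := db.Query("SELECT name ...")
	if err != nil {
		return 0, nil, err
	}
	defer rows.Close()
	for rows.Next() {
		var name string
		if err := rows.Scan(&name); err != nil {
			return 0, nil, err
		}
		names = append(names, name)
	}
	return affected, names, rows.Err()
}

func main() {
	affected, names, err := run()
	fmt.Println(affected, names, err)
}
```

Because memConn implements none of the optional fast-path interfaces, database/sql falls back to Prepare followed by Stmt.Exec or Stmt.Query, which is exactly the path the interface listing describes.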
SQL
Register
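A driver author only needs to implement the interfaces defined by the Golang standard library and then register the driver with the Register function, which is shown below: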
func Register(name string, driver driver.Driver) {
driversMu.Lock()
defer driversMu.Unlock()
if driver == nil {
panic("sql: Register driver is nil")
}
if _, dup := drivers[name]; dup {
panic("sql: Register called twice for driver " + name)
}
drivers[name] = driver
}
The final statement, drivers[name] = driver, stores the driver name and the driver instance in drivers, a package-level variable defined in database/sql.
var (
driversMu sync.RWMutex
drivers = make(map[string]driver.Driver)
)
driversMu is a read-write lock. The map type in Golang is not safe for concurrent use, so every access to drivers has to hold the lock.
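The two panics in Register can be observed directly. The following is a small sketch under the assumption of a do-nothing driver (noopDriver and the names "noop" and "never-registered" are made up for the example); it registers the same name twice and then looks the name up through sql.Drivers, which reads the same registry.

```go
package main

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
)

// noopDriver is a hypothetical driver whose Open always fails; only its
// registration matters in this sketch.
type noopDriver struct{}

func (noopDriver) Open(name string) (driver.Conn, error) {
	return nil, errors.New("noop driver: Open is not implemented")
}

// registerTwice registers the same name twice and reports whether a call panicked.
func registerTwice(name string) (panicked bool) {
	defer func() {
		if r := recover(); r != nil {
			panicked = true
		}
	}()
	sql.Register(name, noopDriver{})
	sql.Register(name, noopDriver{}) // duplicate name: Register panics here
	return false
}

// isRegistered reports whether name appears in the driver registry.
func isRegistered(name string) bool {
	for _, n := range sql.Drivers() {
		if n == name {
			return true
		}
	}
	return false
}

func main() {
	fmt.Println(registerTwice("noop"))
	fmt.Println(isRegistered("noop"))
	fmt.Println(isRegistered("never-registered"))
}
```

This is also why drivers are normally registered from an init function of the driver package and imported with a blank identifier: the registration runs exactly once per process.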
Connect Manager
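Figure 2: Connect Manager
Figure 2 shows how connections are managed. The core structure is DB, which has three main fields: connector creates the underlying connections, connRequests tracks pending connection requests, and freeConn stores idle connections.
A connection must be acquired before any operation on the database. There are three ways to acquire one: take it from the idle connection pool; register a connection request in connRequests and receive a connection on that request's channel; or call the connector's Connect method to obtain an underlying connection and wrap it.
After a connection has been used, it is either sent as a connection request on one of the channels in connRequests or put back into the idle connection pool.
Let us read the source code to see how connection management works in detail, starting with the DB struct.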
type DB struct {
// Atomic access only. At top of struct to prevent mis-alignment
// on 32-bit platforms. Of type time.Duration.
waitDuration int64 // Total time waited for new connections.
connector driver.Connector
// numClosed is an atomic counter which represents a total number of
// closed connections. Stmt.openStmt checks it before cleaning closed
// connections in Stmt.css.
numClosed uint64
mu sync.Mutex // protects following fields
freeConn []*driverConn // free connections ordered by returnedAt oldest to newest
connRequests map[uint64]chan connRequest
nextRequest uint64 // Next key to use in connRequests.
numOpen int // number of opened and pending open connections
// Used to signal the need for new connections
// a goroutine running connectionOpener() reads on this chan and
// maybeOpenNewConnections sends on the chan (one send per needed connection)
// It is closed during db.Close(). The close tells the connectionOpener
// goroutine to exit.
openerCh chan struct{}
closed bool
dep map[finalCloser]depSet
lastPut map[*driverConn]string // stacktrace of last conn's put; debug only
maxIdleCount int // zero means defaultMaxIdleConns; negative means 0
maxOpen int // <= 0 means unlimited
maxLifetime time.Duration // maximum amount of time a connection may be reused
maxIdleTime time.Duration // maximum amount of time a connection may be idle before being closed
cleanerCh chan struct{}
waitCount int64 // Total number of connections waited for.
maxIdleClosed int64 // Total number of connections closed due to idle count.
maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.
stop func() // stop cancels the connection opener.
}
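A DB instance is created with the Open function, which in turn calls OpenDB. OpenDB takes a Connector as its argument, and Open can obtain one in two ways. If the driver also implements the DriverContext interface (the driver.DriverContext type assertion branch), Open calls that interface's OpenConnector method to obtain the Connector. Otherwise (the final return statement), Open wraps the driver in dsnConnector, the default Connector type of the standard library. This design lets a driver define how its Connector is obtained while still providing a default one, and the same pattern is worth considering in scenarios that need this kind of flexibility.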
func Open(driverName, dataSourceName string) (*DB, error) {
driversMu.RLock()
driveri, ok := drivers[driverName]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
if driverCtx, ok := driveri.(driver.DriverContext); ok {
connector, err := driverCtx.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return OpenDB(connector), nil
}
return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
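A Connector can also be passed to sql.OpenDB directly, without going through Register and Open. The sketch below assumes a hypothetical countingConnector that counts Connect calls and returns a stub connection that supports nothing but Close; it shows that OpenDB itself does not connect and that the first operation needing a connection does.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"sync/atomic"
)

// stubConn is a hypothetical connection that supports nothing except Close.
type stubConn struct{}

func (stubConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("stub: Prepare is not supported")
}
func (stubConn) Close() error { return nil }
func (stubConn) Begin() (driver.Tx, error) {
	return nil, errors.New("stub: Begin is not supported")
}

type stubDriver struct{}

func (stubDriver) Open(name string) (driver.Conn, error) { return stubConn{}, nil }

// countingConnector is a hypothetical driver.Connector that counts Connect calls.
type countingConnector struct{ connects int64 }

func (c *countingConnector) Connect(ctx context.Context) (driver.Conn, error) {
	atomic.AddInt64(&c.connects, 1)
	return stubConn{}, nil
}

func (c *countingConnector) Driver() driver.Driver { return stubDriver{} }

// connectCounts returns the number of Connect calls right after OpenDB
// and again after the first Ping.
func connectCounts() (afterOpen, afterPing int64, err error) {
	c := &countingConnector{}
	db := sql.OpenDB(c) // only builds the DB value; no connection is opened here
	defer db.Close()

	afterOpen = atomic.LoadInt64(&c.connects)
	if err = db.Ping(); err != nil { // Ping needs a connection, so Connect is called
		return afterOpen, 0, err
	}
	return afterOpen, atomic.LoadInt64(&c.connects), nil
}

func main() {
	afterOpen, afterPing, err := connectCounts()
	fmt.Println(afterOpen, afterPing, err)
}
```

A practical consequence is that a wrong DSN or an unreachable server is usually not reported by Open; calling Ping right after Open is a common way to surface such errors early.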
OpenDB is shown below:
func OpenDB(c driver.Connector) *DB {
ctx, cancel := context.WithCancel(context.Background())
db := &DB{
connector: c,
openerCh: make(chan struct{}, connectionRequestQueueSize),
lastPut: make(map[*driverConn]string),
connRequests: make(map[uint64]chan connRequest),
stop: cancel,
}
go db.connectionOpener(ctx)
return db
}
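[1] The db := &DB{...} literal creates the DB instance;
[2] go db.connectionOpener(ctx) starts a goroutine that opens connections.
The goroutine is shown below: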
func (db *DB) connectionOpener(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case <-db.openerCh:
db.openNewConnection(ctx)
}
}
}
The goroutine waits for a signal and, on receiving one, calls the openNewConnection method to open a new connection. So when is a signal sent to openerCh? Reading further shows that the maybeOpenNewConnections method sends it.
func (db *DB) maybeOpenNewConnections() {
numRequests := len(db.connRequests)
if db.maxOpen > 0 {
numCanOpen := db.maxOpen - db.numOpen
if numRequests > numCanOpen {
numRequests = numCanOpen
}
}
for numRequests > 0 {
db.numOpen++ // optimistically
numRequests--
if db.closed {
return
}
db.openerCh <- struct{}{}
}
}
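[1] From numRequests := len(db.connRequests) to the end of the if db.maxOpen > 0 block: compute from maxOpen and numOpen how many more connections may be opened; if there are more pending connection requests than that, cap the number of requests at that value;
[2] The for loop: send one signal to openerCh for each remaining connection request.
The goroutine is therefore woken according to the number of connection requests. What problem are connection requests designed to solve? Keep that question in mind as we return to openNewConnection.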
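The interplay between the capping logic and the opener goroutine can be sketched without database/sql at all. In the toy version below, signalsToSend and runOpener are illustrative names, a counter stands in for openNewConnection, and a WaitGroup is added only so that the example can wait for the worker deterministically.

```go
package main

import (
	"context"
	"fmt"
	"sync"
)

// signalsToSend mirrors the capping logic: it returns how many open signals
// should be sent for the given number of pending requests.
func signalsToSend(pending, maxOpen, numOpen int) int {
	if maxOpen > 0 { // maxOpen <= 0 means unlimited
		if canOpen := maxOpen - numOpen; pending > canOpen {
			pending = canOpen
		}
	}
	if pending < 0 {
		return 0
	}
	return pending
}

// runOpener starts a worker goroutine, sends n signals to it, waits until all
// of them are handled, and returns how many "connections" the worker opened.
// n must not be negative.
func runOpener(n int) int {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // tells the worker to exit, as db.stop does

	openerCh := make(chan struct{}, 8) // buffered, like db.openerCh
	var wg sync.WaitGroup
	opened := 0

	go func() {
		for {
			select {
			case <-ctx.Done():
				return
			case <-openerCh:
				opened++ // stands in for db.openNewConnection(ctx)
				wg.Done()
			}
		}
	}()

	wg.Add(n)
	for i := 0; i < n; i++ {
		openerCh <- struct{}{} // one send per needed connection
	}
	wg.Wait() // every Done happens before Wait returns, so reading opened is safe
	return opened
}

func main() {
	n := signalsToSend(5, 3, 1) // 5 waiters, but only 3-1 = 2 connections may be opened
	fmt.Println(n, runOpener(n))
}
```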
func (db *DB) openNewConnection(ctx context.Context) {
// maybeOpenNewConnections has already executed db.numOpen++ before it sent
// on db.openerCh. This function must execute db.numOpen-- if the
// connection fails or is closed before returning.
ci, err := db.connector.Connect(ctx)
db.mu.Lock()
defer db.mu.Unlock()
if db.closed {
if err == nil {
ci.Close()
}
db.numOpen--
return
}
if err != nil {
db.numOpen--
db.putConnDBLocked(nil, err)
db.maybeOpenNewConnections()
return
}
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
}
if db.putConnDBLocked(dc, err) {
db.addDepLocked(dc, dc)
} else {
db.numOpen--
ci.Close()
}
}
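[1] db.connector.Connect(ctx): open a connection through the Connector;
[2] The if db.closed block: if the DB has already been closed, close the connection if it was opened successfully, decrement the open-connection count, and return;
[3] The if err != nil block: if the DB is not closed but opening the connection failed, decrement the open-connection count, pass a nil connection together with the error to putConnDBLocked, and call maybeOpenNewConnections so that the remaining requests can still be served;
[4] The rest of the function: if the connection was opened successfully, wrap it in a driverConn (dc) and hand dc to the DB; if that fails, decrement the open-connection count and close the connection.
As long as the DB is not closed, every path ends in a call to putConnDBLocked, and every failure decrements the open-connection count. The count must therefore have been incremented before the goroutine received its signal (the db.numOpen++ in maybeOpenNewConnections confirms this guess), and putConnDBLocked must be the core operation.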
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
if db.closed {
return false
}
if db.maxOpen > 0 && db.numOpen > db.maxOpen {
return false
}
if c := len(db.connRequests); c > 0 {
var req chan connRequest
var reqKey uint64
for reqKey, req = range db.connRequests {
break
}
delete(db.connRequests, reqKey) // Remove from pending requests.
if err == nil {
dc.inUse = true
}
req <- connRequest{
conn: dc,
err: err,
}
return true
} else if err == nil && !db.closed {
if db.maxIdleConnsLocked() > len(db.freeConn) {
db.freeConn = append(db.freeConn, dc)
db.startCleanerLocked()
return true
}
db.maxIdleClosed++
}
return false
}
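[1] The if c := len(db.connRequests); c > 0 branch: pick an arbitrary chan connRequest from connRequests (Go does not specify map iteration order) and send the connection and the error on it;
[2] The else if branch: if the maximum number of idle connections is greater than the number of connections currently in the idle pool, append the connection to the pool and call startCleanerLocked, which starts the goroutine that periodically removes expired connections from the pool when one is needed.
The code above only sends on chan connRequest, so where is the receiving end? Reading on, it turns out to be the conn method of DB. Recall the interfaces introduced at the beginning of the article: database/sql wraps most of them in its own types, such as Conn, Stmt, and Tx, and the methods that do so all call conn. Whether you create a statement with Prepare, start a transaction with Begin, run a non-query statement with Exec, or run a query with Query, conn is executed first.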
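The decision made by putConnDBLocked can be reduced to a toy model. The sketch below is not the standard library code: toyPool, wait, and put are made-up names, connections are plain ints, and error handling, the closed flag, and the maxOpen check are left out so that only the "waiter first, idle pool second, otherwise discard" order remains.

```go
package main

import (
	"fmt"
	"sync"
)

// toyPool is a simplified stand-in for the parts of DB used by putConnDBLocked.
type toyPool struct {
	mu      sync.Mutex
	free    []int               // idle connections, represented by plain ints
	waiters map[uint64]chan int // pending requests, like db.connRequests
	nextKey uint64
	maxIdle int
}

func newToyPool(maxIdle int) *toyPool {
	return &toyPool{waiters: make(map[uint64]chan int), maxIdle: maxIdle}
}

// wait registers a pending request and returns the channel to receive on.
func (p *toyPool) wait() <-chan int {
	p.mu.Lock()
	defer p.mu.Unlock()
	ch := make(chan int, 1) // buffered so that put never blocks on the send
	p.waiters[p.nextKey] = ch
	p.nextKey++
	return ch
}

// put returns a connection to the pool and reports where it went.
func (p *toyPool) put(conn int) string {
	p.mu.Lock()
	defer p.mu.Unlock()
	for key, ch := range p.waiters { // take any one waiter, then stop
		delete(p.waiters, key)
		ch <- conn
		return "waiter"
	}
	if len(p.free) < p.maxIdle {
		p.free = append(p.free, conn)
		return "idle pool"
	}
	return "discarded" // the real code closes the connection in this case
}

func main() {
	p := newToyPool(1)
	ch := p.wait()
	fmt.Println(p.put(7), <-ch) // the waiter is served first
	fmt.Println(p.put(8))       // no waiter left, room in the idle pool
	fmt.Println(p.put(9))       // idle pool full
}
```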
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
db.mu.Lock()
if db.closed {
db.mu.Unlock()
return nil, errDBClosed
}
// Check if the context is expired.
select {
default:
case <-ctx.Done():
db.mu.Unlock()
return nil, ctx.Err()
}
lifetime := db.maxLifetime
// Prefer a free connection, if possible.
last := len(db.freeConn) - 1
if strategy == cachedOrNewConn && last >= 0 {
// Reuse the lowest idle time connection so we can close
// connections which remain idle as soon as possible.
conn := db.freeConn[last]
db.freeConn = db.freeConn[:last]
conn.inUse = true
if conn.expired(lifetime) {
db.maxLifetimeClosed++
db.mu.Unlock()
conn.Close()
return nil, driver.ErrBadConn
}
db.mu.Unlock()
// Reset the session if required.
if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
conn.Close()
return nil, err
}
return conn, nil
}
// Out of free connections or we were asked not to use one. If we're not
// allowed to open any more connections, make a request and wait.
if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
// Make the connRequest channel. It's buffered so that the
// connectionOpener doesn't block while waiting for the req to be read.
req := make(chan connRequest, 1)
reqKey := db.nextRequestKeyLocked()
db.connRequests[reqKey] = req
db.waitCount++
db.mu.Unlock()
waitStart := nowFunc()
// Timeout the connection request with the context.
select {
case <-ctx.Done():
// Remove the connection request and ensure no value has been sent
// on it after removing.
db.mu.Lock()
delete(db.connRequests, reqKey)
db.mu.Unlock()
atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))
select {
default:
case ret, ok := <-req:
if ok && ret.conn != nil {
db.putConn(ret.conn, ret.err, false)
}
}
return nil, ctx.Err()
case ret, ok := <-req:
atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))
if !ok {
return nil, errDBClosed
}
// Only check if the connection is expired if the strategy is cachedOrNewConns.
// If we require a new connection, just re-use the connection without looking
// at the expiry time. If it is expired, it will be checked when it is placed
// back into the connection pool.
// This prioritizes giving a valid connection to a client over the exact connection
// lifetime, which could expire exactly after this point anyway.
if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
db.mu.Lock()
db.maxLifetimeClosed++
db.mu.Unlock()
ret.conn.Close()
return nil, driver.ErrBadConn
}
if ret.conn == nil {
return nil, ret.err
}
// Reset the session if required.
if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
ret.conn.Close()
return nil, err
}
return ret.conn, ret.err
}
}
db.numOpen++ // optimistically
db.mu.Unlock()
ci, err := db.connector.Connect(ctx)
if err != nil {
db.mu.Lock()
db.numOpen-- // correct for earlier optimism
db.maybeOpenNewConnections()
db.mu.Unlock()
return nil, err
}
db.mu.Lock()
dc := &driverConn{
db: db,
createdAt: nowFunc(),
returnedAt: nowFunc(),
ci: ci,
inUse: true,
}
db.addDepLocked(dc, dc)
db.mu.Unlock()
return dc, nil
}
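[1] The if strategy == cachedOrNewConn && last >= 0 branch: take the last connection in the idle pool, which is the one that has been idle for the shortest time, and mark it as in use. If it has expired, increment the count of connections closed because of the lifetime limit, close it, and return a nil connection with driver.ErrBadConn. Otherwise reset the session if required and return the connection;
[2] The if db.maxOpen > 0 && db.numOpen >= db.maxOpen branch: if the idle pool is empty or the alwaysNewConn strategy is in use, and the number of open connections has reached the maximum, create a channel for a connection request and wait on both the context and that channel. If the context is done first, the request is removed, and a connection that has already been delivered on the channel is handed back through putConn. If a connection request arrives first, the connection it carries is checked: nil is returned if there is a problem, otherwise the connection is returned. This makes the purpose of connection requests clear: they exist to enforce maxOpen;
[3] The rest of the function: create a connection through the Connector; if that fails, call maybeOpenNewConnections so that pending requests can trigger another attempt.
What happens to a connection after conn has returned it and the caller is done with it? The Exec method shows this. Exec eventually calls exec, so we can look at exec directly: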
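The waiting branch can be triggered from the public API. The sketch below assumes a stub connector (stubConnector and its helper types are hypothetical) and a 50 ms timeout chosen arbitrarily: it limits the pool to one open connection, holds that connection, and then asks for a second one, which has to wait on a connection request until the context expires.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"time"
)

// stubConn is a hypothetical connection that supports nothing except Close.
type stubConn struct{}

func (stubConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("stub: Prepare is not supported")
}
func (stubConn) Close() error { return nil }
func (stubConn) Begin() (driver.Tx, error) {
	return nil, errors.New("stub: Begin is not supported")
}

type stubDriver struct{}

func (stubDriver) Open(name string) (driver.Conn, error) { return stubConn{}, nil }

type stubConnector struct{}

func (stubConnector) Connect(ctx context.Context) (driver.Conn, error) { return stubConn{}, nil }
func (stubConnector) Driver() driver.Driver                           { return stubDriver{} }

// waitOnFullPool holds the only allowed connection and then requests a second
// one with a short timeout. It reports whether the second request timed out
// and how many requests the pool has made wait.
func waitOnFullPool() (timedOut bool, waitCount int64, err error) {
	db := sql.OpenDB(stubConnector{})
	defer db.Close()
	db.SetMaxOpenConns(1)

	ctx := context.Background()
	held, err := db.Conn(ctx) // takes the single connection the pool may open
	if err != nil {
		return false, 0, err
	}
	defer held.Close()

	waitCtx, cancel := context.WithTimeout(ctx, 50*time.Millisecond)
	defer cancel()
	_, waitErr := db.Conn(waitCtx) // numOpen >= maxOpen, so this call waits
	return errors.Is(waitErr, context.DeadlineExceeded), db.Stats().WaitCount, nil
}

func main() {
	timedOut, waitCount, err := waitOnFullPool()
	fmt.Println(timedOut, waitCount, err)
}
```

Had the held connection been closed before the timeout, putConnDBLocked would have delivered it on the waiting request's channel and the second call would have succeeded.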
func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
dc, err := db.conn(ctx, strategy)
if err != nil {
return nil, err
}
return db.execDC(ctx, dc, dc.releaseConn, query, args)
}
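conn has already been analyzed. The releaseConn method passed to execDC in the final return statement looks like the follow-up handling for the connection: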
func (dc *driverConn) releaseConn(err error) {
dc.db.putConn(dc, err, true)
}
Next, look at the putConn method:
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
if !errors.Is(err, driver.ErrBadConn) {
if !dc.validateConnection(resetSession) {
err = driver.ErrBadConn
}
}
db.mu.Lock()
if !dc.inUse {
db.mu.Unlock()
if debugGetPut {
fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
}
panic("sql: connection returned that was never out")
}
if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
db.maxLifetimeClosed++
err = driver.ErrBadConn
}
if debugGetPut {
db.lastPut[dc] = stack()
}
dc.inUse = false
dc.returnedAt = nowFunc()
for _, fn := range dc.onPut {
fn()
}
dc.onPut = nil
if errors.Is(err, driver.ErrBadConn) {
// Don't reuse bad connections.
// Since the conn is considered bad and is being discarded, treat it
// as closed. Don't decrement the open count here, finalClose will
// take care of that.
db.maybeOpenNewConnections()
db.mu.Unlock()
dc.Close()
return
}
if putConnHook != nil {
putConnHook(db, dc)
}
added := db.putConnDBLocked(dc, nil)
db.mu.Unlock()
if !added {
dc.Close()
return
}
}
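Near its end, putConn calls putConnDBLocked. Combined with the earlier analysis of putConnDBLocked, this means that a connection that has just been used is handed to a pending connection request if the set of requests is not empty; if there are no pending requests and the idle pool has not reached its capacity, it is put into the idle pool; otherwise it is closed.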
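The effect of the idle pool is easy to measure. The following sketch, again built on a hypothetical counting connector with a stub connection, pings the pool several times and reports how often Connect was called and how many connections were closed because the idle pool had no room (the MaxIdleClosed field of db.Stats).

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"sync/atomic"
)

// stubConn is a hypothetical connection that supports nothing except Close.
type stubConn struct{}

func (stubConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("stub: Prepare is not supported")
}
func (stubConn) Close() error { return nil }
func (stubConn) Begin() (driver.Tx, error) {
	return nil, errors.New("stub: Begin is not supported")
}

type stubDriver struct{}

func (stubDriver) Open(name string) (driver.Conn, error) { return stubConn{}, nil }

// countingConnector is a hypothetical driver.Connector that counts Connect calls.
type countingConnector struct{ connects int64 }

func (c *countingConnector) Connect(ctx context.Context) (driver.Conn, error) {
	atomic.AddInt64(&c.connects, 1)
	return stubConn{}, nil
}

func (c *countingConnector) Driver() driver.Driver { return stubDriver{} }

// connectsFor pings the pool n times with the given idle limit. It returns how
// many times Connect was called and how many connections were closed because
// the idle pool was full.
func connectsFor(maxIdle, n int) (connects, maxIdleClosed int64, err error) {
	c := &countingConnector{}
	db := sql.OpenDB(c)
	defer db.Close()
	db.SetMaxIdleConns(maxIdle) // zero or less disables the idle pool

	for i := 0; i < n; i++ {
		if err = db.Ping(); err != nil {
			return 0, 0, err
		}
	}
	return atomic.LoadInt64(&c.connects), db.Stats().MaxIdleClosed, nil
}

func main() {
	fmt.Println(connectsFor(2, 3)) // idle pool enabled: the connection is reused
	fmt.Println(connectsFor(0, 3)) // idle pool disabled: every Ping reconnects
}
```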
Usage
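Figure 3: Connector acquisition
First create a DB instance with the Open function. The most important step in this process is obtaining the Connector, because every new connection is created through it; Figure 3 shows how the Connector is obtained. After that, run SQL statements with the Exec or Query methods.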
func Open(driverName, dataSourceName string) (*DB, error) {
driversMu.RLock()
driveri, ok := drivers[driverName]
driversMu.RUnlock()
if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
}
if driverCtx, ok := driveri.(driver.DriverContext); ok {
connector, err := driverCtx.OpenConnector(dataSourceName)
if err != nil {
return nil, err
}
return OpenDB(connector), nil
}
return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
The following example uses database/sql to operate on a MySQL database. The database helpers are wrapped as follows:
package mysql
import (
"database/sql"
"errors"
"fmt"
)
var (
DBName = "test"
TableName = "user"
)
type User struct {
ID uint32
Name string
}
const (
mysqlUserCreateDatabase = iota
mysqlUserCreateTable
mysqlUserInsert
mysqlUserSelectByID
mysqlUserSelectAll
)
var (
errInvalidInsert = errors.New("[user] invalid insert ")
userSQLString = map[int]string{
mysqlUserCreateDatabase: fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %s`, DBName),
mysqlUserCreateTable: fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%s(
id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
name VARCHAR(56) NOT NULL UNIQUE COMMENT 'user name'
)ENGINE=InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_bin;`, DBName, TableName),
mysqlUserInsert: fmt.Sprintf(`INSERT INTO %s.%s (name) VALUES (?)`, DBName, TableName),
mysqlUserSelectByID: fmt.Sprintf(`SELECT id, name FROM %s.%s WHERE id = ? LIMIT 1`, DBName, TableName),
mysqlUserSelectAll: fmt.Sprintf(`SELECT id, name FROM %s.%s`, DBName, TableName),
}
)
func CreateDatabase(db *sql.DB) error {
_, err := db.Exec(userSQLString[mysqlUserCreateDatabase])
return err
}
func CreateTable(db *sql.DB) error {
_, err := db.Exec(userSQLString[mysqlUserCreateTable])
return err
}
func InsertUser(db *sql.DB, name string) error {
result, err := db.Exec(userSQLString[mysqlUserInsert], name)
if err != nil {
return err
}
if rows, _ := result.RowsAffected(); rows == 0 {
return errInvalidInsert
}
return nil
}
func TxInsertUser(tx *sql.Tx, name string) error {
result, err := tx.Exec(userSQLString[mysqlUserInsert], name)
if err != nil {
return err
}
if rows, _ := result.RowsAffected(); rows == 0 {
return errInvalidInsert
}
return nil
}
func SelectUserByID(db *sql.DB, id uint32) (*User, error) {
var name string
row := db.QueryRow(userSQLString[mysqlUserSelectByID], id)
if err := row.Scan(&id, &name); err != nil {
return nil, err
}
return &User{
ID: id,
Name: name,
}, nil
}
func TxSelectUserByID(tx *sql.Tx, id uint32) (*User, error) {
var name string
row := tx.QueryRow(userSQLString[mysqlUserSelectByID], id)
if err := row.Scan(&id, &name); err != nil {
return nil, err
}
return &User{
ID: id,
Name: name,
}, nil
}
func ListUsers(db *sql.DB) ([]*User, error) {
var (
Users = make([]*User, 0)
ID uint32
Name string
)
rows, err := db.Query(userSQLString[mysqlUserSelectAll])
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
if err := rows.Scan(&ID, &Name); err != nil {
return nil, err
}
u := &User{
ID: ID,
Name: Name,
}
Users = append(Users, u)
}
// Report an error that ended the iteration early.
if err := rows.Err(); err != nil {
return nil, err
}
return Users, nil
}
func TxListUsers(tx *sql.Tx) ([]*User, error) {
var (
Users = make([]*User, 0)
ID uint32
Name string
)
rows, err := tx.Query(userSQLString[mysqlUserSelectAll])
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
if err := rows.Scan(&ID, &Name); err != nil {
return nil, err
}
u := &User{
ID: ID,
Name: Name,
}
Users = append(Users, u)
}
// Report an error that ended the iteration early.
if err := rows.Err(); err != nil {
return nil, err
}
return Users, nil
}
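Before calling these helpers, import the MySQL driver, look up the DSN format in the driver's official documentation and use it to initialize the DB instance, and then call the wrapped helpers.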
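package main
import (
"database/sql"
"fmt"
"time"
_ "github.com/go-sql-driver/mysql"
"example.com/demo/mysql" // hypothetical import path of the helper package above
)
const ( // placeholder settings; replace them with real values
username = "root"
password = "secret"
host = "127.0.0.1"
port = "3306"
database = "test"
charset = "utf8mb4"
maxOpenConns = 10
maxIdleConns = 5
maxLifetime = 300 // seconds
)
func main() {
dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=%t&loc=%s",
username, password, host, port, database, charset, true, "Local")
db, err := sql.Open("mysql", dsn)
if err != nil {
panic(err)
}
defer db.Close()
db.SetMaxOpenConns(maxOpenConns)
db.SetMaxIdleConns(maxIdleConns)
db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second)
if err := mysql.CreateDatabase(db); err != nil {
panic(err)
}
}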
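[1] SetMaxOpenConns: set the maximum number of open connections to limit the maximum concurrency;
[2] SetMaxIdleConns: set the maximum number of idle connections to limit the size of the idle pool;
[3] SetConnMaxLifetime: set the maximum lifetime of a connection.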
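The Tx-prefixed helpers above take a *sql.Tx, which the example never creates. One possible way to use them is a small wrapper that begins a transaction, runs a callback, and commits or rolls back depending on the callback's result. The sketch below is an assumption about how such a wrapper could look, not part of the original example: withTx, demo, and the log* types are hypothetical, and a stub connector that only counts commits and rollbacks replaces MySQL so that the block runs on its own.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
)

// txLog counts commits and rollbacks seen by the stub driver below.
type txLog struct{ commits, rollbacks int }

type logTx struct{ log *txLog }

func (t logTx) Commit() error   { t.log.commits++; return nil }
func (t logTx) Rollback() error { t.log.rollbacks++; return nil }

// logConn is a hypothetical connection that only supports transactions.
type logConn struct{ log *txLog }

func (c logConn) Prepare(query string) (driver.Stmt, error) {
	return nil, errors.New("stub: Prepare is not supported")
}
func (c logConn) Close() error              { return nil }
func (c logConn) Begin() (driver.Tx, error) { return logTx{log: c.log}, nil }

type logDriver struct{ log *txLog }

func (d logDriver) Open(name string) (driver.Conn, error) { return logConn{log: d.log}, nil }

type logConnector struct{ log *txLog }

func (c logConnector) Connect(ctx context.Context) (driver.Conn, error) {
	return logConn{log: c.log}, nil
}
func (c logConnector) Driver() driver.Driver { return logDriver{log: c.log} }

// withTx runs fn inside a transaction: it commits when fn returns nil and
// rolls back otherwise.
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	if err := fn(tx); err != nil {
		if rbErr := tx.Rollback(); rbErr != nil {
			return fmt.Errorf("rollback failed: %v (after: %w)", rbErr, err)
		}
		return err
	}
	return tx.Commit()
}

// demo runs one transaction against the stub driver and reports what happened.
func demo(fail bool) (commits, rollbacks int, failed bool) {
	log := &txLog{}
	db := sql.OpenDB(logConnector{log: log})
	defer db.Close()

	err := withTx(context.Background(), db, func(tx *sql.Tx) error {
		// With a real driver, helpers such as TxInsertUser(tx, name) would be called here.
		if fail {
			return errors.New("simulated failure")
		}
		return nil
	})
	return log.commits, log.rollbacks, err != nil
}

func main() {
	fmt.Println(demo(false))
	fmt.Println(demo(true))
}
```

Keeping commit and rollback in one place means that every helper called inside the callback shares the same connection for the lifetime of the transaction, and the connection goes back through putConn only once the transaction ends.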