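Every database needs its own driver, and no language can ship drivers for all databases in its standard library. How, then, should a language provide the database part of its standard library? Go answers this by defining a set of driver-related interfaces. These interfaces decouple the database logic in the standard library from any concrete driver implementation, which solves the problem elegantly.
In addition, the database/sql package in Go's standard library reuses connections through an idle connection pool. This avoids establishing a new connection for every operation and so improves performance.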
Driver
Interface
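Figure 1: Interface design
Figure 1 shows the interface design. A Connector returns its database driver through the Driver method and establishes a connection through the Connect method; presumably the Connector obtains connections on top of the Driver. A connection (Conn) can begin a transaction (Begin, returning a Tx) or prepare a SQL statement (Prepare, returning a Stmt). Depending on the kind of statement, you then call either Exec, which returns a Result, or Query, which returns Rows. The full interface definitions from database/sql/driver are as follows: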
```go
type Connector interface {
	Connect(context.Context) (Conn, error)
	Driver() Driver
}

type Driver interface {
	Open(name string) (Conn, error)
}

type Conn interface {
	Prepare(query string) (Stmt, error)
	Close() error
	Begin() (Tx, error)
}

type Tx interface {
	Commit() error
	Rollback() error
}

type Stmt interface {
	Close() error
	NumInput() int
	Exec(args []Value) (Result, error)
	Query(args []Value) (Rows, error)
}

type Result interface {
	LastInsertId() (int64, error)
	RowsAffected() (int64, error)
}

type Rows interface {
	Columns() []string
	Close() error
	Next(dest []Value) error
}
```
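The following sketch shows how the work is divided: a minimal driver that implements only these interfaces. The driver name `echo` and all the `echo*` types are hypothetical, and the driver talks to no real database. Every query yields a single row that echoes its first argument, and Exec reports one affected row per argument. The comments trace which interface method each database/sql call reaches. This assumes the driver implements none of the optional interfaces (such as driver.QueryerContext), so database/sql falls back to Prepare.

```go
package main

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"io"
)

// echoDriver is a hypothetical toy driver used only to show which
// interface each database/sql call ends up invoking.
type echoDriver struct{}

func (echoDriver) Open(name string) (driver.Conn, error) { return echoConn{}, nil }

type echoConn struct{}

func (echoConn) Prepare(query string) (driver.Stmt, error) { return echoStmt{}, nil }
func (echoConn) Close() error                              { return nil }
func (echoConn) Begin() (driver.Tx, error)                 { return echoTx{}, nil }

type echoTx struct{}

func (echoTx) Commit() error   { return nil }
func (echoTx) Rollback() error { return nil }

type echoStmt struct{}

func (echoStmt) Close() error  { return nil }
func (echoStmt) NumInput() int { return -1 } // -1: database/sql skips the argument-count check

func (echoStmt) Exec(args []driver.Value) (driver.Result, error) {
	return driver.RowsAffected(len(args)), nil // LastInsertId is unsupported
}

func (echoStmt) Query(args []driver.Value) (driver.Rows, error) {
	if len(args) == 0 {
		return nil, errors.New("echo: need at least one argument")
	}
	return &echoRows{val: args[0]}, nil
}

type echoRows struct {
	val  driver.Value
	done bool
}

func (r *echoRows) Columns() []string { return []string{"value"} }
func (r *echoRows) Close() error      { return nil }

func (r *echoRows) Next(dest []driver.Value) error {
	if r.done {
		return io.EOF // io.EOF tells database/sql there are no more rows
	}
	dest[0] = r.val
	r.done = true
	return nil
}

func init() { sql.Register("echo", echoDriver{}) }

// run drives the toy driver through the public database/sql API.
func run() (string, int64, error) {
	db, err := sql.Open("echo", "")
	if err != nil {
		return "", 0, err
	}
	defer db.Close()

	var s string
	// QueryRow -> Conn.Prepare -> Stmt.Query -> Rows.Columns/Next
	if err := db.QueryRow("SELECT ?", "hello").Scan(&s); err != nil {
		return "", 0, err
	}

	tx, err := db.Begin() // Conn.Begin -> Tx
	if err != nil {
		return "", 0, err
	}
	// Tx.Exec -> Conn.Prepare -> Stmt.Exec -> Result
	res, err := tx.Exec("UPDATE t SET a = ?, b = ?", 1, 2)
	if err != nil {
		tx.Rollback()
		return "", 0, err
	}
	if err := tx.Commit(); err != nil {
		return "", 0, err
	}
	n, err := res.RowsAffected()
	return s, n, err
}

func main() {
	s, n, err := run()
	fmt.Println(s, n, err) // hello 2 <nil>
}
```

The driver registers itself in init, so calling code only needs `sql.Open("echo", ...)`. This mirrors the blank import of a real driver package.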
SQL
Register
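A driver author only needs to implement the interfaces defined by the Go standard library and then register the driver with the Register function, which is implemented as follows: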
```go
func Register(name string, driver driver.Driver) {
	driversMu.Lock()
	defer driversMu.Unlock()
	if driver == nil {
		panic("sql: Register driver is nil")
	}
	if _, dup := drivers[name]; dup {
		panic("sql: Register called twice for driver " + name)
	}
	drivers[name] = driver
}
```
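Register panics if the driver is nil or if the name is already taken. Otherwise the final statement, `drivers[name] = driver`, stores the driver under its name in drivers, a package-level variable defined in database/sql: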
```go
var (
	driversMu sync.RWMutex
	drivers   = make(map[string]driver.Driver)
)
```
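driversMu is a read-write mutex. Go's built-in map type is not safe for concurrent use, so every access to drivers must hold the lock. Register takes the write lock, while lookups such as the one in Open take only the read lock.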
Connect Manager
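Figure 2: Connect Manager
Figure 2 shows how connections are managed. The core structure is DB, which has three key fields: connector creates the underlying connections, connRequests tracks pending connection requests, and freeConn stores idle connections.
A connection must be obtained before any database operation, and there are three ways to get one:

- Take it from the idle connection pool.
- Register a connection request in connRequests and wait on its channel until another goroutine hands over a connection.
- Create a new underlying connection with the connector's Connect method and wrap it.

Once a connection has been used, it is either handed to a pending connection request by sending it on that request's channel in connRequests, or placed in the idle connection pool. If neither is possible, it is closed.
Let's now read the source code to understand connection management in detail, starting with the DB struct.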
```go
type DB struct {
	// Atomic access only. At top of struct to prevent mis-alignment
	// on 32-bit platforms. Of type time.Duration.
	waitDuration int64 // Total time waited for new connections.

	connector driver.Connector
	// numClosed is an atomic counter which represents a total number of
	// closed connections. Stmt.openStmt checks it before cleaning closed
	// connections in Stmt.css.
	numClosed uint64

	mu           sync.Mutex    // protects following fields
	freeConn     []*driverConn // free connections ordered by returnedAt oldest to newest
	connRequests map[uint64]chan connRequest
	nextRequest  uint64 // Next key to use in connRequests.
	numOpen      int    // number of opened and pending open connections
	// Used to signal the need for new connections
	// a goroutine running connectionOpener() reads on this chan and
	// maybeOpenNewConnections sends on the chan (one send per needed connection)
	// It is closed during db.Close(). The close tells the connectionOpener
	// goroutine to exit.
	openerCh          chan struct{}
	closed            bool
	dep               map[finalCloser]depSet
	lastPut           map[*driverConn]string // stacktrace of last conn's put; debug only
	maxIdleCount      int                    // zero means defaultMaxIdleConns; negative means 0
	maxOpen           int                    // <= 0 means unlimited
	maxLifetime       time.Duration          // maximum amount of time a connection may be reused
	maxIdleTime       time.Duration          // maximum amount of time a connection may be idle before being closed
	cleanerCh         chan struct{}
	waitCount         int64 // Total number of connections waited for.
	maxIdleClosed     int64 // Total number of connections closed due to idle count.
	maxIdleTimeClosed int64 // Total number of connections closed due to idle time.
	maxLifetimeClosed int64 // Total number of connections closed due to max connection lifetime limit.

	stop func() // stop cancels the connection opener.
}
```
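A DB instance is created with the Open function, which ultimately builds it by calling OpenDB, and OpenDB needs a Connector as its argument. Open obtains the Connector in one of two ways. If the driver also implements the DriverContext interface, Open calls its OpenConnector method to get the Connector (the `if driverCtx, ok := driveri.(driver.DriverContext); ok` branch). Otherwise, the final `return` falls back to the standard library's default dsnConnector type. This design lets a driver define how its Connector is obtained while still providing a default. The added flexibility is worth considering in some scenarios.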
```go
func Open(driverName, dataSourceName string) (*DB, error) {
	driversMu.RLock()
	driveri, ok := drivers[driverName]
	driversMu.RUnlock()
	if !ok {
		return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
	}

	if driverCtx, ok := driveri.(driver.DriverContext); ok {
		connector, err := driverCtx.OpenConnector(dataSourceName)
		if err != nil {
			return nil, err
		}
		return OpenDB(connector), nil
	}

	return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
}
```
OpenDB looks like this:
```go
func OpenDB(c driver.Connector) *DB {
	ctx, cancel := context.WithCancel(context.Background())
	db := &DB{
		connector:    c,
		openerCh:     make(chan struct{}, connectionRequestQueueSize),
		lastPut:      make(map[*driverConn]string),
		connRequests: make(map[uint64]chan connRequest),
		stop:         cancel,
	}

	go db.connectionOpener(ctx)

	return db
}
```
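[1] The composite literal creates the DB instance. It initializes openerCh (buffered to connectionRequestQueueSize), lastPut and connRequests, and stores the context's cancel function in stop;
[2] `go db.connectionOpener(ctx)` starts a goroutine that opens new connections.
The goroutine looks like this: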
```go
func (db *DB) connectionOpener(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case <-db.openerCh:
			db.openNewConnection(ctx)
		}
	}
}
```
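The goroutine waits for signals on openerCh and calls openNewConnection each time one arrives. It exits when ctx is cancelled through db.stop, which DB.Close calls. That raises a question: when is a signal sent to openerCh? Reading further shows that maybeOpenNewConnections sends it.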
```go
func (db *DB) maybeOpenNewConnections() {
	numRequests := len(db.connRequests)
	if db.maxOpen > 0 {
		numCanOpen := db.maxOpen - db.numOpen
		if numRequests > numCanOpen {
			numRequests = numCanOpen
		}
	}
	for numRequests > 0 {
		db.numOpen++ // optimistically
		numRequests--
		if db.closed {
			return
		}
		db.openerCh <- struct{}{}
	}
}
```
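[1] The `if db.maxOpen > 0` block computes how many more connections may be opened (maxOpen - numOpen). If more requests are pending than that, numRequests is capped at that number;
[2] The `for numRequests > 0` loop increments numOpen optimistically and sends one signal on openerCh per remaining request, stopping early if the DB has been closed.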
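The interplay between maybeOpenNewConnections and connectionOpener can be modelled without a database. The sketch below is a hypothetical simplification: signalsToSend reproduces the capping arithmetic. openWithOpener uses a buffered channel and a single goroutine in the roles of openerCh and connectionOpener, with the context's cancel function standing in for db.stop. No real connections are opened.

```go
package main

import (
	"context"
	"fmt"
)

// signalsToSend mirrors the arithmetic in maybeOpenNewConnections: one signal
// per pending request, capped by the remaining room under maxOpen
// (maxOpen <= 0 means unlimited).
func signalsToSend(pending, maxOpen, numOpen int) int {
	n := pending
	if maxOpen > 0 {
		if canOpen := maxOpen - numOpen; n > canOpen {
			n = canOpen
		}
	}
	return n
}

// openWithOpener is a hypothetical, stripped-down model of the handshake
// between maybeOpenNewConnections and connectionOpener. It returns how many
// connections the opener goroutine "opened" and the final numOpen.
func openWithOpener(pending, maxOpen, numOpen int) (opened, finalOpen int) {
	n := signalsToSend(pending, maxOpen, numOpen)
	if n <= 0 {
		return 0, numOpen
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel() // plays the role of db.stop(): the opener goroutine exits

	signals := make(chan struct{}, n) // plays the role of db.openerCh
	done := make(chan struct{}, n)
	go func() { // plays the role of connectionOpener
		for {
			select {
			case <-ctx.Done():
				return
			case <-signals:
				done <- struct{}{} // a real pool would call connector.Connect here
			}
		}
	}()

	for i := 0; i < n; i++ {
		numOpen++ // counted optimistically, before the connection exists
		signals <- struct{}{}
	}
	for i := 0; i < n; i++ {
		<-done
		opened++
	}
	return opened, numOpen
}

func main() {
	fmt.Println(signalsToSend(5, 3, 1))  // 2: only two slots left under maxOpen=3
	fmt.Println(signalsToSend(5, 0, 1))  // 5: unlimited
	fmt.Println(openWithOpener(5, 3, 1)) // 2 3
}
```

Sizing the signal channel to the number of signals means the sender never blocks. database/sql likewise gives openerCh a large buffer (connectionRequestQueueSize), presumably for a similar reason, since maybeOpenNewConnections sends while holding db.mu.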
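So the opener goroutine is woken up according to the number of pending connection requests. What problem, then, are connection requests designed to solve? With that question in mind, let's return to openNewConnection.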
```go
func (db *DB) openNewConnection(ctx context.Context) {
	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
	// on db.openerCh. This function must execute db.numOpen-- if the
	// connection fails or is closed before returning.
	ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
	if err != nil {
		db.numOpen--
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
	}
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		db.numOpen--
		ci.Close()
	}
}
```
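[1] Open a connection through the Connector;
[2] If the DB has already been closed, close the new connection (if it was opened successfully), decrement the open-connection count, and return;
[3] If the DB is still open but opening the connection failed, decrement the count and pass a nil connection with the error to putConnDBLocked, which delivers the error to a waiting request if there is one. Then call maybeOpenNewConnections so the remaining requests can trigger another attempt;
[4] If the connection was opened successfully, wrap it in a driverConn dc and hand dc to putConnDBLocked. On success, register dc as a dependency with addDepLocked; otherwise decrement the count and close the connection.
Whenever the DB is not closed, putConnDBLocked is eventually called, and every failure path decrements the open-connection count. This suggests that the count was already incremented before the goroutine received its signal, and that putConnDBLocked is the core operation. maybeOpenNewConnections confirms the first point, and the comment at the top of openNewConnection states it explicitly.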
```go
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
	if db.closed {
		return false
	}
	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
		return false
	}
	if c := len(db.connRequests); c > 0 {
		var req chan connRequest
		var reqKey uint64
		for reqKey, req = range db.connRequests {
			break
		}
		delete(db.connRequests, reqKey) // Remove from pending requests.
		if err == nil {
			dc.inUse = true
		}
		req <- connRequest{
			conn: dc,
			err:  err,
		}
		return true
	} else if err == nil && !db.closed {
		if db.maxIdleConnsLocked() > len(db.freeConn) {
			db.freeConn = append(db.freeConn, dc)
			db.startCleanerLocked()
			return true
		}
		db.maxIdleClosed++
	}
	return false
}
```
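[1] The first two checks reject the connection if the DB is closed or if more connections are open than maxOpen allows. If connRequests is not empty, one pending request is picked. The `for ... range` loop breaks after one iteration, and Go's map iteration order is unspecified, so the choice is effectively random. The request is removed from connRequests and sent the connection and the error on its channel;
[2] Otherwise, if there is no error and the maximum number of idle connections exceeds the current size of the idle pool, the connection is appended to the pool. startCleanerLocked then starts, if needed, the goroutine that periodically removes expired connections from the pool. If the pool is full, maxIdleClosed is incremented and false is returned, so the caller closes the connection.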
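The decision logic of putConnDBLocked fits in a few lines. The sketch below is a hypothetical model, not the database/sql source. conn and pool are made-up types, errors are not passed along, and waiters are plain buffered channels keyed like connRequests.

```go
package main

import "fmt"

// conn is a hypothetical stand-in for *driverConn.
type conn struct{ id int }

// pool keeps only the fields putConnDBLocked consults. This is a simplified
// sketch, not the database/sql code; callers are assumed to hold the pool's
// lock, which is what the "Locked" suffix signals in the original.
type pool struct {
	closed   bool
	numOpen  int
	maxOpen  int                   // <= 0 means unlimited
	maxIdle  int                   // idle connections to keep
	waiters  map[uint64]chan *conn // like db.connRequests
	free     []*conn               // like db.freeConn
	idleDrop int                   // like db.maxIdleClosed
}

// putLocked hands c to a waiting request if there is one, otherwise parks it
// in the idle list when there is room. false means the caller must close c.
func (p *pool) putLocked(c *conn) bool {
	if p.closed || (p.maxOpen > 0 && p.numOpen > p.maxOpen) {
		return false
	}
	for key, ch := range p.waiters { // map order is unspecified: any waiter may win
		delete(p.waiters, key)
		ch <- c // each channel is buffered(1), so this send never blocks
		return true
	}
	if p.maxIdle > len(p.free) {
		p.free = append(p.free, c)
		return true
	}
	p.idleDrop++
	return false
}

func main() {
	p := &pool{numOpen: 3, maxIdle: 1, waiters: map[uint64]chan *conn{}}
	req := make(chan *conn, 1)
	p.waiters[1] = req

	ok := p.putLocked(&conn{id: 1})
	fmt.Println(ok, (<-req).id) // true 1: handed straight to the waiter

	ok = p.putLocked(&conn{id: 2})
	fmt.Println(ok, len(p.free)) // true 1: parked in the idle list

	ok = p.putLocked(&conn{id: 3})
	fmt.Println(ok, p.idleDrop) // false 1: idle list full, caller closes it
}
```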
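That raises another question. So far we have only seen the sending side of chan connRequest; where is the receiving side? Reading on, it turns out to be in DB's conn method. Recall the interfaces introduced at the beginning. sql.DB exposes the corresponding operations as its own methods, such as Prepare for creating a statement, Begin for starting a transaction, Exec for non-query statements and Query for queries. Each of them first obtains a connection by calling conn.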
```go
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	last := len(db.freeConn) - 1
	if strategy == cachedOrNewConn && last >= 0 {
		// Reuse the lowest idle time connection so we can close
		// connections which remain idle as soon as possible.
		conn := db.freeConn[last]
		db.freeConn = db.freeConn[:last]
		conn.inUse = true
		if conn.expired(lifetime) {
			db.maxLifetimeClosed++
			db.mu.Unlock()
			conn.Close()
			return nil, driver.ErrBadConn
		}
		db.mu.Unlock()

		// Reset the session if required.
		if err := conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
			conn.Close()
			return nil, err
		}

		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// Make the connRequest channel. It's buffered so that the
		// connectionOpener doesn't block while waiting for the req to be read.
		req := make(chan connRequest, 1)
		reqKey := db.nextRequestKeyLocked()
		db.connRequests[reqKey] = req
		db.waitCount++
		db.mu.Unlock()

		waitStart := nowFunc()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			delete(db.connRequests, reqKey)
			db.mu.Unlock()

			atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))

			select {
			default:
			case ret, ok := <-req:
				if ok && ret.conn != nil {
					db.putConn(ret.conn, ret.err, false)
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))

			if !ok {
				return nil, errDBClosed
			}
			// Only check if the connection is expired if the strategy is cachedOrNewConns.
			// If we require a new connection, just re-use the connection without looking
			// at the expiry time. If it is expired, it will be checked when it is placed
			// back into the connection pool.
			// This prioritizes giving a valid connection to a client over the exact connection
			// lifetime, which could expire exactly after this point anyway.
			if strategy == cachedOrNewConn && ret.err == nil && ret.conn.expired(lifetime) {
				db.mu.Lock()
				db.maxLifetimeClosed++
				db.mu.Unlock()
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			if ret.conn == nil {
				return nil, ret.err
			}

			// Reset the session if required.
			if err := ret.conn.resetSession(ctx); errors.Is(err, driver.ErrBadConn) {
				ret.conn.Close()
				return nil, err
			}
			return ret.conn, ret.err
		}
	}

	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:         db,
		createdAt:  nowFunc(),
		returnedAt: nowFunc(),
		ci:         ci,
		inUse:      true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
```
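[1] If the strategy is cachedOrNewConn and the idle pool is not empty, take the last connection in the pool and mark it in use. This is the most recently returned connection, i.e. the one with the shortest idle time. If it has exceeded its maximum lifetime, increment maxLifetimeClosed, close it, and return driver.ErrBadConn. Otherwise reset the session if required and return the connection;
[2] If the idle pool is empty or the alwaysNewConn strategy was requested, and the number of open connections has already reached maxOpen, the method waits. It creates a connection-request channel with a buffer of 1, registers it in connRequests, and waits on both ctx.Done() and the request channel. If the context is cancelled or times out first, the request is removed, and any connection delivered in the meantime is returned through putConn; the method then returns ctx.Err(). If a connection arrives on the channel, it is checked. If it has expired (only under cachedOrNewConn), carries an error, or fails the session reset, an error is returned; otherwise the connection is returned. This shows that connection requests exist essentially to enforce the maxOpen limit;
[3] Otherwise, increment numOpen optimistically and create a new connection through the Connector. If that fails, decrement numOpen and call maybeOpenNewConnections so that requests still waiting can be served by the opener goroutine. On success, wrap the connection in a driverConn marked in use and return it.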
Next question: what happens to a connection obtained with conn once it has been used? We can trace this through the Exec method. Exec eventually calls exec, so we can look at exec directly:
```go
func (db *DB) exec(ctx context.Context, query string, args []any, strategy connReuseStrategy) (Result, error) {
	dc, err := db.conn(ctx, strategy)
	if err != nil {
		return nil, err
	}
	return db.execDC(ctx, dc, dc.releaseConn, query, args)
}
```
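We have already analyzed conn. The dc.releaseConn callback passed to execDC, which execDC calls once the statement has finished, looks like the method that handles the connection after use: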
```go
func (dc *driverConn) releaseConn(err error) {
	dc.db.putConn(dc, err, true)
}
```

Now look at putConn:

```go
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	if !errors.Is(err, driver.ErrBadConn) {
		if !dc.validateConnection(resetSession) {
			err = driver.ErrBadConn
		}
	}
	db.mu.Lock()
	if !dc.inUse {
		db.mu.Unlock()
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}

	if !errors.Is(err, driver.ErrBadConn) && dc.expired(db.maxLifetime) {
		db.maxLifetimeClosed++
		err = driver.ErrBadConn
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false
	dc.returnedAt = nowFunc()

	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if errors.Is(err, driver.ErrBadConn) {
		// Don't reuse bad connections.
		// Since the conn is considered bad and is being discarded, treat it
		// as closed. Don't decrement the open count here, finalClose will
		// take care of that.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	if !added {
		dc.Close()
		return
	}
}
```
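putConn first checks whether the connection is still valid and unexpired. A bad or expired connection is closed, and maybeOpenNewConnections is called. Otherwise putConn calls putConnDBLocked (`added := db.putConnDBLocked(dc, nil)`). Combined with the earlier analysis of putConnDBLocked, this means a used connection has three possible fates. It is handed to a pending connection request if there is one; otherwise it is added to the idle pool if the pool has not reached its limit; failing both, it is closed.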
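The effect of putConn and the idle pool can be measured by counting how often the driver is asked for a new connection. In the sketch below, the driver name "reusedemo" and the dialCounter and idleConn types are hypothetical. With an idle limit of 2, three sequential Pings share one connection. With SetMaxIdleConns(0), putConnDBLocked refuses every connection, so each Ping dials anew and DBStats.MaxIdleClosed counts the discarded connections.

```go
package main

import (
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"sync/atomic"
)

// dialCounter is a hypothetical toy driver that counts Open calls, i.e.
// how many physical connections database/sql asked for.
type dialCounter struct{ opens *int64 }

func (d dialCounter) Open(string) (driver.Conn, error) {
	atomic.AddInt64(d.opens, 1)
	return idleConn{}, nil
}

type idleConn struct{}

func (idleConn) Prepare(string) (driver.Stmt, error) { return nil, errors.New("not supported") }
func (idleConn) Close() error                       { return nil }
func (idleConn) Begin() (driver.Tx, error)          { return nil, errors.New("not supported") }

var opens int64

func init() { sql.Register("reusedemo", dialCounter{opens: &opens}) }

// dialsFor runs n sequential Pings with the given idle limit and reports how
// many connections the driver opened and how many putConn discarded.
func dialsFor(n, maxIdle int) (int64, int64) {
	atomic.StoreInt64(&opens, 0)
	db, err := sql.Open("reusedemo", "")
	if err != nil {
		panic(err)
	}
	defer db.Close()
	db.SetMaxIdleConns(maxIdle) // 0 (or less) means: keep no idle connections
	for i := 0; i < n; i++ {
		if err := db.Ping(); err != nil { // conn -> ... -> releaseConn -> putConn
			panic(err)
		}
	}
	return atomic.LoadInt64(&opens), db.Stats().MaxIdleClosed
}

func main() {
	fmt.Println(dialsFor(3, 2)) // 1 0: the connection goes to freeConn and is reused
	fmt.Println(dialsFor(3, 0)) // 3 3: putConnDBLocked refuses, so putConn closes it
}
```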
Usage
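Figure 3: Connector acquisition
First create a DB instance with Open. The most important part of this step is obtaining the Connector, because every new connection is created through it; Figure 3 shows how the Connector is obtained. After that, execute SQL statements with the Exec or Query methods.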
Below is an example that uses database/sql to work with a MySQL database. The database operations are wrapped as follows:
```go
package mysql

import (
	"database/sql"
	"errors"
	"fmt"
)

var (
	DBName    = "test"
	TableName = "user"
)

type User struct {
	ID   uint32
	Name string
}

const (
	mysqlUserCreateDatabase = iota
	mysqlUserCreateTable
	mysqlUserInsert
	mysqlUserSelectByID
	mysqlUserSelectAll
)

var (
	errInvalidInsert = errors.New("[user] invalid insert")

	// Keyed by the constants above; a map literal needs explicit keys.
	userSQLString = map[int]string{
		mysqlUserCreateDatabase: fmt.Sprintf(`CREATE DATABASE IF NOT EXISTS %s`, DBName),
		// No trailing comma after the last column definition: MySQL rejects it.
		mysqlUserCreateTable: fmt.Sprintf(`CREATE TABLE IF NOT EXISTS %s.%s(
	id INT UNSIGNED AUTO_INCREMENT PRIMARY KEY,
	name VARCHAR(56) NOT NULL UNIQUE COMMENT 'username'
) ENGINE=InnoDB CHARSET=utf8mb4 COLLATE=utf8mb4_bin`, DBName, TableName),
		mysqlUserInsert:     fmt.Sprintf(`INSERT INTO %s.%s (name) VALUES (?)`, DBName, TableName),
		mysqlUserSelectByID: fmt.Sprintf(`SELECT id, name FROM %s.%s WHERE id = ? LIMIT 1`, DBName, TableName),
		mysqlUserSelectAll:  fmt.Sprintf(`SELECT id, name FROM %s.%s`, DBName, TableName),
	}
)

func CreateDatabase(db *sql.DB) error {
	_, err := db.Exec(userSQLString[mysqlUserCreateDatabase])
	return err
}

func CreateTable(db *sql.DB) error {
	_, err := db.Exec(userSQLString[mysqlUserCreateTable])
	return err
}

func InsertUser(db *sql.DB, name string) error {
	result, err := db.Exec(userSQLString[mysqlUserInsert], name)
	if err != nil {
		return err
	}
	if rows, _ := result.RowsAffected(); rows == 0 {
		return errInvalidInsert
	}
	return nil
}

func TxInsertUser(tx *sql.Tx, name string) error {
	result, err := tx.Exec(userSQLString[mysqlUserInsert], name)
	if err != nil {
		return err
	}
	if rows, _ := result.RowsAffected(); rows == 0 {
		return errInvalidInsert
	}
	return nil
}

func SelectUserByID(db *sql.DB, id uint32) (*User, error) {
	u := &User{}
	row := db.QueryRow(userSQLString[mysqlUserSelectByID], id)
	if err := row.Scan(&u.ID, &u.Name); err != nil {
		return nil, err // sql.ErrNoRows when no user has this id
	}
	return u, nil
}

func TxSelectUserByID(tx *sql.Tx, id uint32) (*User, error) {
	u := &User{}
	row := tx.QueryRow(userSQLString[mysqlUserSelectByID], id)
	if err := row.Scan(&u.ID, &u.Name); err != nil {
		return nil, err
	}
	return u, nil
}

func ListUsers(db *sql.DB) ([]*User, error) {
	rows, err := db.Query(userSQLString[mysqlUserSelectAll])
	if err != nil {
		return nil, err
	}
	defer rows.Close() // returns the connection to the pool

	users := make([]*User, 0)
	for rows.Next() {
		u := &User{}
		if err := rows.Scan(&u.ID, &u.Name); err != nil {
			return nil, err
		}
		users = append(users, u)
	}
	return users, rows.Err() // reports an error that ended iteration early
}

func TxListUsers(tx *sql.Tx) ([]*User, error) {
	rows, err := tx.Query(userSQLString[mysqlUserSelectAll])
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	users := make([]*User, 0)
	for rows.Next() {
		u := &User{}
		if err := rows.Scan(&u.ID, &u.Name); err != nil {
			return nil, err
		}
		users = append(users, u)
	}
	return users, rows.Err()
}
```
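Before calling these wrappers, import the MySQL driver; a blank import is enough, because the driver registers itself. Look up its DSN format in the driver's official documentation and use it to initialize the DB instance, then call the wrapper functions.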
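```go
package main

import (
	"database/sql"
	"fmt"
	"time"

	_ "github.com/go-sql-driver/mysql" // registers the "mysql" driver in its init

	"example.com/app/mysql" // hypothetical import path of the wrapper package above
)

// Placeholder settings (hypothetical values); replace them with real ones.
const (
	username     = "root"
	password     = "password"
	host         = "127.0.0.1"
	port         = "3306"
	database     = "" // empty: the wrapper uses qualified names, so "test" need not exist yet
	charset      = "utf8mb4"
	maxOpenConns = 10
	maxIdleConns = 5
	maxLifetime  = 300 // seconds
)

func main() {
	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=%s&parseTime=%t&loc=%s",
		username, password, host, port, database, charset, true, "Local")
	db, err := sql.Open("mysql", dsn) // sql.Open returns (*DB, error)
	if err != nil {
		panic(err)
	}
	defer db.Close()

	db.SetMaxOpenConns(maxOpenConns)
	db.SetMaxIdleConns(maxIdleConns)
	db.SetConnMaxLifetime(time.Duration(maxLifetime) * time.Second)

	if err := mysql.CreateDatabase(db); err != nil {
		panic(err)
	}
}
```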
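[1] SetMaxOpenConns sets the maximum number of open connections (maxOpen), which bounds concurrency;
[2] SetMaxIdleConns sets the maximum number of idle connections, which bounds the size of the idle connection pool;
[3] SetConnMaxLifetime sets the maximum lifetime of a connection (maxLifetime); an older connection is closed instead of being reused.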
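These settings can be checked without a MySQL server. The sketch below registers a hypothetical do-nothing driver as "limitdemo" and checks out four connections under MaxOpenConns=5 and MaxIdleConns=2. It then returns them all and reads DB.Stats: two connections should stay idle and two should be closed because the idle pool is full. As documented for SetMaxIdleConns, an idle limit higher than MaxOpenConns is lowered to match it.

```go
package main

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"fmt"
	"time"
)

// limitDriver is a hypothetical toy driver whose connections do nothing.
type limitDriver struct{}

func (limitDriver) Open(string) (driver.Conn, error) { return limitConn{}, nil }

type limitConn struct{}

func (limitConn) Prepare(string) (driver.Stmt, error) { return nil, errors.New("not supported") }
func (limitConn) Close() error                       { return nil }
func (limitConn) Begin() (driver.Tx, error)          { return nil, errors.New("not supported") }

func init() { sql.Register("limitdemo", limitDriver{}) }

// poolAfterBurst checks out `held` connections at once under the given
// limits, returns them all, and reports what the pool kept. held must not
// exceed maxOpen, otherwise db.Conn would block waiting for a free slot.
func poolAfterBurst(maxOpen, maxIdle, held int) (sql.DBStats, error) {
	db, err := sql.Open("limitdemo", "")
	if err != nil {
		return sql.DBStats{}, err
	}
	defer db.Close()

	db.SetMaxOpenConns(maxOpen)        // upper bound on open connections
	db.SetMaxIdleConns(maxIdle)        // upper bound on the idle pool
	db.SetConnMaxLifetime(time.Minute) // older connections are closed instead of reused

	ctx := context.Background()
	conns := make([]*sql.Conn, 0, held)
	for i := 0; i < held; i++ {
		c, err := db.Conn(ctx)
		if err != nil {
			return sql.DBStats{}, err
		}
		conns = append(conns, c)
	}
	for _, c := range conns {
		c.Close() // putConn: the first maxIdle go idle, the rest are closed
	}
	return db.Stats(), nil
}

func main() {
	s, err := poolAfterBurst(5, 2, 4)
	if err != nil {
		panic(err)
	}
	fmt.Println(s.MaxOpenConnections, s.OpenConnections, s.Idle, s.MaxIdleClosed) // 5 2 2 2
}
```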
