gin-base/database/drivers/mssql/mssql_do_exec.go

193 lines
6.1 KiB
Go

// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
//
// This Source Code Form is subject to the terms of the MIT License.
// If a copy of the MIT was not distributed with this file,
// You can obtain one at https://github.com/gogf/gf.
package mssql
import (
"context"
"database/sql"
"fmt"
"regexp"
"strings"
"git.magicany.cc/black1552/gin-base/database"
"github.com/gogf/gf/v2/errors/gcode"
"github.com/gogf/gf/v2/errors/gerror"
)
// Statement prefixes that DoExec treats as INSERTs eligible for the
// OUTPUT-clause rewrite.
const (
	insertPrefixDefault = "INSERT INTO"
	insertPrefixIgnore  = "INSERT IGNORE INTO"
)

// Column metadata values that together identify an auto-increment primary
// key when building the OUTPUT clause.
const (
	fieldExtraIdentity = "IDENTITY"
	fieldKeyPrimary    = "PRI"
)

// Fragments used to locate the splice point in an INSERT statement and to
// build the OUTPUT clause and its result-column aliases.
const (
	outputKeyword = "OUTPUT"
	// insertValuesMarker matches the ")" closing the column list together with
	// the following " VALUES"; the OUTPUT clause is inserted right after ")".
	insertValuesMarker = ") VALUES"
	// insertedObjectName is the pseudo-table referenced as INSERTED.<column>.
	insertedObjectName = "INSERTED"
	// affectCountExpression is emitted once per output row.
	affectCountExpression = " 1 as AffectCount"
	// lastInsertIdFieldAlias names the column the last insert id is read from.
	lastInsertIdFieldAlias = "ID"
)
// DoExec commits the sql string and its arguments to the underlying driver
// through the given link object and returns the execution result.
//
// Link selection: a transaction stored in ctx for this driver's group takes
// precedence over a nil or non-transactional link; a nil link without such a
// transaction is replaced by the master link.
//
// Statements that do not start with insertPrefixDefault or insertPrefixIgnore
// are delegated to Core.DoExec. For INSERT statements that contain
// insertValuesMarker, the OUTPUT clause produced by GetInsertOutputSql is
// spliced in right after the column list and the statement is run as a query:
// RowsAffected is the number of returned rows and LastInsertId is read from
// the lastInsertIdFieldAlias column of the first row. INSERTs that cannot be
// rewritten this way (no ") VALUES", or no table metadata) fall back to
// Core.DoExec.
func (d *Driver) DoExec(ctx context.Context, link database.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) {
	// Transaction checks.
	if link == nil {
		if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil {
			// Firstly, check and retrieve transaction link from context.
			link = &txLinkMssql{tx.GetSqlTX()}
		} else if link, err = d.MasterLink(); err != nil {
			// Or else it creates one from master node.
			return nil, err
		}
	} else if !link.IsTransaction() {
		// If current link is not transaction link, it checks and retrieves transaction from context.
		if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil {
			link = &txLinkMssql{tx.GetSqlTX()}
		}
	}

	// SQL filtering.
	sqlStr, args = d.FormatSqlBeforeExecuting(sqlStr, args)
	sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args)
	if err != nil {
		return nil, err
	}
	if !strings.HasPrefix(sqlStr, insertPrefixDefault) && !strings.HasPrefix(sqlStr, insertPrefixIgnore) {
		return d.Core.DoExec(ctx, link, sqlStr, args)
	}

	// Locate the ")" that closes the column list right before VALUES. Without
	// it (e.g. INSERT ... SELECT, DEFAULT VALUES) there is no place to put the
	// OUTPUT clause; slicing at pos+1 with pos == -1 would prepend the clause
	// to the whole statement and produce invalid SQL.
	pos := strings.Index(sqlStr, insertValuesMarker)
	if pos < 0 {
		return d.Core.DoExec(ctx, link, sqlStr, args)
	}
	table := d.GetTableNameFromSql(sqlStr)
	outPutSql := d.GetInsertOutputSql(ctx, table)
	if outPutSql == "" {
		// Table metadata could not be loaded. Running the unmodified INSERT as
		// a query would return no rows and be reported as a failure even
		// though the row was written, so execute it as a plain statement.
		return d.Core.DoExec(ctx, link, sqlStr, args)
	}

	// Rebuild the SQL: keep ")" on the left and splice OUTPUT before " VALUES".
	sqlStr = fmt.Sprintf("%s%s%s", sqlStr[:pos+1], outPutSql, sqlStr[pos+1:])

	// Link execution. The OUTPUT clause yields a result set, so this must be
	// run as a query rather than an exec.
	var out database.DoCommitOutput
	out, err = d.DoCommit(ctx, database.DoCommitInput{
		Link:          link,
		Sql:           sqlStr,
		Args:          args,
		Stmt:          nil,
		Type:          database.SqlTypeQueryContext,
		IsTransaction: link.IsTransaction(),
	})
	if err != nil {
		return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
	}
	records := out.Records
	if len(records) == 0 {
		err = gerror.WrapCode(
			gcode.CodeDbOperationError,
			gerror.New("affected count is zero"),
			`sql.Result.RowsAffected failed`,
		)
		return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
	}
	// For batch insert, OUTPUT clause returns one row per inserted row,
	// so rowsAffected is the count of returned records.
	rowsAffected := int64(len(records))
	// The last insert id is read from the aliased column of the first row.
	lastInsertId := records[0].GMap().GetVar(lastInsertIdFieldAlias).Int64()
	return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, nil
}
// insertIntoTablePattern captures the text between "INTO" and the first "("
// of an INSERT statement, i.e. the (possibly qualified or aliased) table name.
// It is compiled once here instead of on every GetTableNameFromSql call.
var insertIntoTablePattern = regexp.MustCompile(`INTO(.+?)\(`)

// GetTableNameFromSql extracts the bare table name from an INSERT statement.
// It handles table strings like:
// "user"
// "user u"
// "DbLog.dbo.user",
// "user as u".
// Quote characters from GetChars, square brackets and spaces are trimmed from
// the result. An empty string is returned when no table reference is found.
func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) {
	// INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
	leftChars, rightChars := d.GetChars()
	trimStr := leftChars + rightChars + "[] "

	// Only the first match is used: later parentheses may belong to the
	// values, not the table name.
	tableInfo := insertIntoTablePattern.FindStringSubmatch(sqlStr)
	if len(tableInfo) < 2 {
		// No "INTO ... (" in the statement; indexing would panic.
		return ""
	}
	table = strings.TrimSpace(tableInfo[1])
	if strings.Contains(table, ".") {
		// Qualified name: the last segment is the table name.
		parts := strings.Split(table, ".")
		table = parts[len(parts)-1]
	} else if fields := strings.Fields(table); len(fields) > 1 {
		// "user u" / "user as u": keep the first whitespace-separated token.
		// Matching the substring "as" would truncate names such as
		// "classes" or "database_log".
		table = fields[0]
	}
	return strings.Trim(table, trimStr)
}
// txLinkMssql wraps a *sql.Tx so a transaction can be assigned where DoExec
// expects a database.Link. All *sql.Tx methods are promoted from the embedded
// field; the methods below mark the link as transactional.
type txLinkMssql struct {
	*sql.Tx
}

// IsTransaction always reports true because the link wraps a transaction.
func (t *txLinkMssql) IsTransaction() bool { return true }

// IsOnMaster always reports true: transaction operations are always
// performed on the master node.
func (t *txLinkMssql) IsOnMaster() bool { return true }
// GetInsertOutputSql builds the OUTPUT clause that DoExec splices into an
// INSERT statement so the inserted identity value can be read back.
//
// The clause always starts with the affect-count expression and then lists
// INSERTED.<column> for every column whose metadata marks it as a non-null
// IDENTITY primary key. The first such column is aliased as
// lastInsertIdFieldAlias. For a table with identity column "id" the result is
// " OUTPUT  1 as AffectCount,INSERTED.id as ID", which yields e.g.:
// INSERT INTO "ip_to_id"("ip") OUTPUT  1 as AffectCount,INSERTED.id as ID VALUES(?)
//
// An empty string is returned when the table's fields cannot be loaded.
func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string {
	fields, err := d.GetDB().TableFields(ctx, table)
	if err != nil {
		return ""
	}
	parts := []string{fmt.Sprintf(" %s %s", outputKeyword, affectCountExpression)}
	for _, field := range fields {
		isAutoIncrementPK := field.Extra == fieldExtraIdentity &&
			field.Key == fieldKeyPrimary &&
			!field.Null
		if !isAutoIncrementPK {
			continue
		}
		// Only the first identity column gets the fixed alias, so the caller
		// can look it up by a known name.
		alias := ""
		if len(parts) == 1 {
			alias = " as " + lastInsertIdFieldAlias
		}
		parts = append(parts, insertedObjectName+"."+field.Name+alias)
	}
	return strings.Join(parts, ",")
}