193 lines
6.1 KiB
Go
193 lines
6.1 KiB
Go
// Copyright GoFrame Author(https://goframe.org). All Rights Reserved.
|
|
//
|
|
// This Source Code Form is subject to the terms of the MIT License.
|
|
// If a copy of the MIT was not distributed with this file,
|
|
// You can obtain one at https://github.com/gogf/gf.
|
|
|
|
package mssql
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"regexp"
|
|
"strings"
|
|
|
|
"git.magicany.cc/black1552/gin-base/database"
|
|
"github.com/gogf/gf/v2/errors/gcode"
|
|
"github.com/gogf/gf/v2/errors/gerror"
|
|
)
|
|
|
|
// Constants used by DoExec and its helpers to recognise INSERT statements and
// to splice an OUTPUT clause into them.
const (
	// Prefixes that mark a (filtered) SQL string as an INSERT statement.
	insertPrefixDefault = "INSERT INTO"
	insertPrefixIgnore  = "INSERT IGNORE INTO"

	// Column attributes, compared against the fields returned by TableFields,
	// that identify an auto-increment primary key.
	fieldKeyPrimary    = "PRI"
	fieldExtraIdentity = "IDENTITY"

	// Building blocks of the generated OUTPUT clause.
	outputKeyword          = "OUTPUT"
	insertedObjectName     = "INSERTED"
	affectCountExpression  = " 1 as AffectCount"
	lastInsertIdFieldAlias = "ID"

	// The OUTPUT clause is injected right after the ")" of this marker, i.e.
	// between the column list and the VALUES keyword of an INSERT statement.
	insertValuesMarker = ") VALUES"
)
|
|
|
|
// DoExec commits the sql string and its arguments to underlying driver
|
|
// through given link object and returns the execution result.
|
|
func (d *Driver) DoExec(ctx context.Context, link database.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) {
|
|
// Transaction checks.
|
|
if link == nil {
|
|
if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
|
// Firstly, check and retrieve transaction link from context.
|
|
link = &txLinkMssql{tx.GetSqlTX()}
|
|
} else if link, err = d.MasterLink(); err != nil {
|
|
// Or else it creates one from master node.
|
|
return nil, err
|
|
}
|
|
} else if !link.IsTransaction() {
|
|
// If current link is not transaction link, it checks and retrieves transaction from context.
|
|
if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil {
|
|
link = &txLinkMssql{tx.GetSqlTX()}
|
|
}
|
|
}
|
|
|
|
// SQL filtering.
|
|
sqlStr, args = d.FormatSqlBeforeExecuting(sqlStr, args)
|
|
sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if !strings.HasPrefix(sqlStr, insertPrefixDefault) && !strings.HasPrefix(sqlStr, insertPrefixIgnore) {
|
|
return d.Core.DoExec(ctx, link, sqlStr, args)
|
|
}
|
|
// Find the first position of VALUES marker in the INSERT statement.
|
|
pos := strings.Index(sqlStr, insertValuesMarker)
|
|
|
|
table := d.GetTableNameFromSql(sqlStr)
|
|
outPutSql := d.GetInsertOutputSql(ctx, table)
|
|
// rebuild sql add output
|
|
var (
|
|
sqlValueBefore = sqlStr[:pos+1]
|
|
sqlValueAfter = sqlStr[pos+1:]
|
|
)
|
|
|
|
sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter)
|
|
|
|
// fmt.Println("sql str:", sqlStr)
|
|
// Link execution.
|
|
var out database.DoCommitOutput
|
|
out, err = d.DoCommit(ctx, database.DoCommitInput{
|
|
Link: link,
|
|
Sql: sqlStr,
|
|
Args: args,
|
|
Stmt: nil,
|
|
Type: database.SqlTypeQueryContext,
|
|
IsTransaction: link.IsTransaction(),
|
|
})
|
|
if err != nil {
|
|
return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
|
}
|
|
stdSqlResult := out.Records
|
|
if len(stdSqlResult) == 0 {
|
|
err = gerror.WrapCode(
|
|
gcode.CodeDbOperationError,
|
|
gerror.New("affected count is zero"),
|
|
`sql.Result.RowsAffected failed`,
|
|
)
|
|
return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err
|
|
}
|
|
// For batch insert, OUTPUT clause returns one row per inserted row.
|
|
// So the rowsAffected should be the count of returned records.
|
|
rowsAffected := int64(len(stdSqlResult))
|
|
// get last_insert_id from the first returned row
|
|
lastInsertId := stdSqlResult[0].GMap().GetVar(lastInsertIdFieldAlias).Int64()
|
|
|
|
return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err
|
|
}
|
|
|
|
// GetTableNameFromSql get table name from sql statement
|
|
// It handles table string like:
|
|
// "user"
|
|
// "user u"
|
|
// "DbLog.dbo.user",
|
|
// "user as u".
|
|
func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) {
|
|
// INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
|
var (
|
|
leftChars, rightChars = d.GetChars()
|
|
trimStr = leftChars + rightChars + "[] "
|
|
pattern = "INTO(.+?)\\("
|
|
regCompile = regexp.MustCompile(pattern)
|
|
tableInfo = regCompile.FindStringSubmatch(sqlStr)
|
|
)
|
|
// get the first one. after the first it may be content of the value, it's not table name.
|
|
table = tableInfo[1]
|
|
table = strings.Trim(table, " ")
|
|
if strings.Contains(table, ".") {
|
|
tmpAry := strings.Split(table, ".")
|
|
// the last one is table name
|
|
table = tmpAry[len(tmpAry)-1]
|
|
} else if strings.Contains(table, "as") || strings.Contains(table, " ") {
|
|
tmpAry := strings.Split(table, "as")
|
|
if len(tmpAry) < 2 {
|
|
tmpAry = strings.Split(table, " ")
|
|
}
|
|
// get the first one
|
|
table = tmpAry[0]
|
|
}
|
|
table = strings.Trim(table, trimStr)
|
|
return table
|
|
}
|
|
|
|
// txLinkMssql wraps a *sql.Tx so that it can be used wherever DoExec expects
// a database.Link, letting statements run inside an ongoing transaction.
type txLinkMssql struct {
	*sql.Tx
}

// IsOnMaster reports whether the link operates on the master node.
// Transactions are always executed on the master node, so it is always true.
func (*txLinkMssql) IsOnMaster() bool { return true }

// IsTransaction reports whether the link is a transaction; always true.
func (*txLinkMssql) IsTransaction() bool { return true }
|
|
|
|
// GetInsertOutputSql gen get last_insert_id code
|
|
func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string {
|
|
fds, errFd := d.GetDB().TableFields(ctx, table)
|
|
if errFd != nil {
|
|
return ""
|
|
}
|
|
extraSqlAry := make([]string, 0)
|
|
extraSqlAry = append(extraSqlAry, fmt.Sprintf(" %s %s", outputKeyword, affectCountExpression))
|
|
incrNo := 0
|
|
if len(fds) > 0 {
|
|
for _, fd := range fds {
|
|
// has primary key and is auto-increment
|
|
if fd.Extra == fieldExtraIdentity && fd.Key == fieldKeyPrimary && !fd.Null {
|
|
incrNoStr := ""
|
|
if incrNo == 0 { // fixed first field named id, convenient to get
|
|
incrNoStr = fmt.Sprintf(" as %s", lastInsertIdFieldAlias)
|
|
}
|
|
|
|
extraSqlAry = append(extraSqlAry, fmt.Sprintf("%s.%s%s", insertedObjectName, fd.Name, incrNoStr))
|
|
incrNo++
|
|
}
|
|
// fmt.Printf("null:%t name:%s key:%s k:%s \n", fd.Null, fd.Name, fd.Key, k)
|
|
}
|
|
}
|
|
return strings.Join(extraSqlAry, ",")
|
|
// sql example:INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?)
|
|
}
|