// Copyright GoFrame Author(https://goframe.org). All Rights Reserved. // // This Source Code Form is subject to the terms of the MIT License. // If a copy of the MIT was not distributed with this file, // You can obtain one at https://github.com/gogf/gf. package mssql import ( "context" "database/sql" "fmt" "regexp" "strings" "git.magicany.cc/black1552/gin-base/database" "github.com/gogf/gf/v2/errors/gcode" "github.com/gogf/gf/v2/errors/gerror" ) const ( // INSERT statement prefixes insertPrefixDefault = "INSERT INTO" insertPrefixIgnore = "INSERT IGNORE INTO" // Database field attributes fieldExtraIdentity = "IDENTITY" fieldKeyPrimary = "PRI" // SQL keywords and syntax markers outputKeyword = "OUTPUT" insertValuesMarker = ") VALUES" // find the position of the string "VALUES" in the INSERT SQL statement to embed output code for retrieving the last inserted ID // Object and field references insertedObjectName = "INSERTED" // Result field names and aliases affectCountExpression = " 1 as AffectCount" lastInsertIdFieldAlias = "ID" ) // DoExec commits the sql string and its arguments to underlying driver // through given link object and returns the execution result. func (d *Driver) DoExec(ctx context.Context, link database.Link, sqlStr string, args ...interface{}) (result sql.Result, err error) { // Transaction checks. if link == nil { if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil { // Firstly, check and retrieve transaction link from context. link = &txLinkMssql{tx.GetSqlTX()} } else if link, err = d.MasterLink(); err != nil { // Or else it creates one from master node. return nil, err } } else if !link.IsTransaction() { // If current link is not transaction link, it checks and retrieves transaction from context. if tx := database.TXFromCtx(ctx, d.GetGroup()); tx != nil { link = &txLinkMssql{tx.GetSqlTX()} } } // SQL filtering. 
sqlStr, args = d.FormatSqlBeforeExecuting(sqlStr, args) sqlStr, args, err = d.DoFilter(ctx, link, sqlStr, args) if err != nil { return nil, err } if !strings.HasPrefix(sqlStr, insertPrefixDefault) && !strings.HasPrefix(sqlStr, insertPrefixIgnore) { return d.Core.DoExec(ctx, link, sqlStr, args) } // Find the first position of VALUES marker in the INSERT statement. pos := strings.Index(sqlStr, insertValuesMarker) table := d.GetTableNameFromSql(sqlStr) outPutSql := d.GetInsertOutputSql(ctx, table) // rebuild sql add output var ( sqlValueBefore = sqlStr[:pos+1] sqlValueAfter = sqlStr[pos+1:] ) sqlStr = fmt.Sprintf("%s%s%s", sqlValueBefore, outPutSql, sqlValueAfter) // fmt.Println("sql str:", sqlStr) // Link execution. var out database.DoCommitOutput out, err = d.DoCommit(ctx, database.DoCommitInput{ Link: link, Sql: sqlStr, Args: args, Stmt: nil, Type: database.SqlTypeQueryContext, IsTransaction: link.IsTransaction(), }) if err != nil { return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } stdSqlResult := out.Records if len(stdSqlResult) == 0 { err = gerror.WrapCode( gcode.CodeDbOperationError, gerror.New("affected count is zero"), `sql.Result.RowsAffected failed`, ) return &Result{lastInsertId: 0, rowsAffected: 0, err: err}, err } // For batch insert, OUTPUT clause returns one row per inserted row. // So the rowsAffected should be the count of returned records. rowsAffected := int64(len(stdSqlResult)) // get last_insert_id from the first returned row lastInsertId := stdSqlResult[0].GMap().GetVar(lastInsertIdFieldAlias).Int64() return &Result{lastInsertId: lastInsertId, rowsAffected: rowsAffected}, err } // GetTableNameFromSql get table name from sql statement // It handles table string like: // "user" // "user u" // "DbLog.dbo.user", // "user as u". func (d *Driver) GetTableNameFromSql(sqlStr string) (table string) { // INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?) 
var ( leftChars, rightChars = d.GetChars() trimStr = leftChars + rightChars + "[] " pattern = "INTO(.+?)\\(" regCompile = regexp.MustCompile(pattern) tableInfo = regCompile.FindStringSubmatch(sqlStr) ) // get the first one. after the first it may be content of the value, it's not table name. table = tableInfo[1] table = strings.Trim(table, " ") if strings.Contains(table, ".") { tmpAry := strings.Split(table, ".") // the last one is table name table = tmpAry[len(tmpAry)-1] } else if strings.Contains(table, "as") || strings.Contains(table, " ") { tmpAry := strings.Split(table, "as") if len(tmpAry) < 2 { tmpAry = strings.Split(table, " ") } // get the first one table = tmpAry[0] } table = strings.Trim(table, trimStr) return table } // txLink is used to implement interface Link for TX. type txLinkMssql struct { *sql.Tx } // IsTransaction returns if current Link is a transaction. func (l *txLinkMssql) IsTransaction() bool { return true } // IsOnMaster checks and returns whether current link is operated on master node. // Note that, transaction operation is always operated on master node. 
func (l *txLinkMssql) IsOnMaster() bool {
	return true
}

// GetInsertOutputSql builds the OUTPUT clause that DoExec embeds into INSERT
// statements for the given table, so the last inserted ID can be read back.
//
// The clause always selects the constant expression affectCountExpression.
// It also selects, from the INSERTED pseudo table, every column whose Extra is
// IDENTITY, whose Key is PRI, and which is not nullable. The first such column
// is aliased as lastInsertIdFieldAlias so callers can read it by a fixed name.
//
// An empty string is returned when the table fields cannot be retrieved.
//
// Example: for table "ip_to_id" with identity key "id" the result is
// " OUTPUT  1 as AffectCount,INSERTED.id as ID", giving
// INSERT INTO "ip_to_id"("ip") OUTPUT 1 as AffectCount,INSERTED.id as ID VALUES(?).
func (d *Driver) GetInsertOutputSql(ctx context.Context, table string) string {
	fields, err := d.GetDB().TableFields(ctx, table)
	if err != nil {
		return ""
	}

	columns := []string{fmt.Sprintf(" %s %s", outputKeyword, affectCountExpression)}
	for _, field := range fields {
		isIdentityKey := field.Extra == fieldExtraIdentity && field.Key == fieldKeyPrimary && !field.Null
		if !isIdentityKey {
			continue
		}
		column := fmt.Sprintf("%s.%s", insertedObjectName, field.Name)
		if len(columns) == 1 {
			// First identity column: give it the fixed alias DoExec reads.
			column += fmt.Sprintf(" as %s", lastInsertIdFieldAlias)
		}
		columns = append(columns, column)
	}
	return strings.Join(columns, ",")
}