injectionCode.go 4.28 KB
Newer Older
haoyanbin's avatar
1  
haoyanbin committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
package utils

import (
	"fmt"
	"go/ast"
	"go/parser"
	"go/token"
	"io/ioutil"
	"strings"
)

// AutoInjectionCode inserts codeData (whitespace-trimmed) into the Go source
// file at filepath, on a new line directly above the marker comment
// "Code generated by gin-vue-admin End; DO NOT EDIT.". The text that precedes
// the marker on its line (normally its indentation) is repeated after the
// inserted line so the marker keeps its alignment.
//
// If funcName is non-empty and a function with that name and a body exists in
// the file, only comments strictly inside that function's braces are
// considered. Otherwise the whole file is searched. When several comments
// match a marker, the last one wins.
// NOTE(review): a funcName that matches nothing silently falls back to a
// whole-file search; consider returning an error instead.
//
// When funcName resolves to a function and both the Begin and End markers are
// present in it, the code between the markers is checked first. If codeData
// is already there, a message is printed and the file is left untouched
// (nil is returned).
//
// It returns an error if the file cannot be read, parsed or written, or if the
// End marker is not found in the searched region.
//
// Original author: [LeonardWang](https://github.com/WangLeonard)
func AutoInjectionCode(filepath string, funcName string, codeData string) error {
	startComment := "Code generated by gin-vue-admin Begin; DO NOT EDIT."
	endComment := "Code generated by gin-vue-admin End; DO NOT EDIT."
	srcData, err := ioutil.ReadFile(filepath)
	if err != nil {
		return err
	}
	srcDataLen := len(srcData)
	fset := token.NewFileSet()
	fparser, err := parser.ParseFile(fset, filepath, srcData, parser.ParseComments)
	if err != nil {
		return err
	}
	codeData = strings.TrimSpace(codeData)

	// All positions below are token.Pos values. This is the only file in a
	// fresh FileSet (base 1), so a Pos equals its byte offset in srcData plus
	// one. For a "//" comment, srcData[Pos] is therefore the second '/' and
	// srcData[Pos-1] the first.
	codeStartPos := -1
	codeEndPos := srcDataLen
	var expectedFunction *ast.FuncDecl

	// If a function name was given, restrict the search to that function's
	// body. Declarations without a body (external/assembly functions) are
	// skipped: their Body is nil and reading Lbrace would panic.
	if funcName != "" {
		for _, decl := range fparser.Decls {
			funDecl, ok := decl.(*ast.FuncDecl)
			if !ok || funDecl.Body == nil || funDecl.Name.Name != funcName {
				continue
			}
			expectedFunction = funDecl
			codeStartPos = int(funDecl.Body.Lbrace)
			codeEndPos = int(funDecl.Body.Rbrace)
			break
		}
	}

	// Locate the Begin/End marker comments inside the search region.
	startCommentPos := -1
	endCommentPos := srcDataLen
	for _, comment := range fparser.Comments {
		if int(comment.Pos()) <= codeStartPos || int(comment.End()) > codeEndPos {
			continue
		}
		text := comment.Text()
		if strings.Contains(text, startComment) {
			startCommentPos = int(comment.Pos())
		}
		if strings.Contains(text, endComment) {
			endCommentPos = int(comment.Pos())
		}
	}

	if endCommentPos == srcDataLen {
		return fmt.Errorf("comment:%s not found", endComment)
	}

	// Deduplicate only when the target function and both markers were found.
	// The End marker is known to exist at this point.
	if expectedFunction != nil && startCommentPos != -1 {
		if checkExist(&srcData, startCommentPos, endCommentPos, expectedFunction.Body, codeData) {
			fmt.Printf("文件 %s 待插入数据 %s 已存在\n", filepath, codeData)
			return nil // already present: treated as success, nothing written
		}
	}

	// Marker comments on adjacent lines (no blank line between them) form a
	// single comment group, so both positions point at the group start.
	// Find the End marker's text inside the group and step back to the second
	// '/' of its "//".
	if startCommentPos == endCommentPos {
		endCommentPos = startCommentPos + strings.Index(string(srcData[startCommentPos:]), endComment)
		for srcData[endCommentPos] != '/' {
			endCommentPos--
		}
	}

	// indexPos is the offset of the End marker's first '/'. The bytes from the
	// start of its line up to indexPos are copied after the inserted line.
	indexPos := endCommentPos - 1
	lineStart := indexPos
	for lineStart > 0 && srcData[lineStart-1] != '\n' {
		lineStart--
	}
	indent := srcData[lineStart:indexPos]

	// Build the result in a fresh buffer rather than appending into
	// srcData[:indexPos], which would overwrite srcData's own tail.
	out := make([]byte, 0, len(srcData)+len(codeData)+1+len(indent))
	out = append(out, srcData[:indexPos]...)
	out = append(out, codeData...)
	out = append(out, '\n')
	out = append(out, indent...)
	out = append(out, srcData[indexPos:]...)

	return ioutil.WriteFile(filepath, out, 0600)
}

// checkExist reports whether target already appears in blockStmt as source
// text lying strictly between the token positions startPos and endPos.
//
// Only these nodes are compared against target:
//   - call expressions used as statements, e.g. `RegisterRouter(group)`;
//   - each argument of a call that is the first right-hand side of an
//     assignment, e.g. the arguments of `err := db.AutoMigrate(a{}, b{})`.
//
// The search recurses into bare nested `{ ... }` blocks only, not into
// if/for/switch bodies.
//
// The text taken for a node spans from its first byte through the byte just
// after it, and is then whitespace-trimmed. A ',' or ')' immediately following
// a node is therefore part of the compared text (e.g. an argument followed by
// a comma compares as "a{},"). Positions are treated as byte offset + 1, which
// assumes srcData is the only file in a fresh token.FileSet.
func checkExist(srcData *[]byte, startPos int, endPos int, blockStmt *ast.BlockStmt, target string) bool {
	src := *srcData

	// hit reports whether node lies strictly inside (startPos, endPos) and its
	// trimmed source text equals target.
	hit := func(node ast.Node) bool {
		from, to := int(node.Pos()), int(node.End())
		if from <= startPos || to >= endPos {
			return false
		}
		return strings.TrimSpace(string(src[from-1:to])) == target
	}

	for _, s := range blockStmt.List {
		if nested, ok := s.(*ast.BlockStmt); ok {
			if checkExist(srcData, startPos, endPos, nested, target) {
				return true
			}
			continue
		}

		if expr, ok := s.(*ast.ExprStmt); ok {
			if call, isCall := expr.X.(*ast.CallExpr); isCall && hit(call) {
				return true
			}
			continue
		}

		// Assignments cover model registration such as `db.AutoMigrate(...)`.
		assign, ok := s.(*ast.AssignStmt)
		if !ok || len(assign.Rhs) == 0 {
			continue
		}
		call, isCall := assign.Rhs[0].(*ast.CallExpr)
		if !isCall {
			continue
		}
		for _, arg := range call.Args {
			if hit(arg) {
				return true
			}
		}
	}
	return false
}