package common

import (
	"database/sql"
	"gin-vue-admin/global"

	// import mysql
	_ "github.com/go-sql-driver/mysql"
)

// Package-level state used by the connection helpers in this file.
//
// NOTE(review): a package-level error variable is not safe for concurrent
// callers; prefer function-local errors in new code.
var (
	db       *sql.DB // cached handle for the DSN in DoctorDb.Account (see NewAccountConn)
	commonDB *sql.DB // cached handle for the DSN in DoctorDb.Common (see NewCommonConn)
	err      error   // package-level error variable
)

// NewAccountConn returns the shared *sql.DB handle for the hospital user
// (account) database, opening it on first use.
//
// The handle is created with the "mysql" driver from the DSN stored in
// global.GVA_CONFIG.DoctorDb.Account and cached in the package-level db
// variable, so later calls return the same handle. When sql.Open fails the
// error is returned, nothing is cached, and the next call tries again.
//
// sql.Open may only validate its arguments without connecting; call Ping on
// the returned handle to verify that the database is reachable.
//
// NOTE(review): the lazy initialisation is not synchronised — confirm that
// the first call cannot race with other callers.
func NewAccountConn() (*sql.DB, error) {
	if db != nil {
		return db, nil
	}
	// Keep the error local: a transient open failure should not be parked in
	// package-level state where an unrelated caller could observe it.
	conn, openErr := sql.Open("mysql", global.GVA_CONFIG.DoctorDb.Account)
	if openErr != nil {
		return nil, openErr
	}
	db = conn
	return db, nil
}

// NewCommonConn returns the shared *sql.DB handle for the common (public)
// user database, opening it on first use.
//
// The handle is created with the "mysql" driver from the DSN stored in
// global.GVA_CONFIG.DoctorDb.Common and cached in the package-level commonDB
// variable, so later calls return the same handle. When sql.Open fails the
// error is returned, nothing is cached, and the next call tries again.
//
// sql.Open may only validate its arguments without connecting; call Ping on
// the returned handle to verify that the database is reachable.
//
// NOTE(review): the lazy initialisation is not synchronised — confirm that
// the first call cannot race with other callers.
func NewCommonConn() (*sql.DB, error) {
	if commonDB != nil {
		return commonDB, nil
	}
	// Keep the error local: a transient open failure should not be parked in
	// package-level state where an unrelated caller could observe it.
	conn, openErr := sql.Open("mysql", global.GVA_CONFIG.DoctorDb.Common)
	if openErr != nil {
		return nil, openErr
	}
	commonDB = conn
	return commonDB, nil
}

// GetResultRow flattens a result set into a single column-name -> value map.
// It is meant for queries that yield at most one row.
//
// Every remaining row in rows is consumed. For each row, each non-NULL column
// is stored under its column name, converted to a string. NULL columns are
// skipped, so a NULL column has no key in the map (or keeps the value written
// by an earlier row). With several rows, later rows overwrite earlier ones.
//
// An empty, non-nil map is returned when rows is nil, when the column names
// cannot be read, or when there are no rows. A row that fails to scan is
// skipped so that stale or partially written buffers never reach the result.
// The function cannot report errors through its signature, and it does not
// call rows.Close itself.
func GetResultRow(rows *sql.Rows) map[string]string {
	record := make(map[string]string)
	if rows == nil {
		return record
	}

	columns, colErr := rows.Columns()
	if colErr != nil {
		return record
	}

	// Scan needs pointers; each destination points at one reusable []byte slot.
	values := make([][]byte, len(columns))
	scanArgs := make([]interface{}, len(columns))
	for j := range values {
		scanArgs[j] = &values[j]
	}

	for rows.Next() {
		if scanErr := rows.Scan(scanArgs...); scanErr != nil {
			// values may hold a mix of this row and the previous one.
			continue
		}
		for i, v := range values {
			if v != nil {
				record[columns[i]] = string(v)
			}
		}
	}
	return record
}

// GetResultRows reads every remaining row of rows and returns them keyed by
// their zero-based position in the output.
//
// Each row is a column-name -> value map with the value converted to a
// string. A NULL column is stored as the empty string.
//
// An empty, non-nil map is returned when rows is nil, when the column names
// cannot be read, or when there are no rows. A row that fails to scan is
// skipped (and does not consume an index) so that stale or partially written
// buffers never reach the result. The function cannot report errors through
// its signature, and it does not call rows.Close itself.
func GetResultRows(rows *sql.Rows) map[int]map[string]string {
	result := make(map[int]map[string]string)
	if rows == nil {
		return result
	}

	columns, colErr := rows.Columns()
	if colErr != nil {
		return result
	}

	// Scan needs pointers; each destination points at one reusable []byte slot.
	values := make([][]byte, len(columns))
	scans := make([]interface{}, len(columns))
	for k := range values {
		scans[k] = &values[k]
	}

	for rows.Next() {
		if scanErr := rows.Scan(scans...); scanErr != nil {
			// values may hold a mix of this row and the previous one.
			continue
		}
		row := make(map[string]string, len(columns))
		for k, v := range values {
			row[columns[k]] = string(v)
		}
		// len(result) is the next free index because keys are assigned densely.
		result[len(result)] = row
	}
	return result
}

// CloseTx finishes tx according to err: the transaction is rolled back when
// err is non-nil and committed otherwise.
//
// A nil tx is ignored, so the function can be called with the result of a
// failed Begin (nil transaction plus non-nil error) without panicking.
// Errors returned by Rollback and Commit are discarded because the signature
// has no way to report them.
func CloseTx(tx *sql.Tx, err error) {
	if tx == nil {
		return
	}
	if err != nil {
		// Return after rolling back: the transaction is finished and must not
		// also be committed.
		_ = tx.Rollback()
		return
	}
	_ = tx.Commit()
}