diff --git a/README.md b/README.md index 871ebfe..1f8778d 100644 --- a/README.md +++ b/README.md @@ -13,10 +13,6 @@ Bind9不能直接支持API的方式添加解析记录, 通过脚本修改Bind - gcc - go version >= 1.20 -## 接口文档 - -[在线阅读](https://oss.odboy.cn/blog/files/onlinedoc/kenaito-dns/index.html) - ## 项目结构 - constant 常量 @@ -43,6 +39,9 @@ more than 7 hours - 支持回滚 2024-11-08 [ok] - 添加缓存 2024-11-09 [ok] +- 新增Web控制台 2024-11-11 [ok] +- 支持一键启/停用 2024-11-11 [ok] +- 支持一键回滚 2024-11-11 [ok] ## 运行配置 @@ -143,7 +142,8 @@ nslookup example.com 192.168.1.103 ## 代码托管(以私人仓库Gitea为准) -- Gitea: [https://gitea.odboy.cn/odboy/kenaito-dns](https://gitea.odboy.cn/odboy/kenaito-dns) +- Gitea后端: [https://gitea.odboy.cn/odboy/kenaito-dns](https://gitea.odboy.cn/odboy/kenaito-dns) +- Gitea前端: [https://gitea.odboy.cn/odboy/kenaito-dns-front](https://gitea.odboy.cn/odboy/kenaito-dns-front) - Github: [https://github.com/odboy-tianjun/kenaito-dns](https://github.com/odboy-tianjun/kenaito-dns) - Gitee(已关闭,单纯的不想放在gitee): [https://gitee.com/odboy/kenaito-dns](https://gitee.com/odboy/kenaito-dns) diff --git a/cache/resolve_record.go b/cache/resolve_record.go index a1fa267..d4fa9c5 100644 --- a/cache/resolve_record.go +++ b/cache/resolve_record.go @@ -15,7 +15,7 @@ func ReloadCache() { fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start") KeyResolveRecordMap.Range(cleanKeyCache) IdResolveRecordMap.Range(cleanIdCache) - resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion()) + resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion(), false) for _, record := range resolveRecords { // id -> resolveRecord IdResolveRecordMap.Store(record.Id, record) diff --git a/config/app.go b/config/app.go index 4c65cba..781a178 100644 --- a/config/app.go +++ b/config/app.go @@ -1,6 +1,7 @@ package config const ( - AppVersion = "1.0.0" - AppTimeFormat = "2006/01/02 15:04:05.999999" + AppVersion = "1.0.0" + AppTimeFormat = "2006/01/02 15:04:05.999999" + DataTimeFormat = "2006-01-02 15:04:05" ) diff --git a/controller/resolve_record.go b/controller/resolve_record.go index 947aef1..e2d75a7 100644 --- a/controller/resolve_record.go +++ b/controller/resolve_record.go @@ -21,9 +21,60 @@ func InitRestFunc(r *gin.Engine) { // 健康检查 r.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ - "message": "pong", + "code": 0, + "message": "success", + "data": "pong", }) }) + // 测试解析状态 + r.POST("/test", func(c *gin.Context) { + var jsonObj domain.TestArgs + err := c.ShouldBindJSON(&jsonObj) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("校验失败, %v", err)}) + return + } + name := jsonObj.Name + valid := util.IsValidDomain(name) + if !valid { + c.JSON(http.StatusBadRequest, gin.H{"message": "域名解析失败"}) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "域名解析成功", + }) + }) + // 启停记录 + r.POST("/switch", func(c *gin.Context) { + var jsonObj domain.SwitchArgs + err := c.ShouldBindJSON(&jsonObj) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("校验失败, %v", err)}) + return + } + _, err = dao.SwitchResolveRecord(jsonObj.Id, jsonObj.Enabled) + if err != nil { + if jsonObj.Enabled == 1 { + c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("启用失败, %v", err)}) + } else { + c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("停用失败, %v", err)}) + } + return + } + cache.ReloadCache() + if jsonObj.Enabled == 1 { + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "启用成功", + }) + } else { + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "停用成功", + }) + } + }) // 创建RR记录 r.POST("/create", func(c *gin.Context) { var jsonObj domain.CreateResolveRecord @@ -33,7 +84,9 @@ func InitRestFunc(r *gin.Engine) { return } if dao.IsResolveRecordExist(newRecord) { - c.JSON(http.StatusBadRequest, gin.H{"message": "记录 " + newRecord.Name + " " + newRecord.RecordType + " " + newRecord.Value + " 已存在"}) + c.JSON(http.StatusBadRequest, gin.H{ + "message": "记录 " + newRecord.Name + " " + newRecord.RecordType + " " + newRecord.Value + " 已存在", + }) return } newRecord.Ttl = jsonObj.Ttl @@ -52,8 +105,9 @@ func InitRestFunc(r *gin.Engine) { body["oldVersion"] = oldVersion body["newVersion"] = newVersion c.JSON(http.StatusOK, gin.H{ + "code": 0, "message": "添加" + newRecord.RecordType + "记录成功", - "body": body, + "data": body, }) return }) @@ -84,8 +138,9 @@ func InitRestFunc(r *gin.Engine) { body["oldVersion"] = oldVersion body["newVersion"] = newVersion c.JSON(http.StatusOK, gin.H{ + "code": 0, "message": "删除" + newRecord.RecordType + "记录成功", - "body": body, + "data": body, }) return }) @@ -128,6 +183,7 @@ func InitRestFunc(r *gin.Engine) { updRecord.RecordType = newRecord.RecordType updRecord.Ttl = newRecord.Ttl updRecord.Value = newRecord.Value + updRecord.CreateTime = localNewRecord.CreateTime executeResult, err = dao.ModifyResolveRecordById(localNewRecord.Id, updRecord) if !executeResult { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)}) @@ -138,8 +194,9 @@ func InitRestFunc(r *gin.Engine) { body["oldVersion"] = oldVersion body["newVersion"] = newVersion c.JSON(http.StatusOK, gin.H{ + "code": 0, "message": "更新" + newRecord.RecordType + "记录成功", - "body": body, + "data": body, }) return }) @@ -152,7 +209,13 @@ func InitRestFunc(r *gin.Engine) { return } records := dao.FindResolveRecordPage(jsonObj.Page, jsonObj.PageSize, &jsonObj) - c.JSON(http.StatusOK, gin.H{"message": "分页查询RR记录成功", "body": records}) + count := dao.CountResolveRecordPage(jsonObj.Page, jsonObj.PageSize, &jsonObj) + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "分页查询RR记录成功", + "data": records, + "count": count, + }) return }) // 根据id查询RR记录明细 @@ -164,13 +227,21 @@ func InitRestFunc(r *gin.Engine) { return } records := dao.FindResolveRecordById(jsonObj.Id) - c.JSON(http.StatusOK, gin.H{"message": "根据id查询RR记录明细成功", "body": records}) + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "根据id查询RR记录明细成功", + "data": records, + }) return }) // 查询变更历史记录 r.POST("/queryVersionList", func(c *gin.Context) { records := dao.FindResolveVersion() - c.JSON(http.StatusOK, gin.H{"message": "查询变更历史记录列表成功", "body": records}) + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "查询变更历史记录列表成功", + "data": records, + }) return }) // 回滚到某一版本 @@ -181,7 +252,7 @@ func InitRestFunc(r *gin.Engine) { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("校验失败, %v", err)}) return } - versions := dao.FindResolveRecordByVersion(jsonObj.Version) + versions := dao.FindResolveRecordByVersion(jsonObj.Version, true) if len(versions) == 0 { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("版本号 %d 不存在, 回滚失败", jsonObj.Version)}) return @@ -195,8 +266,9 @@ func InitRestFunc(r *gin.Engine) { body := make(map[string]interface{}) body["currentVersion"] = jsonObj.Version c.JSON(http.StatusOK, gin.H{ + "code": 0, "message": "回滚成功", - "body": body, + "data": body, }) return }) diff --git a/core/gin_cors.go b/core/gin_cors.go new file mode 100644 index 0000000..7ef4ef7 --- /dev/null +++ b/core/gin_cors.go @@ -0,0 +1,42 @@ +package core + +import ( + "fmt" + "github.com/gin-gonic/gin" + "net/http" + "strings" +) + +// Cors 跨域中间件 +func Cors() gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method // 请求方法 + origin := c.Request.Header.Get("Origin") // 请求头部 + var headerKeys []string // 声明请求头keys + for k, _ := range c.Request.Header { + headerKeys = append(headerKeys, k) + } + headerStr := strings.Join(headerKeys, ", ") + if headerStr != "" { + headerStr = fmt.Sprintf("access-control-allow-origin, access-control-allow-headers, %s", headerStr) + } else { + headerStr = "access-control-allow-origin, access-control-allow-headers" + } + if origin != "" { + c.Writer.Header().Set("Access-Control-Allow-Origin", "*") + c.Header("Access-Control-Allow-Origin", "*") // 允许访问所有域 + c.Header("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE, UPDATE") // 服务器支持的所有跨域请求的方法 + c.Header("Access-Control-Allow-Headers", "Authorization, Content-Length, X-CSRF-Token, Token, session, X_Requested_With, Accept, Origin, Host, Connection, Accept-Encoding, Accept-Language, DNT, X-CustomHeader, Keep-Alive, User-Agent, X-Requested-With, If-Modified-Since, Cache-Control, Content-Type, Pragma") // 允许的头类型 + c.Header("Access-Control-Expose-Headers", "Content-Length, Access-Control-Allow-Origin, Access-Control-Allow-Headers, Cache-Control, Content-Language, Content-Type, Expires, Last-Modified, Pragma, FooBar") // 允许跨域设置,可以返回其他子段 + c.Header("Access-Control-Max-Age", "172800") // 缓存请求信息,单位为秒 + c.Header("Access-Control-Allow-Credentials", "false") // 跨域请求是否需要带cookie信息,默认设置为true + c.Set("content-type", "application/json;charset=utf8") // 设置返回格式是json + } + // 放行所有OPTIONS方法 + if method == "OPTIONS" { + c.JSON(http.StatusOK, "Options Request!") + } + // 处理请求 + c.Next() + } +} diff --git a/dao/resolve_record.go b/dao/resolve_record.go index 0df68b5..3e78a1e 100644 --- a/dao/resolve_record.go +++ b/dao/resolve_record.go @@ -7,9 +7,11 @@ package dao */ import ( "fmt" + "kenaito-dns/config" "kenaito-dns/domain" "kenaito-dns/util" "strings" + "time" ) type ResolveRecord struct { @@ -19,6 +21,9 @@ type ResolveRecord struct { Ttl int `xorm:"not null integer 'ttl'" json:"ttl"` Value string `xorm:"not null text 'value'" json:"value"` Version int `xorm:"not null integer 'version'" json:"version"` + CreateTime string `xorm:"not null text 'create_time'" json:"createTime"` + UpdateTime string `xorm:"not null text 'update_time'" json:"updateTime"` + Enabled int `xorm:"not null integer 'enabled'" json:"enabled"` } func (ResolveRecord) TableName() string { @@ -50,9 +55,14 @@ func FindOneResolveRecord(wrapper *ResolveRecord, version int) *ResolveRecord { return &record } -func FindResolveRecordByVersion(version int) []ResolveRecord { +func FindResolveRecordByVersion(version int, isAll bool) []ResolveRecord { var records []ResolveRecord - err := Engine.Table("resolve_record").Where("`version` = ?", version).Find(&records) + session := Engine.Table("resolve_record") + session.Where("`version` = ?", version) + if !isAll { + session.Where("`enabled` = ?", 1) + } + err := session.Find(&records) if err != nil { fmt.Println(err) } @@ -102,8 +112,44 @@ func FindResolveRecordPage(pageNo int, pageSize int, args *domain.QueryPageArgs) } return records } +func CountResolveRecordPage(pageNo int, pageSize int, args *domain.QueryPageArgs) int { + // 每页显示5条记录 + if pageSize <= 5 { + pageSize = 5 + } + // 要查询的页码 + if pageNo <= 0 { + pageNo = 1 + } + // 计算跳过的记录数 + offset := (pageNo - 1) * pageSize + session := Engine.Table("resolve_record").Where("") + if args != nil { + if !util.IsBlank(args.Name) { + qs := "%" + strings.TrimSpace(args.Name) + "%" + session.And("`name` LIKE ?", qs) + } + if !util.IsBlank(args.Type) { + qs := strings.TrimSpace(args.Type) + session.And("`record_type` = ?", qs) + } + if !util.IsBlank(args.Value) { + qs := strings.TrimSpace(args.Value) + session.And("`value` = ?", qs) + } + } + session.And("`version` = ?", GetResolveVersion()) + count, err := session.Limit(pageSize, offset).Count() + if err != nil { + fmt.Println(err) + } + return int(count) +} func SaveResolveRecord(wrapper *ResolveRecord) (bool, error) { + wrapper.CreateTime = time.Now().Format(config.DataTimeFormat) + wrapper.UpdateTime = time.Now().Format(config.DataTimeFormat) + wrapper.Enabled = 1 _, err := Engine.Table("resolve_record").Insert(wrapper) if err != nil { fmt.Println(err) @@ -116,7 +162,7 @@ func BackupResolveRecord(record *ResolveRecord) (bool, error, int, int) { var backupRecords []*ResolveRecord oldVersion := GetResolveVersion() newVersion := GetResolveVersion() + 1 - oldRecords := FindResolveRecordByVersion(oldVersion) + oldRecords := FindResolveRecordByVersion(oldVersion, true) for _, oldRecord := range oldRecords { newRecord := new(ResolveRecord) newRecord.Name = oldRecord.Name @@ -124,6 +170,9 @@ func BackupResolveRecord(record *ResolveRecord) (bool, error, int, int) { newRecord.Ttl = oldRecord.Ttl newRecord.Value = oldRecord.Value newRecord.Version = newVersion + newRecord.CreateTime = oldRecord.CreateTime + newRecord.UpdateTime = oldRecord.UpdateTime + newRecord.Enabled = oldRecord.Enabled backupRecords = append(backupRecords, newRecord) } record.Version = newVersion @@ -180,6 +229,21 @@ func IsUpdResolveRecordExist(id int, wrapper *ResolveRecord) bool { } func ModifyResolveRecordById(id int, updateRecord *ResolveRecord) (bool, error) { + updateRecord.UpdateTime = time.Now().Format(config.DataTimeFormat) + wrapper := new(ResolveRecord) + wrapper.Id = id + _, err := Engine.Table("resolve_record").Update(updateRecord, wrapper) + if err != nil { + fmt.Println(err) + return false, err + } + return true, nil +} + +func SwitchResolveRecord(id int, enabled int) (bool, error) { + var updateRecord ResolveRecord + updateRecord.UpdateTime = time.Now().Format(config.DataTimeFormat) + updateRecord.Enabled = enabled wrapper := new(ResolveRecord) wrapper.Id = id _, err := Engine.Table("resolve_record").Update(updateRecord, wrapper) diff --git a/dns.sqlite3 b/dns.sqlite3 index 1880b6a..88e4cde 100644 Binary files a/dns.sqlite3 and b/dns.sqlite3 differ diff --git a/domain/resolve_record.go b/domain/resolve_record.go index 40245fd..bd9a5d9 100644 --- a/domain/resolve_record.go +++ b/domain/resolve_record.go @@ -34,6 +34,15 @@ type QueryPageArgs struct { Value string `json:"value"` } +type TestArgs struct { + Name string `json:"name" binding:"required"` +} + +type SwitchArgs struct { + Id int `json:"id" binding:"required"` + Enabled int `json:"enabled" binding:"required"` +} + type QueryByIdArgs struct { Id int `json:"id" binding:"required"` } diff --git a/main.go b/main.go index 47e8be7..32239d7 100644 --- a/main.go +++ b/main.go @@ -66,6 +66,8 @@ func initRestfulServer() { param.ErrorMessage, ) })) + // 允许使用跨域请求,全局中间件 + router.Use(core.Cors()) // 使用 Recovery 中间件,处理任何出现的错误,并防止服务崩溃 router.Use(gin.Recovery()) server := &http.Server{ diff --git a/util/strtool.go b/util/strtool.go index d29b8bf..ad7ce85 100644 --- a/util/strtool.go +++ b/util/strtool.go @@ -6,9 +6,14 @@ package util * @Date 20241107 */ import ( + "fmt" + "golang.org/x/net/context" + "kenaito-dns/config" + "log" "net" "regexp" "strings" + "time" ) // IsBlank 检查字符串是否空 @@ -38,9 +43,49 @@ func IsIPv6(ipAddr string) bool { // IsValidDomain 判断域名是否正常解析 func IsValidDomain(domain string) bool { - _, err := net.LookupHost(domain) + dnsServer := getLocalIP() + if dnsServer == "" { + dnsServer = "223.5.5.5" + } + dnsServer = dnsServer + ":53" + _, err := lookupHostWithDNS(domain, dnsServer) if err != nil { return false } return true } + +func lookupHostWithDNS(host string, dnsServer string) ([]string, error) { + resolver := &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + conn, err := d.DialContext(ctx, network, dnsServer) + if err != nil { + fmt.Println("[app] [error] "+time.Now().Format(config.AppTimeFormat)+" [DNSTool] 连接到 DNS 服务器失败: ", err) + return nil, err + } + return conn, nil + }, + } + ips, err := resolver.LookupHost(context.Background(), host) + if err != nil { + return nil, err + } + return ips, nil +} + +func getLocalIP() string { + addrList, err := net.InterfaceAddrs() + if err != nil { + log.Fatal(err) + } + for _, addr := range addrList { + if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { + if ipNet.IP.To4() != nil { + return ipNet.IP.String() + } + } + } + return "" +}