From 9607e94124764ed54254e1e7f1548dd1e44c4db2 Mon Sep 17 00:00:00 2001 From: odboy Date: Sat, 9 Nov 2024 02:06:10 +0800 Subject: [PATCH] feat: add cache --- cache/resolve_record.go | 6 ++-- controller/resolve_record.go | 5 +++ core/handler.go | 58 ++++++++++++++++++++++++++++++++--- dao/resolve_record.go | 8 ++--- dns.sqlite3 | Bin 36864 -> 36864 bytes main.go | 10 +++--- 6 files changed, 71 insertions(+), 16 deletions(-) diff --git a/cache/resolve_record.go b/cache/resolve_record.go index fef0f11..a8aa1b8 100644 --- a/cache/resolve_record.go +++ b/cache/resolve_record.go @@ -2,14 +2,17 @@ package cache import ( "fmt" + "kenaito-dns/config" "kenaito-dns/dao" "sync" + "time" ) var KeyResolveRecordMap sync.Map var IdResolveRecordMap sync.Map func ReloadCache() { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start") resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion()) for _, record := range resolveRecords { // id -> resolveRecord @@ -18,15 +21,14 @@ func ReloadCache() { cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType) records, ok := KeyResolveRecordMap.Load(cacheKey) if !ok { - fmt.Println("读取缓存失败, key=" + cacheKey) var tempRecords []dao.ResolveRecord tempRecords = append(tempRecords, record) KeyResolveRecordMap.Store(cacheKey, tempRecords) } else { - fmt.Println("读取缓存成功, key=" + cacheKey) var newRecords = records.([]dao.ResolveRecord) records = append(newRecords, record) KeyResolveRecordMap.Store(cacheKey, records) } } + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache end") } diff --git a/controller/resolve_record.go b/controller/resolve_record.go index bf39d4c..cd753e9 100644 --- a/controller/resolve_record.go +++ b/controller/resolve_record.go @@ -8,6 +8,7 @@ package controller import ( "fmt" "github.com/gin-gonic/gin" + "kenaito-dns/cache" "kenaito-dns/constant" "kenaito-dns/dao" "kenaito-dns/domain" @@ -46,6 +47,7 @@ func InitRestFunc(r *gin.Engine) { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)}) return } + cache.ReloadCache() body := make(map[string]interface{}) body["oldVersion"] = oldVersion body["newVersion"] = newVersion @@ -77,6 +79,7 @@ func InitRestFunc(r *gin.Engine) { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)}) return } + cache.ReloadCache() body := make(map[string]interface{}) body["oldVersion"] = oldVersion body["newVersion"] = newVersion @@ -116,6 +119,7 @@ func InitRestFunc(r *gin.Engine) { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)}) return } + cache.ReloadCache() body := make(map[string]interface{}) body["oldVersion"] = oldVersion body["newVersion"] = newVersion @@ -173,6 +177,7 @@ func InitRestFunc(r *gin.Engine) { c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)}) return } + cache.ReloadCache() body := make(map[string]interface{}) body["currentVersion"] = jsonObj.Version c.JSON(http.StatusOK, gin.H{ diff --git a/core/handler.go b/core/handler.go index 2745ec9..6e7acc6 100644 --- a/core/handler.go +++ b/core/handler.go @@ -8,9 +8,12 @@ package core import ( "fmt" "github.com/miekg/dns" + "kenaito-dns/cache" + "kenaito-dns/config" "kenaito-dns/constant" "kenaito-dns/dao" "net" + "time" ) func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { @@ -46,7 +49,16 @@ func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { func handleARecord(q dns.Question, msg *dns.Msg) { name := q.Name queryName := name[0 : len(name)-1] - records := dao.FindResolveRecordByNameType(queryName, constant.R_A) + var records []dao.ResolveRecord + cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_A) + value, ok := cache.KeyResolveRecordMap.Load(cacheKey) + if ok { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start") + records = value.([]dao.ResolveRecord) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end") + } else { + records = dao.FindResolveRecordByNameType(queryName, constant.R_A) + } if len(records) > 0 { for _, record := range records { fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value) @@ -69,7 +81,16 @@ func handleARecord(q dns.Question, msg *dns.Msg) { func handleAAAARecord(q dns.Question, msg *dns.Msg) { name := q.Name queryName := name[0 : len(name)-1] - records := dao.FindResolveRecordByNameType(queryName, constant.R_AAAA) + var records []dao.ResolveRecord + cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_AAAA) + value, ok := cache.KeyResolveRecordMap.Load(cacheKey) + if ok { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start") + records = value.([]dao.ResolveRecord) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end") + } else { + records = dao.FindResolveRecordByNameType(queryName, constant.R_AAAA) + } if len(records) > 0 { for _, record := range records { fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value) @@ -92,7 +113,16 @@ func handleAAAARecord(q dns.Question, msg *dns.Msg) { func handleCNAMERecord(q dns.Question, msg *dns.Msg) { name := q.Name queryName := name[0 : len(name)-1] - records := dao.FindResolveRecordByNameType(queryName, constant.R_CNAME) + var records []dao.ResolveRecord + cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_CNAME) + value, ok := cache.KeyResolveRecordMap.Load(cacheKey) + if ok { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start") + records = value.([]dao.ResolveRecord) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end") + } else { + records = dao.FindResolveRecordByNameType(queryName, constant.R_CNAME) + } if len(records) > 0 { for _, record := range records { fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value) @@ -114,7 +144,16 @@ func handleCNAMERecord(q dns.Question, msg *dns.Msg) { func handleMXRecord(q dns.Question, msg *dns.Msg) { name := q.Name queryName := name[0 : len(name)-1] - records := dao.FindResolveRecordByNameType(queryName, constant.R_MX) + var records []dao.ResolveRecord + cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_MX) + value, ok := cache.KeyResolveRecordMap.Load(cacheKey) + if ok { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start") + records = value.([]dao.ResolveRecord) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end") + } else { + records = dao.FindResolveRecordByNameType(queryName, constant.R_MX) + } if len(records) > 0 { for _, record := range records { fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value) @@ -137,7 +176,16 @@ func handleMXRecord(q dns.Question, msg *dns.Msg) { func handleTXTRecord(q dns.Question, msg *dns.Msg) { name := q.Name queryName := name[0 : len(name)-1] - records := dao.FindResolveRecordByNameType(queryName, constant.R_TXT) + var records []dao.ResolveRecord + cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_TXT) + value, ok := cache.KeyResolveRecordMap.Load(cacheKey) + if ok { + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start") + records = value.([]dao.ResolveRecord) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end") + } else { + records = dao.FindResolveRecordByNameType(queryName, constant.R_TXT) + } if len(records) > 0 { for _, record := range records { fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value) diff --git a/dao/resolve_record.go b/dao/resolve_record.go index 8cb9ae6..67cf2c2 100644 --- a/dao/resolve_record.go +++ b/dao/resolve_record.go @@ -21,13 +21,13 @@ type ResolveRecord struct { Version int `xorm:"not null integer 'version'"` } -func FindResolveRecordById(id int) []ResolveRecord { - var records []ResolveRecord - err := Engine.Table("resolve_record").Where("`id` = ?", id).Find(&records) +func FindResolveRecordById(id int) ResolveRecord { + var record ResolveRecord + _, err := Engine.Table("resolve_record").Where("`id` = ?", id).Get(&record) if err != nil { fmt.Println(err) } - return records + return record } func FindResolveRecordByVersion(version int) []ResolveRecord { diff --git a/dns.sqlite3 b/dns.sqlite3 index c252b45b7f6086a59081f8a2480aa6b85774c282..03d37d40bda2fcfe8867df0cbb2aac1e2638d3d4 100644 GIT binary patch delta 366 zcmZozz|^pSX@WGP>qHr6M%RrA^Yz6AdHGWq*!iY0@ZaE@#^1%akKc-4fbZF6MS%vs zdNm$)26cW$ZN|jJM7{i!r2I;~fEHH kB&?1Ad&`)efdQ;uLyeo5U@&A~ENEb`so+09BZwsc0Gl>d>Hq)$ delta 47 zcmZozz|^pSX@WGP#Y7orMvILJ^Yz)eco~2|U^An`fBnq@4!`*~75wLCWL?bYAOHYK CQw_KP diff --git a/main.go b/main.go index 3eda64f..ae8a266 100644 --- a/main.go +++ b/main.go @@ -18,7 +18,7 @@ import ( ) func main() { - fmt.Println("[app] [info] kenaito-dns version = " + config.AppVersion) + fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " kenaito-dns version = " + config.AppVersion) go cache.ReloadCache() go initDNSServer() initRestfulServer() @@ -30,9 +30,9 @@ func initDNSServer() { // 设置服务器地址和协议 server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"} // 开始监听 - fmt.Printf("[dns] [info] Starting DNS server on %s\n", server.Addr) + fmt.Printf("[dns] [info] "+time.Now().Format(config.AppTimeFormat)+" Starting DNS server on %s\n", server.Addr) if err := server.ListenAndServe(); err != nil { - fmt.Printf("[dns] [error] Failed to start DNS server: %s\n", err.Error()) + fmt.Printf("[dns] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start DNS server: %s\n", err.Error()) } } @@ -75,9 +75,9 @@ func initRestfulServer() { WriteTimeout: config.WebWriteTimeout * time.Second, } controller.InitRestFunc(router) - fmt.Printf("[gin] [info] Start Gin server: %s\n", config.WebServerPort) + fmt.Printf("[gin] [info] "+time.Now().Format(config.AppTimeFormat)+" Start Gin server: %s\n", config.WebServerPort) err := server.ListenAndServe() if err != nil { - fmt.Printf("[gin] [error] Failed to start Gin server: %s\n", config.WebServerPort) + fmt.Printf("[gin] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start Gin server: %s\n", config.WebServerPort) } }