feat: add cache

This commit is contained in:
骑着蜗牛追导弹 2024-11-09 02:06:10 +08:00
parent adb78cf3f0
commit 9607e94124
6 changed files with 71 additions and 16 deletions

View File

@ -2,14 +2,17 @@ package cache
import ( import (
"fmt" "fmt"
"kenaito-dns/config"
"kenaito-dns/dao" "kenaito-dns/dao"
"sync" "sync"
"time"
) )
var KeyResolveRecordMap sync.Map var KeyResolveRecordMap sync.Map
var IdResolveRecordMap sync.Map var IdResolveRecordMap sync.Map
func ReloadCache() { func ReloadCache() {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start")
resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion()) resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion())
for _, record := range resolveRecords { for _, record := range resolveRecords {
// id -> resolveRecord // id -> resolveRecord
@ -18,15 +21,14 @@ func ReloadCache() {
cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType) cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType)
records, ok := KeyResolveRecordMap.Load(cacheKey) records, ok := KeyResolveRecordMap.Load(cacheKey)
if !ok { if !ok {
fmt.Println("读取缓存失败, key=" + cacheKey)
var tempRecords []dao.ResolveRecord var tempRecords []dao.ResolveRecord
tempRecords = append(tempRecords, record) tempRecords = append(tempRecords, record)
KeyResolveRecordMap.Store(cacheKey, tempRecords) KeyResolveRecordMap.Store(cacheKey, tempRecords)
} else { } else {
fmt.Println("读取缓存成功, key=" + cacheKey)
var newRecords = records.([]dao.ResolveRecord) var newRecords = records.([]dao.ResolveRecord)
records = append(newRecords, record) records = append(newRecords, record)
KeyResolveRecordMap.Store(cacheKey, records) KeyResolveRecordMap.Store(cacheKey, records)
} }
} }
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache end")
} }

View File

@ -8,6 +8,7 @@ package controller
import ( import (
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"kenaito-dns/cache"
"kenaito-dns/constant" "kenaito-dns/constant"
"kenaito-dns/dao" "kenaito-dns/dao"
"kenaito-dns/domain" "kenaito-dns/domain"
@ -46,6 +47,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)})
return return
} }
cache.ReloadCache()
body := make(map[string]interface{}) body := make(map[string]interface{})
body["oldVersion"] = oldVersion body["oldVersion"] = oldVersion
body["newVersion"] = newVersion body["newVersion"] = newVersion
@ -77,6 +79,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)})
return return
} }
cache.ReloadCache()
body := make(map[string]interface{}) body := make(map[string]interface{})
body["oldVersion"] = oldVersion body["oldVersion"] = oldVersion
body["newVersion"] = newVersion body["newVersion"] = newVersion
@ -116,6 +119,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)})
return return
} }
cache.ReloadCache()
body := make(map[string]interface{}) body := make(map[string]interface{})
body["oldVersion"] = oldVersion body["oldVersion"] = oldVersion
body["newVersion"] = newVersion body["newVersion"] = newVersion
@ -173,6 +177,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)}) c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)})
return return
} }
cache.ReloadCache()
body := make(map[string]interface{}) body := make(map[string]interface{})
body["currentVersion"] = jsonObj.Version body["currentVersion"] = jsonObj.Version
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{

View File

@ -8,9 +8,12 @@ package core
import ( import (
"fmt" "fmt"
"github.com/miekg/dns" "github.com/miekg/dns"
"kenaito-dns/cache"
"kenaito-dns/config"
"kenaito-dns/constant" "kenaito-dns/constant"
"kenaito-dns/dao" "kenaito-dns/dao"
"net" "net"
"time"
) )
func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) { func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
@ -46,7 +49,16 @@ func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
func handleARecord(q dns.Question, msg *dns.Msg) { func handleARecord(q dns.Question, msg *dns.Msg) {
name := q.Name name := q.Name
queryName := name[0 : len(name)-1] queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_A) var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_A)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_A)
}
if len(records) > 0 { if len(records) > 0 {
for _, record := range records { for _, record := range records {
fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value) fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
@ -69,7 +81,16 @@ func handleARecord(q dns.Question, msg *dns.Msg) {
func handleAAAARecord(q dns.Question, msg *dns.Msg) { func handleAAAARecord(q dns.Question, msg *dns.Msg) {
name := q.Name name := q.Name
queryName := name[0 : len(name)-1] queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_AAAA) var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_AAAA)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
}
if len(records) > 0 { if len(records) > 0 {
for _, record := range records { for _, record := range records {
fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value) fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
@ -92,7 +113,16 @@ func handleAAAARecord(q dns.Question, msg *dns.Msg) {
func handleCNAMERecord(q dns.Question, msg *dns.Msg) { func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
name := q.Name name := q.Name
queryName := name[0 : len(name)-1] queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_CNAME) var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_CNAME)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
}
if len(records) > 0 { if len(records) > 0 {
for _, record := range records { for _, record := range records {
fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value) fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value)
@ -114,7 +144,16 @@ func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
func handleMXRecord(q dns.Question, msg *dns.Msg) { func handleMXRecord(q dns.Question, msg *dns.Msg) {
name := q.Name name := q.Name
queryName := name[0 : len(name)-1] queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_MX) var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_MX)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_MX)
}
if len(records) > 0 { if len(records) > 0 {
for _, record := range records { for _, record := range records {
fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value) fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value)
@ -137,7 +176,16 @@ func handleMXRecord(q dns.Question, msg *dns.Msg) {
func handleTXTRecord(q dns.Question, msg *dns.Msg) { func handleTXTRecord(q dns.Question, msg *dns.Msg) {
name := q.Name name := q.Name
queryName := name[0 : len(name)-1] queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_TXT) var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_TXT)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
}
if len(records) > 0 { if len(records) > 0 {
for _, record := range records { for _, record := range records {
fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value) fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value)

View File

@ -21,13 +21,13 @@ type ResolveRecord struct {
Version int `xorm:"not null integer 'version'"` Version int `xorm:"not null integer 'version'"`
} }
func FindResolveRecordById(id int) []ResolveRecord { func FindResolveRecordById(id int) ResolveRecord {
var records []ResolveRecord var record ResolveRecord
err := Engine.Table("resolve_record").Where("`id` = ?", id).Find(&records) _, err := Engine.Table("resolve_record").Where("`id` = ?", id).Get(&record)
if err != nil { if err != nil {
fmt.Println(err) fmt.Println(err)
} }
return records return record
} }
func FindResolveRecordByVersion(version int) []ResolveRecord { func FindResolveRecordByVersion(version int) []ResolveRecord {

Binary file not shown.

10
main.go
View File

@ -18,7 +18,7 @@ import (
) )
func main() { func main() {
fmt.Println("[app] [info] kenaito-dns version = " + config.AppVersion) fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " kenaito-dns version = " + config.AppVersion)
go cache.ReloadCache() go cache.ReloadCache()
go initDNSServer() go initDNSServer()
initRestfulServer() initRestfulServer()
@ -30,9 +30,9 @@ func initDNSServer() {
// 设置服务器地址和协议 // 设置服务器地址和协议
server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"} server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"}
// 开始监听 // 开始监听
fmt.Printf("[dns] [info] Starting DNS server on %s\n", server.Addr) fmt.Printf("[dns] [info] "+time.Now().Format(config.AppTimeFormat)+" Starting DNS server on %s\n", server.Addr)
if err := server.ListenAndServe(); err != nil { if err := server.ListenAndServe(); err != nil {
fmt.Printf("[dns] [error] Failed to start DNS server: %s\n", err.Error()) fmt.Printf("[dns] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start DNS server: %s\n", err.Error())
} }
} }
@ -75,9 +75,9 @@ func initRestfulServer() {
WriteTimeout: config.WebWriteTimeout * time.Second, WriteTimeout: config.WebWriteTimeout * time.Second,
} }
controller.InitRestFunc(router) controller.InitRestFunc(router)
fmt.Printf("[gin] [info] Start Gin server: %s\n", config.WebServerPort) fmt.Printf("[gin] [info] "+time.Now().Format(config.AppTimeFormat)+" Start Gin server: %s\n", config.WebServerPort)
err := server.ListenAndServe() err := server.ListenAndServe()
if err != nil { if err != nil {
fmt.Printf("[gin] [error] Failed to start Gin server: %s\n", config.WebServerPort) fmt.Printf("[gin] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start Gin server: %s\n", config.WebServerPort)
} }
} }