feat: add cache

This commit is contained in:
骑着蜗牛追导弹 2024-11-09 02:06:10 +08:00
parent adb78cf3f0
commit 9607e94124
6 changed files with 71 additions and 16 deletions

View File

@ -2,14 +2,17 @@ package cache
import (
"fmt"
"kenaito-dns/config"
"kenaito-dns/dao"
"sync"
"time"
)
var KeyResolveRecordMap sync.Map
var IdResolveRecordMap sync.Map
func ReloadCache() {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start")
resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion())
for _, record := range resolveRecords {
// id -> resolveRecord
@ -18,15 +21,14 @@ func ReloadCache() {
cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType)
records, ok := KeyResolveRecordMap.Load(cacheKey)
if !ok {
fmt.Println("读取缓存失败, key=" + cacheKey)
var tempRecords []dao.ResolveRecord
tempRecords = append(tempRecords, record)
KeyResolveRecordMap.Store(cacheKey, tempRecords)
} else {
fmt.Println("读取缓存成功, key=" + cacheKey)
var newRecords = records.([]dao.ResolveRecord)
records = append(newRecords, record)
KeyResolveRecordMap.Store(cacheKey, records)
}
}
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache end")
}

View File

@ -8,6 +8,7 @@ package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"kenaito-dns/cache"
"kenaito-dns/constant"
"kenaito-dns/dao"
"kenaito-dns/domain"
@ -46,6 +47,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)})
return
}
cache.ReloadCache()
body := make(map[string]interface{})
body["oldVersion"] = oldVersion
body["newVersion"] = newVersion
@ -77,6 +79,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)})
return
}
cache.ReloadCache()
body := make(map[string]interface{})
body["oldVersion"] = oldVersion
body["newVersion"] = newVersion
@ -116,6 +119,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)})
return
}
cache.ReloadCache()
body := make(map[string]interface{})
body["oldVersion"] = oldVersion
body["newVersion"] = newVersion
@ -173,6 +177,7 @@ func InitRestFunc(r *gin.Engine) {
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)})
return
}
cache.ReloadCache()
body := make(map[string]interface{})
body["currentVersion"] = jsonObj.Version
c.JSON(http.StatusOK, gin.H{

View File

@ -8,9 +8,12 @@ package core
import (
"fmt"
"github.com/miekg/dns"
"kenaito-dns/cache"
"kenaito-dns/config"
"kenaito-dns/constant"
"kenaito-dns/dao"
"net"
"time"
)
func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
@ -46,7 +49,16 @@ func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
func handleARecord(q dns.Question, msg *dns.Msg) {
name := q.Name
queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_A)
var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_A)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_A)
}
if len(records) > 0 {
for _, record := range records {
fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
@ -69,7 +81,16 @@ func handleARecord(q dns.Question, msg *dns.Msg) {
func handleAAAARecord(q dns.Question, msg *dns.Msg) {
name := q.Name
queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_AAAA)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
}
if len(records) > 0 {
for _, record := range records {
fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
@ -92,7 +113,16 @@ func handleAAAARecord(q dns.Question, msg *dns.Msg) {
func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
name := q.Name
queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_CNAME)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
}
if len(records) > 0 {
for _, record := range records {
fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value)
@ -114,7 +144,16 @@ func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
func handleMXRecord(q dns.Question, msg *dns.Msg) {
name := q.Name
queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_MX)
var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_MX)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_MX)
}
if len(records) > 0 {
for _, record := range records {
fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value)
@ -137,7 +176,16 @@ func handleMXRecord(q dns.Question, msg *dns.Msg) {
func handleTXTRecord(q dns.Question, msg *dns.Msg) {
name := q.Name
queryName := name[0 : len(name)-1]
records := dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
var records []dao.ResolveRecord
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_TXT)
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
if ok {
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
records = value.([]dao.ResolveRecord)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
} else {
records = dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
}
if len(records) > 0 {
for _, record := range records {
fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value)

View File

@ -21,13 +21,13 @@ type ResolveRecord struct {
Version int `xorm:"not null integer 'version'"`
}
func FindResolveRecordById(id int) []ResolveRecord {
var records []ResolveRecord
err := Engine.Table("resolve_record").Where("`id` = ?", id).Find(&records)
func FindResolveRecordById(id int) ResolveRecord {
var record ResolveRecord
_, err := Engine.Table("resolve_record").Where("`id` = ?", id).Get(&record)
if err != nil {
fmt.Println(err)
}
return records
return record
}
func FindResolveRecordByVersion(version int) []ResolveRecord {

Binary file not shown.

10
main.go
View File

@ -18,7 +18,7 @@ import (
)
func main() {
fmt.Println("[app] [info] kenaito-dns version = " + config.AppVersion)
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " kenaito-dns version = " + config.AppVersion)
go cache.ReloadCache()
go initDNSServer()
initRestfulServer()
@ -30,9 +30,9 @@ func initDNSServer() {
// 设置服务器地址和协议
server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"}
// 开始监听
fmt.Printf("[dns] [info] Starting DNS server on %s\n", server.Addr)
fmt.Printf("[dns] [info] "+time.Now().Format(config.AppTimeFormat)+" Starting DNS server on %s\n", server.Addr)
if err := server.ListenAndServe(); err != nil {
fmt.Printf("[dns] [error] Failed to start DNS server: %s\n", err.Error())
fmt.Printf("[dns] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start DNS server: %s\n", err.Error())
}
}
@ -75,9 +75,9 @@ func initRestfulServer() {
WriteTimeout: config.WebWriteTimeout * time.Second,
}
controller.InitRestFunc(router)
fmt.Printf("[gin] [info] Start Gin server: %s\n", config.WebServerPort)
fmt.Printf("[gin] [info] "+time.Now().Format(config.AppTimeFormat)+" Start Gin server: %s\n", config.WebServerPort)
err := server.ListenAndServe()
if err != nil {
fmt.Printf("[gin] [error] Failed to start Gin server: %s\n", config.WebServerPort)
fmt.Printf("[gin] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start Gin server: %s\n", config.WebServerPort)
}
}