feat: add cache
This commit is contained in:
parent
adb78cf3f0
commit
9607e94124
|
@ -2,14 +2,17 @@ package cache
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"kenaito-dns/config"
|
||||
"kenaito-dns/dao"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var KeyResolveRecordMap sync.Map
|
||||
var IdResolveRecordMap sync.Map
|
||||
|
||||
func ReloadCache() {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start")
|
||||
resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion())
|
||||
for _, record := range resolveRecords {
|
||||
// id -> resolveRecord
|
||||
|
@ -18,15 +21,14 @@ func ReloadCache() {
|
|||
cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType)
|
||||
records, ok := KeyResolveRecordMap.Load(cacheKey)
|
||||
if !ok {
|
||||
fmt.Println("读取缓存失败, key=" + cacheKey)
|
||||
var tempRecords []dao.ResolveRecord
|
||||
tempRecords = append(tempRecords, record)
|
||||
KeyResolveRecordMap.Store(cacheKey, tempRecords)
|
||||
} else {
|
||||
fmt.Println("读取缓存成功, key=" + cacheKey)
|
||||
var newRecords = records.([]dao.ResolveRecord)
|
||||
records = append(newRecords, record)
|
||||
KeyResolveRecordMap.Store(cacheKey, records)
|
||||
}
|
||||
}
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache end")
|
||||
}
|
||||
|
|
|
@ -8,6 +8,7 @@ package controller
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"kenaito-dns/cache"
|
||||
"kenaito-dns/constant"
|
||||
"kenaito-dns/dao"
|
||||
"kenaito-dns/domain"
|
||||
|
@ -46,6 +47,7 @@ func InitRestFunc(r *gin.Engine) {
|
|||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)})
|
||||
return
|
||||
}
|
||||
cache.ReloadCache()
|
||||
body := make(map[string]interface{})
|
||||
body["oldVersion"] = oldVersion
|
||||
body["newVersion"] = newVersion
|
||||
|
@ -77,6 +79,7 @@ func InitRestFunc(r *gin.Engine) {
|
|||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)})
|
||||
return
|
||||
}
|
||||
cache.ReloadCache()
|
||||
body := make(map[string]interface{})
|
||||
body["oldVersion"] = oldVersion
|
||||
body["newVersion"] = newVersion
|
||||
|
@ -116,6 +119,7 @@ func InitRestFunc(r *gin.Engine) {
|
|||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)})
|
||||
return
|
||||
}
|
||||
cache.ReloadCache()
|
||||
body := make(map[string]interface{})
|
||||
body["oldVersion"] = oldVersion
|
||||
body["newVersion"] = newVersion
|
||||
|
@ -173,6 +177,7 @@ func InitRestFunc(r *gin.Engine) {
|
|||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)})
|
||||
return
|
||||
}
|
||||
cache.ReloadCache()
|
||||
body := make(map[string]interface{})
|
||||
body["currentVersion"] = jsonObj.Version
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
|
|
|
@ -8,9 +8,12 @@ package core
|
|||
import (
|
||||
"fmt"
|
||||
"github.com/miekg/dns"
|
||||
"kenaito-dns/cache"
|
||||
"kenaito-dns/config"
|
||||
"kenaito-dns/constant"
|
||||
"kenaito-dns/dao"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
@ -46,7 +49,16 @@ func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
|||
func handleARecord(q dns.Question, msg *dns.Msg) {
|
||||
name := q.Name
|
||||
queryName := name[0 : len(name)-1]
|
||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_A)
|
||||
var records []dao.ResolveRecord
|
||||
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_A)
|
||||
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||
if ok {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||
records = value.([]dao.ResolveRecord)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||
} else {
|
||||
records = dao.FindResolveRecordByNameType(queryName, constant.R_A)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
for _, record := range records {
|
||||
fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
||||
|
@ -69,7 +81,16 @@ func handleARecord(q dns.Question, msg *dns.Msg) {
|
|||
func handleAAAARecord(q dns.Question, msg *dns.Msg) {
|
||||
name := q.Name
|
||||
queryName := name[0 : len(name)-1]
|
||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
|
||||
var records []dao.ResolveRecord
|
||||
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_AAAA)
|
||||
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||
if ok {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||
records = value.([]dao.ResolveRecord)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||
} else {
|
||||
records = dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
for _, record := range records {
|
||||
fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
||||
|
@ -92,7 +113,16 @@ func handleAAAARecord(q dns.Question, msg *dns.Msg) {
|
|||
func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
|
||||
name := q.Name
|
||||
queryName := name[0 : len(name)-1]
|
||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
|
||||
var records []dao.ResolveRecord
|
||||
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_CNAME)
|
||||
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||
if ok {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||
records = value.([]dao.ResolveRecord)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||
} else {
|
||||
records = dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
for _, record := range records {
|
||||
fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value)
|
||||
|
@ -114,7 +144,16 @@ func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
|
|||
func handleMXRecord(q dns.Question, msg *dns.Msg) {
|
||||
name := q.Name
|
||||
queryName := name[0 : len(name)-1]
|
||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_MX)
|
||||
var records []dao.ResolveRecord
|
||||
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_MX)
|
||||
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||
if ok {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||
records = value.([]dao.ResolveRecord)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||
} else {
|
||||
records = dao.FindResolveRecordByNameType(queryName, constant.R_MX)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
for _, record := range records {
|
||||
fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value)
|
||||
|
@ -137,7 +176,16 @@ func handleMXRecord(q dns.Question, msg *dns.Msg) {
|
|||
func handleTXTRecord(q dns.Question, msg *dns.Msg) {
|
||||
name := q.Name
|
||||
queryName := name[0 : len(name)-1]
|
||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
|
||||
var records []dao.ResolveRecord
|
||||
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_TXT)
|
||||
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||
if ok {
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||
records = value.([]dao.ResolveRecord)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||
} else {
|
||||
records = dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
|
||||
}
|
||||
if len(records) > 0 {
|
||||
for _, record := range records {
|
||||
fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value)
|
||||
|
|
|
@ -21,13 +21,13 @@ type ResolveRecord struct {
|
|||
Version int `xorm:"not null integer 'version'"`
|
||||
}
|
||||
|
||||
func FindResolveRecordById(id int) []ResolveRecord {
|
||||
var records []ResolveRecord
|
||||
err := Engine.Table("resolve_record").Where("`id` = ?", id).Find(&records)
|
||||
func FindResolveRecordById(id int) ResolveRecord {
|
||||
var record ResolveRecord
|
||||
_, err := Engine.Table("resolve_record").Where("`id` = ?", id).Get(&record)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
}
|
||||
return records
|
||||
return record
|
||||
}
|
||||
|
||||
func FindResolveRecordByVersion(version int) []ResolveRecord {
|
||||
|
|
BIN
dns.sqlite3
BIN
dns.sqlite3
Binary file not shown.
10
main.go
10
main.go
|
@ -18,7 +18,7 @@ import (
|
|||
)
|
||||
|
||||
func main() {
|
||||
fmt.Println("[app] [info] kenaito-dns version = " + config.AppVersion)
|
||||
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " kenaito-dns version = " + config.AppVersion)
|
||||
go cache.ReloadCache()
|
||||
go initDNSServer()
|
||||
initRestfulServer()
|
||||
|
@ -30,9 +30,9 @@ func initDNSServer() {
|
|||
// 设置服务器地址和协议
|
||||
server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"}
|
||||
// 开始监听
|
||||
fmt.Printf("[dns] [info] Starting DNS server on %s\n", server.Addr)
|
||||
fmt.Printf("[dns] [info] "+time.Now().Format(config.AppTimeFormat)+" Starting DNS server on %s\n", server.Addr)
|
||||
if err := server.ListenAndServe(); err != nil {
|
||||
fmt.Printf("[dns] [error] Failed to start DNS server: %s\n", err.Error())
|
||||
fmt.Printf("[dns] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start DNS server: %s\n", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,9 +75,9 @@ func initRestfulServer() {
|
|||
WriteTimeout: config.WebWriteTimeout * time.Second,
|
||||
}
|
||||
controller.InitRestFunc(router)
|
||||
fmt.Printf("[gin] [info] Start Gin server: %s\n", config.WebServerPort)
|
||||
fmt.Printf("[gin] [info] "+time.Now().Format(config.AppTimeFormat)+" Start Gin server: %s\n", config.WebServerPort)
|
||||
err := server.ListenAndServe()
|
||||
if err != nil {
|
||||
fmt.Printf("[gin] [error] Failed to start Gin server: %s\n", config.WebServerPort)
|
||||
fmt.Printf("[gin] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start Gin server: %s\n", config.WebServerPort)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue