feat: add cache
This commit is contained in:
parent
adb78cf3f0
commit
9607e94124
|
@ -2,14 +2,17 @@ package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"kenaito-dns/config"
|
||||||
"kenaito-dns/dao"
|
"kenaito-dns/dao"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
var KeyResolveRecordMap sync.Map
|
var KeyResolveRecordMap sync.Map
|
||||||
var IdResolveRecordMap sync.Map
|
var IdResolveRecordMap sync.Map
|
||||||
|
|
||||||
func ReloadCache() {
|
func ReloadCache() {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache start")
|
||||||
resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion())
|
resolveRecords := dao.FindResolveRecordByVersion(dao.GetResolveVersion())
|
||||||
for _, record := range resolveRecords {
|
for _, record := range resolveRecords {
|
||||||
// id -> resolveRecord
|
// id -> resolveRecord
|
||||||
|
@ -18,15 +21,14 @@ func ReloadCache() {
|
||||||
cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType)
|
cacheKey := fmt.Sprintf("%s-%s", record.Name, record.RecordType)
|
||||||
records, ok := KeyResolveRecordMap.Load(cacheKey)
|
records, ok := KeyResolveRecordMap.Load(cacheKey)
|
||||||
if !ok {
|
if !ok {
|
||||||
fmt.Println("读取缓存失败, key=" + cacheKey)
|
|
||||||
var tempRecords []dao.ResolveRecord
|
var tempRecords []dao.ResolveRecord
|
||||||
tempRecords = append(tempRecords, record)
|
tempRecords = append(tempRecords, record)
|
||||||
KeyResolveRecordMap.Store(cacheKey, tempRecords)
|
KeyResolveRecordMap.Store(cacheKey, tempRecords)
|
||||||
} else {
|
} else {
|
||||||
fmt.Println("读取缓存成功, key=" + cacheKey)
|
|
||||||
var newRecords = records.([]dao.ResolveRecord)
|
var newRecords = records.([]dao.ResolveRecord)
|
||||||
records = append(newRecords, record)
|
records = append(newRecords, record)
|
||||||
KeyResolveRecordMap.Store(cacheKey, records)
|
KeyResolveRecordMap.Store(cacheKey, records)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Reload cache end")
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,6 +8,7 @@ package controller
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"kenaito-dns/cache"
|
||||||
"kenaito-dns/constant"
|
"kenaito-dns/constant"
|
||||||
"kenaito-dns/dao"
|
"kenaito-dns/dao"
|
||||||
"kenaito-dns/domain"
|
"kenaito-dns/domain"
|
||||||
|
@ -46,6 +47,7 @@ func InitRestFunc(r *gin.Engine) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)})
|
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("添加"+newRecord.RecordType+"记录失败, %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cache.ReloadCache()
|
||||||
body := make(map[string]interface{})
|
body := make(map[string]interface{})
|
||||||
body["oldVersion"] = oldVersion
|
body["oldVersion"] = oldVersion
|
||||||
body["newVersion"] = newVersion
|
body["newVersion"] = newVersion
|
||||||
|
@ -77,6 +79,7 @@ func InitRestFunc(r *gin.Engine) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)})
|
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("删除"+newRecord.RecordType+"记录失败, %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cache.ReloadCache()
|
||||||
body := make(map[string]interface{})
|
body := make(map[string]interface{})
|
||||||
body["oldVersion"] = oldVersion
|
body["oldVersion"] = oldVersion
|
||||||
body["newVersion"] = newVersion
|
body["newVersion"] = newVersion
|
||||||
|
@ -116,6 +119,7 @@ func InitRestFunc(r *gin.Engine) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)})
|
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("更新"+newRecord.RecordType+"记录失败, %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cache.ReloadCache()
|
||||||
body := make(map[string]interface{})
|
body := make(map[string]interface{})
|
||||||
body["oldVersion"] = oldVersion
|
body["oldVersion"] = oldVersion
|
||||||
body["newVersion"] = newVersion
|
body["newVersion"] = newVersion
|
||||||
|
@ -173,6 +177,7 @@ func InitRestFunc(r *gin.Engine) {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)})
|
c.JSON(http.StatusBadRequest, gin.H{"message": fmt.Sprintf("回滚失败, %v", err)})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cache.ReloadCache()
|
||||||
body := make(map[string]interface{})
|
body := make(map[string]interface{})
|
||||||
body["currentVersion"] = jsonObj.Version
|
body["currentVersion"] = jsonObj.Version
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
|
|
@ -8,9 +8,12 @@ package core
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
"kenaito-dns/cache"
|
||||||
|
"kenaito-dns/config"
|
||||||
"kenaito-dns/constant"
|
"kenaito-dns/constant"
|
||||||
"kenaito-dns/dao"
|
"kenaito-dns/dao"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
@ -46,7 +49,16 @@ func HandleDNSRequest(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
func handleARecord(q dns.Question, msg *dns.Msg) {
|
func handleARecord(q dns.Question, msg *dns.Msg) {
|
||||||
name := q.Name
|
name := q.Name
|
||||||
queryName := name[0 : len(name)-1]
|
queryName := name[0 : len(name)-1]
|
||||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_A)
|
var records []dao.ResolveRecord
|
||||||
|
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_A)
|
||||||
|
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||||
|
if ok {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||||
|
records = value.([]dao.ResolveRecord)
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||||
|
} else {
|
||||||
|
records = dao.FindResolveRecordByNameType(queryName, constant.R_A)
|
||||||
|
}
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
fmt.Printf("=== A记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
||||||
|
@ -69,7 +81,16 @@ func handleARecord(q dns.Question, msg *dns.Msg) {
|
||||||
func handleAAAARecord(q dns.Question, msg *dns.Msg) {
|
func handleAAAARecord(q dns.Question, msg *dns.Msg) {
|
||||||
name := q.Name
|
name := q.Name
|
||||||
queryName := name[0 : len(name)-1]
|
queryName := name[0 : len(name)-1]
|
||||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
|
var records []dao.ResolveRecord
|
||||||
|
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_AAAA)
|
||||||
|
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||||
|
if ok {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||||
|
records = value.([]dao.ResolveRecord)
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||||
|
} else {
|
||||||
|
records = dao.FindResolveRecordByNameType(queryName, constant.R_AAAA)
|
||||||
|
}
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
fmt.Printf("=== AAAA记录 === 请求解析的域名:%s,解析的目标IP地址:%s\n", name, record.Value)
|
||||||
|
@ -92,7 +113,16 @@ func handleAAAARecord(q dns.Question, msg *dns.Msg) {
|
||||||
func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
|
func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
|
||||||
name := q.Name
|
name := q.Name
|
||||||
queryName := name[0 : len(name)-1]
|
queryName := name[0 : len(name)-1]
|
||||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
|
var records []dao.ResolveRecord
|
||||||
|
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_CNAME)
|
||||||
|
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||||
|
if ok {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||||
|
records = value.([]dao.ResolveRecord)
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||||
|
} else {
|
||||||
|
records = dao.FindResolveRecordByNameType(queryName, constant.R_CNAME)
|
||||||
|
}
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value)
|
fmt.Printf("=== CNAME记录 === 请求解析的域名:%s,解析的目标域名:%s\n", name, record.Value)
|
||||||
|
@ -114,7 +144,16 @@ func handleCNAMERecord(q dns.Question, msg *dns.Msg) {
|
||||||
func handleMXRecord(q dns.Question, msg *dns.Msg) {
|
func handleMXRecord(q dns.Question, msg *dns.Msg) {
|
||||||
name := q.Name
|
name := q.Name
|
||||||
queryName := name[0 : len(name)-1]
|
queryName := name[0 : len(name)-1]
|
||||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_MX)
|
var records []dao.ResolveRecord
|
||||||
|
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_MX)
|
||||||
|
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||||
|
if ok {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||||
|
records = value.([]dao.ResolveRecord)
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||||
|
} else {
|
||||||
|
records = dao.FindResolveRecordByNameType(queryName, constant.R_MX)
|
||||||
|
}
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value)
|
fmt.Printf("=== MX记录 === 请求解析的域名:%s,解析的目标域名:%s, MX优先级: 10\n", name, record.Value)
|
||||||
|
@ -137,7 +176,16 @@ func handleMXRecord(q dns.Question, msg *dns.Msg) {
|
||||||
func handleTXTRecord(q dns.Question, msg *dns.Msg) {
|
func handleTXTRecord(q dns.Question, msg *dns.Msg) {
|
||||||
name := q.Name
|
name := q.Name
|
||||||
queryName := name[0 : len(name)-1]
|
queryName := name[0 : len(name)-1]
|
||||||
records := dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
|
var records []dao.ResolveRecord
|
||||||
|
cacheKey := fmt.Sprintf("%s-%s", queryName, constant.R_TXT)
|
||||||
|
value, ok := cache.KeyResolveRecordMap.Load(cacheKey)
|
||||||
|
if ok {
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache start")
|
||||||
|
records = value.([]dao.ResolveRecord)
|
||||||
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " [Cache] Query cache end")
|
||||||
|
} else {
|
||||||
|
records = dao.FindResolveRecordByNameType(queryName, constant.R_TXT)
|
||||||
|
}
|
||||||
if len(records) > 0 {
|
if len(records) > 0 {
|
||||||
for _, record := range records {
|
for _, record := range records {
|
||||||
fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value)
|
fmt.Printf("=== TXT记录 === 请求解析的域名:%s,解析的目标值:%s\n", name, record.Value)
|
||||||
|
|
|
@ -21,13 +21,13 @@ type ResolveRecord struct {
|
||||||
Version int `xorm:"not null integer 'version'"`
|
Version int `xorm:"not null integer 'version'"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindResolveRecordById(id int) []ResolveRecord {
|
func FindResolveRecordById(id int) ResolveRecord {
|
||||||
var records []ResolveRecord
|
var record ResolveRecord
|
||||||
err := Engine.Table("resolve_record").Where("`id` = ?", id).Find(&records)
|
_, err := Engine.Table("resolve_record").Where("`id` = ?", id).Get(&record)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
}
|
}
|
||||||
return records
|
return record
|
||||||
}
|
}
|
||||||
|
|
||||||
func FindResolveRecordByVersion(version int) []ResolveRecord {
|
func FindResolveRecordByVersion(version int) []ResolveRecord {
|
||||||
|
|
BIN
dns.sqlite3
BIN
dns.sqlite3
Binary file not shown.
10
main.go
10
main.go
|
@ -18,7 +18,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
fmt.Println("[app] [info] kenaito-dns version = " + config.AppVersion)
|
fmt.Println("[app] [info] " + time.Now().Format(config.AppTimeFormat) + " kenaito-dns version = " + config.AppVersion)
|
||||||
go cache.ReloadCache()
|
go cache.ReloadCache()
|
||||||
go initDNSServer()
|
go initDNSServer()
|
||||||
initRestfulServer()
|
initRestfulServer()
|
||||||
|
@ -30,9 +30,9 @@ func initDNSServer() {
|
||||||
// 设置服务器地址和协议
|
// 设置服务器地址和协议
|
||||||
server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"}
|
server := &dns.Server{Addr: config.DnsServerPort, Net: "udp"}
|
||||||
// 开始监听
|
// 开始监听
|
||||||
fmt.Printf("[dns] [info] Starting DNS server on %s\n", server.Addr)
|
fmt.Printf("[dns] [info] "+time.Now().Format(config.AppTimeFormat)+" Starting DNS server on %s\n", server.Addr)
|
||||||
if err := server.ListenAndServe(); err != nil {
|
if err := server.ListenAndServe(); err != nil {
|
||||||
fmt.Printf("[dns] [error] Failed to start DNS server: %s\n", err.Error())
|
fmt.Printf("[dns] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start DNS server: %s\n", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,9 +75,9 @@ func initRestfulServer() {
|
||||||
WriteTimeout: config.WebWriteTimeout * time.Second,
|
WriteTimeout: config.WebWriteTimeout * time.Second,
|
||||||
}
|
}
|
||||||
controller.InitRestFunc(router)
|
controller.InitRestFunc(router)
|
||||||
fmt.Printf("[gin] [info] Start Gin server: %s\n", config.WebServerPort)
|
fmt.Printf("[gin] [info] "+time.Now().Format(config.AppTimeFormat)+" Start Gin server: %s\n", config.WebServerPort)
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fmt.Printf("[gin] [error] Failed to start Gin server: %s\n", config.WebServerPort)
|
fmt.Printf("[gin] [error] "+time.Now().Format(config.AppTimeFormat)+" Failed to start Gin server: %s\n", config.WebServerPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue