2022-07-05 22:47:35

dnstrace源码学习

用于dns压力测试的工具,就一个单一文件,逻辑也比较清晰简单

参数解析

命令参数包用的是kingpin: github.com/alecthomas/kingpin,看着很是简洁,直接使用func方式来定义参数,默认值,支持长短命令。

var ( pApp = kingpin.New("dnstrace", "A high QPS DNS benchmark.").Author(author) pServer = pApp.Flag("server", "DNS server IP:port to test.").Short('s').Default("127.0.0.1").String() ) func main() { pApp.Version(version) kingpin.MustParse(pApp.Parse(os.Args[1:])) ... }

dns协议包

dns协议用到了实现完备的包: github.com/miekg/dns

抽出demo用法:

package main import ( "github.com/miekg/dns" "log" "time" ) func main() { var msg *dns.Msg dnsServer := "192.168.223.1:53" dnsTimeout := 2 * time.Second co, err := dns.DialTimeout("udp", dnsServer, dnsTimeout) if err != nil { log.Println(err) return } defer co.Close() _ = co.SetDeadline(time.Now().Add(2 * time.Second)) q := dns.Question{Name: dns.Fqdn("sre.wiki"), Qtype: dns.TypeA, Qclass: dns.ClassINET} msg = &dns.Msg{ Compress: false, Question: []dns.Question{q}, Answer: nil, Ns: nil, Extra: nil, } start := time.Now() err = co.WriteMsg(msg) if err != nil { log.Println(err) return } msg, err = co.ReadMsg() if err != nil { log.Println(err) return } log.Printf("rcode: %d, answer: %+v, timeused: %vms", msg.Rcode, msg.Answer, time.Since(start).Milliseconds()) }

限速

限速用uber的包: go.uber.org/ratelimit,简洁明了

limit = ratelimit.New(*pRate) // 在需要限速的地方使用 limit.Take()

抽出demo用法:

package main import ( "fmt" "go.uber.org/ratelimit" "time" ) func main() { limit := ratelimit.New(1) n := 0 for { n += 1 fmt.Println(time.Now(), "count", n) limit.Take() } }

日志颜色

使用 github.com/fatih/color

demo,给限速日志打印加颜色:

package main import ( "github.com/fatih/color" "go.uber.org/ratelimit" "time" ) func main() { limit := ratelimit.New(2) n := 0 redPrintln := color.New(color.FgRed).Println greenPrintln := color.New(color.FgGreen).Println for { n += 1 if n % 2 == 0 { redPrintln(time.Now(), "count", n) } else { greenPrintln(time.Now(), "count", n) } limit.Take() } }

dnstrace完整源码

package main import ( "context" "fmt" "math/rand" "os" "os/signal" "strconv" "strings" "sync" "sync/atomic" "syscall" "time" "github.com/HdrHistogram/hdrhistogram-go" "github.com/alecthomas/kingpin" "github.com/fatih/color" "github.com/miekg/dns" "github.com/olekukonko/tablewriter" "go.uber.org/ratelimit" ) var ( // Tag is set by build at compile time to Git Tag Tag = "" // Commit is set by build at compile time to Git SHA1 Commit = "" author = "Rahul Powar <rahul@redsift.io>" ) var ( pApp = kingpin.New("dnstrace", "A high QPS DNS benchmark.").Author(author) pServer = pApp.Flag("server", "DNS server IP:port to test.").Short('s').Default("127.0.0.1").String() pType = pApp.Flag("type", "Query type.").Short('t').Default("A").Enum("TXT", "A", "AAAA") //TODO: Rest of them pt 1 pCount = pApp.Flag("number", "Number of queries to issue. Note that the total number of queries issued = number*concurrency*len(queries).").Short('n').Default("1").Int64() pConcurrency = pApp.Flag("concurrency", "Number of concurrent queries to issue.").Short('c').Default("1").Uint32() pRate = pApp.Flag("rate-limit", "Apply a global questions / second rate limit.").Short('l').Default("0").Int() pQperConn = pApp.Flag("query-per-conn", "Queries on a connection before creating a new one. 0: unlimited").Default("0").Int64() pExpect = pApp.Flag("expect", "Expect a specific response.").Short('e').Strings() pRecurse = pApp.Flag("recurse", "Allow DNS recursion.").Short('r').Default("false").Bool() pUdpSize = pApp.Flag("edns0", "Enable EDNS0 with specified size.").Default("0").Uint16() pTCP = pApp.Flag("tcp", "Use TCP fot DNS requests.").Default("false").Bool() pWriteTimeout = pApp.Flag("write", "DNS write timeout.").Default("1s").Duration() pReadTimeout = pApp.Flag("read", "DNS read timeout.").Default(dnsTimeout.String()).Duration() pRCodes = pApp.Flag("codes", "Enable counting DNS return codes.").Default("true").Bool() pHistMin = pApp.Flag("min", "Minimum value for timing histogram.").Default((time.Microsecond * 400).String()).Duration() pHistMax = pApp.Flag("max", "Maximum value for histogram.").Default(dnsTimeout.String()).Duration() pHistPre = pApp.Flag("precision", "Significant figure for histogram precision.").Default("1").PlaceHolder("[1-5]").Int() pHistDisplay = pApp.Flag("distribution", "Display distribution histogram of timings to stdout.").Default("true").Bool() pCsv = pApp.Flag("csv", "Export distribution to CSV.").Default("").PlaceHolder("/path/to/file.csv").String() pIOErrors = pApp.Flag("io-errors", "Log I/O errors to stderr.").Default("false").Bool() pSilent = pApp.Flag("silent", "Disable stdout.").Default("false").Bool() pColor = pApp.Flag("color", "ANSI Color output.").Default("true").Bool() pQueries = pApp.Arg("queries", "Queries to issue.").Required().Strings() ) var ( count int64 cerror int64 ecount int64 success int64 matched int64 mismatch int64 ) const dnsTimeout = time.Second * 2 type rstats struct { codes map[int]int64 hist *hdrhistogram.Histogram } func isExpected(a string) bool { for _, b := range *pExpect { if b == a { return true } } return false } // 具体干活 func do(ctx context.Context) []*rstats { // 获取请求,支持多个域名 questions := make([]string, len(*pQueries)) for i, q := range *pQueries { // 转换为fqdn标准格式 questions[i] = dns.Fqdn(q) } // dns类型 qType := dns.TypeNone switch *pType { // TODO: Rest of them pt 2 case "TXT": qType = dns.TypeTXT case "A": qType = dns.TypeA case "AAAA": qType = dns.TypeAAAA default: panic(fmt.Errorf("unknown type %q", *pType)) } // dns服务 srv := *pServer if !strings.Contains(srv, ":") { srv += ":53" } // 协议 network := "udp" if *pTCP { network = "tcp" } // 并行度 concurrent := *pConcurrency // 限速 limitStr := "" var limit ratelimit.Limiter if *pRate > 0 { limit = ratelimit.New(*pRate) limitStr = fmt.Sprintf("(limited to %d QPS)", *pRate) } // 静默 if !*pSilent { fmt.Printf("Benchmarking %s via %s with %d concurrent requests %s\n\n", srv, network, concurrent, limitStr) } // 状态 stats := make([]*rstats, concurrent) // wg控制并发 var wg sync.WaitGroup var w uint32 for w = 0; w < concurrent; w++ { st := &rstats{hist: hdrhistogram.New(pHistMin.Nanoseconds(), pHistMax.Nanoseconds(), *pHistPre)} stats[w] = st if *pRCodes { st.codes = make(map[int]int64) } var co *dns.Conn var err error wg.Add(1) // 起routine go func(st *rstats) { defer func() { if co != nil { co.Close() } wg.Done() }() // 构造 var r *dns.Msg m := new(dns.Msg) // 是否允许递归 m.RecursionDesired = *pRecurse m.Question = make([]dns.Question, 1) question := dns.Question{Name: "", Qtype: qType, Qclass: dns.ClassINET} // create a new lock free rand source for this goroutine rando := rand.New(rand.NewSource(time.Now().Unix())) var i int64 for i = 0; i < *pCount; i++ { for _, q := range questions { if ctx.Err() != nil { return } if co != nil && *pQperConn > 0 && i%*pQperConn == 0 { co.Close() co = nil } atomic.AddInt64(&count, 1) // instead of setting the question, do this manually for lower overhead and lock free access to id question.Name = q m.Id = uint16(rando.Uint32()) m.Question[0] = question // 初始化co if co == nil { co, err = dns.DialTimeout(network, srv, dnsTimeout) if err != nil { atomic.AddInt64(&cerror, 1) // 如果有网络错误,就打印 if *pIOErrors { fmt.Fprintln(os.Stderr, "i/o error dialing: ", err.Error()) } continue } // 如果udpSize大于0,那么就设置edns0 if udpSize := *pUdpSize; udpSize > 0 { m.SetEdns0(udpSize, true) co.UDPSize = udpSize } } // 限速 if limit != nil { limit.Take() } // 设置超时,发起请求 start := time.Now() co.SetWriteDeadline(start.Add(*pWriteTimeout)) if err = co.WriteMsg(m); err != nil { // error writing atomic.AddInt64(&ecount, 1) if *pIOErrors { fmt.Fprintln(os.Stderr, "i/o error writing: ", err.Error()) } co.Close() co = nil continue } // 读取返回 co.SetReadDeadline(time.Now().Add(*pReadTimeout)) r, err = co.ReadMsg() if err != nil { // error reading atomic.AddInt64(&ecount, 1) if *pIOErrors { fmt.Fprintln(os.Stderr, "i/o error reading: ", err.Error()) } co.Close() co = nil continue } // 耗时 timing := time.Since(start) st.hist.RecordValue(timing.Nanoseconds()) if r.Rcode == dns.RcodeSuccess { if r.Id != m.Id { atomic.AddInt64(&mismatch, 1) continue } atomic.AddInt64(&success, 1) // 判断期待返回 if expect := *pExpect; len(expect) > 0 { for _, s := range r.Answer { ok := false switch s.Header().Rrtype { //TODO: Rest of them pt 3 case dns.TypeA: a := s.(*dns.A) ok = isExpected(a.A.To4().String()) case dns.TypeAAAA: a := s.(*dns.AAAA) ok = isExpected(a.AAAA.String()) case dns.TypeTXT: t := s.(*dns.TXT) ok = isExpected(strings.Join(t.Txt, "")) } if ok { atomic.AddInt64(&matched, 1) break } } } } if st.codes != nil { var c int64 if v, ok := st.codes[r.Rcode]; ok { c = v } c++ st.codes[r.Rcode] = c } } } }(st) } wg.Wait() return stats } func printProgress() { if *pSilent { return } fmt.Println() errorFprint := color.New(color.FgRed).Fprint successFprint := color.New(color.FgGreen).Fprint total := uint64(*pCount) * uint64(len(*pQueries)) * uint64(*pConcurrency) acount := atomic.LoadInt64(&count) acerror := atomic.LoadInt64(&cerror) aecount := atomic.LoadInt64(&ecount) amismatch := atomic.LoadInt64(&mismatch) asuccess := atomic.LoadInt64(&success) amatched := atomic.LoadInt64(&matched) fmt.Printf("Total requests:\t %d of %d (%0.1f%%)\n", acount, total, 100.0*float64(acount)/float64(total)) if acerror > 0 || aecount > 0 { errorFprint(os.Stdout, "Connection errors:\t", acerror, "\n") errorFprint(os.Stdout, "Read/Write errors:\t", aecount, "\n") } if amismatch > 0 { errorFprint(os.Stdout, "ID mismatch errors:\t", amismatch, "\n") } successFprint(os.Stdout, "DNS success codes:\t", asuccess, "\n") if len(*pExpect) > 0 { expect := successFprint if amatched != asuccess { expect = errorFprint } expect(os.Stdout, "Expected results:\t", amatched, "\n") } } func printReport(t time.Duration, stats []*rstats, csv *os.File) { defer func() { if csv != nil { csv.Close() } }() // merge all the stats here timings := hdrhistogram.New(pHistMin.Nanoseconds(), pHistMax.Nanoseconds(), *pHistPre) codeTotals := make(map[int]int64) for _, s := range stats { timings.Merge(s.hist) if s.codes != nil { for k, v := range s.codes { codeTotals[k] = codeTotals[k] + v } } } if csv != nil { writeBars(csv, timings.Distribution()) fmt.Println() fmt.Println("DNS distribution written to", csv.Name()) } if *pSilent { return } printProgress() if len(codeTotals) > 0 { errorFprint := color.New(color.FgRed).Fprint successFprint := color.New(color.FgGreen).Fprint fmt.Println() fmt.Println("DNS response codes") for i := dns.RcodeSuccess; i <= dns.RcodeBadCookie; i++ { printFn := errorFprint if i == dns.RcodeSuccess { printFn = successFprint } if c, ok := codeTotals[i]; ok { printFn(os.Stdout, "\t", dns.RcodeToString[i]+":\t", c, "\n") } } } fmt.Println() fmt.Println("Time taken for tests:\t", t.String()) fmt.Printf("Questions per second:\t %0.1f\n", float64(count)/t.Seconds()) min := time.Duration(timings.Min()) mean := time.Duration(timings.Mean()) sd := time.Duration(timings.StdDev()) max := time.Duration(timings.Max()) if tc := timings.TotalCount(); tc > 0 { fmt.Println() fmt.Println("DNS timings,", tc, "datapoints") fmt.Println("\t min:\t\t", min) fmt.Println("\t mean:\t\t", mean) fmt.Println("\t [+/-sd]:\t", sd) fmt.Println("\t max:\t\t", max) dist := timings.Distribution() if *pHistDisplay && tc > 1 { fmt.Println() fmt.Println("DNS distribution,", tc, "datapoints") printBars(dist) } } } func writeBars(f *os.File, bars []hdrhistogram.Bar) { f.WriteString("From (ns), To (ns), Count\n") for _, b := range bars { f.WriteString(b.String()) } } func printBars(bars []hdrhistogram.Bar) { counts := make([]int64, 0, len(bars)) lines := make([][]string, 0, len(bars)) added := false var max int64 for _, b := range bars { if b.Count == 0 && !added { // trim the start continue } if b.Count > max { max = b.Count } added = true line := make([]string, 3) lines = append(lines, line) counts = append(counts, b.Count) line[0] = time.Duration(b.To/2 + b.From/2).String() line[2] = strconv.FormatInt(b.Count, 10) } for i, l := range lines { l[1] = makeBar(counts[i], max) } table := tablewriter.NewWriter(os.Stdout) table.SetHeader([]string{"Latency", "", "Count"}) table.SetBorder(false) table.AppendBulk(lines) table.Render() } func makeBar(c int64, max int64) string { if c == 0 { return "" } t := int((43 * float64(c) / float64(max)) + 0.5) return strings.Repeat("▄", t) } const fileNoBuffer = 9 // app itself needs about 9 for libs func main() { version := "unknown" if Tag == "" { if Commit != "" { version = Commit } } else { version = fmt.Sprintf("%s-%s", Tag, Commit) } pApp.Version(version) kingpin.MustParse(pApp.Parse(os.Args[1:])) // process args color.NoColor = !*pColor var rLimit syscall.Rlimit if err := syscall.Getrlimit(syscall.RLIMIT_NOFILE, &rLimit); err == nil { if needed := uint64(*pConcurrency) + uint64(fileNoBuffer); rLimit.Cur < needed { fmt.Fprintf(os.Stderr, "current process limit for number of files is %d and insufficient for level of requested concurrency.\n", rLimit.Cur) os.Exit(2) } } var csv *os.File if *pCsv != "" { f, err := os.Create(*pCsv) if err != nil { fmt.Fprintln(os.Stderr, err.Error()) os.Exit(2) } csv = f } sigsInt := make(chan os.Signal, 8) signal.Notify(sigsInt, syscall.SIGINT) sigsHup := make(chan os.Signal, 8) signal.Notify(sigsHup, syscall.SIGHUP) defer close(sigsInt) defer close(sigsHup) ctx, cancel := context.WithCancel(context.Background()) go func() { <-sigsInt printProgress() fmt.Fprintln(os.Stderr, "Cancelling benchmark ^C, again to terminate now.") cancel() <-sigsInt os.Exit(130) }() go func() { for range sigsHup { printProgress() } }() // get going rand.Seed(time.Now().UnixNano()) start := time.Now() res := do(ctx) end := time.Now() printReport(end.Sub(start), res, csv) if cerror > 0 || ecount > 0 || mismatch > 0 { // something was wrong os.Exit(1) } }

本文链接:https://troy.wang/post/dnstrace.html

-- EOF --