golang实现简单网关

golang实现简单网关 #

网关=反向代理+负载均衡+各种策略,技术实现也有多种多样,有基于nginx使用lua的实现,比如openresty、kong;也有基于zuul的通用网关;还有就是golang的网关,比如tyk。

这篇文章主要是讲如何基于golang实现一个简单的网关。

1. 预备 #

1.1. 准备两个后端web服务 #

启动两个后端web服务(代码)
// RealServer is a minimal backend HTTP server used as a proxy target.
type RealServer struct {
	// Addr is the "host:port" address the server listens on.
	Addr string
}

// Run registers the demo routes and starts serving on r.Addr in a background
// goroutine. It returns immediately; if the listener stops with an error the
// process exits through log.Fatal.
func (r *RealServer) Run() {
	log.Println("start http server at " + r.Addr)

	routes := http.NewServeMux()
	routes.HandleFunc("/", r.EchoHandler)
	routes.HandleFunc("/base/error", r.ErrorHandler)
	routes.HandleFunc("/timeout", r.TimeoutHandler)

	srv := &http.Server{
		Addr:         r.Addr,
		Handler:      routes,
		WriteTimeout: 3 * time.Second,
	}

	go func() {
		log.Fatal(srv.ListenAndServe())
	}()
}

// EchoHandler writes back the URL that was hit on this backend, the
// client-address information (RemoteAddr, X-Forwarded-For, X-Real-Ip) and the
// full request header map, one item per line.
func (r *RealServer) EchoHandler(w http.ResponseWriter, req *http.Request) {
	fmt.Fprintf(w, "http://%s%s\n", r.Addr, req.URL.Path)
	fmt.Fprintf(w, "RemoteAddr=%s,X-Forwarded-For=%v,X-Real-Ip=%v\n",
		req.RemoteAddr,
		req.Header.Get("X-Forwarded-For"),
		req.Header.Get("X-Real-Ip"),
	)
	fmt.Fprintf(w, "headers =%v\n", req.Header)
}

// ErrorHandler always answers with status 500 and a fixed body.
func (r *RealServer) ErrorHandler(w http.ResponseWriter, req *http.Request) {
	w.WriteHeader(http.StatusInternalServerError)
	io.WriteString(w, "error handler")
}

// TimeoutHandler sleeps for 6 seconds before answering, which is longer than
// the 3-second WriteTimeout configured in Run.
func (r *RealServer) TimeoutHandler(w http.ResponseWriter, req *http.Request) {
	const delay = 6 * time.Second
	time.Sleep(delay)
	w.WriteHeader(http.StatusOK)
	io.WriteString(w, "timeout handler")
}

// main starts the two demo backends and then blocks until the process
// receives SIGINT or SIGTERM.
func main() {
	rs1 := &RealServer{Addr: "127.0.0.1:2003"}
	rs1.Run()
	rs2 := &RealServer{Addr: "127.0.0.1:2004"}
	rs2.Run()

	// signal.Notify never blocks when delivering, so the channel needs a
	// buffer; with an unbuffered channel a signal that arrives before the
	// receive below is ready would be dropped.
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	<-quit
}

1.2. 访问工具 #

这里使用命令行工具进行测试

curl -v http://localhost:2002/base

2. 反向代理 #

2.1. 单后端(target)反向代理 #

具体代码
package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
)

var (
	addr = "127.0.0.1:2002"
)

// main proxies every request received on addr to the single backend at
// http://127.0.0.1:2003/base using the standard library reverse proxy.
func main() {
	rsURL, err := url.Parse("http://127.0.0.1:2003/base")
	if err != nil {
		// Without a valid target the proxy cannot be built.
		log.Fatal(err)
	}
	reverseProxy := httputil.NewSingleHostReverseProxy(rsURL)
	log.Println("Starting Httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, reverseProxy))
}

直接使用基础库httputil提供的NewSingleHostReverseProxy即可,返回的ReverseProxy对象实现了ServeHTTP方法,因此可以直接作为handler。

2.2. 分析反向代理代码,并添加修改response内容 #

具体代码
package main

import (
	"bytes"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"strings"
)

var (
	addr = "127.0.0.1:2002"
)

// main proxies every request received on addr to the backend at
// http://127.0.0.1:2003/base through the customised reverse proxy built by
// NewSingleHostReverseProxy.
func main() {
	rsURL, err := url.Parse("http://127.0.0.1:2003/base")
	if err != nil {
		// Without a valid target the proxy cannot be built.
		log.Fatal(err)
	}
	reverseProxy := NewSingleHostReverseProxy(rsURL)
	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, reverseProxy))
}

// NewSingleHostReverseProxy returns a ReverseProxy that forwards every request
// to target. The request path is joined onto target.Path and the query strings
// are merged. The proxy sets X-Real-Ip from the peer address, and prefixes the
// body of every non-200 response with "hello ".
func NewSingleHostReverseProxy(target *url.URL) *httputil.ReverseProxy {
	targetQuery := target.RawQuery
	director := func(req *http.Request) {
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			// explicitly disable User-Agent so it's not set to default value
			req.Header.Set("User-Agent", "")
		}

		// Strip the port by cutting at the LAST colon and dropping the
		// brackets, so that an IPv6 peer such as "[::1]:54321" yields "::1".
		// Splitting on the first colon would yield "[" for such addresses.
		host := req.RemoteAddr
		if i := strings.LastIndex(host, ":"); i >= 0 {
			host = host[:i]
		}
		host = strings.TrimSuffix(strings.TrimPrefix(host, "["), "]")

		// Only appropriate when this proxy is the first hop: Set overwrites
		// whatever value the client supplied.
		req.Header.Set("X-Real-Ip", host)
	}

	modifyFunc := func(res *http.Response) error {
		if res.StatusCode == http.StatusOK {
			return nil
		}

		oldPayload, err := ioutil.ReadAll(res.Body)
		if err != nil {
			return err
		}
		// The upstream body is about to be replaced, so nobody else will
		// close it; close it here to release the backend connection. The
		// payload has already been read in full, so a close error carries no
		// useful information and is deliberately ignored.
		_ = res.Body.Close()

		newPayload := []byte("hello " + string(oldPayload))
		res.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
		res.ContentLength = int64(len(newPayload))
		res.Header.Set("Content-Length", fmt.Sprint(len(newPayload)))
		return nil
	}

	return &httputil.ReverseProxy{Director: director, ModifyResponse: modifyFunc}
}

// singleJoiningSlash concatenates a and b so that exactly one "/" sits at the
// junction when either side already provides one or neither does.
func singleJoiningSlash(a, b string) string {
	trailing := strings.HasSuffix(a, "/")
	leading := strings.HasPrefix(b, "/")

	if trailing && leading {
		// Both sides bring a slash: drop the one from b.
		return a + b[1:]
	}
	if trailing || leading {
		// Exactly one side brings the slash.
		return a + b
	}
	// Neither side has a slash: insert one.
	return a + "/" + b
}

director中定义回调函数,入参为*http.Request,决定如何构造向后端的请求,比如host是否向后传递,是否进行url重写,对于header的处理,后端target的选择等,都可以在这里完成。

director在这里具体做了:

  1. 根据后端target,构造到后端请求的url
  2. 选择性传递必要的header
  3. 设置代理相关的header,比如X-Forwarded-For、X-Real-Ip
    1. X-Forwarded-For记录经过的所有代理,以proxyIp01, proxyIp02, proxyIp03的格式记录,由于是追加,可能被篡改,当然,如果第一代理以覆盖该头的方式进行记录,也是可信的
    2. X-Real-Ip用于记录客户端IP,一般放在第一代理上,用于记录客户端的来源公网IP,可信

modifyResponse中定义回调函数,入参为*http.Response,用于修改响应的信息,比如响应的Body,响应的Header等信息。

最终依旧是返回一个ReverseProxy,然后将这个对象作为handler传入即可。

2.3. 支持多个后端服务器 #

参考2.2 中的NewSingleHostReverseProxy,只需要实现一个类似的、支持多targets的方法即可,具体实现见后面。

3. 负载均衡 #

作为一个网关服务,在上面2.3的基础上,需要支持必要的负载均衡策略,比如:

  • 随机
  • 轮询
  • 加权轮询
  • 一致性hash

3.1. 负载均衡算法 #

3.1.1. 随机 #

随便random一个整数作为索引,然后取对应的地址即可,实现比较简单。

具体代码
// RandomN is a load-balancing strategy that picks a backend uniformly at
// random from the registered addresses.
type RandomN struct {
	rss []string // registered backend addresses
}

// Add registers one backend address. It expects exactly one parameter and
// returns an error otherwise.
func (r *RandomN) Add(params ...string) error {
	if len(params) == 1 {
		r.rss = append(r.rss, params[0])
		return nil
	}
	return fmt.Errorf("param length should be one")
}

// Next returns a randomly chosen backend address, or "" when none has been
// registered.
func (r *RandomN) Next() string {
	count := len(r.rss)
	if count == 0 {
		return ""
	}
	pick := rand.Intn(count)
	return r.rss[pick]
}

// Get returns the next backend address; key is ignored by this strategy and
// the error is always nil.
func (r *RandomN) Get(key string) (string, error) {
	addr := r.Next()
	return addr, nil
}

3.1.2. 轮询 #

使用curIndex进行累加计数,一旦达到rss数组的长度,则重置为0。

具体代码
// RR is a round-robin load-balancing strategy. It carries no locking of its
// own.
type RR struct {
	curIndex int      // index of the backend the next call will return
	rss      []string // registered backend addresses
}

// Add registers one backend address. It expects exactly one parameter and
// returns an error otherwise.
func (r *RR) Add(params ...string) error {
	if len(params) == 1 {
		r.rss = append(r.rss, params[0])
		return nil
	}
	return fmt.Errorf("param length should be one")
}

// Next returns the backends in registration order, wrapping around to the
// first one after the last. It returns "" when no backend is registered.
func (r *RR) Next() string {
	count := len(r.rss)
	if count == 0 {
		return ""
	}

	idx := r.curIndex
	if idx == count {
		// The previous call handed out the last backend: start over.
		idx = 0
	}
	r.curIndex = idx + 1
	return r.rss[idx]
}

// Get returns the next backend address; key is ignored by this strategy and
// the error is always nil.
func (r *RR) Get(key string) (string, error) {
	addr := r.Next()
	return addr, nil
}

3.1.3. 加权轮询 #

轮询带权重,如果使用计数递减的方式,如果权重是5,1,1那么后端rs依次为a,a,a,a,a,b,c,a,a,a,a...,其中a后端会瞬间压力过大;参考nginx内部的加权轮询,或者应该称之为平滑加权轮询,思路是:

后端真实节点包含三个权重:

  • 本身权重weight —— 设置的权重
  • 有效权重effectiveWeight —— 根据后端节点健康状态动态变化,当异常时,减一;当正常时,加一,最多到weight值
  • 当前权重curWeight —— 初始值为weight,计算时curWeight += effectiveWeight,如果curWeight最大,则被选中,然后curWeight -= total

操作步骤:

  1. 计算curWeight
  2. 选取最大curWeight的节点
  3. 重新计算curWeight
具体代码
// WeightedRR is a smooth weighted round-robin load-balancing strategy.
type WeightedRR struct {
	rss []*WeightedNode // registered backends
}

// WeightedNode is one backend together with its weights.
type WeightedNode struct {
	addr            string // backend address
	weight          int    // configured weight
	curWeight       int    // running weight, adjusted on every Next call
	effectiveWeight int    // weight added to curWeight each round; set to weight in Add
}

// Add registers a backend. It expects exactly two parameters, the address and
// the weight as a base-10 integer, and returns an error when the count is
// wrong or the weight cannot be parsed.
func (r *WeightedRR) Add(params ...string) error {
	if len(params) != 2 {
		return fmt.Errorf("param length should be two")
	}

	parsed, err := strconv.ParseInt(params[1], 10, 64)
	if err != nil {
		return err
	}
	w := int(parsed)

	r.rss = append(r.rss, &WeightedNode{
		addr:            params[0],
		weight:          w,
		curWeight:       w,
		effectiveWeight: w,
	})
	return nil
}

// Next picks a backend with the smooth weighted round-robin algorithm: every
// node's curWeight grows by its effectiveWeight, the node with the largest
// curWeight wins (the earliest one on ties), and the winner's curWeight is
// reduced by the sum of all effective weights. It returns "" when no backend
// is registered.
func (r *WeightedRR) Next() string {
	sum := 0
	pick := -1

	for i, node := range r.rss {
		sum += node.effectiveWeight
		node.curWeight += node.effectiveWeight

		if pick < 0 || node.curWeight > r.rss[pick].curWeight {
			pick = i
		}
	}

	if pick < 0 {
		return ""
	}

	chosen := r.rss[pick]
	chosen.curWeight -= sum
	return chosen.addr
}

// Get returns the next backend address; key is ignored by this strategy and
// the error is always nil.
func (r *WeightedRR) Get(key string) (string, error) {
	addr := r.Next()
	return addr, nil
}

3.1.4. 一致性hash #

一致性hash算法,主要是用于分布式cache热点/命中问题;这里用于基于某key的hash值,路由到固定后端,但是只能是基本满足流量绑定,一旦后端目标节点故障,会自动平移到环上最近的那个节点。

实现:

  • 首先存在一个环,环上的每个点都能被选择的hash函数映射到

  • 然后将后端真实节点+序号(副本数)映射到环上

  • 当请求进来的时候,使用某特定组成的key代入hash函数计算得到一个位置

    • 如果key是由url组成,那就是url hash
    • 如果key是由remoteIp组成,那么就是IP hash
  • 使用二分查找,找到其在环上的下一个节点

具体代码
// Keys is a slice of ring positions (hash values). It implements
// sort.Interface so the ring can be kept in ascending order.
type Keys []uint32

// Less reports whether the position at i sorts before the position at j.
func (k Keys) Less(i, j int) bool {
	return k[i] < k[j]
}

// Swap exchanges the positions at i and j.
func (k Keys) Swap(i, j int) {
	k[i], k[j] = k[j], k[i]
}

// Len returns the number of positions on the ring.
func (k Keys) Len() int {
	return len(k)
}

// ConsistentHash is a consistent-hashing load-balancing strategy. Every
// backend is placed on the ring `replicas` times; a key is routed to the
// backend owning the first ring position at or after the key's hash.
//
// All methods take mux, so Add may run concurrently with Get and IsEmpty.
type ConsistentHash struct {
	mux      sync.RWMutex             // guards keys and hashMap
	hash     func(data []byte) uint32 // hash function mapping bytes onto the ring
	replicas int                      // ring positions created per backend
	keys     Keys                     // ring positions, sorted ascending
	hashMap  map[uint32]string        // ring position -> backend address
}

// NewConsistentHash creates an empty ring that places each backend on
// `replicas` positions. When fn is nil, crc32.ChecksumIEEE is used as the
// hash function.
func NewConsistentHash(replicas int, fn func(data []byte) uint32) *ConsistentHash {
	m := &ConsistentHash{
		hash:     fn,
		replicas: replicas,
		hashMap:  make(map[uint32]string),
	}

	if m.hash == nil {
		m.hash = crc32.ChecksumIEEE
	}

	return m
}

// Add places the backend params[0] on the ring. Only the first parameter is
// used; an error is returned when no parameter is given.
func (c *ConsistentHash) Add(params ...string) error {
	if len(params) == 0 {
		return errors.New("param len 1 at least")
	}
	addr := params[0]

	c.mux.Lock()
	defer c.mux.Unlock()

	for i := 0; i < c.replicas; i++ {
		// Each replica hashes "<index><addr>" to spread the backend over the ring.
		hash := c.hash([]byte(strconv.Itoa(i) + addr))
		c.keys = append(c.keys, hash)
		c.hashMap[hash] = addr
	}

	sort.Sort(c.keys)
	return nil
}

// IsEmpty reports whether the ring has no positions.
func (c *ConsistentHash) IsEmpty() bool {
	c.mux.RLock()
	defer c.mux.RUnlock()
	return len(c.keys) == 0
}

// Get returns the backend responsible for key, or an error when the ring is
// empty.
func (c *ConsistentHash) Get(key string) (string, error) {
	// Add mutates keys and hashMap under the write lock, so lookups must hold
	// the read lock. The emptiness check is done inline instead of through
	// IsEmpty to avoid acquiring the read lock twice.
	c.mux.RLock()
	defer c.mux.RUnlock()

	if len(c.keys) == 0 {
		return "", fmt.Errorf("nodes empty")
	}

	hash := c.hash([]byte(key))

	// Binary search for the first ring position at or after the key's hash.
	idx := sort.Search(len(c.keys), func(i int) bool {
		return c.keys[i] >= hash
	})

	// Past the last position: wrap around to the start of the ring.
	if idx == len(c.keys) {
		idx = 0
	}
	return c.hashMap[c.keys[idx]], nil
}

3.2. 通用接口/工厂模式 #

// LoadBalanceStrategy is the contract every load-balancing algorithm
// implements.
type LoadBalanceStrategy interface {
	// Add registers a backend; the meaning and number of params depend on
	// the implementation.
	Add(params ...string) error
	// Get returns the backend address chosen for key.
	Get(key string) (string, error)
}

每一种不同的负载均衡算法,只需要实现添加以及获取的接口即可。

// LbType identifies a load-balancing algorithm.
type LbType int

// Supported load-balancing algorithms.
const (
	// LbRandom selects the random strategy.
	LbRandom LbType = iota
	// LbRoundRobin selects the round-robin strategy.
	LbRoundRobin
	// LbWeightRoundRobin selects the weighted round-robin strategy.
	LbWeightRoundRobin
	// LbConsistentHash selects the consistent-hash strategy.
	LbConsistentHash
)

// LoadBanlanceFactory returns a fresh load-balancing strategy for lbType.
// The consistent-hash strategy is created with 10 replicas and the default
// hash function. LbRoundRobin and any unrecognised value yield round robin.
func LoadBanlanceFactory(lbType LbType) LoadBalanceStrategy {
	switch lbType {
	case LbRandom:
		return &RandomN{}
	case LbWeightRoundRobin:
		return &WeightedRR{}
	case LbConsistentHash:
		return NewConsistentHash(10, nil)
	}
	// LbRoundRobin and every other value.
	return &RR{}
}

然后使用工厂方法,根据传入的参数,决定使用哪种负载均衡策略。

3.3. 支持负载均衡算法的反向代理实现 #

  • 使用LoadBanlanceFactory工厂函数,传入负载均衡类型,获取负载均衡对象
  • 添加后端真实节点
  • 然后初始化NewMultiTargetsReverseProxy,在director回调函数中,根据负载均衡策略获取要请求的后端真实节点
  • 剩下的逻辑同2.2
具体代码
// NewMultiTargetsReverseProxy returns a ReverseProxy that asks lb for a
// backend on every request, keyed by the client IP, and forwards the request
// there. The body of every non-200 response is prefixed with
// "StatusCode error:", and transport failures are answered with status 500.
func NewMultiTargetsReverseProxy(lb lb_strategy.LoadBalanceStrategy) *httputil.ReverseProxy {
	director := func(req *http.Request) {
		// Strip the port by cutting at the LAST colon and dropping the
		// brackets, so that an IPv6 peer such as "[::1]:54321" yields "::1".
		// Splitting on the first colon would map every IPv6 client to "[".
		remoteIP := req.RemoteAddr
		if i := strings.LastIndex(remoteIP, ":"); i >= 0 {
			remoteIP = remoteIP[:i]
		}
		remoteIP = strings.TrimSuffix(strings.TrimPrefix(remoteIP, "["), "]")

		// A director cannot return an error, and exiting the process would
		// let one failed lookup take the whole gateway down. Instead, log the
		// problem and blank the destination so the request is never forwarded
		// to a client-chosen host; the round trip is then expected to fail and
		// be answered by errFunc below (assumes transport rejects a URL
		// without scheme/host, as the standard http.Transport does).
		nextAddr, err := lb.Get(remoteIP)
		if err != nil {
			log.Printf("get next addr fail: %v", err)
			req.URL.Scheme, req.URL.Host = "", ""
			return
		}

		target, err := url.Parse(nextAddr)
		if err != nil {
			log.Printf("parse backend addr %q: %v", nextAddr, err)
			req.URL.Scheme, req.URL.Host = "", ""
			return
		}

		targetQuery := target.RawQuery
		req.URL.Scheme = target.Scheme
		req.URL.Host = target.Host
		req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header["User-Agent"]; !ok {
			req.Header.Set("User-Agent", "user-agent")
		}
	}

	modifyFunc := func(resp *http.Response) error {
		// Try it with: curl 'http://127.0.0.1:2002/error'
		if resp.StatusCode == 200 {
			return nil
		}

		// Read the original payload.
		oldPayload, err := ioutil.ReadAll(resp.Body)
		if err != nil {
			return err
		}
		// The upstream body is about to be replaced, so nobody else will
		// close it; close it here to release the backend connection. The
		// payload has already been read in full, so a close error carries no
		// useful information and is deliberately ignored.
		_ = resp.Body.Close()

		// Prefix the payload and fix up the length metadata.
		newPayload := []byte("StatusCode error:" + string(oldPayload))
		resp.Body = ioutil.NopCloser(bytes.NewBuffer(newPayload))
		resp.ContentLength = int64(len(newPayload))
		resp.Header.Set("Content-Length", strconv.FormatInt(int64(len(newPayload)), 10))
		return nil
	}

	errFunc := func(w http.ResponseWriter, r *http.Request, err error) {
		// TODO: for the weighted strategy, adjust the node's effective weight here.
		http.Error(w, "ErrorHandler error:"+err.Error(), 500)
	}

	return &httputil.ReverseProxy{Director: director, Transport: transport, ModifyResponse: modifyFunc, ErrorHandler: errFunc}
}

// singleJoiningSlash concatenates a and b with a single "/" at the junction:
// one trailing slash of a and one leading slash of b, if present, collapse
// into the one separator that is inserted.
func singleJoiningSlash(a, b string) string {
	head := strings.TrimSuffix(a, "/")
	tail := strings.TrimPrefix(b, "/")
	return head + "/" + tail
}

// main builds a consistent-hash load balancer over three backends and serves
// the load-balanced reverse proxy on addr.
func main() {
	rb := lb_strategy.LoadBanlanceFactory(lb_strategy.LbConsistentHash)

	backends := []string{
		"http://127.0.0.1:2003/base",
		"http://127.0.0.1:2004/base",
		"http://127.0.0.1:2005/base",
	}
	for _, backend := range backends {
		// Add can fail; starting with a partially populated ring would
		// silently skew the routing, so stop instead.
		if err := rb.Add(backend); err != nil {
			log.Fatal(err)
		}
	}

	proxy := NewMultiTargetsReverseProxy(rb)

	log.Println("Starting httpserver at " + addr)
	log.Fatal(http.ListenAndServe(addr, proxy))
}

4. 中间件 #

作为网关,中间件必不可少,这类包括请求响应的模式,一般称作洋葱模式,每一层都是中间件,一层层进去,然后一层层出来。

中间件的实现一般有两种,一种是使用数组,然后配合index计数;一种是链式调用。

4.1. 基于数组的中间件实现 #

  1. NewSliceRouterHandler 获取SliceRouterHandler对象,该对象实现了Handler接口,可以作为handler传入http服务
    • ServeHTTP方法中,调用newSliceRouterContext初始化SliceRouterContext,并且根据req中的url,按照最长url前缀匹配的规则寻找groups中满足条件的SliceGroup丢给SliceRouterContext
    • ServeHTTP方法中,调用Next方法开始整个handlers数组的handler调用
  2. SliceRouterHandler包含coreFunc以及SliceRouter对象
  3. SliceRouter包含SliceGroup列表
  4. SliceGroup对象包含path以及handlers
    • 使用Use方法来添加中间件,并且去重添加到SliceRouter中的groups中去
    • 使用Group方法初始化一个SliceGroup
  5. 贯穿整条调用链的是SliceRouterContext对象,包含:
    • SliceGroup指针
    • ResponseWriter
    • Request指针
    • Context
    • index索引
  6. 中间件中可以调用SliceRouterContext中的Next方法继续,也可以调用Abort方法进行终止
  7. Abort终止的方式就是设置索引index为abortIndex
具体代码
// abortIndex is the index value Abort stores to stop the handler chain.
const abortIndex int8 = math.MaxInt8 / 2

// HandlerFunc is a middleware or terminal handler in a slice-based chain.
type HandlerFunc func(*SliceRouterContext)

// SliceRouter holds every registered group.
type SliceRouter struct {
	groups []*SliceGroup
}

// SliceGroup is a path prefix together with the handlers that run for it.
type SliceGroup struct {
	*SliceRouter
	path     string
	handlers []HandlerFunc
}

// SliceRouterContext carries one request through the handler chain of the
// group that matched it.
type SliceRouterContext struct {
	*SliceGroup
	RespW http.ResponseWriter
	Req   *http.Request
	Ctx   context.Context
	index int8 // position of the handler currently running; -1 before the first
}

// newSliceRouterContext builds the context for req. Among the groups of r
// whose path is a prefix of req.RequestURI, the one with the longest path is
// selected; when none matches, the context has no handlers.
func newSliceRouterContext(rw http.ResponseWriter, req *http.Request, r *SliceRouter) *SliceRouterContext {
	newSliceGroup := &SliceGroup{}

	matchURLLen := 0
	for _, group := range r.groups {
		if !strings.HasPrefix(req.RequestURI, group.path) {
			continue
		}
		if pathLen := len(group.path); pathLen > matchURLLen {
			matchURLLen = pathLen
			// Shallow copy: the handlers slice still shares its backing
			// array with the registered group.
			*newSliceGroup = *group
		}
	}

	c := &SliceRouterContext{RespW: rw, Req: req, SliceGroup: newSliceGroup, Ctx: req.Context()}
	c.Reset()
	return c
}

// Get returns the value stored under key in the context, or nil.
func (ctx *SliceRouterContext) Get(key interface{}) interface{} {
	return ctx.Ctx.Value(key)
}

// Set stores val under key by deriving a new context value.
func (ctx *SliceRouterContext) Set(key, val interface{}) {
	ctx.Ctx = context.WithValue(ctx.Ctx, key, val)
}

// Next runs the remaining handlers of the chain in order. A middleware calls
// it to run everything after itself before continuing with its own code.
func (ctx *SliceRouterContext) Next() {
	ctx.index++

	// The bound must be the number of handlers of the matched group. Using
	// the router's group count instead skips handlers or indexes out of
	// range, and dereferences a nil router when no group matched. Comparing
	// as int also avoids truncating the length to int8.
	for int(ctx.index) < len(ctx.handlers) {
		ctx.handlers[ctx.index](ctx)
		ctx.index++
	}
}

// Reset rewinds the chain so that the next call to Next starts at the first
// handler.
func (ctx *SliceRouterContext) Reset() {
	ctx.index = -1
}

// Abort stops the chain: handlers that have not started yet will not run.
func (ctx *SliceRouterContext) Abort() {
	ctx.index = abortIndex
}

// IsAborted reports whether Abort has been called on this context.
func (ctx *SliceRouterContext) IsAborted() bool {
	return ctx.index >= abortIndex
}

// SliceRouterHandler adapts a SliceRouter to http.Handler. When coreFunc is
// set, the handler it returns runs as the last element of every chain.
type SliceRouterHandler struct {
	coreFunc func(*SliceRouterContext) http.Handler
	router   *SliceRouter
}

// NewSliceRouterHandler returns a handler that dispatches requests through
// router and, if coreFunc is non-nil, finishes each chain with the handler
// coreFunc produces.
func NewSliceRouterHandler(coreFunc func(*SliceRouterContext) http.Handler, router *SliceRouter) *SliceRouterHandler {
	return &SliceRouterHandler{
		coreFunc: coreFunc,
		router:   router,
	}
}

// ServeHTTP builds the per-request context for the matching group, appends
// the core handler when one is configured, and runs the chain.
func (w *SliceRouterHandler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
	c := newSliceRouterContext(rw, req, w.router)
	if w.coreFunc != nil {
		// c.handlers shares its backing array with the registered group.
		// Appending in place could write this request's closure into spare
		// capacity of that shared array, where a concurrent request would
		// overwrite or pick it up. Build a private slice instead.
		handlers := make([]HandlerFunc, 0, len(c.handlers)+1)
		handlers = append(handlers, c.handlers...)
		c.handlers = append(handlers, func(c *SliceRouterContext) {
			w.coreFunc(c).ServeHTTP(rw, req)
		})
	}
	c.Reset()
	c.Next()
}

// NewSliceRouter returns an empty router.
func NewSliceRouter() *SliceRouter {
	return new(SliceRouter)
}

// Group creates a group for path that belongs to this router. The group is
// only added to the router's group list once Use is called on it.
func (r *SliceRouter) Group(path string) *SliceGroup {
	group := &SliceGroup{path: path}
	group.SliceRouter = r
	return group
}

// Use appends middlewares to the group's handler chain and makes sure the
// group is present in its router's group list exactly once. It returns the
// group so calls can be chained.
func (g *SliceGroup) Use(middlewares ...HandlerFunc) *SliceGroup {
	g.handlers = append(g.handlers, middlewares...)

	for _, registered := range g.SliceRouter.groups {
		if registered == g {
			// Already known to the router: nothing more to do.
			return g
		}
	}

	g.SliceRouter.groups = append(g.SliceRouter.groups, g)
	return g
}
tracelog中间件 具体代码
// TraceLogSliceMiddleware returns a middleware that logs "trace_in", runs the
// rest of the handler chain, and then logs "trace_out".
func TraceLogSliceMiddleware() func(c *SliceRouterContext) {
	return func(c *SliceRouterContext) {
		log.Println("trace_in")
		// Continue with the remaining handlers. Calling Abort here would
		// stop the chain, so nothing registered after this middleware
		// (including the reverse proxy) would ever run.
		c.Next()
		log.Println("trace_out")
	}
}
中间件使用 具体代码
var addr = "127.0.0.1:2002"

// main wires two middleware groups, "/base" and "/", into a slice router and
// serves it on addr. Both groups start with the trace-log middleware; "/base"
// then writes a fixed body, "/" then forwards to the reverse proxy.
func main() {
	// reverseProxy builds a proxy over the two demo backends for the given
	// context. A backend URL that fails to parse is logged and still passed
	// on as parsed.
	reverseProxy := func(c *middleware.SliceRouterContext) http.Handler {
		backends := []string{
			"http://127.0.0.1:2003/base",
			"http://127.0.0.1:2004/base",
		}

		urls := make([]*url.URL, 0, len(backends))
		for _, raw := range backends {
			parsed, err := url.Parse(raw)
			if err != nil {
				log.Println(err)
			}
			urls = append(urls, parsed)
		}
		return proxy.NewMultipleHostsReverseProxy(c, urls)
	}

	log.Println("Starting httpserver at " + addr)

	sliceRouter := middleware.NewSliceRouter()

	// Handler for the "/base" prefix: answers directly.
	writeTestBody := func(c *middleware.SliceRouterContext) {
		c.RespW.Write([]byte("test func"))
	}
	sliceRouter.Group("/base").Use(middleware.TraceLogSliceMiddleware(), writeTestBody)

	// Handler for everything else: forwards to a backend.
	forward := func(c *middleware.SliceRouterContext) {
		fmt.Println("reverseProxy")
		reverseProxy(c).ServeHTTP(c.RespW, c.Req)
	}
	sliceRouter.Group("/").Use(middleware.TraceLogSliceMiddleware(), forward)

	routerHandler := middleware.NewSliceRouterHandler(nil, sliceRouter)
	log.Fatal(http.ListenAndServe(addr, routerHandler))
}