跳转至

限流算法

一. 滑动窗口限流

1. 代码实现

采用一个滚动数组来承载 unit 窗口,使用 curIdx 标记当前的窗口下标。在最开始启动一个协程执行 curIdx 的定时更新

每次只需要在 curIdx 的位置更新当前 unit 的 count 即可

type SlidingWindowLimiter struct {
    limit        int
    unitCnt      int
    unitTime     int64
    arrSize      int
    unitCountArr []int
    curIdx       int
    mu        sync.Mutex
}

func NewSlidingWindowLimiter(limit, unitCnt int, unitTime int64, arrSize int) *SlidingWindowLimiter {
    unitCountArr := make([]int, arrSize)
    ret := &SlidingWindowLimiter{
        limit:        limit,
        unitCnt:      unitCnt,
        unitTime:     unitTime,
        arrSize:      arrSize,
        unitCountArr: unitCountArr,
        curIdx:       0,
    }
    go func() {
        ticker := time.NewTicker(time.Duration(ret.unitTime) * time.Millisecond)
        defer ticker.Stop()
        for {
            select {
            case <-ticker.C:
                ret.mu.Lock()
                ret.curIdx = (ret.curIdx + 1) % ret.arrSize
                ret.unitCountArr[ret.curIdx] = 0 // 重置当前时间窗口的计数
                ret.mu.Unlock()
            }
        }
    }()
    return ret
}

func (l *SlidingWindowLimiter) TryAcquire() (bool, error) {
    l.mu.Lock()
    defer l.mu.Unlock()

    totalCount := 0
    endIdx := l.curIdx
    startIdx := (endIdx - l.unitCnt + 1 + l.arrSize) % l.arrSize
    if startIdx > endIdx {
        for i := 0; i <= endIdx; i++ {
            totalCount += l.unitCountArr[i]
        }
        for i := startIdx; i < l.arrSize; i++ {
            totalCount += l.unitCountArr[i]
        }
    } else {
        for i := startIdx; i <= endIdx; i++ {
            totalCount += l.unitCountArr[i]
        }
    }

    if totalCount + 1 > l.limit {
        return false, nil
    }

    l.unitCountArr[l.curIdx]++
    return true, nil
}

2. 优点与缺点

  • 优点:解决了固定窗口带来的窗口边界问题(一个边界的左侧和右侧分别有突发请求)
  • 缺点:虽然相比于固定窗口更加平滑,但依旧不够平滑。比如第一秒发送了 100 个请求达到流量限制;后面的请求没有回旋余地,直接被拒绝

二. 漏桶算法

请求会进行排队,以恒定速率出队,分布式系统下可以使用消息队列来缓存请求队列,消费者服务以恒定速率从队列中拉取

如果使用数组存储请求 + level计数 + mutex 会出现比较大的并发问题

使 channel 作为队列,用 goroutine 来控制漏速

type LeakyBucket struct {
    rate int64
    capacity int64
    queue chan struct{}
    stopCh chan struct{} // stop signal
    mu sync.Mutex
}

func NewLeakyBucket(rate, capacity int64) *LeakyBucket {
    ret := &LeakyBucket{
        rate: rate,
        capacity: capacity,
        queue: make(chan struct{}, capacity),
        stopCh: make(chan struct{}),
    }

    go func() {
        ticker := time.NewTicker(time.Second / time.Duration(rate)) // rate 为每秒放行的请求数量
        for {
            select {
            case <-ticker.C:
                select {
                case req := <-ret.queue:
                    fmt.Println(req)
                default: // 桶为空
                }
            case <-ret.stopCh:
                return
            } 
        }
    }()

    return ret
}

func (l *LeakyBucket) TryEnterBucket(req struct{}) bool {
    select {
    case l.queue<-req:
        return true
    default:
        return false
    }
}

func (l *LeakyBucket) Stop() {
    close(l.stopCh)
}

三. 分布式令牌桶算法

以固定速率向桶中添加令牌。一个请求到来时,如果桶中没有令牌,则被拒绝(或排队等待)

// 使用原子性的 lua 脚本保证 token 扣减的原子性
func TryAcquire(ctx context.Context, key string, rate, capacity int64) bool {
    client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
    defer client.Close()

    // 0 成功获取令牌,1 失败
    script := `
        local key = KEYS[1]
        local capacity = tonumber(ARGV[1])
        local ttl = tonumber(ARGV[2])

        local current = redis.call('GET', key)

        if not current then
            redis.call('SET', key, capacity, 'EX', ttl)
            current = capacity
        else
            current = tonumber(current)
        end

        if current <= 0 then
            return 0
        end

        local newValue = redis.call('DECR', key)
        if newValue < 0 then
            redis.call('INCR', key)
            return 0
        end

        return 1
    `

    result, err := client.Eval(ctx, script, []string{key}, capacity, 10).Result()
    if err != nil {
        return false
    }

    return result.(int64) == 1
}

// 按照 rate 补充token,应该在后台goroutine中运行
func RefillToken(ctx context.Context, key string, rate, capacity int64) {
    client := redis.NewClient(&redis.Options{Addr: "localhost:6379"})
    defer client.Close()

    script := `
        local key = KEYS[1]
        local rate = tonumber(ARGV[1])
        local capacity = tonumber(ARGV[2])
        local ttl = tonumber(ARGV[3])

        local current = redis.call('GET', key)

        if not current then
            redis.call('SET', key, capacity, 'EX', ttl)
            return capacity
        end

        current = tonumber(current)
        local newValue = math.min(current + rate, capacity)
        redis.call('SET', key, newValue, 'EX', ttl)

        return newValue
    `

    // 每秒补充一次token
    ticker := time.NewTicker(time.Second)
    defer ticker.Stop()

    for {
        select {
        case <-ctx.Done():
            return
        case <-ticker.C:
            client.Eval(ctx, script, []string{key}, rate, capacity, 10)
        }
    }
}