不依赖外部库的情况下,限流算法有什么实现的思路?本文介绍了3种实现限流的方式。
package main
import (
"fmt"
"sync"
"time"
)
// 每个请求来了,把需要执行的业务逻辑封装成Task,放入木桶,等待worker取出执行
type Task struct {
handler func() Result // worker从木桶中取出请求对象后要执行的业务逻辑函数
resChan chan Result // 等待worker执行并返回结果的channel
taskID int
}
// 封装业务逻辑的执行结果
type Result struct {
}
// 模拟业务逻辑的函数
func handler() Result {
time.Sleep(300 * time.Millisecond)
return Result{}
}
func NewTask(id int) Task {
return Task{
handler: handler,
resChan: make(chan Result),
taskID: id,
}
}
// 漏桶
type LeakyBucket struct {
BucketSize int // 木桶的大小
NumWorker int // 同时从木桶中获取任务执行的worker数量
bucket chan Task // 存方任务的木桶
}
func NewLeakyBucket(bucketSize int, numWorker int) *LeakyBucket {
return &LeakyBucket{
BucketSize: bucketSize,
NumWorker: numWorker,
bucket: make(chan Task, bucketSize),
}
}
func (b *LeakyBucket) validate(task Task) bool {
// 如果木桶已经满了,返回false
select {
case b.bucket <- task:
default:
fmt.Printf("request[id=%d] is refused ", task.taskID)
return false
}
// 等待worker执行
<-task.resChan
fmt.Printf("request[id=%d] is run ", task.taskID)
return true
}
func (b *LeakyBucket) Start() {
// 开启worker从木桶拉取任务执行
go func() {
for i := 0; i < b.NumWorker; i++ {
go func() {
for {
task := <-b.bucket
result := task.handler()
task.resChan <- result
}
}()
}
}()
}
func main() {
bucket := NewLeakyBucket(10, 4)
bucket.Start()
var wg sync.WaitGroup
for i := 0; i < 20; i++ {
wg.Add(1)
go func(id int) {
defer wg.Done()
task := NewTask(id)
bucket.validate(task)
}(i)
}
wg.Wait()
}
package main
import (
"fmt"
"sync"
"time"
)
// 并发访问同一个user_id/ip的记录需要上锁
var recordMu map[string]*sync.RWMutex
func init() {
recordMu = make(map[string]*sync.RWMutex)
}
func max(a, b int) int {
if a > b {
return a
}
return b
}
type TokenBucket struct {
BucketSize int // 木桶内的容量:最多可以存放多少个令牌
TokenRate time.Duration // 多长时间生成一个令牌
records map[string]*record // 报错user_id/ip的访问记录
}
// 上次访问时的时间戳和令牌数
type record struct {
last time.Time
token int
}
func NewTokenBucket(bucketSize int, tokenRate time.Duration) *TokenBucket {
return &TokenBucket{
BucketSize: bucketSize,
TokenRate: tokenRate,
records: make(map[string]*record),
}
}
func (t *TokenBucket) getUidOrIp() string {
// 获取请求用户的user_id或者ip地址
return "127.0.0.1"
}
// 获取这个user_id/ip上次访问时的时间戳和令牌数
func (t *TokenBucket) getRecord(uidOrIp string) *record {
if r, ok := t.records[uidOrIp]; ok {
return r
}
return &record{}
}
// 保存user_id/ip最近一次请求时的时间戳和令牌数量
func (t *TokenBucket) storeRecord(uidOrIp string, r *record) {
t.records[uidOrIp] = r
}
// 验证是否能获取一个令牌
func (t *TokenBucket) validate(uidOrIp string) bool {
// 并发修改同一个用户的记录上写锁
rl, ok := recordMu[uidOrIp]
if !ok {
var mu sync.RWMutex
rl = &mu
recordMu[uidOrIp] = rl
}
rl.Lock()
defer rl.Unlock()
r := t.getRecord(uidOrIp)
now := time.Now()
if r.last.IsZero() {
// 第一次访问初始化为最大令牌数
r.last, r.token = now, t.BucketSize
} else {
if r.last.Add(t.TokenRate).Before(now) {
// 如果与上次请求的间隔超过了token rate
// 则增加令牌,更新last
r.token += max(int(now.Sub(r.last) / t.TokenRate), t.BucketSize)
r.last = now
}
}
var result bool
if r.token > 0 {
// 如果令牌数大于1,取走一个令牌,validate结果为true
r.token--
result = true
}
// 保存最新的record
t.storeRecord(uidOrIp, r)
return result
}
// 返回是否被限流
func (t *TokenBucket) IsLimited() bool {
return !t.validate(t.getUidOrIp())
}
func main() {
tokenBucket := NewTokenBucket(5, 100*time.Millisecond)
for i := 0; i< 6; i++ {
fmt.Println(tokenBucket.IsLimited())
}
time.Sleep(100 * time.Millisecond)
fmt.Println(tokenBucket.IsLimited())
}
package main
import (
"fmt"
"sync"
"time"
)
var winMu map[string]*sync.RWMutex
func init() {
winMu = make(map[string]*sync.RWMutex)
}
type timeSlot struct {
timestamp time.Time // 这个timeSlot的时间起点
count int // 落在这个timeSlot内的请求数
}
func countReq(win []*timeSlot) int {
var count int
for _, ts := range win {
count += ts.count
}
return count
}
type SlidingWindowLimiter struct {
SlotDuration time.Duration // time slot的长度
WinDuration time.Duration // sliding window的长度
numSlots int // window内最多有多少个slot
windows map[string][]*timeSlot
maxReq int // win duration内允许的最大请求数
}
func NewSliding(slotDuration time.Duration, winDuration time.Duration, maxReq int) *SlidingWindowLimiter {
return &SlidingWindowLimiter{
SlotDuration: slotDuration,
WinDuration: winDuration,
numSlots: int(winDuration / slotDuration),
windows: make(map[string][]*timeSlot),
maxReq: maxReq,
}
}
// 获取user_id/ip的时间窗口
func (l *SlidingWindowLimiter) getWindow(uidOrIp string) []*timeSlot {
win, ok := l.windows[uidOrIp]
if !ok {
win = make([]*timeSlot, 0, l.numSlots)
}
return win
}
func (l *SlidingWindowLimiter) storeWindow(uidOrIp string, win []*timeSlot) {
l.windows[uidOrIp] = win
}
func (l *SlidingWindowLimiter) validate(uidOrIp string) bool {
// 同一user_id/ip并发安全
mu, ok := winMu[uidOrIp]
if !ok {
var m sync.RWMutex
mu = &m
winMu[uidOrIp] = mu
}
mu.Lock()
defer mu.Unlock()
win := l.getWindow(uidOrIp)
now := time.Now()
// 已经过期的time slot移出时间窗
timeoutOffset := -1
for i, ts := range win {
if ts.timestamp.Add(l.WinDuration).After(now) {
break
}
timeoutOffset = i
}
if timeoutOffset > -1 {
win = win[timeoutOffset+1:]
}
// 判断请求是否超限
var result bool
if countReq(win) < l.maxReq {
result = true
}
// 记录这次的请求数
var lastSlot *timeSlot
if len(win) > 0 {
lastSlot = win[len(win)-1]
if lastSlot.timestamp.Add(l.SlotDuration).Before(now) {
lastSlot = &timeSlot{timestamp: now, count: 1}
win = append(win, lastSlot)
} else {
lastSlot.count++
}
} else {
lastSlot = &timeSlot{timestamp: now, count: 1}
win = append(win, lastSlot)
}
l.storeWindow(uidOrIp, win)
return result
}
func (l *SlidingWindowLimiter) getUidOrIp() string {
return "127.0.0.1"
}
func (l *SlidingWindowLimiter) IsLimited() bool {
return !l.validate(l.getUidOrIp())
}
func main() {
limiter := NewSliding(100*time.Millisecond, time.Second, 10)
for i := 0; i < 5; i++ {
fmt.Println(limiter.IsLimited())
}
time.Sleep(100 * time.Millisecond)
for i := 0; i < 5; i++ {
fmt.Println(limiter.IsLimited())
}
fmt.Println(limiter.IsLimited())
for _, v := range limiter.windows[limiter.getUidOrIp()] {
fmt.Println(v.timestamp, v.count)
}
fmt.Println("a thousand years later...")
time.Sleep(time.Second)
for i := 0; i < 7; i++ {
fmt.Println(limiter.IsLimited())
}
for _, v := range limiter.windows[limiter.getUidOrIp()] {
fmt.Println(v.timestamp, v.count)
}
}
文章出处:【微信公众号:马哥Linux运维】欢迎添加关注!文章转载请注明出处。
全部0条评论
快来发表一下你的评论吧 !