feature/rate-limit (#91)

konrad
2019-07-21 21:27:30 +00:00
committed by Gitea
parent 2e599e792e
commit 4327a559e5
33 changed files with 1520 additions and 86 deletions

View File

@@ -0,0 +1,28 @@
package common

import (
	"time"

	"github.com/ulule/limiter/v3"
)

// GetContextFromState generates a new limiter.Context from the given state.
func GetContextFromState(now time.Time, rate limiter.Rate, expiration time.Time, count int64) limiter.Context {
	limit := rate.Limit
	remaining := int64(0)
	reached := true

	if count <= limit {
		remaining = limit - count
		reached = false
	}

	reset := expiration.Unix()

	return limiter.Context{
		Limit:     limit,
		Remaining: remaining,
		Reset:     reset,
		Reached:   reached,
	}
}
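
For illustration only (not part of the diff), a minimal sketch of calling this helper directly; the rate, count, and expiration values are made up for the example:

	package main

	import (
		"fmt"
		"time"

		"github.com/ulule/limiter/v3"
		"github.com/ulule/limiter/v3/drivers/store/common"
	)

	func main() {
		// With a limit of 10 per minute and a current count of 3, the resulting
		// context reports 7 remaining requests and Reached == false.
		rate := limiter.Rate{Period: 1 * time.Minute, Limit: 10}
		now := time.Now()
		lctx := common.GetContextFromState(now, rate, now.Add(rate.Period), 3)
		fmt.Println(lctx.Limit, lctx.Remaining, lctx.Reached)
	}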

View File

@@ -0,0 +1,159 @@
package memory

import (
	"runtime"
	"sync"
	"time"
)

// Forked from https://github.com/patrickmn/go-cache

// CacheWrapper is used to ensure that the underlying cleaner goroutine used to clean expired keys will not prevent
// Cache from being garbage collected.
type CacheWrapper struct {
	*Cache
}

// A cleaner will periodically delete expired keys from the cache.
type cleaner struct {
	interval time.Duration
	stop     chan bool
}

// Run will periodically delete expired keys from the given cache until the GC notifies it to stop.
func (cleaner *cleaner) Run(cache *Cache) {
	ticker := time.NewTicker(cleaner.interval)
	for {
		select {
		case <-ticker.C:
			cache.Clean()
		case <-cleaner.stop:
			ticker.Stop()
			return
		}
	}
}

// stopCleaner is a callback from the GC used to stop the cleaner goroutine.
func stopCleaner(wrapper *CacheWrapper) {
	wrapper.cleaner.stop <- true
}

// startCleaner will start a cleaner goroutine for the given cache.
func startCleaner(cache *Cache, interval time.Duration) {
	cleaner := &cleaner{
		interval: interval,
		stop:     make(chan bool),
	}

	cache.cleaner = cleaner
	go cleaner.Run(cache)
}

// Counter is a simple counter with an optional expiration.
type Counter struct {
	Value      int64
	Expiration int64
}

// Expired returns true if the counter has expired.
func (counter Counter) Expired() bool {
	if counter.Expiration == 0 {
		return false
	}
	return time.Now().UnixNano() > counter.Expiration
}

// Cache contains a collection of counters.
type Cache struct {
	mutex    sync.RWMutex
	counters map[string]Counter
	cleaner  *cleaner
}

// NewCache returns a new cache.
func NewCache(cleanInterval time.Duration) *CacheWrapper {
	cache := &Cache{
		counters: map[string]Counter{},
	}

	wrapper := &CacheWrapper{Cache: cache}

	if cleanInterval > 0 {
		startCleaner(cache, cleanInterval)
		runtime.SetFinalizer(wrapper, stopCleaner)
	}

	return wrapper
}

// Increment increments the given value on key.
// If the key is undefined or expired, it will be created.
func (cache *Cache) Increment(key string, value int64, duration time.Duration) (int64, time.Time) {
	cache.mutex.Lock()

	counter, ok := cache.counters[key]
	if !ok || counter.Expired() {
		expiration := time.Now().Add(duration).UnixNano()
		counter = Counter{
			Value:      value,
			Expiration: expiration,
		}

		cache.counters[key] = counter
		cache.mutex.Unlock()

		return value, time.Unix(0, expiration)
	}

	value = counter.Value + value
	counter.Value = value
	expiration := counter.Expiration

	cache.counters[key] = counter
	cache.mutex.Unlock()

	return value, time.Unix(0, expiration)
}

// Get returns the key's value and expiration.
func (cache *Cache) Get(key string, duration time.Duration) (int64, time.Time) {
	cache.mutex.RLock()

	counter, ok := cache.counters[key]
	if !ok || counter.Expired() {
		expiration := time.Now().Add(duration).UnixNano()
		cache.mutex.RUnlock()
		return 0, time.Unix(0, expiration)
	}

	value := counter.Value
	expiration := counter.Expiration
	cache.mutex.RUnlock()

	return value, time.Unix(0, expiration)
}

// Clean will delete any expired keys.
func (cache *Cache) Clean() {
	now := time.Now().UnixNano()

	cache.mutex.Lock()
	for key, counter := range cache.counters {
		if now > counter.Expiration {
			delete(cache.counters, key)
		}
	}
	cache.mutex.Unlock()
}

// Reset changes the key's value and resets the expiration.
func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) {
	cache.mutex.Lock()
	delete(cache.counters, key)
	cache.mutex.Unlock()

	expiration := time.Now().Add(duration).UnixNano()
	return 0, time.Unix(0, expiration)
}
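
For illustration only, a small sketch of exercising this cache on its own; the key and durations are made up, and a clean interval of zero simply skips starting the background cleaner:

	package main

	import (
		"fmt"
		"time"

		"github.com/ulule/limiter/v3/drivers/store/memory"
	)

	func main() {
		// A clean interval of 0 disables the background cleaner goroutine.
		cache := memory.NewCache(0)

		// Two increments within the same window accumulate on the same counter.
		count, expiration := cache.Increment("client-1", 1, time.Minute)
		count, expiration = cache.Increment("client-1", 1, time.Minute)
		fmt.Println(count, expiration) // 2, roughly one minute from now
	}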

View File

@@ -0,0 +1,67 @@
package memory

import (
	"context"
	"fmt"
	"time"

	"github.com/ulule/limiter/v3"
	"github.com/ulule/limiter/v3/drivers/store/common"
)

// Store is the in-memory store.
type Store struct {
	// Prefix used for the key.
	Prefix string
	// cache used to store values in-memory.
	cache *CacheWrapper
}

// NewStore creates a new instance of the memory store with defaults.
func NewStore() limiter.Store {
	return NewStoreWithOptions(limiter.StoreOptions{
		Prefix:          limiter.DefaultPrefix,
		CleanUpInterval: limiter.DefaultCleanUpInterval,
	})
}

// NewStoreWithOptions creates a new instance of the memory store with options.
func NewStoreWithOptions(options limiter.StoreOptions) limiter.Store {
	return &Store{
		Prefix: options.Prefix,
		cache:  NewCache(options.CleanUpInterval),
	}
}

// Get returns the limit for the given identifier.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	count, expiration := store.cache.Increment(key, 1, rate.Period)

	lctx := common.GetContextFromState(now, rate, expiration, count)
	return lctx, nil
}

// Peek returns the limit for the given identifier, without modifying the current values.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	count, expiration := store.cache.Get(key, rate.Period)

	lctx := common.GetContextFromState(now, rate, expiration, count)
	return lctx, nil
}

// Reset resets the limit to zero for the given identifier.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	count, expiration := store.cache.Reset(key, rate.Period)

	lctx := common.GetContextFromState(now, rate, expiration, count)
	return lctx, nil
}
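
For illustration, a hedged sketch of wiring this store into a limiter instance with limiter.New; the rate string and key are invented for the example, and the actual integration lives elsewhere in this commit:

	package main

	import (
		"context"
		"fmt"

		"github.com/ulule/limiter/v3"
		"github.com/ulule/limiter/v3/drivers/store/memory"
	)

	func main() {
		// "20-M" means 20 requests per minute; the string is parsed by the limiter package.
		rate, err := limiter.NewRateFromFormatted("20-M")
		if err != nil {
			panic(err)
		}

		instance := limiter.New(memory.NewStore(), rate)

		// Each call for the same key consumes one request from the current window.
		lctx, err := instance.Get(context.Background(), "some-client-ip")
		if err != nil {
			panic(err)
		}
		fmt.Println(lctx.Remaining, lctx.Reached)
	}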

View File

@@ -0,0 +1,320 @@
package redis

import (
	"context"
	"fmt"
	"time"

	libredis "github.com/go-redis/redis"
	"github.com/pkg/errors"

	"github.com/ulule/limiter/v3"
	"github.com/ulule/limiter/v3/drivers/store/common"
)

// Client is an interface that allows using a redis cluster or a single redis client seamlessly.
type Client interface {
	Ping() *libredis.StatusCmd
	Get(key string) *libredis.StringCmd
	Set(key string, value interface{}, expiration time.Duration) *libredis.StatusCmd
	Watch(handler func(*libredis.Tx) error, keys ...string) error
	Del(keys ...string) *libredis.IntCmd
	SetNX(key string, value interface{}, expiration time.Duration) *libredis.BoolCmd
	Eval(script string, keys []string, args ...interface{}) *libredis.Cmd
}
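
As an aside, both the single-node and cluster clients from github.com/go-redis/redis should satisfy this interface; compile-time assertions along these lines (illustrative, not part of the diff, and assuming the vendored go-redis version exposes these method sets) would make that explicit:

	package redis

	import (
		libredis "github.com/go-redis/redis"
	)

	// Compile-time assertions (illustrative): these fail to build if the
	// go-redis client types stop satisfying the Client interface.
	var (
		_ Client = (*libredis.Client)(nil)
		_ Client = (*libredis.ClusterClient)(nil)
	)
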
// Store is the redis store.
type Store struct {
	// Prefix used for the key.
	Prefix string
	// MaxRetry is the maximum number of retries under race conditions.
	MaxRetry int
	// client used to communicate with the redis server.
	client Client
}

// NewStore returns an instance of the redis store with defaults.
func NewStore(client Client) (limiter.Store, error) {
	return NewStoreWithOptions(client, limiter.StoreOptions{
		Prefix:          limiter.DefaultPrefix,
		CleanUpInterval: limiter.DefaultCleanUpInterval,
		MaxRetry:        limiter.DefaultMaxRetry,
	})
}

// NewStoreWithOptions returns an instance of the redis store with options.
func NewStoreWithOptions(client Client, options limiter.StoreOptions) (limiter.Store, error) {
	store := &Store{
		client:   client,
		Prefix:   options.Prefix,
		MaxRetry: options.MaxRetry,
	}

	if store.MaxRetry <= 0 {
		store.MaxRetry = 1
	}

	_, err := store.ping()
	if err != nil {
		return nil, err
	}

	return store, nil
}

// Get returns the limit for the given identifier.
func (store *Store) Get(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	lctx := limiter.Context{}
	onWatch := func(rtx *libredis.Tx) error {
		created, err := store.doSetValue(rtx, key, rate.Period)
		if err != nil {
			return err
		}

		if created {
			expiration := now.Add(rate.Period)
			lctx = common.GetContextFromState(now, rate, expiration, 1)
			return nil
		}

		count, ttl, err := store.doUpdateValue(rtx, key, rate.Period)
		if err != nil {
			return err
		}

		expiration := now.Add(rate.Period)
		if ttl > 0 {
			expiration = now.Add(ttl)
		}

		lctx = common.GetContextFromState(now, rate, expiration, count)
		return nil
	}

	err := store.client.Watch(onWatch, key)
	if err != nil {
		err = errors.Wrapf(err, "limiter: cannot get value for %s", key)
		return limiter.Context{}, err
	}

	return lctx, nil
}

// Peek returns the limit for the given identifier, without modifying the current values.
func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	lctx := limiter.Context{}
	onWatch := func(rtx *libredis.Tx) error {
		count, ttl, err := store.doPeekValue(rtx, key)
		if err != nil {
			return err
		}

		expiration := now.Add(rate.Period)
		if ttl > 0 {
			expiration = now.Add(ttl)
		}

		lctx = common.GetContextFromState(now, rate, expiration, count)
		return nil
	}

	err := store.client.Watch(onWatch, key)
	if err != nil {
		err = errors.Wrapf(err, "limiter: cannot peek value for %s", key)
		return limiter.Context{}, err
	}

	return lctx, nil
}

// Reset returns the limit for the given identifier, which is set to zero.
func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) {
	key = fmt.Sprintf("%s:%s", store.Prefix, key)
	now := time.Now()

	lctx := limiter.Context{}
	onWatch := func(rtx *libredis.Tx) error {
		err := store.doResetValue(rtx, key)
		if err != nil {
			return err
		}

		count := int64(0)
		expiration := now.Add(rate.Period)

		lctx = common.GetContextFromState(now, rate, expiration, count)
		return nil
	}

	err := store.client.Watch(onWatch, key)
	if err != nil {
		err = errors.Wrapf(err, "limiter: cannot reset value for %s", key)
		return limiter.Context{}, err
	}

	return lctx, nil
}

// doPeekValue will execute peekValue with a retry mechanism (optimistic locking) until store.MaxRetry is reached.
func (store *Store) doPeekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) {
	for i := 0; i < store.MaxRetry; i++ {
		count, ttl, err := peekValue(rtx, key)
		if err == nil {
			return count, ttl, nil
		}
	}
	return 0, 0, errors.New("retry limit exceeded")
}

// peekValue will retrieve the counter and its expiration for the given key.
func peekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) {
	pipe := rtx.Pipeline()
	value := pipe.Get(key)
	expire := pipe.PTTL(key)

	_, err := pipe.Exec()
	if err != nil && err != libredis.Nil {
		return 0, 0, err
	}

	count, err := value.Int64()
	if err != nil && err != libredis.Nil {
		return 0, 0, err
	}

	ttl, err := expire.Result()
	if err != nil {
		return 0, 0, err
	}

	return count, ttl, nil
}

// doSetValue will execute setValue with a retry mechanism (optimistic locking) until store.MaxRetry is reached.
func (store *Store) doSetValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) {
	for i := 0; i < store.MaxRetry; i++ {
		created, err := setValue(rtx, key, expiration)
		if err == nil {
			return created, nil
		}
	}
	return false, errors.New("retry limit exceeded")
}

// setValue will try to initialize a new counter if the given key doesn't exist.
func setValue(rtx *libredis.Tx, key string, expiration time.Duration) (bool, error) {
	value := rtx.SetNX(key, 1, expiration)

	created, err := value.Result()
	if err != nil {
		return false, err
	}

	return created, nil
}

// doUpdateValue will execute updateValue with a retry mechanism (optimistic locking) until store.MaxRetry is reached.
func (store *Store) doUpdateValue(rtx *libredis.Tx, key string,
	expiration time.Duration) (int64, time.Duration, error) {
	for i := 0; i < store.MaxRetry; i++ {
		count, ttl, err := updateValue(rtx, key, expiration)
		if err == nil {
			return count, ttl, nil
		}

		// If ttl is negative and there is an error, do not retry an update.
		if ttl < 0 {
			return 0, 0, err
		}
	}
	return 0, 0, errors.New("retry limit exceeded")
}

// updateValue will try to increment the counter identified by the given key.
func updateValue(rtx *libredis.Tx, key string, expiration time.Duration) (int64, time.Duration, error) {
	pipe := rtx.Pipeline()
	value := pipe.Incr(key)
	expire := pipe.PTTL(key)

	_, err := pipe.Exec()
	if err != nil {
		return 0, 0, err
	}

	count, err := value.Result()
	if err != nil {
		return 0, 0, err
	}

	ttl, err := expire.Result()
	if err != nil {
		return 0, 0, err
	}

	// If ttl is -1ms, we have to define the key expiration.
	// The PTTL return values changed as of Redis 2.8: the command now returns -2ms if the key does not exist,
	// and -1ms if the key exists but has no expiry set.
	// We shouldn't try to set an expiry on a key that doesn't exist.
	if ttl == (-1 * time.Millisecond) {
		expire := rtx.Expire(key, expiration)

		ok, err := expire.Result()
		if err != nil {
			return count, ttl, err
		}

		if !ok {
			return count, ttl, errors.New("cannot configure timeout on key")
		}
	}

	return count, ttl, nil
}

// doResetValue will execute resetValue with a retry mechanism (optimistic locking) until store.MaxRetry is reached.
func (store *Store) doResetValue(rtx *libredis.Tx, key string) error {
	for i := 0; i < store.MaxRetry; i++ {
		err := resetValue(rtx, key)
		if err == nil {
			return nil
		}
	}
	return errors.New("retry limit exceeded")
}

// resetValue will try to reset the counter identified by the given key.
func resetValue(rtx *libredis.Tx, key string) error {
	deletion := rtx.Del(key)

	count, err := deletion.Result()
	if err != nil {
		return err
	}

	if count != 1 {
		return errors.New("cannot delete key")
	}

	return nil
}

// ping checks if the redis server is alive.
func (store *Store) ping() (bool, error) {
	cmd := store.client.Ping()

	pong, err := cmd.Result()
	if err != nil {
		return false, errors.Wrap(err, "limiter: cannot ping redis server")
	}

	return (pong == "PONG"), nil
}
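
For illustration, a hedged sketch of constructing the redis-backed store end to end; the address, prefix, retry count, and rate string are invented for the example, and it assumes the go-redis client satisfies the Client interface defined above:

	package main

	import (
		"context"
		"fmt"

		libredis "github.com/go-redis/redis"
		"github.com/ulule/limiter/v3"
		redisstore "github.com/ulule/limiter/v3/drivers/store/redis"
	)

	func main() {
		// Illustrative connection options; NewStoreWithOptions pings the server
		// and fails fast if redis is unreachable.
		client := libredis.NewClient(&libredis.Options{Addr: "localhost:6379"})

		store, err := redisstore.NewStoreWithOptions(client, limiter.StoreOptions{
			Prefix:   "limiter",
			MaxRetry: 3,
		})
		if err != nil {
			panic(err)
		}

		rate, err := limiter.NewRateFromFormatted("100-H") // 100 requests per hour
		if err != nil {
			panic(err)
		}

		lctx, err := limiter.New(store, rate).Get(context.Background(), "some-client-ip")
		if err != nil {
			panic(err)
		}
		fmt.Println(lctx.Limit, lctx.Remaining)
	}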