feature/rate-limit (#91)
This commit is contained in:
@ -78,6 +78,12 @@ const (
|
||||
LogHTTP Key = `log.echo`
|
||||
LogEcho Key = `log.echo`
|
||||
LogPath Key = `log.path`
|
||||
|
||||
RateLimitEnabled Key = `ratelimit.enabled`
|
||||
RateLimitKind Key = `ratelimit.kind`
|
||||
RateLimitPeriod Key = `ratelimit.period`
|
||||
RateLimitLimit Key = `ratelimit.limit`
|
||||
RateLimitStore Key = `ratelimit.store`
|
||||
)
|
||||
|
||||
// GetString returns a string config value
|
||||
@ -95,6 +101,11 @@ func (k Key) GetInt() int {
|
||||
return viper.GetInt(string(k))
|
||||
}
|
||||
|
||||
// GetInt64 returns an int64 config value
|
||||
func (k Key) GetInt64() int64 {
|
||||
return viper.GetInt64(string(k))
|
||||
}
|
||||
|
||||
// GetDuration returns a duration config value
|
||||
func (k Key) GetDuration() time.Duration {
|
||||
return viper.GetDuration(string(k))
|
||||
@ -174,6 +185,12 @@ func InitConfig() {
|
||||
LogHTTP.setDefault("stdout")
|
||||
LogEcho.setDefault("off")
|
||||
LogPath.setDefault(ServiceRootpath.GetString() + "/logs")
|
||||
// Rate Limit
|
||||
RateLimitEnabled.setDefault(false)
|
||||
RateLimitKind.setDefault("user")
|
||||
RateLimitLimit.setDefault(100)
|
||||
RateLimitPeriod.setDefault(60)
|
||||
RateLimitStore.setDefault("memory")
|
||||
|
||||
// Init checking for environment variables
|
||||
viper.SetEnvPrefix("vikunja")
|
||||
|
@ -116,7 +116,7 @@ func Debug(args ...interface{}) {
|
||||
|
||||
// Debugf is for debug messages
|
||||
func Debugf(format string, args ...interface{}) {
|
||||
logInstance.Debugf(format, args)
|
||||
logInstance.Debugf(format, args...)
|
||||
}
|
||||
|
||||
// Info is for info messages
|
||||
@ -126,7 +126,7 @@ func Info(args ...interface{}) {
|
||||
|
||||
// Infof is for info messages
|
||||
func Infof(format string, args ...interface{}) {
|
||||
logInstance.Infof(format, args)
|
||||
logInstance.Infof(format, args...)
|
||||
}
|
||||
|
||||
// Error is for error messages
|
||||
@ -136,7 +136,7 @@ func Error(args ...interface{}) {
|
||||
|
||||
// Errorf is for error messages
|
||||
func Errorf(format string, args ...interface{}) {
|
||||
logInstance.Errorf(format, args)
|
||||
logInstance.Errorf(format, args...)
|
||||
}
|
||||
|
||||
// Warning is for warning messages
|
||||
@ -146,7 +146,7 @@ func Warning(args ...interface{}) {
|
||||
|
||||
// Warningf is for warning messages
|
||||
func Warningf(format string, args ...interface{}) {
|
||||
logInstance.Warningf(format, args)
|
||||
logInstance.Warningf(format, args...)
|
||||
}
|
||||
|
||||
// Critical is for critical messages
|
||||
@ -156,7 +156,7 @@ func Critical(args ...interface{}) {
|
||||
|
||||
// Criticalf is for critical messages
|
||||
func Criticalf(format string, args ...interface{}) {
|
||||
logInstance.Critical(format, args)
|
||||
logInstance.Criticalf(format, args...)
|
||||
}
|
||||
|
||||
// Fatal is for fatal messages
|
||||
@ -166,5 +166,5 @@ func Fatal(args ...interface{}) {
|
||||
|
||||
// Fatalf is for fatal messages
|
||||
func Fatalf(format string, args ...interface{}) {
|
||||
logInstance.Fatal(format, args)
|
||||
logInstance.Fatalf(format, args...)
|
||||
}
|
||||
|
103
pkg/routes/rate_limit.go
Normal file
103
pkg/routes/rate_limit.go
Normal file
@ -0,0 +1,103 @@
|
||||
// Copyright 2019 Vikunja and contriubtors. All rights reserved.
|
||||
//
|
||||
// This file is part of Vikunja.
|
||||
//
|
||||
// Vikunja is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Vikunja is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Vikunja. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package routes
|
||||
|
||||
import (
|
||||
"code.vikunja.io/api/pkg/config"
|
||||
"code.vikunja.io/api/pkg/log"
|
||||
"code.vikunja.io/api/pkg/models"
|
||||
"code.vikunja.io/api/pkg/red"
|
||||
"github.com/labstack/echo/v4"
|
||||
"github.com/ulule/limiter/v3"
|
||||
"github.com/ulule/limiter/v3/drivers/store/memory"
|
||||
"github.com/ulule/limiter/v3/drivers/store/redis"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
// RateLimit is the rate limit middleware
|
||||
func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c echo.Context) (err error) {
|
||||
var rateLimitKey string
|
||||
switch config.RateLimitKind.GetString() {
|
||||
case "ip":
|
||||
rateLimitKey = c.RealIP()
|
||||
case "user":
|
||||
user, err := models.GetCurrentUser(c)
|
||||
if err != nil {
|
||||
log.Errorf("Error while getting the current user for rate limiting: %s", err)
|
||||
}
|
||||
rateLimitKey = "user_" + strconv.FormatInt(user.ID, 10)
|
||||
default:
|
||||
log.Errorf("Unknown rate limit kind configured: %s", config.RateLimitKind.GetString())
|
||||
}
|
||||
limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey)
|
||||
if err != nil {
|
||||
log.Errorf("IPRateLimit - rateLimiter.Get - err: %v, %s on %s", err, rateLimitKey, c.Request().URL)
|
||||
return c.JSON(http.StatusInternalServerError, echo.Map{
|
||||
"message": err,
|
||||
})
|
||||
}
|
||||
|
||||
h := c.Response().Header()
|
||||
h.Set("X-RateLimit-Limit", strconv.FormatInt(limiterCtx.Limit, 10))
|
||||
h.Set("X-RateLimit-Remaining", strconv.FormatInt(limiterCtx.Remaining, 10))
|
||||
h.Set("X-RateLimit-Reset", strconv.FormatInt(limiterCtx.Reset, 10))
|
||||
|
||||
if limiterCtx.Reached {
|
||||
log.Infof("Too Many Requests from %s on %s", rateLimitKey, c.Request().URL)
|
||||
return c.JSON(http.StatusTooManyRequests, echo.Map{
|
||||
"message": "Too Many Requests on " + c.Request().URL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// log.Printf("%s request continue", c.RealIP())
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setupRateLimit(a *echo.Group) {
|
||||
if config.RateLimitEnabled.GetBool() {
|
||||
rate := limiter.Rate{
|
||||
Period: config.RateLimitPeriod.GetDuration() * time.Second,
|
||||
Limit: config.RateLimitLimit.GetInt64(),
|
||||
}
|
||||
var store limiter.Store
|
||||
var err error
|
||||
switch config.RateLimitStore.GetString() {
|
||||
case "memory":
|
||||
store = memory.NewStore()
|
||||
case "redis":
|
||||
if !config.RedisEnabled.GetBool() {
|
||||
log.Fatal("Redis is configured for rate limiting, but not enabled!")
|
||||
}
|
||||
store, err = redis.NewStore(red.GetRedis())
|
||||
if err != nil {
|
||||
log.Fatalf("Error while creating rate limit redis store: %s", err)
|
||||
}
|
||||
default:
|
||||
log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString())
|
||||
}
|
||||
rateLimiter := limiter.New(store, rate)
|
||||
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
|
||||
a.Use(RateLimit(rateLimiter))
|
||||
}
|
||||
}
|
@ -218,10 +218,13 @@ func registerAPIRoutes(a *echo.Group) {
|
||||
// Info endpoint
|
||||
a.GET("/info", apiv1.Info)
|
||||
|
||||
// ===== Routes with Authetification =====
|
||||
// ===== Routes with Authetication =====
|
||||
// Authetification
|
||||
a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString())))
|
||||
|
||||
// Rate limit
|
||||
setupRateLimit(a)
|
||||
|
||||
// Middleware to collect metrics
|
||||
if config.ServiceJWTSecret.GetBool() {
|
||||
a.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
|
Reference in New Issue
Block a user