1
0

feature/rate-limit (#91)

This commit is contained in:
konrad
2019-07-21 21:27:30 +00:00
committed by Gitea
parent 2e599e792e
commit 4327a559e5
33 changed files with 1520 additions and 86 deletions

View File

@ -78,6 +78,12 @@ const (
LogHTTP Key = `log.echo`
LogEcho Key = `log.echo`
LogPath Key = `log.path`
RateLimitEnabled Key = `ratelimit.enabled`
RateLimitKind Key = `ratelimit.kind`
RateLimitPeriod Key = `ratelimit.period`
RateLimitLimit Key = `ratelimit.limit`
RateLimitStore Key = `ratelimit.store`
)
// GetString returns a string config value
@ -95,6 +101,11 @@ func (k Key) GetInt() int {
return viper.GetInt(string(k))
}
// GetInt64 returns an int64 config value
func (k Key) GetInt64() int64 {
return viper.GetInt64(string(k))
}
// GetDuration returns a duration config value
func (k Key) GetDuration() time.Duration {
return viper.GetDuration(string(k))
@ -174,6 +185,12 @@ func InitConfig() {
LogHTTP.setDefault("stdout")
LogEcho.setDefault("off")
LogPath.setDefault(ServiceRootpath.GetString() + "/logs")
// Rate Limit
RateLimitEnabled.setDefault(false)
RateLimitKind.setDefault("user")
RateLimitLimit.setDefault(100)
RateLimitPeriod.setDefault(60)
RateLimitStore.setDefault("memory")
// Init checking for environment variables
viper.SetEnvPrefix("vikunja")

View File

@ -116,7 +116,7 @@ func Debug(args ...interface{}) {
// Debugf is for debug messages
func Debugf(format string, args ...interface{}) {
logInstance.Debugf(format, args)
logInstance.Debugf(format, args...)
}
// Info is for info messages
@ -126,7 +126,7 @@ func Info(args ...interface{}) {
// Infof is for info messages
func Infof(format string, args ...interface{}) {
logInstance.Infof(format, args)
logInstance.Infof(format, args...)
}
// Error is for error messages
@ -136,7 +136,7 @@ func Error(args ...interface{}) {
// Errorf is for error messages
func Errorf(format string, args ...interface{}) {
logInstance.Errorf(format, args)
logInstance.Errorf(format, args...)
}
// Warning is for warning messages
@ -146,7 +146,7 @@ func Warning(args ...interface{}) {
// Warningf is for warning messages
func Warningf(format string, args ...interface{}) {
logInstance.Warningf(format, args)
logInstance.Warningf(format, args...)
}
// Critical is for critical messages
@ -156,7 +156,7 @@ func Critical(args ...interface{}) {
// Criticalf is for critical messages
func Criticalf(format string, args ...interface{}) {
logInstance.Critical(format, args)
logInstance.Criticalf(format, args...)
}
// Fatal is for fatal messages
@ -166,5 +166,5 @@ func Fatal(args ...interface{}) {
// Fatalf is for fatal messages
func Fatalf(format string, args ...interface{}) {
logInstance.Fatal(format, args)
logInstance.Fatalf(format, args...)
}

103
pkg/routes/rate_limit.go Normal file
View File

@ -0,0 +1,103 @@
// Copyright 2019 Vikunja and contriubtors. All rights reserved.
//
// This file is part of Vikunja.
//
// Vikunja is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Vikunja is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Vikunja. If not, see <https://www.gnu.org/licenses/>.
package routes
import (
"code.vikunja.io/api/pkg/config"
"code.vikunja.io/api/pkg/log"
"code.vikunja.io/api/pkg/models"
"code.vikunja.io/api/pkg/red"
"github.com/labstack/echo/v4"
"github.com/ulule/limiter/v3"
"github.com/ulule/limiter/v3/drivers/store/memory"
"github.com/ulule/limiter/v3/drivers/store/redis"
"net/http"
"strconv"
"time"
)
// RateLimit is the rate limit middleware
func RateLimit(rateLimiter *limiter.Limiter) echo.MiddlewareFunc {
return func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) (err error) {
var rateLimitKey string
switch config.RateLimitKind.GetString() {
case "ip":
rateLimitKey = c.RealIP()
case "user":
user, err := models.GetCurrentUser(c)
if err != nil {
log.Errorf("Error while getting the current user for rate limiting: %s", err)
}
rateLimitKey = "user_" + strconv.FormatInt(user.ID, 10)
default:
log.Errorf("Unknown rate limit kind configured: %s", config.RateLimitKind.GetString())
}
limiterCtx, err := rateLimiter.Get(c.Request().Context(), rateLimitKey)
if err != nil {
log.Errorf("IPRateLimit - rateLimiter.Get - err: %v, %s on %s", err, rateLimitKey, c.Request().URL)
return c.JSON(http.StatusInternalServerError, echo.Map{
"message": err,
})
}
h := c.Response().Header()
h.Set("X-RateLimit-Limit", strconv.FormatInt(limiterCtx.Limit, 10))
h.Set("X-RateLimit-Remaining", strconv.FormatInt(limiterCtx.Remaining, 10))
h.Set("X-RateLimit-Reset", strconv.FormatInt(limiterCtx.Reset, 10))
if limiterCtx.Reached {
log.Infof("Too Many Requests from %s on %s", rateLimitKey, c.Request().URL)
return c.JSON(http.StatusTooManyRequests, echo.Map{
"message": "Too Many Requests on " + c.Request().URL.String(),
})
}
// log.Printf("%s request continue", c.RealIP())
return next(c)
}
}
}
func setupRateLimit(a *echo.Group) {
if config.RateLimitEnabled.GetBool() {
rate := limiter.Rate{
Period: config.RateLimitPeriod.GetDuration() * time.Second,
Limit: config.RateLimitLimit.GetInt64(),
}
var store limiter.Store
var err error
switch config.RateLimitStore.GetString() {
case "memory":
store = memory.NewStore()
case "redis":
if !config.RedisEnabled.GetBool() {
log.Fatal("Redis is configured for rate limiting, but not enabled!")
}
store, err = redis.NewStore(red.GetRedis())
if err != nil {
log.Fatalf("Error while creating rate limit redis store: %s", err)
}
default:
log.Fatalf("Unknown Rate limit store \"%s\"", config.RateLimitStore.GetString())
}
rateLimiter := limiter.New(store, rate)
log.Debugf("Rate limit configured with %s and %v requests per %v", config.RateLimitStore.GetString(), rate.Limit, rate.Period)
a.Use(RateLimit(rateLimiter))
}
}

View File

@ -218,10 +218,13 @@ func registerAPIRoutes(a *echo.Group) {
// Info endpoint
a.GET("/info", apiv1.Info)
// ===== Routes with Authetification =====
// ===== Routes with Authetication =====
// Authetification
a.Use(middleware.JWT([]byte(config.ServiceJWTSecret.GetString())))
// Rate limit
setupRateLimit(a)
// Middleware to collect metrics
if config.ServiceJWTSecret.GetBool() {
a.Use(func(next echo.HandlerFunc) echo.HandlerFunc {