1
0

Refactor & fix storing struct-values in redis keyvalue

This commit is contained in:
kolaente
2021-05-28 10:52:32 +02:00
parent df45675df3
commit d48aa101cf
10 changed files with 117 additions and 59 deletions

View File

@ -46,12 +46,12 @@ type Callback struct {
// Provider is the structure of an OpenID Connect provider
type Provider struct {
Name string `json:"name"`
Key string `json:"key"`
AuthURL string `json:"auth_url"`
ClientID string `json:"client_id"`
ClientSecret string `json:"-"`
OpenIDProvider *oidc.Provider `json:"-"`
Name string `json:"name"`
Key string `json:"key"`
AuthURL string `json:"auth_url"`
ClientID string `json:"client_id"`
ClientSecret string `json:"-"`
openIDProvider *oidc.Provider
Oauth2Config *oauth2.Config `json:"-"`
}
@ -66,6 +66,11 @@ func init() {
rand.Seed(time.Now().UTC().UnixNano())
}
func (p *Provider) setOicdProvider() (err error) {
p.openIDProvider, err = oidc.NewProvider(context.Background(), p.AuthURL)
return err
}
// HandleCallback handles the auth request callback after redirecting from the provider with an auth code
// @Summary Authenticate a user with OpenID Connect
// @Description After a redirect from the OpenID Connect provider to the frontend has been made with the authentication `code`, this endpoint can be used to obtain a jwt token for that user and thus log them in.
@ -122,7 +127,7 @@ func HandleCallback(c echo.Context) error {
return c.JSON(http.StatusBadRequest, models.Message{Message: "Missing token"})
}
verifier := provider.OpenIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID})
verifier := provider.openIDProvider.Verifier(&oidc.Config{ClientID: provider.ClientID})
// Parse and verify ID Token payload.
idToken, err := verifier.Verify(context.Background(), rawIDToken)
@ -140,7 +145,7 @@ func HandleCallback(c echo.Context) error {
}
if cl.Email == "" || cl.Name == "" || cl.PreferredUsername == "" {
info, err := provider.OpenIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token))
info, err := provider.openIDProvider.UserInfo(context.Background(), provider.Oauth2Config.TokenSource(context.Background(), oauth2Token))
if err != nil {
log.Errorf("Error getting userinfo for provider %s: %v", provider.Name, err)
return handler.HandleHTTPError(err, c)

View File

@ -17,7 +17,6 @@
package openid
import (
"context"
"regexp"
"strconv"
"strings"
@ -36,7 +35,8 @@ func GetAllProviders() (providers []*Provider, err error) {
return nil, nil
}
ps, exists, err := keyvalue.Get("openid_providers")
providers = []*Provider{}
exists, err := keyvalue.GetWithValue("openid_providers", &providers)
if !exists {
rawProviders := config.AuthOpenIDProviders.Get()
if rawProviders == nil {
@ -68,31 +68,30 @@ func GetAllProviders() (providers []*Provider, err error) {
err = keyvalue.Put("openid_providers", providers)
}
if ps != nil {
return ps.([]*Provider), nil
}
return
}
// GetProvider retrieves a provider from keyvalue
func GetProvider(key string) (provider *Provider, err error) {
var p interface{}
p, exists, err := keyvalue.Get("openid_provider_" + key)
provider = &Provider{}
exists, err := keyvalue.GetWithValue("openid_provider_"+key, provider)
if err != nil {
return nil, err
}
if !exists {
_, err = GetAllProviders() // This will put all providers in cache
if err != nil {
return nil, err
}
p, _, err = keyvalue.Get("openid_provider_" + key)
_, err = keyvalue.GetWithValue("openid_provider_"+key, provider)
if err != nil {
return nil, err
}
}
if p != nil {
return p.(*Provider), nil
}
return nil, err
err = provider.setOicdProvider()
return
}
func getKeyFromName(name string) string {
@ -100,7 +99,7 @@ func getKeyFromName(name string) string {
return reg.ReplaceAllString(strings.ToLower(name), "")
}
func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
func getProviderFromMap(pi map[interface{}]interface{}) (provider *Provider, err error) {
name, is := pi["name"].(string)
if !is {
return nil, nil
@ -108,7 +107,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
k := getKeyFromName(name)
provider := &Provider{
provider = &Provider{
Name: pi["name"].(string),
Key: k,
AuthURL: pi["authurl"].(string),
@ -122,10 +121,9 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
provider.ClientID = pi["clientid"].(string)
}
var err error
provider.OpenIDProvider, err = oidc.NewProvider(context.Background(), provider.AuthURL)
err = provider.setOicdProvider()
if err != nil {
return provider, err
return
}
provider.Oauth2Config = &oauth2.Config{
@ -134,7 +132,7 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
RedirectURL: config.AuthOpenIDRedirectURL.GetString() + k,
// Discovery returns the OAuth2 endpoints.
Endpoint: provider.OpenIDProvider.Endpoint(),
Endpoint: provider.openIDProvider.Endpoint(),
// "openid" is a required scope for OpenID Connect flows.
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
@ -142,5 +140,5 @@ func getProviderFromMap(pi map[interface{}]interface{}) (*Provider, error) {
provider.AuthURL = provider.Oauth2Config.Endpoint.AuthURL
return provider, nil
return
}