Add password reset (#3)
This commit is contained in:
@ -17,8 +17,10 @@ func InitConfig() (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Service
|
||||
viper.SetDefault("service.JWTSecret", random)
|
||||
viper.SetDefault("service.interface", ":3456")
|
||||
viper.SetDefault("service.frontendurl", "")
|
||||
// Database
|
||||
viper.SetDefault("database.type", "sqlite")
|
||||
viper.SetDefault("database.host", "localhost")
|
||||
@ -34,6 +36,15 @@ func InitConfig() (err error) {
|
||||
viper.SetDefault("cache.maxelementsize", 1000)
|
||||
viper.SetDefault("cache.redishost", "localhost:6379")
|
||||
viper.SetDefault("cache.redispassword", "")
|
||||
// Mailer
|
||||
viper.SetDefault("mailer.host", "")
|
||||
viper.SetDefault("mailer.port", "587")
|
||||
viper.SetDefault("mailer.user", "user")
|
||||
viper.SetDefault("mailer.password", "")
|
||||
viper.SetDefault("mailer.skiptlsverify", false)
|
||||
viper.SetDefault("mailer.fromemail", "mail@vikunja")
|
||||
viper.SetDefault("mailer.queuelength", 100)
|
||||
viper.SetDefault("mailer.queuetimeout", 30)
|
||||
|
||||
// Init checking for environment variables
|
||||
viper.SetEnvPrefix("vikunja")
|
||||
|
@ -134,25 +134,45 @@ func (err ErrCouldNotGetUserID) HTTPError() HTTPError {
|
||||
return HTTPError{HTTPCode: http.StatusBadRequest, Code: ErrCodeCouldNotGetUserID, Message: "Could not get user id."}
|
||||
}
|
||||
|
||||
// ErrCannotDeleteLastUser represents a "ErrCannotDeleteLastUser" kind of error.
|
||||
type ErrCannotDeleteLastUser struct{}
|
||||
|
||||
// IsErrCannotDeleteLastUser checks if an error is a ErrCannotDeleteLastUser.
|
||||
func IsErrCannotDeleteLastUser(err error) bool {
|
||||
_, ok := err.(ErrCannotDeleteLastUser)
|
||||
return ok
|
||||
// ErrNoPasswordResetToken represents an error where no password reset token exists for that user
|
||||
type ErrNoPasswordResetToken struct {
|
||||
UserID int64
|
||||
}
|
||||
|
||||
func (err ErrCannotDeleteLastUser) Error() string {
|
||||
return fmt.Sprintf("Cannot delete last user")
|
||||
func (err ErrNoPasswordResetToken) Error() string {
|
||||
return fmt.Sprintf("No token to reset a password [UserID: %d]", err.UserID)
|
||||
}
|
||||
|
||||
// ErrCodeCannotDeleteLastUser holds the unique world-error code of this error
|
||||
const ErrCodeCannotDeleteLastUser = 1007
|
||||
// ErrCodeNoPasswordResetToken holds the unique world-error code of this error
|
||||
const ErrCodeNoPasswordResetToken = 1008
|
||||
|
||||
// HTTPError holds the http error description
|
||||
func (err ErrCannotDeleteLastUser) HTTPError() HTTPError {
|
||||
return HTTPError{HTTPCode: http.StatusConflict, Code: ErrCodeCannotDeleteLastUser, Message: "Cannot delete the last user on the server."}
|
||||
func (err ErrNoPasswordResetToken) HTTPError() HTTPError {
|
||||
return HTTPError{HTTPCode: http.StatusPreconditionFailed, Code: ErrCodeNoPasswordResetToken, Message: "No token to reset a user's password provided."}
|
||||
}
|
||||
|
||||
// ErrInvalidPasswordResetToken is an error where the password reset token is invalid
|
||||
type ErrInvalidPasswordResetToken struct {
|
||||
UserID int64
|
||||
Token string
|
||||
}
|
||||
|
||||
func (err ErrInvalidPasswordResetToken) Error() string {
|
||||
return fmt.Sprintf("Invalid token to reset a password [UserID: %d, Token: %s]", err.UserID, err.Token)
|
||||
}
|
||||
|
||||
// ErrCodeInvalidPasswordResetToken holds the unique world-error code of this error
|
||||
const ErrCodeInvalidPasswordResetToken = 1009
|
||||
|
||||
// HTTPError holds the http error description
|
||||
func (err ErrInvalidPasswordResetToken) HTTPError() HTTPError {
|
||||
return HTTPError{HTTPCode: http.StatusPreconditionFailed, Code: ErrCodeInvalidPasswordResetToken, Message: "Invalid token to reset a user's password provided."}
|
||||
}
|
||||
|
||||
// IsErrInvalidPasswordResetToken checks if an error is a ErrInvalidPasswordResetToken.
|
||||
func IsErrInvalidPasswordResetToken(err error) bool {
|
||||
_, ok := err.(ErrInvalidPasswordResetToken)
|
||||
return ok
|
||||
}
|
||||
|
||||
// ===================
|
||||
|
66
models/mail/mail.go
Normal file
66
models/mail/mail.go
Normal file
@ -0,0 +1,66 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/gomail.v2"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Queue is the mail queue
|
||||
var Queue chan *gomail.Message
|
||||
|
||||
// StartMailDaemon starts the mail daemon
|
||||
func StartMailDaemon() {
|
||||
if viper.GetString("mailer.host") == "" {
|
||||
//models.Log.Warning("Mailer seems to be not configured! Please see the config docs for more details.")
|
||||
fmt.Println("Mailer seems to be not configured! Please see the config docs for more details.")
|
||||
return
|
||||
}
|
||||
|
||||
Queue = make(chan *gomail.Message, viper.GetInt("mailer.queuelength"))
|
||||
|
||||
go func() {
|
||||
d := gomail.NewDialer(viper.GetString("mailer.host"), viper.GetInt("mailer.port"), viper.GetString("mailer.username"), viper.GetString("mailer.password"))
|
||||
d.TLSConfig = &tls.Config{InsecureSkipVerify: viper.GetBool("mailer.skiptlsverify")}
|
||||
|
||||
var s gomail.SendCloser
|
||||
var err error
|
||||
open := false
|
||||
for {
|
||||
select {
|
||||
case m, ok := <-Queue:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if !open {
|
||||
if s, err = d.Dial(); err != nil {
|
||||
// models.Log.Error("Error during connect to smtp server: %s", err)
|
||||
fmt.Printf("Error during connect to smtp server: %s \n", err)
|
||||
}
|
||||
open = true
|
||||
}
|
||||
if err := gomail.Send(s, m); err != nil {
|
||||
// models.Log.Error("Error when sending mail: %s", err)
|
||||
fmt.Printf("Error when sending mail: %s \n", err)
|
||||
}
|
||||
// Close the connection to the SMTP server if no email was sent in
|
||||
// the last 30 seconds.
|
||||
case <-time.After(viper.GetDuration("mailer.queuetimeout") * time.Second):
|
||||
if open {
|
||||
if err := s.Close(); err != nil {
|
||||
fmt.Printf("Error closing the mail server connection: %s\n", err)
|
||||
}
|
||||
fmt.Println("Closed connection to mailserver")
|
||||
open = false
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopMailDaemon closes the mail queue channel, aka stops the daemon
|
||||
func StopMailDaemon() {
|
||||
close(Queue)
|
||||
}
|
101
models/mail/send_mail.go
Normal file
101
models/mail/send_mail.go
Normal file
@ -0,0 +1,101 @@
|
||||
package mail
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"code.vikunja.io/api/models/utils"
|
||||
"github.com/labstack/gommon/log"
|
||||
"github.com/spf13/viper"
|
||||
"gopkg.in/gomail.v2"
|
||||
"text/template"
|
||||
)
|
||||
|
||||
// Opts holds infos for a mail
|
||||
type Opts struct {
|
||||
To string
|
||||
Subject string
|
||||
Message string
|
||||
HTMLMessage string
|
||||
ContentType ContentType
|
||||
Boundary string
|
||||
Headers []*header
|
||||
}
|
||||
|
||||
// ContentType represents mail content types
|
||||
type ContentType int
|
||||
|
||||
// Enumerate all the team rights
|
||||
const (
|
||||
ContentTypePlain ContentType = iota
|
||||
ContentTypeHTML
|
||||
ContentTypeMultipart
|
||||
)
|
||||
|
||||
type header struct {
|
||||
Field string
|
||||
Content string
|
||||
}
|
||||
|
||||
// SendMail puts a mail in the queue
|
||||
func SendMail(opts *Opts) {
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", viper.GetString("mailer.fromemail"))
|
||||
m.SetHeader("To", opts.To)
|
||||
m.SetHeader("Subject", opts.Subject)
|
||||
for _, h := range opts.Headers {
|
||||
m.SetHeader(h.Field, h.Content)
|
||||
}
|
||||
|
||||
switch opts.ContentType {
|
||||
case ContentTypePlain:
|
||||
m.SetBody("text/plain", opts.Message)
|
||||
case ContentTypeHTML:
|
||||
m.SetBody("text/html", opts.Message)
|
||||
case ContentTypeMultipart:
|
||||
m.SetBody("text/plain", opts.Message)
|
||||
m.AddAlternative("text/html", opts.HTMLMessage)
|
||||
}
|
||||
|
||||
Queue <- m
|
||||
}
|
||||
|
||||
// Template holds a pointer about a template
|
||||
type Template struct {
|
||||
Templates *template.Template
|
||||
}
|
||||
|
||||
// SendMailWithTemplate parses a template and sends it via mail
|
||||
func SendMailWithTemplate(to, subject, tpl string, data map[string]interface{}) {
|
||||
var htmlContent bytes.Buffer
|
||||
var plainContent bytes.Buffer
|
||||
|
||||
t := &Template{
|
||||
Templates: template.Must(template.ParseGlob("templates/mail/*.tmpl")),
|
||||
}
|
||||
|
||||
boundary := "np" + utils.MakeRandomString(13)
|
||||
|
||||
data["Boundary"] = boundary
|
||||
data["FrontendURL"] = viper.GetString("service.frontendurl")
|
||||
|
||||
if err := t.Templates.ExecuteTemplate(&htmlContent, tpl+".html.tmpl", data); err != nil {
|
||||
log.Error(3, "Template: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := t.Templates.ExecuteTemplate(&plainContent, tpl+".plain.tmpl", data); err != nil {
|
||||
log.Error(3, "Template: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
opts := &Opts{
|
||||
To: to,
|
||||
Subject: subject,
|
||||
Message: plainContent.String(),
|
||||
HTMLMessage: htmlContent.String(),
|
||||
ContentType: ContentTypeMultipart,
|
||||
Boundary: boundary,
|
||||
Headers: []*header{{Field: "MIME-Version", Content: "1.0"}},
|
||||
}
|
||||
|
||||
SendMail(opts)
|
||||
}
|
@ -18,8 +18,11 @@ type User struct {
|
||||
Username string `xorm:"varchar(250) not null unique" json:"username"`
|
||||
Password string `xorm:"varchar(250) not null" json:"-"`
|
||||
Email string `xorm:"varchar(250)" json:"email"`
|
||||
Created int64 `xorm:"created" json:"-"`
|
||||
Updated int64 `xorm:"updated" json:"-"`
|
||||
|
||||
PasswordResetToken string `xorm:"varchar(450)" json:"-"`
|
||||
|
||||
Created int64 `xorm:"created" json:"-"`
|
||||
Updated int64 `xorm:"updated" json:"-"`
|
||||
}
|
||||
|
||||
// TableName returns the table name for users
|
||||
@ -61,7 +64,7 @@ func GetUser(user User) (userOut User, err error) {
|
||||
exists, err := x.Get(&userOut)
|
||||
|
||||
if !exists {
|
||||
return User{}, ErrUserDoesNotExist{}
|
||||
return User{}, ErrUserDoesNotExist{UserID: user.ID}
|
||||
}
|
||||
|
||||
return userOut, err
|
||||
|
@ -7,18 +7,8 @@ func DeleteUserByID(id int64, doer *User) error {
|
||||
return ErrIDCannotBeZero{}
|
||||
}
|
||||
|
||||
// Check if there is > 1 user
|
||||
total, err := x.Count(User{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if total < 2 {
|
||||
return ErrCannotDeleteLastUser{}
|
||||
}
|
||||
|
||||
// Delete the user
|
||||
_, err = x.Id(id).Delete(&User{})
|
||||
_, err := x.Id(id).Delete(&User{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
|
90
models/user_password_reset.go
Normal file
90
models/user_password_reset.go
Normal file
@ -0,0 +1,90 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"code.vikunja.io/api/models/mail"
|
||||
"code.vikunja.io/api/models/utils"
|
||||
)
|
||||
|
||||
// PasswordReset holds the data to reset a password
|
||||
type PasswordReset struct {
|
||||
UserID int64 `json:"user_id"`
|
||||
Token string `json:"token"`
|
||||
NewPassword string `json:"new_password"`
|
||||
}
|
||||
|
||||
// UserPasswordReset resets a users password
|
||||
func UserPasswordReset(reset *PasswordReset) (err error) {
|
||||
|
||||
// Check if the password is not empty
|
||||
if reset.NewPassword == "" {
|
||||
return ErrNoUsernamePassword{}
|
||||
}
|
||||
|
||||
// Check if the user exists
|
||||
user, err := GetUserByID(reset.UserID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if we have a token
|
||||
exists, err := x.Where("password_reset_token = ? AND id = ?", reset.Token, user.ID).Exist(&User{})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !exists {
|
||||
return ErrInvalidPasswordResetToken{UserID: reset.UserID, Token: reset.Token}
|
||||
}
|
||||
|
||||
// Hash the password
|
||||
user.Password, err = hashPassword(reset.NewPassword)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Save it
|
||||
_, err = x.Where("id = ?", user.ID).Update(&user)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Send a mail to the user to notify it his password was changed.
|
||||
data := map[string]interface{}{
|
||||
"User": user,
|
||||
}
|
||||
|
||||
mail.SendMailWithTemplate(user.Email, "Your password on Vikunja was changed", "password-changed", data)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// PasswordTokenRequest defines the request format for password reset resqest
|
||||
type PasswordTokenRequest struct {
|
||||
Username string `json:"user_name"`
|
||||
}
|
||||
|
||||
// RequestUserPasswordResetToken inserts a random token to reset a users password into the databsse
|
||||
func RequestUserPasswordResetToken(tr *PasswordTokenRequest) (err error) {
|
||||
// Check if the user exists
|
||||
user, err := GetUser(User{Username: tr.Username})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Generate a token and save it
|
||||
user.PasswordResetToken = utils.MakeRandomString(400)
|
||||
|
||||
// Save it
|
||||
_, err = x.Where("id = ?", user.ID).Update(&user)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
data := map[string]interface{}{
|
||||
"User": user,
|
||||
}
|
||||
|
||||
// Send the user a mail with the reset token
|
||||
mail.SendMailWithTemplate(user.Email, "Reset your password on Vikunja", "reset-password", data)
|
||||
return
|
||||
}
|
@ -1,6 +1,7 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"code.vikunja.io/api/models/utils"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"testing"
|
||||
)
|
||||
@ -20,24 +21,12 @@ func TestCreateUser(t *testing.T) {
|
||||
Email: "noone@example.com",
|
||||
}
|
||||
|
||||
// Delete every preexisting user to have a fresh start
|
||||
_, err = x.Where("1 = 1").Delete(&User{})
|
||||
assert.NoError(t, err)
|
||||
|
||||
allusers, err := ListUsers("")
|
||||
assert.NoError(t, err)
|
||||
for _, user := range allusers {
|
||||
// Delete it
|
||||
err := DeleteUserByID(user.ID, &doer)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
// Create a new user
|
||||
createdUser, err := CreateUser(dummyuser)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Create a second new user
|
||||
createdUser2, err := CreateUser(User{Username: dummyuser.Username + "2", Email: dummyuser.Email + "m", Password: dummyuser.Password})
|
||||
_, err = CreateUser(User{Username: dummyuser.Username + "2", Email: dummyuser.Email + "m", Password: dummyuser.Password})
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Check if it fails to create the same user again
|
||||
@ -128,9 +117,39 @@ func TestCreateUser(t *testing.T) {
|
||||
err = DeleteUserByID(0, &doer)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, IsErrIDCannotBeZero(err))
|
||||
|
||||
// Try delete the last user (Should fail)
|
||||
err = DeleteUserByID(createdUser2.ID, &doer)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, IsErrCannotDeleteLastUser(err))
|
||||
}
|
||||
|
||||
func TestUserPasswordReset(t *testing.T) {
|
||||
// Request a new token
|
||||
tr := &PasswordTokenRequest{
|
||||
UserID: 1,
|
||||
}
|
||||
err := RequestUserPasswordResetToken(tr)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Get the token / inside the user object
|
||||
userWithToken, err := GetUserByID(1)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Try resetting it
|
||||
reset := &PasswordReset{
|
||||
UserID: 1,
|
||||
Token: userWithToken.PasswordResetToken,
|
||||
}
|
||||
|
||||
// Try resetting it without a password
|
||||
reset.NewPassword = ""
|
||||
err = UserPasswordReset(reset)
|
||||
assert.True(t, IsErrNoUsernamePassword(err))
|
||||
|
||||
// Reset it
|
||||
reset.NewPassword = "1234"
|
||||
err = UserPasswordReset(reset)
|
||||
assert.NoError(t, err)
|
||||
|
||||
// Try resetting it with a wrong token
|
||||
reset.Token = utils.MakeRandomString(400)
|
||||
err = UserPasswordReset(reset)
|
||||
assert.Error(t, err)
|
||||
assert.True(t, IsErrInvalidPasswordResetToken(err))
|
||||
}
|
||||
|
36
models/utils/random_string.go
Normal file
36
models/utils/random_string.go
Normal file
@ -0,0 +1,36 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
const (
|
||||
letterIdxBits = 6 // 6 bits to represent a letter index
|
||||
letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits
|
||||
letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits
|
||||
)
|
||||
|
||||
// MakeRandomString return a random string
|
||||
func MakeRandomString(n int) string {
|
||||
b := make([]byte, n)
|
||||
// A rand.Int63() generates 63 random bits, enough for letterIdxMax letters!
|
||||
for i, cache, remain := n-1, rand.Int63(), letterIdxMax; i >= 0; {
|
||||
if remain == 0 {
|
||||
cache, remain = rand.Int63(), letterIdxMax
|
||||
}
|
||||
if idx := int(cache & letterIdxMask); idx < len(letterBytes) {
|
||||
b[i] = letterBytes[idx]
|
||||
i--
|
||||
}
|
||||
cache >>= letterIdxBits
|
||||
remain--
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
Reference in New Issue
Block a user