From fc8252e751fcf06f3a2e3f507839ee9e90613d7f Mon Sep 17 00:00:00 2001 From: kolaente Date: Tue, 3 Sep 2024 22:03:55 +0200 Subject: [PATCH] fix(subscriptions): correctly inherit subscriptions Resolves https://community.vikunja.io/t/e-mail-notification-twice/2740/20 (cherry picked from commit 06305eb6b3300bf1c989e06e54766e427bcc749a) --- pkg/db/fixtures/subscriptions.yml | 10 ++++ pkg/models/subscription.go | 82 +++++++++++++++---------------- pkg/models/subscription_test.go | 9 ++++ 3 files changed, 59 insertions(+), 42 deletions(-) diff --git a/pkg/db/fixtures/subscriptions.yml b/pkg/db/fixtures/subscriptions.yml index 7f97ca40b..f4f298211 100644 --- a/pkg/db/fixtures/subscriptions.yml +++ b/pkg/db/fixtures/subscriptions.yml @@ -28,3 +28,13 @@ entity_id: 32 user_id: 6 created: 2021-02-01 15:13:12 +- id: 9 + entity_type: 3 # Task + entity_id: 18 + user_id: 6 + created: 2021-02-01 15:13:12 +- id: 10 + entity_type: 2 # Project + entity_id: 9 + user_id: 6 + created: 2021-02-01 15:13:12 diff --git a/pkg/models/subscription.go b/pkg/models/subscription.go index 4bd3761c1..e775ce81c 100644 --- a/pkg/models/subscription.go +++ b/pkg/models/subscription.go @@ -196,23 +196,16 @@ func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs [ // It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for // that task, if there is none it will look for a subscription on the project the task belongs to. func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) { - subs, err := GetSubscriptions(s, entityType, []int64{entityID}, a) + subs, err := GetSubscriptions(s, entityType, entityID, a) if err != nil || len(subs) == 0 { return nil, err } - if sub, exists := subs[entityID]; exists && len(sub) > 0 { - return sub[0], nil // Take exact match first, if available - } - for _, sub := range subs { - if len(sub) > 0 { - return sub[0], nil // For parents, take next available - } - } - return nil, nil + + return subs[0], nil } -// GetSubscriptions returns a map of subscriptions to a set of given entity IDs -func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) { +// GetSubscriptions returns a list of subscriptions to for an entity ID +func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) { u, is := a.(*user.User) if u != nil && !is { return @@ -223,23 +216,37 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity switch entityType { case SubscriptionEntityProject: - projects, err := GetProjectsByIDs(s, entityIDs) + project, err := GetProjectSimpleByID(s, entityID) if err != nil { return nil, err } - return GetSubscriptionsForProjects(s, projects, u) + subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u) + if err != nil { + return nil, err + } + if _, has := subs[entityID]; has && subs[entityID] != nil { + return subs[entityID], nil + } + + for _, sub := range subs { + // Fallback to the first non-nil subscription + if len(sub) > 0 { + return sub, nil + } + } + + return nil, nil case SubscriptionEntityTask: - subs, err := getSubscriptionsForTasks(s, entityIDs, u) + subs, err := getSubscriptionsForTask(s, entityID, u) if err != nil { return nil, err } - projects, err := GetProjectsSimplByTaskIDs(s, entityIDs) - if err != nil { - return nil, err + if len(subs) > 0 { + return subs, nil } - tasks, err := GetTasksSimpleByIDs(s, entityIDs) + projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID}) if err != nil { return nil, err } @@ -249,18 +256,14 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity return nil, err } - for _, task := range tasks { - // If a task is already subscribed through the parent project, - // remove the task subscription since that's a duplicate. - // But if the user is not subscribed to the task but a parent project is, add that to the subscriptions - psub, hasProjectSub := projectSubscriptions[task.ProjectID] - _, hasTaskSub := subs[task.ID] - if hasProjectSub && hasTaskSub { - delete(subs, task.ID) - } + if _, has := projectSubscriptions[projects[0].ID]; has { + return projectSubscriptions[projects[0].ID], nil + } - if !hasTaskSub && !hasProjectSub { - subs[task.ID] = psub + for _, psub := range projectSubscriptions { + // Fallback to the first non-nil subscription + if len(psub) > 0 { + return psub, nil } } @@ -360,26 +363,23 @@ func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Aut return projectsToSubscriptions, nil } -func getSubscriptionsForTasks(s *xorm.Session, taskIDs []int64, u *user.User) (projectsToSubscriptions map[int64][]*Subscription, err error) { - var subscriptions []*Subscription +func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) { if u != nil { err = s. Where("user_id = ?", u.ID). - And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)). + And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})). Find(&subscriptions) } else { err = s. - And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)). + And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})). Find(&subscriptions) } if err != nil { return nil, err } - projectsToSubscriptions = make(map[int64][]*Subscription) for _, sub := range subscriptions { sub.Entity = sub.EntityType.String() - projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub) } return @@ -390,18 +390,16 @@ func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType, return nil, err } - subs, err := GetSubscriptions(s, entityType, []int64{entityID}, nil) + subs, err := GetSubscriptions(s, entityType, entityID, nil) if err != nil { return } userIDs := []int64{} subscriptions = make([]*Subscription, 0, len(subs)) - for _, subss := range subs { - for _, subscription := range subss { - userIDs = append(userIDs, subscription.UserID) - subscriptions = append(subscriptions, subscription) - } + for _, subscription := range subs { + userIDs = append(userIDs, subscription.UserID) + subscriptions = append(subscriptions, subscription) } users, err := user.GetUsersByIDs(s, userIDs) diff --git a/pkg/models/subscription_test.go b/pkg/models/subscription_test.go index 5bf7dfd8d..574cfbc26 100644 --- a/pkg/models/subscription_test.go +++ b/pkg/models/subscription_test.go @@ -313,4 +313,13 @@ func TestSubscriptionGet(t *testing.T) { require.Error(t, err) assert.True(t, IsErrUnknownSubscriptionEntityType(err)) }) + t.Run("double subscription should be returned once", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u) + require.NoError(t, err) + assert.Equal(t, int64(9), sub.ID) + }) }