1
0

fix(subscriptions): correctly inherit subscriptions

Resolves https://community.vikunja.io/t/e-mail-notification-twice/2740/20

(cherry picked from commit 06305eb6b3300bf1c989e06e54766e427bcc749a)
This commit is contained in:
kolaente 2024-09-03 22:03:55 +02:00
parent 1c9590075a
commit fc8252e751
No known key found for this signature in database
GPG Key ID: F40E70337AB24C9B
3 changed files with 59 additions and 42 deletions

View File

@ -28,3 +28,13 @@
entity_id: 32 entity_id: 32
user_id: 6 user_id: 6
created: 2021-02-01 15:13:12 created: 2021-02-01 15:13:12
- id: 9
entity_type: 3 # Task
entity_id: 18
user_id: 6
created: 2021-02-01 15:13:12
- id: 10
entity_type: 2 # Project
entity_id: 9
user_id: 6
created: 2021-02-01 15:13:12

View File

@ -196,23 +196,16 @@ func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs [
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for // It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to. // that task, if there is none it will look for a subscription on the project the task belongs to.
func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) { func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) {
subs, err := GetSubscriptions(s, entityType, []int64{entityID}, a) subs, err := GetSubscriptions(s, entityType, entityID, a)
if err != nil || len(subs) == 0 { if err != nil || len(subs) == 0 {
return nil, err return nil, err
} }
if sub, exists := subs[entityID]; exists && len(sub) > 0 {
return sub[0], nil // Take exact match first, if available return subs[0], nil
}
for _, sub := range subs {
if len(sub) > 0 {
return sub[0], nil // For parents, take next available
}
}
return nil, nil
} }
// GetSubscriptions returns a map of subscriptions to a set of given entity IDs // GetSubscriptions returns a list of subscriptions to for an entity ID
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) { func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) {
u, is := a.(*user.User) u, is := a.(*user.User)
if u != nil && !is { if u != nil && !is {
return return
@ -223,23 +216,37 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity
switch entityType { switch entityType {
case SubscriptionEntityProject: case SubscriptionEntityProject:
projects, err := GetProjectsByIDs(s, entityIDs) project, err := GetProjectSimpleByID(s, entityID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return GetSubscriptionsForProjects(s, projects, u) subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u)
if err != nil {
return nil, err
}
if _, has := subs[entityID]; has && subs[entityID] != nil {
return subs[entityID], nil
}
for _, sub := range subs {
// Fallback to the first non-nil subscription
if len(sub) > 0 {
return sub, nil
}
}
return nil, nil
case SubscriptionEntityTask: case SubscriptionEntityTask:
subs, err := getSubscriptionsForTasks(s, entityIDs, u) subs, err := getSubscriptionsForTask(s, entityID, u)
if err != nil { if err != nil {
return nil, err return nil, err
} }
projects, err := GetProjectsSimplByTaskIDs(s, entityIDs) if len(subs) > 0 {
if err != nil { return subs, nil
return nil, err
} }
tasks, err := GetTasksSimpleByIDs(s, entityIDs) projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID})
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -249,18 +256,14 @@ func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entity
return nil, err return nil, err
} }
for _, task := range tasks { if _, has := projectSubscriptions[projects[0].ID]; has {
// If a task is already subscribed through the parent project, return projectSubscriptions[projects[0].ID], nil
// remove the task subscription since that's a duplicate. }
// But if the user is not subscribed to the task but a parent project is, add that to the subscriptions
psub, hasProjectSub := projectSubscriptions[task.ProjectID]
_, hasTaskSub := subs[task.ID]
if hasProjectSub && hasTaskSub {
delete(subs, task.ID)
}
if !hasTaskSub && !hasProjectSub { for _, psub := range projectSubscriptions {
subs[task.ID] = psub // Fallback to the first non-nil subscription
if len(psub) > 0 {
return psub, nil
} }
} }
@ -360,26 +363,23 @@ func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Aut
return projectsToSubscriptions, nil return projectsToSubscriptions, nil
} }
func getSubscriptionsForTasks(s *xorm.Session, taskIDs []int64, u *user.User) (projectsToSubscriptions map[int64][]*Subscription, err error) { func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) {
var subscriptions []*Subscription
if u != nil { if u != nil {
err = s. err = s.
Where("user_id = ?", u.ID). Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)). And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions) Find(&subscriptions)
} else { } else {
err = s. err = s.
And(getSubscriberCondForEntities(SubscriptionEntityTask, taskIDs)). And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions) Find(&subscriptions)
} }
if err != nil { if err != nil {
return nil, err return nil, err
} }
projectsToSubscriptions = make(map[int64][]*Subscription)
for _, sub := range subscriptions { for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String() sub.Entity = sub.EntityType.String()
projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub)
} }
return return
@ -390,18 +390,16 @@ func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType,
return nil, err return nil, err
} }
subs, err := GetSubscriptions(s, entityType, []int64{entityID}, nil) subs, err := GetSubscriptions(s, entityType, entityID, nil)
if err != nil { if err != nil {
return return
} }
userIDs := []int64{} userIDs := []int64{}
subscriptions = make([]*Subscription, 0, len(subs)) subscriptions = make([]*Subscription, 0, len(subs))
for _, subss := range subs { for _, subscription := range subs {
for _, subscription := range subss { userIDs = append(userIDs, subscription.UserID)
userIDs = append(userIDs, subscription.UserID) subscriptions = append(subscriptions, subscription)
subscriptions = append(subscriptions, subscription)
}
} }
users, err := user.GetUsersByIDs(s, userIDs) users, err := user.GetUsersByIDs(s, userIDs)

View File

@ -313,4 +313,13 @@ func TestSubscriptionGet(t *testing.T) {
require.Error(t, err) require.Error(t, err)
assert.True(t, IsErrUnknownSubscriptionEntityType(err)) assert.True(t, IsErrUnknownSubscriptionEntityType(err))
}) })
t.Run("double subscription should be returned once", func(t *testing.T) {
db.LoadAndAssertFixtures(t)
s := db.NewSession()
defer s.Close()
sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u)
require.NoError(t, err)
assert.Equal(t, int64(9), sub.ID)
})
} }