From 8b8ec19bb38cd379f7d19265d55fd09c11de920a Mon Sep 17 00:00:00 2001 From: kolaente Date: Wed, 4 Sep 2024 19:54:22 +0200 Subject: [PATCH] fix(subscriptions): cleanup and simplify fetching subscribers for tasks and projects logic Vikunja now uses one recursive CTE and a few optimizations to fetch all subscribers for a task or project. This makes the relevant code easier to maintain and more performant. (cherry picked from commit 4ff8815fe1bfe72e02c10f6a6877c93a630f36a4) --- pkg/models/error.go | 28 +- pkg/models/listeners.go | 12 +- pkg/models/project.go | 26 +- pkg/models/subscription.go | 455 +++++++++++++++----------------- pkg/models/subscription_test.go | 40 ++- pkg/models/tasks.go | 3 +- pkg/utils/strings.go | 31 +++ 7 files changed, 330 insertions(+), 265 deletions(-) create mode 100644 pkg/utils/strings.go diff --git a/pkg/models/error.go b/pkg/models/error.go index e7b99d77b..2f06ad00c 100644 --- a/pkg/models/error.go +++ b/pkg/models/error.go @@ -1745,7 +1745,7 @@ func IsErrSubscriptionAlreadyExists(err error) bool { } func (err *ErrSubscriptionAlreadyExists) Error() string { - return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, ID: %d]", err.EntityType, err.EntityID, err.UserID) + return fmt.Sprintf("Subscription for this (entity_id, entity_type, user_id) already exists [EntityType: %d, EntityID: %d, UserID: %d]", err.EntityType, err.EntityID, err.UserID) } // ErrCodeSubscriptionAlreadyExists holds the unique world-error code of this error @@ -1760,6 +1760,32 @@ func (err ErrSubscriptionAlreadyExists) HTTPError() web.HTTPError { } } +// ErrMustProvideUser represents an error where you need to provide a user to fetch subscriptions +type ErrMustProvideUser struct { +} + +// IsErrMustProvideUser checks if an error is ErrMustProvideUser. +func IsErrMustProvideUser(err error) bool { + _, ok := err.(*ErrMustProvideUser) + return ok +} + +func (err *ErrMustProvideUser) Error() string { + return "no user provided while fetching subscriptions" +} + +// ErrCodeMustProvideUser holds the unique world-error code of this error +const ErrCodeMustProvideUser = 12003 + +// HTTPError holds the http error description +func (err ErrMustProvideUser) HTTPError() web.HTTPError { + return web.HTTPError{ + HTTPCode: http.StatusPreconditionFailed, + Code: ErrCodeMustProvideUser, + Message: "You must provide a user to fetch subscriptions", + } +} + // ================= // Link Share errors // ================= diff --git a/pkg/models/listeners.go b/pkg/models/listeners.go index b2811aa1e..5d2f90161 100644 --- a/pkg/models/listeners.go +++ b/pkg/models/listeners.go @@ -199,7 +199,7 @@ func (s *SendTaskCommentNotification) Handle(msg *message.Message) (err error) { return err } - subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID) + subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID) if err != nil { return err } @@ -279,7 +279,7 @@ func (s *SendTaskAssignedNotification) Handle(msg *message.Message) (err error) sess := db.NewSession() defer sess.Close() - subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID) + subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID) if err != nil { return err } @@ -340,12 +340,12 @@ func (s *SendTaskDeletedNotification) Handle(msg *message.Message) (err error) { sess := db.NewSession() defer sess.Close() - var subscribers []*Subscription - subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityTask, event.Task.ID) + var subscribers []*SubscriptionWithUser + subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityTask, event.Task.ID) // If the task does not exist and no one has explicitly subscribed to it, we won't find any subscriptions for it. // Hence, we need to check for subscriptions to the parent project manually. if err != nil && (IsErrTaskDoesNotExist(err) || IsErrProjectDoesNotExist(err)) { - subscribers, err = getSubscribersForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID) + subscribers, err = GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Task.ProjectID) } if err != nil { return err @@ -801,7 +801,7 @@ func (s *SendProjectCreatedNotification) Handle(msg *message.Message) (err error sess := db.NewSession() defer sess.Close() - subscribers, err := getSubscribersForEntity(sess, SubscriptionEntityProject, event.Project.ID) + subscribers, err := GetSubscriptionsForEntity(sess, SubscriptionEntityProject, event.Project.ID) if err != nil { return err } diff --git a/pkg/models/project.go b/pkg/models/project.go index de6aa543c..8d4cdbbbe 100644 --- a/pkg/models/project.go +++ b/pkg/models/project.go @@ -297,10 +297,13 @@ func (p *Project) ReadOne(s *xorm.Session, a web.Auth) (err error) { return } - p.Subscription, err = GetSubscription(s, SubscriptionEntityProject, p.ID, a) + subs, err := GetSubscriptionForUser(s, SubscriptionEntityProject, p.ID, a) if err != nil && IsErrProjectDoesNotExist(err) && isFilter { return nil } + if subs != nil { + p.Subscription = &subs.Subscription + } p.Views, err = getViewsForProject(s, p.ID) return @@ -629,10 +632,23 @@ func addProjectDetails(s *xorm.Session, projects []*Project, a web.Auth) (err er return err } - subscriptions, err := GetSubscriptionsForProjects(s, projects, a) - if err != nil { - log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error()) - subscriptions = make(map[int64][]*Subscription) + var subscriptions = make(map[int64][]*Subscription) + u, is := a.(*user.User) + if is { + subscriptionsWithUser, err := GetSubscriptionsForEntitiesAndUser(s, SubscriptionEntityProject, projectIDs, u) + if err != nil { + log.Errorf("An error occurred while getting project subscriptions for a project: %s", err.Error()) + } + if err == nil { + for pID, subs := range subscriptionsWithUser { + for _, sub := range subs { + if _, has := subscriptions[pID]; !has { + subscriptions[pID] = []*Subscription{} + } + subscriptions[pID] = append(subscriptions[pID], &sub.Subscription) + } + } + } } views := []*ProjectView{} diff --git a/pkg/models/subscription.go b/pkg/models/subscription.go index 3ff498e58..91431f4d8 100644 --- a/pkg/models/subscription.go +++ b/pkg/models/subscription.go @@ -17,12 +17,13 @@ package models import ( + "strconv" "time" - "xorm.io/builder" - "code.vikunja.io/api/pkg/user" + "code.vikunja.io/api/pkg/utils" "code.vikunja.io/api/pkg/web" + "xorm.io/xorm" ) @@ -52,8 +53,7 @@ type Subscription struct { EntityID int64 `xorm:"bigint index not null" json:"entity_id" param:"entityID"` // The user who made this subscription - User *user.User `xorm:"-" json:"user"` - UserID int64 `xorm:"bigint index not null" json:"-"` + UserID int64 `xorm:"bigint index not null" json:"-"` // A timestamp when this subscription was created. You cannot change this value. Created time.Time `xorm:"created not null" json:"created"` @@ -62,7 +62,18 @@ type Subscription struct { web.Rights `xorm:"-" json:"-"` } -// TableName gives us a better tabel name for the subscriptions table +type SubscriptionWithUser struct { + Subscription `xorm:"extends"` + User *user.User `xorm:"extends" json:"user"` +} + +type subscriptionResolved struct { + OriginalEntityID int64 + SubscriptionID int64 + SubscriptionWithUser `xorm:"extends"` +} + +// TableName gives us a better table name for the subscriptions table func (sb *Subscription) TableName() string { return "subscriptions" } @@ -115,28 +126,23 @@ func (et SubscriptionEntityType) validate() error { // @Failure 500 {object} models.Message "Internal error" // @Router /subscriptions/{entity}/{entityID} [put] func (sb *Subscription) Create(s *xorm.Session, auth web.Auth) (err error) { - // Rights method alread does the validation of the entity type so we don't need to do that here + // Rights method already does the validation of the entity type, so we don't need to do that here sb.UserID = auth.GetID() - sub, err := GetSubscription(s, sb.EntityType, sb.EntityID, auth) + sub, err := GetSubscriptionForUser(s, sb.EntityType, sb.EntityID, auth) if err != nil { return err } if sub != nil { return &ErrSubscriptionAlreadyExists{ - EntityID: sb.EntityID, - EntityType: sb.EntityType, - UserID: sb.UserID, + EntityID: sub.EntityID, + EntityType: sub.EntityType, + UserID: sub.UserID, } } _, err = s.Insert(sb) - if err != nil { - return - } - - sb.User, err = user.GetFromAuth(auth) return } @@ -163,261 +169,228 @@ func (sb *Subscription) Delete(s *xorm.Session, auth web.Auth) (err error) { return } -func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs []int64) (cond builder.Cond) { - if entityType == SubscriptionEntityProject { - return builder.And( - builder.In("entity_id", entityIDs), - builder.Eq{"entity_type": SubscriptionEntityProject}, - ) - } - - if entityType == SubscriptionEntityTask { - return builder.Or( - builder.And( - builder.In("entity_id", entityIDs), - builder.Eq{"entity_type": SubscriptionEntityTask}, - ), - builder.And( - builder.Eq{"entity_id": builder. - Select("project_id"). - From("tasks"). - Where(builder.In("id", entityIDs)), - // TODO parent project - }, - builder.Eq{"entity_type": SubscriptionEntityProject}, - ), - ) - } - - return -} - -// GetSubscription returns a matching subscription for an entity and user. -// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for -// that task, if there is none it will look for a subscription on the project the task belongs to. -func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) { - subs, err := GetSubscriptions(s, entityType, entityID, a) - if err != nil || len(subs) == 0 { - return nil, err - } - - return subs[0], nil -} - -// GetSubscriptions returns a list of subscriptions to for an entity ID -func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) { +func GetSubscriptionForUser(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *SubscriptionWithUser, err error) { u, is := a.(*user.User) if u != nil && !is { return } + + subs, err := GetSubscriptionsForEntitiesAndUser(s, entityType, []int64{entityID}, u) + if err != nil || len(subs) == 0 || len(subs[entityID]) == 0 { + return nil, err + } + + return subs[entityID][0], nil +} + +// GetSubscriptionsForEntities returns a list of subscriptions to for an entity ID +func GetSubscriptionsForEntities(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64) (subscriptions map[int64][]*SubscriptionWithUser, err error) { + return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, nil, false) +} + +func GetSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User) (subscriptions map[int64][]*SubscriptionWithUser, err error) { + return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, u, true) +} + +func GetSubscriptionsForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*SubscriptionWithUser, err error) { + subs, err := GetSubscriptionsForEntities(s, entityType, []int64{entityID}) + if err != nil || len(subs[entityID]) == 0 { + return + } + + return subs[entityID], nil +} + +// This function returns a matching subscription for an entity and user. +// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for +// that task, if there is none it will look for a subscription on the project the task belongs to. +// It will return a map where the key is the entity id and the value is a slice with all subscriptions for that entity. +func getSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User, userOnly bool) (subscriptions map[int64][]*SubscriptionWithUser, err error) { if err := entityType.validate(); err != nil { return nil, err } + rawSubscriptions := []*subscriptionResolved{} + entityIDString := utils.JoinInt64Slice(entityIDs, ", ") + + var sUserCond string + if userOnly { + if u == nil { + return nil, &ErrMustProvideUser{} + } + sUserCond = " AND s.user_id = " + strconv.FormatInt(u.ID, 10) + } + switch entityType { case SubscriptionEntityProject: - project, err := GetProjectSimpleByID(s, entityID) - if err != nil { - return nil, err - } - subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u) - if err != nil { - return nil, err - } - if _, has := subs[entityID]; has && subs[entityID] != nil { - return subs[entityID], nil - } + err = s.SQL(` +WITH RECURSIVE project_hierarchy AS ( + -- Base case: Start with the specified projects + SELECT + id, + parent_project_id, + 0 AS level, + id AS original_project_id + FROM projects + WHERE id IN (`+entityIDString+`) - for _, sub := range subs { - // Fallback to the first non-nil subscription - if len(sub) > 0 { - return sub, nil - } - } + UNION ALL - return nil, nil + -- Recursive case: Get parent projects + SELECT + p.id, + p.parent_project_id, + ph.level + 1, + ph.original_project_id + FROM projects p + INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id +), + +subscription_hierarchy AS ( + -- Check for project subscriptions (including parent projects) + SELECT + s.id, + s.entity_type, + s.entity_id, + s.created, + s.user_id, + CASE + WHEN s.entity_id = ph.original_project_id THEN 1 -- Direct project match + ELSE ph.level + 1 -- Parent projects + END AS priority, + ph.original_project_id + FROM subscriptions s + INNER JOIN project_hierarchy ph ON s.entity_id = ph.id + WHERE s.entity_type = ?`+sUserCond+` +) + +SELECT + p.id AS original_entity_id, + sh.id AS subscription_id, + sh.entity_type, + sh.entity_id, + sh.created, + sh.user_id, + CASE + WHEN sh.priority = 1 THEN 'Direct Project' + ELSE 'Parent Project' + END + AS subscription_level, + users.* +FROM projects p + LEFT JOIN ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY original_project_id, user_id ORDER BY priority) AS rn + FROM subscription_hierarchy +) sh ON p.id = sh.original_project_id AND sh.rn = 1 + LEFT JOIN users ON sh.user_id = users.id +WHERE p.id IN (`+entityIDString+`) +ORDER BY p.id, sh.user_id`, SubscriptionEntityProject). + Find(&rawSubscriptions) case SubscriptionEntityTask: - subs, err := getSubscriptionsForTask(s, entityID, u) - if err != nil { - return nil, err - } + err = s.SQL(` +WITH RECURSIVE project_hierarchy AS ( + -- Base case: Start with the projects associated with the tasks + SELECT + p.id, + p.parent_project_id, + 0 AS level, + t.id AS task_id + FROM tasks t + JOIN projects p ON t.project_id = p.id + WHERE t.id IN (`+entityIDString+`) - for _, sub := range subs { - // The subscriptions might also contain the immediate parent subscription, if that exists. - // This loop makes sure to only return the task subscription if it exists. The fallback - // happens in the next if after the loop. - if sub.EntityID == entityID && sub.EntityType == SubscriptionEntityTask { - return []*Subscription{sub}, nil - } - } + UNION ALL - if len(subs) > 0 { - return subs, nil - } + -- Recursive case: Get parent projects + SELECT + p.id, + p.parent_project_id, + ph.level + 1, + ph.task_id + FROM projects p + INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id +), - projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID}) - if err != nil { - return nil, err - } +subscription_hierarchy AS ( + -- Check for task subscriptions + SELECT + s.id, + s.entity_type, + s.entity_id, + s.created, + s.user_id, + 1 AS priority, + t.id AS task_id + FROM subscriptions s + JOIN tasks t ON s.entity_id = t.id + WHERE s.entity_type = ? AND t.id IN (`+entityIDString+`)`+sUserCond+` - projectSubscriptions, err := GetSubscriptionsForProjects(s, projects, u) - if err != nil { - return nil, err - } + UNION ALL - if _, has := projectSubscriptions[projects[0].ID]; has { - return projectSubscriptions[projects[0].ID], nil - } + -- Check for project subscriptions (including parent projects) + SELECT + s.id, + s.entity_type, + s.entity_id, + s.created, + s.user_id, + ph.level + 2 AS priority, + ph.task_id + FROM subscriptions s + INNER JOIN project_hierarchy ph ON s.entity_id = ph.id + WHERE s.entity_type = ? +) - for _, psub := range projectSubscriptions { - // Fallback to the first non-nil subscription - if len(psub) > 0 { - return psub, nil - } - } - - return subs, nil +SELECT + t.id AS original_entity_id, + sh.id AS subscription_id, + sh.entity_type, + sh.entity_id, + sh.created, + sh.user_id, + CASE + WHEN sh.entity_type = ? THEN 'Task' + WHEN sh.priority = ? THEN 'Direct Project' + ELSE 'Parent Project' + END + AS subscription_level, + users.* +FROM tasks t + LEFT JOIN ( + SELECT *, + ROW_NUMBER() OVER (PARTITION BY task_id, user_id ORDER BY priority) AS rn + FROM subscription_hierarchy +) sh ON t.id = sh.task_id AND sh.rn = 1 + LEFT JOIN users ON sh.user_id = users.id +WHERE t.id IN (`+entityIDString+`) +ORDER BY t.id, sh.user_id`, + SubscriptionEntityTask, SubscriptionEntityProject, SubscriptionEntityTask, SubscriptionEntityProject). + Find(&rawSubscriptions) + } + if err != nil { + return nil, err } - return -} + subscriptions = make(map[int64][]*SubscriptionWithUser) + for _, sub := range rawSubscriptions { -func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) { - u, is := a.(*user.User) - if u != nil && !is { - return - } - - var ps = make(map[int64]*Project) - origProjectIDs := make([]int64, 0, len(projects)) - allProjectIDs := make([]int64, 0, len(projects)) - - for _, p := range projects { - ps[p.ID] = p - origProjectIDs = append(origProjectIDs, p.ID) - allProjectIDs = append(allProjectIDs, p.ID) - } - - // We can't just use the projects we have, we need to fetch the parents - // because they may not be loaded in the same object - - for _, p := range projects { - if p.ParentProjectID == 0 { + if sub.Subscription.EntityID == 0 { continue } - if _, has := ps[p.ParentProjectID]; has { - continue + _, has := subscriptions[sub.OriginalEntityID] + if !has { + subscriptions[sub.OriginalEntityID] = []*SubscriptionWithUser{} } - parents, err := GetAllParentProjects(s, p.ID) - if err != nil { - return nil, err + sub.Subscription.ID = sub.SubscriptionID + if sub.User != nil { + sub.User.ID = sub.UserID } - // Walk the tree up until we reach the top - var parent = parents[p.ParentProjectID] // parent now has a pointer… - ps[p.ID].ParentProject = parents[p.ParentProjectID] - for parent != nil { - allProjectIDs = append(allProjectIDs, parent.ID) - parent = parents[parent.ParentProjectID] // … which means we can update it here and then update the pointer in the map - } + subscriptions[sub.OriginalEntityID] = append(subscriptions[sub.OriginalEntityID], &sub.SubscriptionWithUser) } - var subscriptions []*Subscription - if u != nil { - err = s. - Where("user_id = ?", u.ID). - And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)). - Find(&subscriptions) - } else { - err = s. - And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)). - Find(&subscriptions) - } - if err != nil { - return nil, err - } - - projectsToSubscriptions = make(map[int64][]*Subscription) - for _, sub := range subscriptions { - sub.Entity = sub.EntityType.String() - projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub) - } - - // Rearrange so that subscriptions trickle down - - for _, eID := range origProjectIDs { - // If the current project does not have a subscription, climb up the tree until a project has one, - // then use that subscription for all child projects - _, has := projectsToSubscriptions[eID] - _, hasProject := ps[eID] - if !has && hasProject { - _, exists := ps[eID] - if !exists { - continue - } - var parent = ps[eID].ParentProject - for parent != nil { - sub, has := projectsToSubscriptions[parent.ID] - projectsToSubscriptions[eID] = sub - parent = parent.ParentProject - if has { // reached the top of the tree - break - } - } - } - } - - return projectsToSubscriptions, nil -} - -func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) { - if u != nil { - err = s. - Where("user_id = ?", u.ID). - And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})). - Find(&subscriptions) - } else { - err = s. - And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})). - Find(&subscriptions) - } - if err != nil { - return nil, err - } - - for _, sub := range subscriptions { - sub.Entity = sub.EntityType.String() - } - - return -} - -func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*Subscription, err error) { - if err := entityType.validate(); err != nil { - return nil, err - } - - subs, err := GetSubscriptions(s, entityType, entityID, nil) - if err != nil { - return - } - - userIDs := []int64{} - subscriptions = make([]*Subscription, 0, len(subs)) - for _, subscription := range subs { - userIDs = append(userIDs, subscription.UserID) - subscriptions = append(subscriptions, subscription) - } - - users, err := user.GetUsersByIDs(s, userIDs) - if err != nil { - return - } - - for _, subscription := range subscriptions { - subscription.User = users[subscription.UserID] - } - return + return subscriptions, nil } diff --git a/pkg/models/subscription_test.go b/pkg/models/subscription_test.go index 574cfbc26..31dec2e96 100644 --- a/pkg/models/subscription_test.go +++ b/pkg/models/subscription_test.go @@ -52,7 +52,6 @@ func TestSubscription_Create(t *testing.T) { sb := &Subscription{ Entity: "task", EntityID: 1, - UserID: u.ID, } can, err := sb.CanCreate(s, u) @@ -61,7 +60,6 @@ func TestSubscription_Create(t *testing.T) { err = sb.Create(s, u) require.NoError(t, err) - assert.NotNil(t, sb.User) db.AssertExists(t, "subscriptions", map[string]interface{}{ "entity_type": 3, @@ -69,6 +67,26 @@ func TestSubscription_Create(t *testing.T) { "user_id": u.ID, }, false) }) + t.Run("already exists", func(t *testing.T) { + db.LoadAndAssertFixtures(t) + s := db.NewSession() + defer s.Close() + + sb := &Subscription{ + Entity: "task", + EntityID: 2, + UserID: u.ID, + } + + can, err := sb.CanCreate(s, u) + require.NoError(t, err) + assert.True(t, can) + + err = sb.Create(s, u) + require.Error(t, err) + terr := &ErrSubscriptionAlreadyExists{} + assert.ErrorAs(t, err, &terr) + }) t.Run("forbidden for link shares", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession() @@ -86,7 +104,7 @@ func TestSubscription_Create(t *testing.T) { require.Error(t, err) assert.False(t, can) }) - t.Run("noneixsting project", func(t *testing.T) { + t.Run("nonexisting project", func(t *testing.T) { db.LoadAndAssertFixtures(t) s := db.NewSession() defer s.Close() @@ -240,7 +258,7 @@ func TestSubscriptionGet(t *testing.T) { s := db.NewSession() defer s.Close() - sub, err := GetSubscription(s, SubscriptionEntityProject, 12, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 12, u) require.NoError(t, err) assert.NotNil(t, sub) assert.Equal(t, int64(3), sub.ID) @@ -250,7 +268,7 @@ func TestSubscriptionGet(t *testing.T) { s := db.NewSession() defer s.Close() - sub, err := GetSubscription(s, SubscriptionEntityTask, 22, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 22, u) require.NoError(t, err) assert.NotNil(t, sub) assert.Equal(t, int64(4), sub.ID) @@ -263,7 +281,7 @@ func TestSubscriptionGet(t *testing.T) { defer s.Close() // Project 25 belongs to project 12 where user 6 has subscribed to - sub, err := GetSubscription(s, SubscriptionEntityProject, 25, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 25, u) require.NoError(t, err) assert.NotNil(t, sub) assert.Equal(t, int64(12), sub.EntityID) @@ -275,7 +293,7 @@ func TestSubscriptionGet(t *testing.T) { defer s.Close() // Project 26 belongs to project 25 which belongs to project 12 where user 6 has subscribed to - sub, err := GetSubscription(s, SubscriptionEntityProject, 26, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityProject, 26, u) require.NoError(t, err) assert.NotNil(t, sub) assert.Equal(t, int64(12), sub.EntityID) @@ -287,7 +305,7 @@ func TestSubscriptionGet(t *testing.T) { defer s.Close() // Task 39 belongs to project 25 which belongs to project 12 where the user has subscribed - sub, err := GetSubscription(s, SubscriptionEntityTask, 39, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 39, u) require.NoError(t, err) assert.NotNil(t, sub) // assert.Equal(t, int64(2), sub.ID) TODO @@ -298,7 +316,7 @@ func TestSubscriptionGet(t *testing.T) { defer s.Close() // Task 21 belongs to project 32 which the user has subscribed to - sub, err := GetSubscription(s, SubscriptionEntityTask, 21, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 21, u) require.NoError(t, err) assert.NotNil(t, sub) assert.Equal(t, int64(8), sub.ID) @@ -309,7 +327,7 @@ func TestSubscriptionGet(t *testing.T) { s := db.NewSession() defer s.Close() - _, err := GetSubscription(s, 2342, 21, u) + _, err := GetSubscriptionForUser(s, 2342, 21, u) require.Error(t, err) assert.True(t, IsErrUnknownSubscriptionEntityType(err)) }) @@ -318,7 +336,7 @@ func TestSubscriptionGet(t *testing.T) { s := db.NewSession() defer s.Close() - sub, err := GetSubscription(s, SubscriptionEntityTask, 18, u) + sub, err := GetSubscriptionForUser(s, SubscriptionEntityTask, 18, u) require.NoError(t, err) assert.Equal(t, int64(9), sub.ID) }) diff --git a/pkg/models/tasks.go b/pkg/models/tasks.go index 192b2e411..98b3fa887 100644 --- a/pkg/models/tasks.go +++ b/pkg/models/tasks.go @@ -1573,10 +1573,11 @@ func (t *Task) ReadOne(s *xorm.Session, a web.Auth) (err error) { *t = *taskMap[t.ID] - t.Subscription, err = GetSubscription(s, SubscriptionEntityTask, t.ID, a) + subs, err := GetSubscriptionForUser(s, SubscriptionEntityTask, t.ID, a) if err != nil && IsErrProjectDoesNotExist(err) { return nil } + t.Subscription = &subs.Subscription return } diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go new file mode 100644 index 000000000..e236ed95d --- /dev/null +++ b/pkg/utils/strings.go @@ -0,0 +1,31 @@ +// Vikunja is a to-do list application to facilitate your life. +// Copyright 2018-present Vikunja and contributors. All rights reserved. +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public Licensee as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public Licensee for more details. +// +// You should have received a copy of the GNU Affero General Public Licensee +// along with this program. If not, see . + +package utils + +import "strconv" + +func JoinInt64Slice(ints []int64, delim string) string { + b := "" + for _, v := range ints { + if len(b) > 0 { + b += delim + } + b += strconv.FormatInt(v, 10) + } + + return b +}