1
0

fix(subscriptions): cleanup and simplify fetching subscribers for tasks and projects logic

Vikunja now uses one recursive CTE and a few optimizations to fetch all subscribers for a task or project. This makes the relevant code easier to maintain and more performant.

(cherry picked from commit 4ff8815fe1bfe72e02c10f6a6877c93a630f36a4)
This commit is contained in:
kolaente
2024-09-04 19:54:22 +02:00
parent 7646c7f0c9
commit 8b8ec19bb3
7 changed files with 330 additions and 265 deletions

View File

@ -17,12 +17,13 @@
package models
import (
"strconv"
"time"
"xorm.io/builder"
"code.vikunja.io/api/pkg/user"
"code.vikunja.io/api/pkg/utils"
"code.vikunja.io/api/pkg/web"
"xorm.io/xorm"
)
@ -52,8 +53,7 @@ type Subscription struct {
EntityID int64 `xorm:"bigint index not null" json:"entity_id" param:"entityID"`
// The user who made this subscription
User *user.User `xorm:"-" json:"user"`
UserID int64 `xorm:"bigint index not null" json:"-"`
UserID int64 `xorm:"bigint index not null" json:"-"`
// A timestamp when this subscription was created. You cannot change this value.
Created time.Time `xorm:"created not null" json:"created"`
@ -62,7 +62,18 @@ type Subscription struct {
web.Rights `xorm:"-" json:"-"`
}
// TableName gives us a better tabel name for the subscriptions table
type SubscriptionWithUser struct {
Subscription `xorm:"extends"`
User *user.User `xorm:"extends" json:"user"`
}
type subscriptionResolved struct {
OriginalEntityID int64
SubscriptionID int64
SubscriptionWithUser `xorm:"extends"`
}
// TableName gives us a better table name for the subscriptions table
func (sb *Subscription) TableName() string {
return "subscriptions"
}
@ -115,28 +126,23 @@ func (et SubscriptionEntityType) validate() error {
// @Failure 500 {object} models.Message "Internal error"
// @Router /subscriptions/{entity}/{entityID} [put]
func (sb *Subscription) Create(s *xorm.Session, auth web.Auth) (err error) {
// Rights method alread does the validation of the entity type so we don't need to do that here
// Rights method already does the validation of the entity type, so we don't need to do that here
sb.UserID = auth.GetID()
sub, err := GetSubscription(s, sb.EntityType, sb.EntityID, auth)
sub, err := GetSubscriptionForUser(s, sb.EntityType, sb.EntityID, auth)
if err != nil {
return err
}
if sub != nil {
return &ErrSubscriptionAlreadyExists{
EntityID: sb.EntityID,
EntityType: sb.EntityType,
UserID: sb.UserID,
EntityID: sub.EntityID,
EntityType: sub.EntityType,
UserID: sub.UserID,
}
}
_, err = s.Insert(sb)
if err != nil {
return
}
sb.User, err = user.GetFromAuth(auth)
return
}
@ -163,261 +169,228 @@ func (sb *Subscription) Delete(s *xorm.Session, auth web.Auth) (err error) {
return
}
func getSubscriberCondForEntities(entityType SubscriptionEntityType, entityIDs []int64) (cond builder.Cond) {
if entityType == SubscriptionEntityProject {
return builder.And(
builder.In("entity_id", entityIDs),
builder.Eq{"entity_type": SubscriptionEntityProject},
)
}
if entityType == SubscriptionEntityTask {
return builder.Or(
builder.And(
builder.In("entity_id", entityIDs),
builder.Eq{"entity_type": SubscriptionEntityTask},
),
builder.And(
builder.Eq{"entity_id": builder.
Select("project_id").
From("tasks").
Where(builder.In("id", entityIDs)),
// TODO parent project
},
builder.Eq{"entity_type": SubscriptionEntityProject},
),
)
}
return
}
// GetSubscription returns a matching subscription for an entity and user.
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to.
func GetSubscription(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *Subscription, err error) {
subs, err := GetSubscriptions(s, entityType, entityID, a)
if err != nil || len(subs) == 0 {
return nil, err
}
return subs[0], nil
}
// GetSubscriptions returns a list of subscriptions to for an entity ID
func GetSubscriptions(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscriptions []*Subscription, err error) {
func GetSubscriptionForUser(s *xorm.Session, entityType SubscriptionEntityType, entityID int64, a web.Auth) (subscription *SubscriptionWithUser, err error) {
u, is := a.(*user.User)
if u != nil && !is {
return
}
subs, err := GetSubscriptionsForEntitiesAndUser(s, entityType, []int64{entityID}, u)
if err != nil || len(subs) == 0 || len(subs[entityID]) == 0 {
return nil, err
}
return subs[entityID][0], nil
}
// GetSubscriptionsForEntities returns a list of subscriptions to for an entity ID
func GetSubscriptionsForEntities(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, nil, false)
}
func GetSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
return getSubscriptionsForEntitiesAndUser(s, entityType, entityIDs, u, true)
}
func GetSubscriptionsForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*SubscriptionWithUser, err error) {
subs, err := GetSubscriptionsForEntities(s, entityType, []int64{entityID})
if err != nil || len(subs[entityID]) == 0 {
return
}
return subs[entityID], nil
}
// This function returns a matching subscription for an entity and user.
// It will return the next parent of a subscription. That means for tasks, it will first look for a subscription for
// that task, if there is none it will look for a subscription on the project the task belongs to.
// It will return a map where the key is the entity id and the value is a slice with all subscriptions for that entity.
func getSubscriptionsForEntitiesAndUser(s *xorm.Session, entityType SubscriptionEntityType, entityIDs []int64, u *user.User, userOnly bool) (subscriptions map[int64][]*SubscriptionWithUser, err error) {
if err := entityType.validate(); err != nil {
return nil, err
}
rawSubscriptions := []*subscriptionResolved{}
entityIDString := utils.JoinInt64Slice(entityIDs, ", ")
var sUserCond string
if userOnly {
if u == nil {
return nil, &ErrMustProvideUser{}
}
sUserCond = " AND s.user_id = " + strconv.FormatInt(u.ID, 10)
}
switch entityType {
case SubscriptionEntityProject:
project, err := GetProjectSimpleByID(s, entityID)
if err != nil {
return nil, err
}
subs, err := GetSubscriptionsForProjects(s, []*Project{project}, u)
if err != nil {
return nil, err
}
if _, has := subs[entityID]; has && subs[entityID] != nil {
return subs[entityID], nil
}
err = s.SQL(`
WITH RECURSIVE project_hierarchy AS (
-- Base case: Start with the specified projects
SELECT
id,
parent_project_id,
0 AS level,
id AS original_project_id
FROM projects
WHERE id IN (`+entityIDString+`)
for _, sub := range subs {
// Fallback to the first non-nil subscription
if len(sub) > 0 {
return sub, nil
}
}
UNION ALL
return nil, nil
-- Recursive case: Get parent projects
SELECT
p.id,
p.parent_project_id,
ph.level + 1,
ph.original_project_id
FROM projects p
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
),
subscription_hierarchy AS (
-- Check for project subscriptions (including parent projects)
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
CASE
WHEN s.entity_id = ph.original_project_id THEN 1 -- Direct project match
ELSE ph.level + 1 -- Parent projects
END AS priority,
ph.original_project_id
FROM subscriptions s
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
WHERE s.entity_type = ?`+sUserCond+`
)
SELECT
p.id AS original_entity_id,
sh.id AS subscription_id,
sh.entity_type,
sh.entity_id,
sh.created,
sh.user_id,
CASE
WHEN sh.priority = 1 THEN 'Direct Project'
ELSE 'Parent Project'
END
AS subscription_level,
users.*
FROM projects p
LEFT JOIN (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY original_project_id, user_id ORDER BY priority) AS rn
FROM subscription_hierarchy
) sh ON p.id = sh.original_project_id AND sh.rn = 1
LEFT JOIN users ON sh.user_id = users.id
WHERE p.id IN (`+entityIDString+`)
ORDER BY p.id, sh.user_id`, SubscriptionEntityProject).
Find(&rawSubscriptions)
case SubscriptionEntityTask:
subs, err := getSubscriptionsForTask(s, entityID, u)
if err != nil {
return nil, err
}
err = s.SQL(`
WITH RECURSIVE project_hierarchy AS (
-- Base case: Start with the projects associated with the tasks
SELECT
p.id,
p.parent_project_id,
0 AS level,
t.id AS task_id
FROM tasks t
JOIN projects p ON t.project_id = p.id
WHERE t.id IN (`+entityIDString+`)
for _, sub := range subs {
// The subscriptions might also contain the immediate parent subscription, if that exists.
// This loop makes sure to only return the task subscription if it exists. The fallback
// happens in the next if after the loop.
if sub.EntityID == entityID && sub.EntityType == SubscriptionEntityTask {
return []*Subscription{sub}, nil
}
}
UNION ALL
if len(subs) > 0 {
return subs, nil
}
-- Recursive case: Get parent projects
SELECT
p.id,
p.parent_project_id,
ph.level + 1,
ph.task_id
FROM projects p
INNER JOIN project_hierarchy ph ON p.id = ph.parent_project_id
),
projects, err := GetProjectsSimplByTaskIDs(s, []int64{entityID})
if err != nil {
return nil, err
}
subscription_hierarchy AS (
-- Check for task subscriptions
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
1 AS priority,
t.id AS task_id
FROM subscriptions s
JOIN tasks t ON s.entity_id = t.id
WHERE s.entity_type = ? AND t.id IN (`+entityIDString+`)`+sUserCond+`
projectSubscriptions, err := GetSubscriptionsForProjects(s, projects, u)
if err != nil {
return nil, err
}
UNION ALL
if _, has := projectSubscriptions[projects[0].ID]; has {
return projectSubscriptions[projects[0].ID], nil
}
-- Check for project subscriptions (including parent projects)
SELECT
s.id,
s.entity_type,
s.entity_id,
s.created,
s.user_id,
ph.level + 2 AS priority,
ph.task_id
FROM subscriptions s
INNER JOIN project_hierarchy ph ON s.entity_id = ph.id
WHERE s.entity_type = ?
)
for _, psub := range projectSubscriptions {
// Fallback to the first non-nil subscription
if len(psub) > 0 {
return psub, nil
}
}
return subs, nil
SELECT
t.id AS original_entity_id,
sh.id AS subscription_id,
sh.entity_type,
sh.entity_id,
sh.created,
sh.user_id,
CASE
WHEN sh.entity_type = ? THEN 'Task'
WHEN sh.priority = ? THEN 'Direct Project'
ELSE 'Parent Project'
END
AS subscription_level,
users.*
FROM tasks t
LEFT JOIN (
SELECT *,
ROW_NUMBER() OVER (PARTITION BY task_id, user_id ORDER BY priority) AS rn
FROM subscription_hierarchy
) sh ON t.id = sh.task_id AND sh.rn = 1
LEFT JOIN users ON sh.user_id = users.id
WHERE t.id IN (`+entityIDString+`)
ORDER BY t.id, sh.user_id`,
SubscriptionEntityTask, SubscriptionEntityProject, SubscriptionEntityTask, SubscriptionEntityProject).
Find(&rawSubscriptions)
}
if err != nil {
return nil, err
}
return
}
subscriptions = make(map[int64][]*SubscriptionWithUser)
for _, sub := range rawSubscriptions {
func GetSubscriptionsForProjects(s *xorm.Session, projects []*Project, a web.Auth) (projectsToSubscriptions map[int64][]*Subscription, err error) {
u, is := a.(*user.User)
if u != nil && !is {
return
}
var ps = make(map[int64]*Project)
origProjectIDs := make([]int64, 0, len(projects))
allProjectIDs := make([]int64, 0, len(projects))
for _, p := range projects {
ps[p.ID] = p
origProjectIDs = append(origProjectIDs, p.ID)
allProjectIDs = append(allProjectIDs, p.ID)
}
// We can't just use the projects we have, we need to fetch the parents
// because they may not be loaded in the same object
for _, p := range projects {
if p.ParentProjectID == 0 {
if sub.Subscription.EntityID == 0 {
continue
}
if _, has := ps[p.ParentProjectID]; has {
continue
_, has := subscriptions[sub.OriginalEntityID]
if !has {
subscriptions[sub.OriginalEntityID] = []*SubscriptionWithUser{}
}
parents, err := GetAllParentProjects(s, p.ID)
if err != nil {
return nil, err
sub.Subscription.ID = sub.SubscriptionID
if sub.User != nil {
sub.User.ID = sub.UserID
}
// Walk the tree up until we reach the top
var parent = parents[p.ParentProjectID] // parent now has a pointer…
ps[p.ID].ParentProject = parents[p.ParentProjectID]
for parent != nil {
allProjectIDs = append(allProjectIDs, parent.ID)
parent = parents[parent.ParentProjectID] // … which means we can update it here and then update the pointer in the map
}
subscriptions[sub.OriginalEntityID] = append(subscriptions[sub.OriginalEntityID], &sub.SubscriptionWithUser)
}
var subscriptions []*Subscription
if u != nil {
err = s.
Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
Find(&subscriptions)
} else {
err = s.
And(getSubscriberCondForEntities(SubscriptionEntityProject, allProjectIDs)).
Find(&subscriptions)
}
if err != nil {
return nil, err
}
projectsToSubscriptions = make(map[int64][]*Subscription)
for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String()
projectsToSubscriptions[sub.EntityID] = append(projectsToSubscriptions[sub.EntityID], sub)
}
// Rearrange so that subscriptions trickle down
for _, eID := range origProjectIDs {
// If the current project does not have a subscription, climb up the tree until a project has one,
// then use that subscription for all child projects
_, has := projectsToSubscriptions[eID]
_, hasProject := ps[eID]
if !has && hasProject {
_, exists := ps[eID]
if !exists {
continue
}
var parent = ps[eID].ParentProject
for parent != nil {
sub, has := projectsToSubscriptions[parent.ID]
projectsToSubscriptions[eID] = sub
parent = parent.ParentProject
if has { // reached the top of the tree
break
}
}
}
}
return projectsToSubscriptions, nil
}
func getSubscriptionsForTask(s *xorm.Session, taskID int64, u *user.User) (subscriptions []*Subscription, err error) {
if u != nil {
err = s.
Where("user_id = ?", u.ID).
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
} else {
err = s.
And(getSubscriberCondForEntities(SubscriptionEntityTask, []int64{taskID})).
Find(&subscriptions)
}
if err != nil {
return nil, err
}
for _, sub := range subscriptions {
sub.Entity = sub.EntityType.String()
}
return
}
func getSubscribersForEntity(s *xorm.Session, entityType SubscriptionEntityType, entityID int64) (subscriptions []*Subscription, err error) {
if err := entityType.validate(); err != nil {
return nil, err
}
subs, err := GetSubscriptions(s, entityType, entityID, nil)
if err != nil {
return
}
userIDs := []int64{}
subscriptions = make([]*Subscription, 0, len(subs))
for _, subscription := range subs {
userIDs = append(userIDs, subscription.UserID)
subscriptions = append(subscriptions, subscription)
}
users, err := user.GetUsersByIDs(s, userIDs)
if err != nil {
return
}
for _, subscription := range subscriptions {
subscription.User = users[subscription.UserID]
}
return
return subscriptions, nil
}