Move db related basic functions to models/db (#17075)

* Move db related basic functions to models/db

* Fix lint

* Fix lint

* Fix test

* Fix lint

* Fix lint

* revert unnecessary change

* Fix test

* Fix wrong replace string

* Use *Context

* Correct committer spelling and fix wrong replaced words

Co-authored-by: zeripath <art27@cantab.net>
This commit is contained in:
Lunny Xiao
2021-09-19 19:49:59 +08:00
committed by GitHub
parent 462306e263
commit a4bfef265d
335 changed files with 4191 additions and 3654 deletions
+14 -9
View File
@@ -8,6 +8,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
)
@@ -71,7 +72,11 @@ type Access struct {
Mode AccessMode
}
func accessLevel(e Engine, user *User, repo *Repository) (AccessMode, error) {
func init() {
db.RegisterModel(new(Access))
}
func accessLevel(e db.Engine, user *User, repo *Repository) (AccessMode, error) {
mode := AccessModeNone
var userID int64
restricted := false
@@ -111,7 +116,7 @@ func (repoAccess) TableName() string {
// GetRepositoryAccesses finds all repositories with their access mode where a user has access but does not own.
func (user *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) {
rows, err := x.
rows, err := db.DefaultContext().Engine().
Join("INNER", "repository", "repository.id = access.repo_id").
Where("access.user_id = ?", user.ID).
And("repository.owner_id <> ?", user.ID).
@@ -146,7 +151,7 @@ func (user *User) GetRepositoryAccesses() (map[*Repository]AccessMode, error) {
// GetAccessibleRepositories finds repositories which the user has access but does not own.
// If limit is smaller than 1 means returns all found results.
func (user *User) GetAccessibleRepositories(limit int) (repos []*Repository, _ error) {
sess := x.
sess := db.DefaultContext().Engine().
Where("owner_id !=? ", user.ID).
Desc("updated_unix")
if limit > 0 {
@@ -185,7 +190,7 @@ func updateUserAccess(accessMap map[int64]*userAccess, user *User, mode AccessMo
}
// FIXME: do cross-comparison so reduce deletions and additions to the minimum?
func (repo *Repository) refreshAccesses(e Engine, accessMap map[int64]*userAccess) (err error) {
func (repo *Repository) refreshAccesses(e db.Engine, accessMap map[int64]*userAccess) (err error) {
minMode := AccessModeRead
if !repo.IsPrivate {
minMode = AccessModeWrite
@@ -219,7 +224,7 @@ func (repo *Repository) refreshAccesses(e Engine, accessMap map[int64]*userAcces
}
// refreshCollaboratorAccesses retrieves repository collaborations with their access modes.
func (repo *Repository) refreshCollaboratorAccesses(e Engine, accessMap map[int64]*userAccess) error {
func (repo *Repository) refreshCollaboratorAccesses(e db.Engine, accessMap map[int64]*userAccess) error {
collaborators, err := repo.getCollaborators(e, ListOptions{})
if err != nil {
return fmt.Errorf("getCollaborations: %v", err)
@@ -233,7 +238,7 @@ func (repo *Repository) refreshCollaboratorAccesses(e Engine, accessMap map[int6
// recalculateTeamAccesses recalculates new accesses for teams of an organization
// except the team whose ID is given. It is used to assign a team ID when
// remove repository from that team.
func (repo *Repository) recalculateTeamAccesses(e Engine, ignTeamID int64) (err error) {
func (repo *Repository) recalculateTeamAccesses(e db.Engine, ignTeamID int64) (err error) {
accessMap := make(map[int64]*userAccess, 20)
if err = repo.getOwner(e); err != nil {
@@ -276,7 +281,7 @@ func (repo *Repository) recalculateTeamAccesses(e Engine, ignTeamID int64) (err
// recalculateUserAccess recalculates new access for a single user
// Usable if we know access only affected one user
func (repo *Repository) recalculateUserAccess(e Engine, uid int64) (err error) {
func (repo *Repository) recalculateUserAccess(e db.Engine, uid int64) (err error) {
minMode := AccessModeRead
if !repo.IsPrivate {
minMode = AccessModeWrite
@@ -323,7 +328,7 @@ func (repo *Repository) recalculateUserAccess(e Engine, uid int64) (err error) {
return nil
}
func (repo *Repository) recalculateAccesses(e Engine) error {
func (repo *Repository) recalculateAccesses(e db.Engine) error {
if repo.Owner.IsOrganization() {
return repo.recalculateTeamAccesses(e, 0)
}
@@ -337,5 +342,5 @@ func (repo *Repository) recalculateAccesses(e Engine) error {
// RecalculateAccesses recalculates all accesses for repository.
func (repo *Repository) RecalculateAccesses() error {
return repo.recalculateAccesses(x)
return repo.recalculateAccesses(db.DefaultContext().Engine())
}
+34 -33
View File
@@ -7,27 +7,28 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestAccessLevel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user2 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user5 := AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
user29 := AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user5 := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
user29 := db.AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
// A public repository owned by User 2
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.False(t, repo1.IsPrivate)
// A private repository owned by Org 3
repo3 := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
repo3 := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
assert.True(t, repo3.IsPrivate)
// Another public repository
repo4 := AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
repo4 := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
assert.False(t, repo4.IsPrivate)
// org. owned private repo
repo24 := AssertExistsAndLoadBean(t, &Repository{ID: 24}).(*Repository)
repo24 := db.AssertExistsAndLoadBean(t, &Repository{ID: 24}).(*Repository)
level, err := AccessLevel(user2, repo1)
assert.NoError(t, err)
@@ -62,15 +63,15 @@ func TestAccessLevel(t *testing.T) {
}
func TestHasAccess(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user2 := AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
// A public repository owned by User 2
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.False(t, repo1.IsPrivate)
// A private repository owned by Org 3
repo2 := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
repo2 := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
assert.True(t, repo2.IsPrivate)
has, err := HasAccess(user1.ID, repo1)
@@ -88,33 +89,33 @@ func TestHasAccess(t *testing.T) {
}
func TestUser_GetRepositoryAccesses(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
accesses, err := user1.GetRepositoryAccesses()
assert.NoError(t, err)
assert.Len(t, accesses, 0)
user29 := AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
user29 := db.AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
accesses, err = user29.GetRepositoryAccesses()
assert.NoError(t, err)
assert.Len(t, accesses, 2)
}
func TestUser_GetAccessibleRepositories(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
repos, err := user1.GetAccessibleRepositories(0)
assert.NoError(t, err)
assert.Len(t, repos, 0)
user2 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
repos, err = user2.GetAccessibleRepositories(0)
assert.NoError(t, err)
assert.Len(t, repos, 4)
user29 := AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
user29 := db.AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
repos, err = user29.GetAccessibleRepositories(0)
assert.NoError(t, err)
assert.Len(t, repos, 2)
@@ -122,16 +123,16 @@ func TestUser_GetAccessibleRepositories(t *testing.T) {
func TestRepository_RecalculateAccesses(t *testing.T) {
// test with organization repo
assert.NoError(t, PrepareTestDatabase())
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
assert.NoError(t, repo1.GetOwner())
_, err := x.Delete(&Collaboration{UserID: 2, RepoID: 3})
_, err := db.DefaultContext().Engine().Delete(&Collaboration{UserID: 2, RepoID: 3})
assert.NoError(t, err)
assert.NoError(t, repo1.RecalculateAccesses())
access := &Access{UserID: 2, RepoID: 3}
has, err := x.Get(access)
has, err := db.DefaultContext().Engine().Get(access)
assert.NoError(t, err)
assert.True(t, has)
assert.Equal(t, AccessModeOwner, access.Mode)
@@ -139,25 +140,25 @@ func TestRepository_RecalculateAccesses(t *testing.T) {
func TestRepository_RecalculateAccesses2(t *testing.T) {
// test with non-organization repo
assert.NoError(t, PrepareTestDatabase())
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
assert.NoError(t, repo1.GetOwner())
_, err := x.Delete(&Collaboration{UserID: 4, RepoID: 4})
_, err := db.DefaultContext().Engine().Delete(&Collaboration{UserID: 4, RepoID: 4})
assert.NoError(t, err)
assert.NoError(t, repo1.RecalculateAccesses())
has, err := x.Get(&Access{UserID: 4, RepoID: 4})
has, err := db.DefaultContext().Engine().Get(&Access{UserID: 4, RepoID: 4})
assert.NoError(t, err)
assert.False(t, has)
}
func TestRepository_RecalculateAccesses3(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
team5 := AssertExistsAndLoadBean(t, &Team{ID: 5}).(*Team)
user29 := AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
team5 := db.AssertExistsAndLoadBean(t, &Team{ID: 5}).(*Team)
user29 := db.AssertExistsAndLoadBean(t, &User{ID: 29}).(*User)
has, err := x.Get(&Access{UserID: 29, RepoID: 23})
has, err := db.DefaultContext().Engine().Get(&Access{UserID: 29, RepoID: 23})
assert.NoError(t, err)
assert.False(t, has)
@@ -165,7 +166,7 @@ func TestRepository_RecalculateAccesses3(t *testing.T) {
// even though repo 23 is public
assert.NoError(t, AddTeamMember(team5, user29.ID))
has, err = x.Get(&Access{UserID: 29, RepoID: 23})
has, err = db.DefaultContext().Engine().Get(&Access{UserID: 29, RepoID: 23})
assert.NoError(t, err)
assert.True(t, has)
}
+9 -4
View File
@@ -12,6 +12,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
@@ -74,6 +75,10 @@ type Action struct {
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
}
func init() {
db.RegisterModel(new(Action))
}
// GetOpType gets the ActionType of this action.
func (a *Action) GetOpType() ActionType {
return a.OpType
@@ -203,10 +208,10 @@ func GetRepositoryFromMatch(ownerName, repoName string) (*Repository, error) {
// GetCommentLink returns link to action comment.
func (a *Action) GetCommentLink() string {
return a.getCommentLink(x)
return a.getCommentLink(db.DefaultContext().Engine())
}
func (a *Action) getCommentLink(e Engine) string {
func (a *Action) getCommentLink(e db.Engine) string {
if a == nil {
return "#"
}
@@ -312,7 +317,7 @@ func GetFeeds(opts GetFeedsOptions) ([]*Action, error) {
actions := make([]*Action, 0, setting.UI.FeedPagingNum)
if err := x.Limit(setting.UI.FeedPagingNum).Desc("id").Where(cond).Find(&actions); err != nil {
if err := db.DefaultContext().Engine().Limit(setting.UI.FeedPagingNum).Desc("id").Where(cond).Find(&actions); err != nil {
return nil, fmt.Errorf("Find: %v", err)
}
@@ -403,6 +408,6 @@ func DeleteOldActions(olderThan time.Duration) (err error) {
return nil
}
_, err = x.Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{})
_, err = db.DefaultContext().Engine().Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).Delete(&Action{})
return
}
+11 -7
View File
@@ -4,7 +4,11 @@
package models
import "fmt"
import (
"fmt"
"code.gitea.io/gitea/models/db"
)
// ActionList defines a list of actions
type ActionList []*Action
@@ -19,7 +23,7 @@ func (actions ActionList) getUserIDs() []int64 {
return keysInt64(userIDs)
}
func (actions ActionList) loadUsers(e Engine) ([]*User, error) {
func (actions ActionList) loadUsers(e db.Engine) ([]*User, error) {
if len(actions) == 0 {
return nil, nil
}
@@ -41,7 +45,7 @@ func (actions ActionList) loadUsers(e Engine) ([]*User, error) {
// LoadUsers loads actions' all users
func (actions ActionList) LoadUsers() ([]*User, error) {
return actions.loadUsers(x)
return actions.loadUsers(db.DefaultContext().Engine())
}
func (actions ActionList) getRepoIDs() []int64 {
@@ -54,7 +58,7 @@ func (actions ActionList) getRepoIDs() []int64 {
return keysInt64(repoIDs)
}
func (actions ActionList) loadRepositories(e Engine) ([]*Repository, error) {
func (actions ActionList) loadRepositories(e db.Engine) ([]*Repository, error) {
if len(actions) == 0 {
return nil, nil
}
@@ -76,11 +80,11 @@ func (actions ActionList) loadRepositories(e Engine) ([]*Repository, error) {
// LoadRepositories loads actions' all repositories
func (actions ActionList) LoadRepositories() ([]*Repository, error) {
return actions.loadRepositories(x)
return actions.loadRepositories(db.DefaultContext().Engine())
}
// loadAttributes loads all attributes
func (actions ActionList) loadAttributes(e Engine) (err error) {
func (actions ActionList) loadAttributes(e db.Engine) (err error) {
if _, err = actions.loadUsers(e); err != nil {
return
}
@@ -94,5 +98,5 @@ func (actions ActionList) loadAttributes(e Engine) (err error) {
// LoadAttributes loads attributes of the actions
func (actions ActionList) LoadAttributes() error {
return actions.loadAttributes(x)
return actions.loadAttributes(db.DefaultContext().Engine())
}
+12 -11
View File
@@ -8,23 +8,24 @@ import (
"path"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestAction_GetRepoPath(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{}).(*Repository)
owner := AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{}).(*Repository)
owner := db.AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
action := &Action{RepoID: repo.ID}
assert.Equal(t, path.Join(owner.Name, repo.Name), action.GetRepoPath())
}
func TestAction_GetRepoLink(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{}).(*Repository)
owner := AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{}).(*Repository)
owner := db.AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
action := &Action{RepoID: repo.ID}
setting.AppSubURL = "/suburl"
expected := path.Join(setting.AppSubURL, owner.Name, repo.Name)
@@ -33,8 +34,8 @@ func TestAction_GetRepoLink(t *testing.T) {
func TestGetFeeds(t *testing.T) {
// test with an individual user
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
actions, err := GetFeeds(GetFeedsOptions{
RequestedUser: user,
@@ -61,9 +62,9 @@ func TestGetFeeds(t *testing.T) {
func TestGetFeeds2(t *testing.T) {
// test with an organization user
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
actions, err := GetFeeds(GetFeedsOptions{
RequestedUser: org,
+19 -14
View File
@@ -8,6 +8,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/storage"
"code.gitea.io/gitea/modules/timeutil"
@@ -32,6 +33,10 @@ type Notice struct {
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
}
func init() {
db.RegisterModel(new(Notice))
}
// TrStr returns a translation format string.
func (n *Notice) TrStr() string {
return fmt.Sprintf("admin.notices.type_%d", n.Type)
@@ -39,10 +44,10 @@ func (n *Notice) TrStr() string {
// CreateNotice creates new system notice.
func CreateNotice(tp NoticeType, desc string, args ...interface{}) error {
return createNotice(x, tp, desc, args...)
return createNotice(db.DefaultContext().Engine(), tp, desc, args...)
}
func createNotice(e Engine, tp NoticeType, desc string, args ...interface{}) error {
func createNotice(e db.Engine, tp NoticeType, desc string, args ...interface{}) error {
if len(args) > 0 {
desc = fmt.Sprintf(desc, args...)
}
@@ -56,22 +61,22 @@ func createNotice(e Engine, tp NoticeType, desc string, args ...interface{}) err
// CreateRepositoryNotice creates new system notice with type NoticeRepository.
func CreateRepositoryNotice(desc string, args ...interface{}) error {
return createNotice(x, NoticeRepository, desc, args...)
return createNotice(db.DefaultContext().Engine(), NoticeRepository, desc, args...)
}
// RemoveAllWithNotice removes all directories in given path and
// creates a system notice when error occurs.
func RemoveAllWithNotice(title, path string) {
removeAllWithNotice(x, title, path)
removeAllWithNotice(db.DefaultContext().Engine(), title, path)
}
// RemoveStorageWithNotice removes a file from the storage and
// creates a system notice when error occurs.
func RemoveStorageWithNotice(bucket storage.ObjectStorage, title, path string) {
removeStorageWithNotice(x, bucket, title, path)
removeStorageWithNotice(db.DefaultContext().Engine(), bucket, title, path)
}
func removeStorageWithNotice(e Engine, bucket storage.ObjectStorage, title, path string) {
func removeStorageWithNotice(e db.Engine, bucket storage.ObjectStorage, title, path string) {
if err := bucket.Delete(path); err != nil {
desc := fmt.Sprintf("%s [%s]: %v", title, path, err)
log.Warn(title+" [%s]: %v", path, err)
@@ -81,7 +86,7 @@ func removeStorageWithNotice(e Engine, bucket storage.ObjectStorage, title, path
}
}
func removeAllWithNotice(e Engine, title, path string) {
func removeAllWithNotice(e db.Engine, title, path string) {
if err := util.RemoveAll(path); err != nil {
desc := fmt.Sprintf("%s [%s]: %v", title, path, err)
log.Warn(title+" [%s]: %v", path, err)
@@ -93,14 +98,14 @@ func removeAllWithNotice(e Engine, title, path string) {
// CountNotices returns number of notices.
func CountNotices() int64 {
count, _ := x.Count(new(Notice))
count, _ := db.DefaultContext().Engine().Count(new(Notice))
return count
}
// Notices returns notices in given page.
func Notices(page, pageSize int) ([]*Notice, error) {
notices := make([]*Notice, 0, pageSize)
return notices, x.
return notices, db.DefaultContext().Engine().
Limit(pageSize, (page-1)*pageSize).
Desc("id").
Find(&notices)
@@ -108,18 +113,18 @@ func Notices(page, pageSize int) ([]*Notice, error) {
// DeleteNotice deletes a system notice by given ID.
func DeleteNotice(id int64) error {
_, err := x.ID(id).Delete(new(Notice))
_, err := db.DefaultContext().Engine().ID(id).Delete(new(Notice))
return err
}
// DeleteNotices deletes all notices with ID from start to end (inclusive).
func DeleteNotices(start, end int64) error {
if start == 0 && end == 0 {
_, err := x.Exec("DELETE FROM notice")
_, err := db.DefaultContext().Engine().Exec("DELETE FROM notice")
return err
}
sess := x.Where("id >= ?", start)
sess := db.DefaultContext().Engine().Where("id >= ?", start)
if end > 0 {
sess.And("id <= ?", end)
}
@@ -132,7 +137,7 @@ func DeleteNoticesByIDs(ids []int64) error {
if len(ids) == 0 {
return nil
}
_, err := x.
_, err := db.DefaultContext().Engine().
In("id", ids).
Delete(new(Notice))
return err
@@ -141,7 +146,7 @@ func DeleteNoticesByIDs(ids []int64) error {
// GetAdminUser returns the first administrator
func GetAdminUser() (*User, error) {
var admin User
has, err := x.Where("is_admin=?", true).Get(&admin)
has, err := db.DefaultContext().Engine().Where("is_admin=?", true).Get(&admin)
if err != nil {
return nil, err
} else if !has {
+33 -32
View File
@@ -7,6 +7,7 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
@@ -19,38 +20,38 @@ func TestNotice_TrStr(t *testing.T) {
}
func TestCreateNotice(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
noticeBean := &Notice{
Type: NoticeRepository,
Description: "test description",
}
AssertNotExistsBean(t, noticeBean)
db.AssertNotExistsBean(t, noticeBean)
assert.NoError(t, CreateNotice(noticeBean.Type, noticeBean.Description))
AssertExistsAndLoadBean(t, noticeBean)
db.AssertExistsAndLoadBean(t, noticeBean)
}
func TestCreateRepositoryNotice(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
noticeBean := &Notice{
Type: NoticeRepository,
Description: "test description",
}
AssertNotExistsBean(t, noticeBean)
db.AssertNotExistsBean(t, noticeBean)
assert.NoError(t, CreateRepositoryNotice(noticeBean.Description))
AssertExistsAndLoadBean(t, noticeBean)
db.AssertExistsAndLoadBean(t, noticeBean)
}
// TODO TestRemoveAllWithNotice
func TestCountNotices(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.Equal(t, int64(3), CountNotices())
}
func TestNotices(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
notices, err := Notices(1, 2)
assert.NoError(t, err)
@@ -67,47 +68,47 @@ func TestNotices(t *testing.T) {
}
func TestDeleteNotice(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
assert.NoError(t, DeleteNotice(3))
AssertNotExistsBean(t, &Notice{ID: 3})
db.AssertNotExistsBean(t, &Notice{ID: 3})
}
func TestDeleteNotices(t *testing.T) {
// delete a non-empty range
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
AssertExistsAndLoadBean(t, &Notice{ID: 1})
AssertExistsAndLoadBean(t, &Notice{ID: 2})
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertExistsAndLoadBean(t, &Notice{ID: 1})
db.AssertExistsAndLoadBean(t, &Notice{ID: 2})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
assert.NoError(t, DeleteNotices(1, 2))
AssertNotExistsBean(t, &Notice{ID: 1})
AssertNotExistsBean(t, &Notice{ID: 2})
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertNotExistsBean(t, &Notice{ID: 1})
db.AssertNotExistsBean(t, &Notice{ID: 2})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
}
func TestDeleteNotices2(t *testing.T) {
// delete an empty range
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
AssertExistsAndLoadBean(t, &Notice{ID: 1})
AssertExistsAndLoadBean(t, &Notice{ID: 2})
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertExistsAndLoadBean(t, &Notice{ID: 1})
db.AssertExistsAndLoadBean(t, &Notice{ID: 2})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
assert.NoError(t, DeleteNotices(3, 2))
AssertExistsAndLoadBean(t, &Notice{ID: 1})
AssertExistsAndLoadBean(t, &Notice{ID: 2})
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertExistsAndLoadBean(t, &Notice{ID: 1})
db.AssertExistsAndLoadBean(t, &Notice{ID: 2})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
}
func TestDeleteNoticesByIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
AssertExistsAndLoadBean(t, &Notice{ID: 1})
AssertExistsAndLoadBean(t, &Notice{ID: 2})
AssertExistsAndLoadBean(t, &Notice{ID: 3})
db.AssertExistsAndLoadBean(t, &Notice{ID: 1})
db.AssertExistsAndLoadBean(t, &Notice{ID: 2})
db.AssertExistsAndLoadBean(t, &Notice{ID: 3})
assert.NoError(t, DeleteNoticesByIDs([]int64{1, 3}))
AssertNotExistsBean(t, &Notice{ID: 1})
AssertExistsAndLoadBean(t, &Notice{ID: 2})
AssertNotExistsBean(t, &Notice{ID: 3})
db.AssertNotExistsBean(t, &Notice{ID: 1})
db.AssertExistsAndLoadBean(t, &Notice{ID: 2})
db.AssertNotExistsBean(t, &Notice{ID: 3})
}
+33 -28
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"path"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
"code.gitea.io/gitea/modules/timeutil"
@@ -30,10 +31,14 @@ type Attachment struct {
CreatedUnix timeutil.TimeStamp `xorm:"created"`
}
func init() {
db.RegisterModel(new(Attachment))
}
// IncreaseDownloadCount is update download count + 1
func (a *Attachment) IncreaseDownloadCount() error {
// Update download count.
if _, err := x.Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil {
if _, err := db.DefaultContext().Engine().Exec("UPDATE `attachment` SET download_count=download_count+1 WHERE id=?", a.ID); err != nil {
return fmt.Errorf("increase attachment count: %v", err)
}
@@ -81,10 +86,10 @@ func (a *Attachment) LinkedRepository() (*Repository, UnitType, error) {
// GetAttachmentByID returns attachment by given id
func GetAttachmentByID(id int64) (*Attachment, error) {
return getAttachmentByID(x, id)
return getAttachmentByID(db.DefaultContext().Engine(), id)
}
func getAttachmentByID(e Engine, id int64) (*Attachment, error) {
func getAttachmentByID(e db.Engine, id int64) (*Attachment, error) {
attach := &Attachment{}
if has, err := e.ID(id).Get(attach); err != nil {
return nil, err
@@ -94,7 +99,7 @@ func getAttachmentByID(e Engine, id int64) (*Attachment, error) {
return attach, nil
}
func getAttachmentByUUID(e Engine, uuid string) (*Attachment, error) {
func getAttachmentByUUID(e db.Engine, uuid string) (*Attachment, error) {
attach := &Attachment{}
has, err := e.Where("uuid=?", uuid).Get(attach)
if err != nil {
@@ -106,11 +111,11 @@ func getAttachmentByUUID(e Engine, uuid string) (*Attachment, error) {
}
// GetAttachmentsByUUIDs returns attachment by given UUID list.
func GetAttachmentsByUUIDs(ctx DBContext, uuids []string) ([]*Attachment, error) {
return getAttachmentsByUUIDs(ctx.e, uuids)
func GetAttachmentsByUUIDs(ctx *db.Context, uuids []string) ([]*Attachment, error) {
return getAttachmentsByUUIDs(ctx.Engine(), uuids)
}
func getAttachmentsByUUIDs(e Engine, uuids []string) ([]*Attachment, error) {
func getAttachmentsByUUIDs(e db.Engine, uuids []string) ([]*Attachment, error) {
if len(uuids) == 0 {
return []*Attachment{}, nil
}
@@ -122,41 +127,41 @@ func getAttachmentsByUUIDs(e Engine, uuids []string) ([]*Attachment, error) {
// GetAttachmentByUUID returns attachment by given UUID.
func GetAttachmentByUUID(uuid string) (*Attachment, error) {
return getAttachmentByUUID(x, uuid)
return getAttachmentByUUID(db.DefaultContext().Engine(), uuid)
}
// ExistAttachmentsByUUID returns true if attachment is exist by given UUID
func ExistAttachmentsByUUID(uuid string) (bool, error) {
return x.Where("`uuid`=?", uuid).Exist(new(Attachment))
return db.DefaultContext().Engine().Where("`uuid`=?", uuid).Exist(new(Attachment))
}
// GetAttachmentByReleaseIDFileName returns attachment by given releaseId and fileName.
func GetAttachmentByReleaseIDFileName(releaseID int64, fileName string) (*Attachment, error) {
return getAttachmentByReleaseIDFileName(x, releaseID, fileName)
return getAttachmentByReleaseIDFileName(db.DefaultContext().Engine(), releaseID, fileName)
}
func getAttachmentsByIssueID(e Engine, issueID int64) ([]*Attachment, error) {
func getAttachmentsByIssueID(e db.Engine, issueID int64) ([]*Attachment, error) {
attachments := make([]*Attachment, 0, 10)
return attachments, e.Where("issue_id = ? AND comment_id = 0", issueID).Find(&attachments)
}
// GetAttachmentsByIssueID returns all attachments of an issue.
func GetAttachmentsByIssueID(issueID int64) ([]*Attachment, error) {
return getAttachmentsByIssueID(x, issueID)
return getAttachmentsByIssueID(db.DefaultContext().Engine(), issueID)
}
// GetAttachmentsByCommentID returns all attachments if comment by given ID.
func GetAttachmentsByCommentID(commentID int64) ([]*Attachment, error) {
return getAttachmentsByCommentID(x, commentID)
return getAttachmentsByCommentID(db.DefaultContext().Engine(), commentID)
}
func getAttachmentsByCommentID(e Engine, commentID int64) ([]*Attachment, error) {
func getAttachmentsByCommentID(e db.Engine, commentID int64) ([]*Attachment, error) {
attachments := make([]*Attachment, 0, 10)
return attachments, e.Where("comment_id=?", commentID).Find(&attachments)
}
// getAttachmentByReleaseIDFileName return a file based on the the following infos:
func getAttachmentByReleaseIDFileName(e Engine, releaseID int64, fileName string) (*Attachment, error) {
func getAttachmentByReleaseIDFileName(e db.Engine, releaseID int64, fileName string) (*Attachment, error) {
attach := &Attachment{ReleaseID: releaseID, Name: fileName}
has, err := e.Get(attach)
if err != nil {
@@ -169,12 +174,12 @@ func getAttachmentByReleaseIDFileName(e Engine, releaseID int64, fileName string
// DeleteAttachment deletes the given attachment and optionally the associated file.
func DeleteAttachment(a *Attachment, remove bool) error {
_, err := DeleteAttachments(DefaultDBContext(), []*Attachment{a}, remove)
_, err := DeleteAttachments(db.DefaultContext(), []*Attachment{a}, remove)
return err
}
// DeleteAttachments deletes the given attachments and optionally the associated files.
func DeleteAttachments(ctx DBContext, attachments []*Attachment, remove bool) (int, error) {
func DeleteAttachments(ctx *db.Context, attachments []*Attachment, remove bool) (int, error) {
if len(attachments) == 0 {
return 0, nil
}
@@ -184,7 +189,7 @@ func DeleteAttachments(ctx DBContext, attachments []*Attachment, remove bool) (i
ids = append(ids, a.ID)
}
cnt, err := ctx.e.In("id", ids).NoAutoCondition().Delete(attachments[0])
cnt, err := ctx.Engine().In("id", ids).NoAutoCondition().Delete(attachments[0])
if err != nil {
return 0, err
}
@@ -206,7 +211,7 @@ func DeleteAttachmentsByIssue(issueID int64, remove bool) (int, error) {
return 0, err
}
return DeleteAttachments(DefaultDBContext(), attachments, remove)
return DeleteAttachments(db.DefaultContext(), attachments, remove)
}
// DeleteAttachmentsByComment deletes all attachments associated with the given comment.
@@ -216,24 +221,24 @@ func DeleteAttachmentsByComment(commentID int64, remove bool) (int, error) {
return 0, err
}
return DeleteAttachments(DefaultDBContext(), attachments, remove)
return DeleteAttachments(db.DefaultContext(), attachments, remove)
}
// UpdateAttachment updates the given attachment in database
func UpdateAttachment(atta *Attachment) error {
return updateAttachment(x, atta)
return updateAttachment(db.DefaultContext().Engine(), atta)
}
// UpdateAttachmentByUUID Updates attachment via uuid
func UpdateAttachmentByUUID(ctx DBContext, attach *Attachment, cols ...string) error {
func UpdateAttachmentByUUID(ctx *db.Context, attach *Attachment, cols ...string) error {
if attach.UUID == "" {
return fmt.Errorf("Attachement uuid should not blank")
}
_, err := ctx.e.Where("uuid=?", attach.UUID).Cols(cols...).Update(attach)
_, err := ctx.Engine().Where("uuid=?", attach.UUID).Cols(cols...).Update(attach)
return err
}
func updateAttachment(e Engine, atta *Attachment) error {
func updateAttachment(e db.Engine, atta *Attachment) error {
var sess *xorm.Session
if atta.ID != 0 && atta.UUID == "" {
sess = e.ID(atta.ID)
@@ -247,7 +252,7 @@ func updateAttachment(e Engine, atta *Attachment) error {
// DeleteAttachmentsByRelease deletes all attachments associated with the given release.
func DeleteAttachmentsByRelease(releaseID int64) error {
_, err := x.Where("release_id = ?", releaseID).Delete(&Attachment{})
_, err := db.DefaultContext().Engine().Where("release_id = ?", releaseID).Delete(&Attachment{})
return err
}
@@ -257,7 +262,7 @@ func IterateAttachment(f func(attach *Attachment) error) error {
const batchSize = 100
for {
attachments := make([]*Attachment, 0, batchSize)
if err := x.Limit(batchSize, start).Find(&attachments); err != nil {
if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&attachments); err != nil {
return err
}
if len(attachments) == 0 {
@@ -275,13 +280,13 @@ func IterateAttachment(f func(attach *Attachment) error) error {
// CountOrphanedAttachments returns the number of bad attachments
func CountOrphanedAttachments() (int64, error) {
return x.Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))").
return db.DefaultContext().Engine().Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))").
Count(new(Attachment))
}
// DeleteOrphanedAttachments delete all bad attachments
func DeleteOrphanedAttachments() error {
_, err := x.Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))").
_, err := db.DefaultContext().Engine().Where("(issue_id > 0 and issue_id not in (select id from issue)) or (release_id > 0 and release_id not in (select id from `release`))").
Delete(new(Attachment))
return err
}
+10 -9
View File
@@ -7,11 +7,12 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestIncreaseDownloadCount(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
attachment, err := GetAttachmentByUUID("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11")
assert.NoError(t, err)
@@ -27,7 +28,7 @@ func TestIncreaseDownloadCount(t *testing.T) {
}
func TestGetByCommentOrIssueID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// count of attachments from issue ID
attachments, err := GetAttachmentsByIssueID(1)
@@ -40,7 +41,7 @@ func TestGetByCommentOrIssueID(t *testing.T) {
}
func TestDeleteAttachments(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
count, err := DeleteAttachmentsByIssue(4, false)
assert.NoError(t, err)
@@ -60,7 +61,7 @@ func TestDeleteAttachments(t *testing.T) {
}
func TestGetAttachmentByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
attach, err := GetAttachmentByID(1)
assert.NoError(t, err)
@@ -76,7 +77,7 @@ func TestAttachment_DownloadURL(t *testing.T) {
}
func TestUpdateAttachment(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
attach, err := GetAttachmentByID(1)
assert.NoError(t, err)
@@ -85,13 +86,13 @@ func TestUpdateAttachment(t *testing.T) {
attach.Name = "new_name"
assert.NoError(t, UpdateAttachment(attach))
AssertExistsAndLoadBean(t, &Attachment{Name: "new_name"})
db.AssertExistsAndLoadBean(t, &Attachment{Name: "new_name"})
}
func TestGetAttachmentsByUUIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
attachList, err := GetAttachmentsByUUIDs(DefaultDBContext(), []string{"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", "not-existing-uuid"})
attachList, err := GetAttachmentsByUUIDs(db.DefaultContext(), []string{"a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a17", "not-existing-uuid"})
assert.NoError(t, err)
assert.Len(t, attachList, 2)
assert.Equal(t, "a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11", attachList[0].UUID)
@@ -101,7 +102,7 @@ func TestGetAttachmentsByUUIDs(t *testing.T) {
}
func TestLinkedRepository(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testCases := []struct {
name string
attachID int64
+15 -14
View File
@@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/log"
@@ -24,6 +25,10 @@ type EmailHash struct {
Email string `xorm:"UNIQUE NOT NULL"`
}
func init() {
db.RegisterModel(new(EmailHash))
}
// DefaultAvatarLink the default avatar link
func DefaultAvatarLink() string {
u, err := url.Parse(setting.AppSubURL)
@@ -59,7 +64,7 @@ func GetEmailForHash(md5Sum string) (string, error) {
Hash: strings.ToLower(strings.TrimSpace(md5Sum)),
}
_, err := x.Get(&emailHash)
_, err := db.DefaultContext().Engine().Get(&emailHash)
return emailHash.Email, err
})
}
@@ -90,19 +95,15 @@ func HashedAvatarLink(email string, size int) string {
Hash: sum,
}
// OK we're going to open a session just because I think that that might hide away any problems with postgres reporting errors
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
// we don't care about any DB problem just return the lowerEmail
return lowerEmail, nil
}
has, err := sess.Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
return lowerEmail, nil
}
_, _ = sess.Insert(emailHash)
if err := sess.Commit(); err != nil {
if err := db.WithTx(func(ctx *db.Context) error {
has, err := ctx.Engine().Where("email = ? AND hash = ?", emailHash.Email, emailHash.Hash).Get(new(EmailHash))
if has || err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
return nil
}
_, _ = ctx.Engine().Insert(emailHash)
return nil
}); err != nil {
// Seriously we don't care about any DB problems just return the lowerEmail - we expect the transaction to fail most of the time
return lowerEmail, nil
}
+27 -42
View File
@@ -10,6 +10,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
@@ -49,6 +50,11 @@ type ProtectedBranch struct {
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ProtectedBranch))
db.RegisterModel(new(DeletedBranch))
}
// IsProtected returns if the branch is protected
func (protectBranch *ProtectedBranch) IsProtected() bool {
return protectBranch.ID > 0
@@ -116,10 +122,10 @@ func (protectBranch *ProtectedBranch) IsUserMergeWhitelisted(userID int64, permi
// IsUserOfficialReviewer check if user is official reviewer for the branch (counts towards required approvals)
func (protectBranch *ProtectedBranch) IsUserOfficialReviewer(user *User) (bool, error) {
return protectBranch.isUserOfficialReviewer(x, user)
return protectBranch.isUserOfficialReviewer(db.DefaultContext().Engine(), user)
}
func (protectBranch *ProtectedBranch) isUserOfficialReviewer(e Engine, user *User) (bool, error) {
func (protectBranch *ProtectedBranch) isUserOfficialReviewer(e db.Engine, user *User) (bool, error) {
repo, err := getRepositoryByID(e, protectBranch.RepoID)
if err != nil {
return false, err
@@ -156,7 +162,7 @@ func (protectBranch *ProtectedBranch) HasEnoughApprovals(pr *PullRequest) bool {
// GetGrantedApprovalsCount returns the number of granted approvals for pr. A granted approval must be authored by a user in an approval whitelist.
func (protectBranch *ProtectedBranch) GetGrantedApprovalsCount(pr *PullRequest) int64 {
sess := x.Where("issue_id = ?", pr.IssueID).
sess := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID).
And("type = ?", ReviewTypeApprove).
And("official = ?", true).
And("dismissed = ?", false)
@@ -177,7 +183,7 @@ func (protectBranch *ProtectedBranch) MergeBlockedByRejectedReview(pr *PullReque
if !protectBranch.BlockOnRejectedReviews {
return false
}
rejectExist, err := x.Where("issue_id = ?", pr.IssueID).
rejectExist, err := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID).
And("type = ?", ReviewTypeReject).
And("official = ?", true).
And("dismissed = ?", false).
@@ -196,7 +202,7 @@ func (protectBranch *ProtectedBranch) MergeBlockedByOfficialReviewRequests(pr *P
if !protectBranch.BlockOnOfficialReviewRequests {
return false
}
has, err := x.Where("issue_id = ?", pr.IssueID).
has, err := db.DefaultContext().Engine().Where("issue_id = ?", pr.IssueID).
And("type = ?", ReviewTypeRequest).
And("official = ?", true).
Exist(new(Review))
@@ -294,10 +300,10 @@ func (protectBranch *ProtectedBranch) IsUnprotectedFile(patterns []glob.Glob, pa
// GetProtectedBranchBy getting protected branch by ID/Name
func GetProtectedBranchBy(repoID int64, branchName string) (*ProtectedBranch, error) {
return getProtectedBranchBy(x, repoID, branchName)
return getProtectedBranchBy(db.DefaultContext().Engine(), repoID, branchName)
}
func getProtectedBranchBy(e Engine, repoID int64, branchName string) (*ProtectedBranch, error) {
func getProtectedBranchBy(e db.Engine, repoID int64, branchName string) (*ProtectedBranch, error) {
rel := &ProtectedBranch{RepoID: repoID, BranchName: branchName}
has, err := e.Get(rel)
if err != nil {
@@ -369,13 +375,13 @@ func UpdateProtectBranch(repo *Repository, protectBranch *ProtectedBranch, opts
// Make sure protectBranch.ID is not 0 for whitelists
if protectBranch.ID == 0 {
if _, err = x.Insert(protectBranch); err != nil {
if _, err = db.DefaultContext().Engine().Insert(protectBranch); err != nil {
return fmt.Errorf("Insert: %v", err)
}
return nil
}
if _, err = x.ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
if _, err = db.DefaultContext().Engine().ID(protectBranch.ID).AllCols().Update(protectBranch); err != nil {
return fmt.Errorf("Update: %v", err)
}
@@ -385,7 +391,7 @@ func UpdateProtectBranch(repo *Repository, protectBranch *ProtectedBranch, opts
// GetProtectedBranches get all protected branches
func (repo *Repository) GetProtectedBranches() ([]*ProtectedBranch, error) {
protectedBranches := make([]*ProtectedBranch, 0)
return protectedBranches, x.Find(&protectedBranches, &ProtectedBranch{RepoID: repo.ID})
return protectedBranches, db.DefaultContext().Engine().Find(&protectedBranches, &ProtectedBranch{RepoID: repo.ID})
}
// GetBranchProtection get the branch protection of a branch
@@ -400,7 +406,7 @@ func (repo *Repository) IsProtectedBranch(branchName string) (bool, error) {
BranchName: branchName,
}
has, err := x.Exist(protectedBranch)
has, err := db.DefaultContext().Engine().Exist(protectedBranch)
if err != nil {
return true, err
}
@@ -487,19 +493,13 @@ func (repo *Repository) DeleteProtectedBranch(id int64) (err error) {
ID: id,
}
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
}
if affected, err := sess.Delete(protectedBranch); err != nil {
if affected, err := db.DefaultContext().Engine().Delete(protectedBranch); err != nil {
return err
} else if affected != 1 {
return fmt.Errorf("delete protected branch ID(%v) failed", id)
}
return sess.Commit()
return nil
}
// DeletedBranch struct
@@ -522,29 +522,20 @@ func (repo *Repository) AddDeletedBranch(branchName, commit string, deletedByID
DeletedByID: deletedByID,
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if _, err := sess.InsertOne(deletedBranch); err != nil {
return err
}
return sess.Commit()
_, err := db.DefaultContext().Engine().InsertOne(deletedBranch)
return err
}
// GetDeletedBranches returns all the deleted branches
func (repo *Repository) GetDeletedBranches() ([]*DeletedBranch, error) {
deletedBranches := make([]*DeletedBranch, 0)
return deletedBranches, x.Where("repo_id = ?", repo.ID).Desc("deleted_unix").Find(&deletedBranches)
return deletedBranches, db.DefaultContext().Engine().Where("repo_id = ?", repo.ID).Desc("deleted_unix").Find(&deletedBranches)
}
// GetDeletedBranchByID get a deleted branch by its ID
func (repo *Repository) GetDeletedBranchByID(id int64) (*DeletedBranch, error) {
deletedBranch := &DeletedBranch{}
has, err := x.ID(id).Get(deletedBranch)
has, err := db.DefaultContext().Engine().ID(id).Get(deletedBranch)
if err != nil {
return nil, err
}
@@ -561,19 +552,13 @@ func (repo *Repository) RemoveDeletedBranch(id int64) (err error) {
ID: id,
}
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
}
if affected, err := sess.Delete(deletedBranch); err != nil {
if affected, err := db.DefaultContext().Engine().Delete(deletedBranch); err != nil {
return err
} else if affected != 1 {
return fmt.Errorf("remove deleted branch ID(%v) failed", id)
}
return sess.Commit()
return nil
}
// LoadUser loads the user that deleted the branch
@@ -588,7 +573,7 @@ func (deletedBranch *DeletedBranch) LoadUser() {
// RemoveDeletedBranch removes all deleted branches
func RemoveDeletedBranch(repoID int64, branch string) error {
_, err := x.Where("repo_id=? AND name=?", repoID, branch).Delete(new(DeletedBranch))
_, err := db.DefaultContext().Engine().Where("repo_id=? AND name=?", repoID, branch).Delete(new(DeletedBranch))
return err
}
@@ -598,7 +583,7 @@ func RemoveOldDeletedBranches(ctx context.Context, olderThan time.Duration) {
log.Trace("Doing: DeletedBranchesCleanup")
deleteBefore := time.Now().Add(-olderThan)
_, err := x.Where("deleted_unix < ?", deleteBefore.Unix()).Delete(new(DeletedBranch))
_, err := db.DefaultContext().Engine().Where("deleted_unix < ?", deleteBefore.Unix()).Delete(new(DeletedBranch))
if err != nil {
log.Error("DeletedBranchesCleanup: %v", err)
}
+17 -16
View File
@@ -7,21 +7,22 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestAddDeletedBranch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
firstBranch := AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
firstBranch := db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
assert.Error(t, repo.AddDeletedBranch(firstBranch.Name, firstBranch.Commit, firstBranch.DeletedByID))
assert.NoError(t, repo.AddDeletedBranch("test", "5655464564554545466464656", int64(1)))
}
func TestGetDeletedBranches(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
branches, err := repo.GetDeletedBranches()
assert.NoError(t, err)
@@ -29,17 +30,17 @@ func TestGetDeletedBranches(t *testing.T) {
}
func TestGetDeletedBranch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
firstBranch := AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
assert.NoError(t, db.PrepareTestDatabase())
firstBranch := db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
assert.NotNil(t, getDeletedBranch(t, firstBranch))
}
func TestDeletedBranchLoadUser(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
firstBranch := AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
secondBranch := AssertExistsAndLoadBean(t, &DeletedBranch{ID: 2}).(*DeletedBranch)
firstBranch := db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
secondBranch := db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 2}).(*DeletedBranch)
branch := getDeletedBranch(t, firstBranch)
assert.Nil(t, branch.DeletedBy)
@@ -55,19 +56,19 @@ func TestDeletedBranchLoadUser(t *testing.T) {
}
func TestRemoveDeletedBranch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
firstBranch := AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
firstBranch := db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 1}).(*DeletedBranch)
err := repo.RemoveDeletedBranch(1)
assert.NoError(t, err)
AssertNotExistsBean(t, firstBranch)
AssertExistsAndLoadBean(t, &DeletedBranch{ID: 2})
db.AssertNotExistsBean(t, firstBranch)
db.AssertExistsAndLoadBean(t, &DeletedBranch{ID: 2})
}
func getDeletedBranch(t *testing.T, branch *DeletedBranch) *DeletedBranch {
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
deletedBranch, err := repo.GetDeletedBranchByID(branch.ID)
assert.NoError(t, err)
+18 -20
View File
@@ -10,6 +10,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
api "code.gitea.io/gitea/modules/structs"
@@ -37,7 +38,11 @@ type CommitStatus struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func (status *CommitStatus) loadAttributes(e Engine) (err error) {
func init() {
db.RegisterModel(new(CommitStatus))
}
func (status *CommitStatus) loadAttributes(e db.Engine) (err error) {
if status.Repo == nil {
status.Repo, err = getRepositoryByID(e, status.RepoID)
if err != nil {
@@ -55,7 +60,7 @@ func (status *CommitStatus) loadAttributes(e Engine) (err error) {
// APIURL returns the absolute APIURL to this commit-status.
func (status *CommitStatus) APIURL() string {
_ = status.loadAttributes(x)
_ = status.loadAttributes(db.DefaultContext().Engine())
return fmt.Sprintf("%sapi/v1/repos/%s/statuses/%s",
setting.AppURL, status.Repo.FullName(), status.SHA)
}
@@ -112,7 +117,7 @@ func GetCommitStatuses(repo *Repository, sha string, opts *CommitStatusOptions)
}
func listCommitStatusesStatement(repo *Repository, sha string, opts *CommitStatusOptions) *xorm.Session {
sess := x.Where("repo_id = ?", repo.ID).And("sha = ?", sha)
sess := db.DefaultContext().Engine().Where("repo_id = ?", repo.ID).And("sha = ?", sha)
switch opts.State {
case "pending", "success", "error", "failure", "warning":
sess.And("state = ?", opts.State)
@@ -139,10 +144,10 @@ func sortCommitStatusesSession(sess *xorm.Session, sortType string) {
// GetLatestCommitStatus returns all statuses with a unique context for a given commit.
func GetLatestCommitStatus(repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) {
return getLatestCommitStatus(x, repoID, sha, listOptions)
return getLatestCommitStatus(db.DefaultContext().Engine(), repoID, sha, listOptions)
}
func getLatestCommitStatus(e Engine, repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) {
func getLatestCommitStatus(e db.Engine, repoID int64, sha string, listOptions ListOptions) ([]*CommitStatus, error) {
ids := make([]int64, 0, 10)
sess := e.Table(&CommitStatus{}).
Where("repo_id = ?", repoID).And("sha = ?", sha).
@@ -166,7 +171,7 @@ func getLatestCommitStatus(e Engine, repoID int64, sha string, listOptions ListO
func FindRepoRecentCommitStatusContexts(repoID int64, before time.Duration) ([]string, error) {
start := timeutil.TimeStampNow().AddDuration(-before)
ids := make([]int64, 0, 10)
if err := x.Table("commit_status").
if err := db.DefaultContext().Engine().Table("commit_status").
Where("repo_id = ?", repoID).
And("updated_unix >= ?", start).
Select("max( id ) as id").
@@ -179,7 +184,7 @@ func FindRepoRecentCommitStatusContexts(repoID int64, before time.Duration) ([]s
if len(ids) == 0 {
return contexts, nil
}
return contexts, x.Select("context").Table("commit_status").In("id", ids).Find(&contexts)
return contexts, db.DefaultContext().Engine().Select("context").Table("commit_status").In("id", ids).Find(&contexts)
}
// NewCommitStatusOptions holds options for creating a CommitStatus
@@ -201,12 +206,11 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
return fmt.Errorf("NewCommitStatus[%s, %s]: no user specified", repoPath, opts.SHA)
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return fmt.Errorf("NewCommitStatus[repo_id: %d, user_id: %d, sha: %s]: %v", opts.Repo.ID, opts.Creator.ID, opts.SHA, err)
}
defer committer.Close()
opts.CommitStatus.Description = strings.TrimSpace(opts.CommitStatus.Description)
opts.CommitStatus.Context = strings.TrimSpace(opts.CommitStatus.Context)
@@ -221,11 +225,8 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
SHA: opts.SHA,
RepoID: opts.Repo.ID,
}
has, err := sess.Desc("index").Limit(1).Get(lastCommitStatus)
has, err := ctx.Engine().Desc("index").Limit(1).Get(lastCommitStatus)
if err != nil {
if err := sess.Rollback(); err != nil {
log.Error("NewCommitStatus: sess.Rollback: %v", err)
}
return fmt.Errorf("NewCommitStatus[%s, %s]: %v", repoPath, opts.SHA, err)
}
if has {
@@ -238,14 +239,11 @@ func NewCommitStatus(opts NewCommitStatusOptions) error {
opts.CommitStatus.ContextHash = hashCommitStatusContext(opts.CommitStatus.Context)
// Insert new CommitStatus
if _, err = sess.Insert(opts.CommitStatus); err != nil {
if err := sess.Rollback(); err != nil {
log.Error("Insert CommitStatus: sess.Rollback: %v", err)
}
if _, err = ctx.Engine().Insert(opts.CommitStatus); err != nil {
return fmt.Errorf("Insert CommitStatus[%s, %s]: %v", repoPath, opts.SHA, err)
}
return sess.Commit()
return committer.Commit()
}
// SignCommitWithStatuses represents a commit with validation of signature and status state.
+3 -2
View File
@@ -7,14 +7,15 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/structs"
"github.com/stretchr/testify/assert"
)
func TestGetCommitStatuses(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
sha1 := "1234123412341234123412341234123412341234"
+47 -106
View File
@@ -5,13 +5,11 @@
package models
import (
"fmt"
"reflect"
"regexp"
"strings"
"testing"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
"xorm.io/builder"
)
@@ -43,7 +41,7 @@ func CheckConsistencyFor(t *testing.T, beansToCheck ...interface{}) {
ptrToSliceValue := reflect.New(sliceType)
ptrToSliceValue.Elem().Set(sliceValue)
assert.NoError(t, x.Table(bean).Find(ptrToSliceValue.Interface()))
assert.NoError(t, db.DefaultContext().Engine().Table(bean).Find(ptrToSliceValue.Interface()))
sliceValue = ptrToSliceValue.Elem()
for i := 0; i < sliceValue.Len(); i++ {
@@ -60,7 +58,7 @@ func CheckConsistencyFor(t *testing.T, beansToCheck ...interface{}) {
}
// getCount get the count of database entries matching bean
func getCount(t *testing.T, e Engine, bean interface{}) int64 {
func getCount(t *testing.T, e db.Engine, bean interface{}) int64 {
count, err := e.Count(bean)
assert.NoError(t, err)
return count
@@ -68,7 +66,7 @@ func getCount(t *testing.T, e Engine, bean interface{}) int64 {
// assertCount test the count of database entries matching bean
func assertCount(t *testing.T, bean interface{}, expected int) {
assert.EqualValues(t, expected, getCount(t, x, bean),
assert.EqualValues(t, expected, getCount(t, db.DefaultContext().Engine(), bean),
"Failed consistency test, the counted bean (of type %T) was %+v", bean, bean)
}
@@ -91,46 +89,46 @@ func (repo *Repository) checkForConsistency(t *testing.T) {
assertCount(t, &Milestone{RepoID: repo.ID}, repo.NumMilestones)
assertCount(t, &Repository{ForkID: repo.ID}, repo.NumForks)
if repo.IsFork {
AssertExistsAndLoadBean(t, &Repository{ID: repo.ForkID})
db.AssertExistsAndLoadBean(t, &Repository{ID: repo.ForkID})
}
actual := getCount(t, x.Where("Mode<>?", RepoWatchModeDont), &Watch{RepoID: repo.ID})
actual := getCount(t, db.DefaultContext().Engine().Where("Mode<>?", RepoWatchModeDont), &Watch{RepoID: repo.ID})
assert.EqualValues(t, repo.NumWatches, actual,
"Unexpected number of watches for repo %+v", repo)
actual = getCount(t, x.Where("is_pull=?", false), &Issue{RepoID: repo.ID})
actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=?", false), &Issue{RepoID: repo.ID})
assert.EqualValues(t, repo.NumIssues, actual,
"Unexpected number of issues for repo %+v", repo)
actual = getCount(t, x.Where("is_pull=? AND is_closed=?", false, true), &Issue{RepoID: repo.ID})
actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=? AND is_closed=?", false, true), &Issue{RepoID: repo.ID})
assert.EqualValues(t, repo.NumClosedIssues, actual,
"Unexpected number of closed issues for repo %+v", repo)
actual = getCount(t, x.Where("is_pull=?", true), &Issue{RepoID: repo.ID})
actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=?", true), &Issue{RepoID: repo.ID})
assert.EqualValues(t, repo.NumPulls, actual,
"Unexpected number of pulls for repo %+v", repo)
actual = getCount(t, x.Where("is_pull=? AND is_closed=?", true, true), &Issue{RepoID: repo.ID})
actual = getCount(t, db.DefaultContext().Engine().Where("is_pull=? AND is_closed=?", true, true), &Issue{RepoID: repo.ID})
assert.EqualValues(t, repo.NumClosedPulls, actual,
"Unexpected number of closed pulls for repo %+v", repo)
actual = getCount(t, x.Where("is_closed=?", true), &Milestone{RepoID: repo.ID})
actual = getCount(t, db.DefaultContext().Engine().Where("is_closed=?", true), &Milestone{RepoID: repo.ID})
assert.EqualValues(t, repo.NumClosedMilestones, actual,
"Unexpected number of closed milestones for repo %+v", repo)
}
func (issue *Issue) checkForConsistency(t *testing.T) {
actual := getCount(t, x.Where("type=?", CommentTypeComment), &Comment{IssueID: issue.ID})
actual := getCount(t, db.DefaultContext().Engine().Where("type=?", CommentTypeComment), &Comment{IssueID: issue.ID})
assert.EqualValues(t, issue.NumComments, actual,
"Unexpected number of comments for issue %+v", issue)
if issue.IsPull {
pr := AssertExistsAndLoadBean(t, &PullRequest{IssueID: issue.ID}).(*PullRequest)
pr := db.AssertExistsAndLoadBean(t, &PullRequest{IssueID: issue.ID}).(*PullRequest)
assert.EqualValues(t, pr.Index, issue.Index)
}
}
func (pr *PullRequest) checkForConsistency(t *testing.T) {
issue := AssertExistsAndLoadBean(t, &Issue{ID: pr.IssueID}).(*Issue)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: pr.IssueID}).(*Issue)
assert.True(t, issue.IsPull)
assert.EqualValues(t, issue.Index, pr.Index)
}
@@ -138,7 +136,7 @@ func (pr *PullRequest) checkForConsistency(t *testing.T) {
func (milestone *Milestone) checkForConsistency(t *testing.T) {
assertCount(t, &Issue{MilestoneID: milestone.ID}, milestone.NumIssues)
actual := getCount(t, x.Where("is_closed=?", true), &Issue{MilestoneID: milestone.ID})
actual := getCount(t, db.DefaultContext().Engine().Where("is_closed=?", true), &Issue{MilestoneID: milestone.ID})
assert.EqualValues(t, milestone.NumClosedIssues, actual,
"Unexpected number of closed issues for milestone %+v", milestone)
@@ -151,7 +149,7 @@ func (milestone *Milestone) checkForConsistency(t *testing.T) {
func (label *Label) checkForConsistency(t *testing.T) {
issueLabels := make([]*IssueLabel, 0, 10)
assert.NoError(t, x.Find(&issueLabels, &IssueLabel{LabelID: label.ID}))
assert.NoError(t, db.DefaultContext().Engine().Find(&issueLabels, &IssueLabel{LabelID: label.ID}))
assert.EqualValues(t, label.NumIssues, len(issueLabels),
"Unexpected number of issue for label %+v", label)
@@ -162,7 +160,7 @@ func (label *Label) checkForConsistency(t *testing.T) {
expected := int64(0)
if len(issueIDs) > 0 {
expected = getCount(t, x.In("id", issueIDs).Where("is_closed=?", true), &Issue{})
expected = getCount(t, db.DefaultContext().Engine().In("id", issueIDs).Where("is_closed=?", true), &Issue{})
}
assert.EqualValues(t, expected, label.NumClosedIssues,
"Unexpected number of closed issues for label %+v", label)
@@ -174,18 +172,18 @@ func (team *Team) checkForConsistency(t *testing.T) {
}
func (action *Action) checkForConsistency(t *testing.T) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: action.RepoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: action.RepoID}).(*Repository)
assert.Equal(t, repo.IsPrivate, action.IsPrivate, "action: %+v", action)
}
// CountOrphanedLabels return count of labels witch are broken and not accessible via ui anymore
func CountOrphanedLabels() (int64, error) {
noref, err := x.Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count("label.id")
noref, err := db.DefaultContext().Engine().Table("label").Where("repo_id=? AND org_id=?", 0, 0).Count("label.id")
if err != nil {
return 0, err
}
norepo, err := x.Table("label").
norepo, err := db.DefaultContext().Engine().Table("label").
Where(builder.And(
builder.Gt{"repo_id": 0},
builder.NotIn("repo_id", builder.Select("id").From("repository")),
@@ -195,7 +193,7 @@ func CountOrphanedLabels() (int64, error) {
return 0, err
}
noorg, err := x.Table("label").
noorg, err := db.DefaultContext().Engine().Table("label").
Where(builder.And(
builder.Gt{"org_id": 0},
builder.NotIn("org_id", builder.Select("id").From("user")),
@@ -211,12 +209,12 @@ func CountOrphanedLabels() (int64, error) {
// DeleteOrphanedLabels delete labels witch are broken and not accessible via ui anymore
func DeleteOrphanedLabels() error {
// delete labels with no reference
if _, err := x.Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil {
if _, err := db.DefaultContext().Engine().Table("label").Where("repo_id=? AND org_id=?", 0, 0).Delete(new(Label)); err != nil {
return err
}
// delete labels with none existing repos
if _, err := x.
if _, err := db.DefaultContext().Engine().
Where(builder.And(
builder.Gt{"repo_id": 0},
builder.NotIn("repo_id", builder.Select("id").From("repository")),
@@ -226,7 +224,7 @@ func DeleteOrphanedLabels() error {
}
// delete labels with none existing orgs
if _, err := x.
if _, err := db.DefaultContext().Engine().
Where(builder.And(
builder.Gt{"org_id": 0},
builder.NotIn("org_id", builder.Select("id").From("user")),
@@ -240,14 +238,14 @@ func DeleteOrphanedLabels() error {
// CountOrphanedIssueLabels return count of IssueLabels witch have no label behind anymore
func CountOrphanedIssueLabels() (int64, error) {
return x.Table("issue_label").
return db.DefaultContext().Engine().Table("issue_label").
NotIn("label_id", builder.Select("id").From("label")).
Count()
}
// DeleteOrphanedIssueLabels delete IssueLabels witch have no label behind anymore
func DeleteOrphanedIssueLabels() error {
_, err := x.
_, err := db.DefaultContext().Engine().
NotIn("label_id", builder.Select("id").From("label")).
Delete(IssueLabel{})
@@ -256,7 +254,7 @@ func DeleteOrphanedIssueLabels() error {
// CountOrphanedIssues count issues without a repo
func CountOrphanedIssues() (int64, error) {
return x.Table("issue").
return db.DefaultContext().Engine().Table("issue").
Join("LEFT", "repository", "issue.repo_id=repository.id").
Where(builder.IsNull{"repository.id"}).
Count("id")
@@ -264,15 +262,15 @@ func CountOrphanedIssues() (int64, error) {
// DeleteOrphanedIssues delete issues without a repo
func DeleteOrphanedIssues() error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
var ids []int64
if err := sess.Table("issue").Distinct("issue.repo_id").
if err := ctx.Engine().Table("issue").Distinct("issue.repo_id").
Join("LEFT", "repository", "issue.repo_id=repository.id").
Where(builder.IsNull{"repository.id"}).GroupBy("issue.repo_id").
Find(&ids); err != nil {
@@ -281,27 +279,28 @@ func DeleteOrphanedIssues() error {
var attachmentPaths []string
for i := range ids {
paths, err := deleteIssuesByRepoID(sess, ids[i])
paths, err := deleteIssuesByRepoID(ctx.Engine(), ids[i])
if err != nil {
return err
}
attachmentPaths = append(attachmentPaths, paths...)
}
if err := sess.Commit(); err != nil {
if err := committer.Commit(); err != nil {
return err
}
committer.Close()
// Remove issue attachment files.
for i := range attachmentPaths {
removeAllWithNotice(x, "Delete issue attachment", attachmentPaths[i])
removeAllWithNotice(db.DefaultContext().Engine(), "Delete issue attachment", attachmentPaths[i])
}
return nil
}
// CountOrphanedObjects count subjects with have no existing refobject anymore
func CountOrphanedObjects(subject, refobject, joinCond string) (int64, error) {
return x.Table("`"+subject+"`").
return db.DefaultContext().Engine().Table("`"+subject+"`").
Join("LEFT", refobject, joinCond).
Where(builder.IsNull{"`" + refobject + "`.id"}).
Count("id")
@@ -317,45 +316,45 @@ func DeleteOrphanedObjects(subject, refobject, joinCond string) error {
if err != nil {
return err
}
_, err = x.Exec(append([]interface{}{sql}, args...)...)
_, err = db.DefaultContext().Engine().Exec(append([]interface{}{sql}, args...)...)
return err
}
// CountNullArchivedRepository counts the number of repositories with is_archived is null
func CountNullArchivedRepository() (int64, error) {
return x.Where(builder.IsNull{"is_archived"}).Count(new(Repository))
return db.DefaultContext().Engine().Where(builder.IsNull{"is_archived"}).Count(new(Repository))
}
// FixNullArchivedRepository sets is_archived to false where it is null
func FixNullArchivedRepository() (int64, error) {
return x.Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{
return db.DefaultContext().Engine().Where(builder.IsNull{"is_archived"}).Cols("is_archived").NoAutoTime().Update(&Repository{
IsArchived: false,
})
}
// CountWrongUserType count OrgUser who have wrong type
func CountWrongUserType() (int64, error) {
return x.Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User))
return db.DefaultContext().Engine().Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Count(new(User))
}
// FixWrongUserType fix OrgUser who have wrong type
func FixWrongUserType() (int64, error) {
return x.Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1})
return db.DefaultContext().Engine().Where(builder.Eq{"type": 0}.And(builder.Neq{"num_teams": 0})).Cols("type").NoAutoTime().Update(&User{Type: 1})
}
// CountCommentTypeLabelWithEmptyLabel count label comments with empty label
func CountCommentTypeLabelWithEmptyLabel() (int64, error) {
return x.Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment))
return db.DefaultContext().Engine().Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Count(new(Comment))
}
// FixCommentTypeLabelWithEmptyLabel count label comments with empty label
func FixCommentTypeLabelWithEmptyLabel() (int64, error) {
return x.Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment))
return db.DefaultContext().Engine().Where(builder.Eq{"type": CommentTypeLabel, "label_id": 0}).Delete(new(Comment))
}
// CountCommentTypeLabelWithOutsideLabels count label comments with outside label
func CountCommentTypeLabelWithOutsideLabels() (int64, error) {
return x.Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel).
return db.DefaultContext().Engine().Where("comment.type = ? AND ((label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id))", CommentTypeLabel).
Table("comment").
Join("inner", "label", "label.id = comment.label_id").
Join("inner", "issue", "issue.id = comment.issue_id ").
@@ -365,7 +364,7 @@ func CountCommentTypeLabelWithOutsideLabels() (int64, error) {
// FixCommentTypeLabelWithOutsideLabels count label comments with outside label
func FixCommentTypeLabelWithOutsideLabels() (int64, error) {
res, err := x.Exec(`DELETE FROM comment WHERE comment.id IN (
res, err := db.DefaultContext().Engine().Exec(`DELETE FROM comment WHERE comment.id IN (
SELECT il_too.id FROM (
SELECT com.id
FROM comment AS com
@@ -384,7 +383,7 @@ func FixCommentTypeLabelWithOutsideLabels() (int64, error) {
// CountIssueLabelWithOutsideLabels count label comments with outside label
func CountIssueLabelWithOutsideLabels() (int64, error) {
return x.Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")).
return db.DefaultContext().Engine().Where(builder.Expr("(label.org_id = 0 AND issue.repo_id != label.repo_id) OR (label.repo_id = 0 AND label.org_id != repository.owner_id)")).
Table("issue_label").
Join("inner", "label", "issue_label.label_id = label.id ").
Join("inner", "issue", "issue.id = issue_label.issue_id ").
@@ -394,7 +393,7 @@ func CountIssueLabelWithOutsideLabels() (int64, error) {
// FixIssueLabelWithOutsideLabels fix label comments with outside label
func FixIssueLabelWithOutsideLabels() (int64, error) {
res, err := x.Exec(`DELETE FROM issue_label WHERE issue_label.id IN (
res, err := db.DefaultContext().Engine().Exec(`DELETE FROM issue_label WHERE issue_label.id IN (
SELECT il_too.id FROM (
SELECT il_too_too.id
FROM issue_label AS il_too_too
@@ -411,61 +410,3 @@ func FixIssueLabelWithOutsideLabels() (int64, error) {
return res.RowsAffected()
}
// CountBadSequences looks for broken sequences from recreate-table mistakes
func CountBadSequences() (int64, error) {
if !setting.Database.UsePostgreSQL {
return 0, nil
}
sess := x.NewSession()
defer sess.Close()
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return 0, err
}
sess.Engine().SetSchema(schema)
return int64(len(sequences)), nil
}
// FixBadSequences fixes for broken sequences from recreate-table mistakes
func FixBadSequences() error {
if !setting.Database.UsePostgreSQL {
return nil
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return err
}
sess.Engine().SetSchema(schema)
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
for _, sequence := range sequences {
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
newSequenceName := tableName + "_id_seq"
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
return err
}
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
return err
}
}
return sess.Commit()
}
+5 -4
View File
@@ -7,16 +7,17 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestDeleteOrphanedObjects(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
countBefore, err := x.Count(&PullRequest{})
countBefore, err := db.DefaultContext().Engine().Count(&PullRequest{})
assert.NoError(t, err)
_, err = x.Insert(&PullRequest{IssueID: 1000}, &PullRequest{IssueID: 1001}, &PullRequest{IssueID: 1003})
_, err = db.DefaultContext().Engine().Insert(&PullRequest{IssueID: 1000}, &PullRequest{IssueID: 1001}, &PullRequest{IssueID: 1003})
assert.NoError(t, err)
orphaned, err := CountOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id")
@@ -26,7 +27,7 @@ func TestDeleteOrphanedObjects(t *testing.T) {
err = DeleteOrphanedObjects("pull_request", "issue", "pull_request.issue_id=issue.id")
assert.NoError(t, err)
countAfter, err := x.Count(&PullRequest{})
countAfter, err := db.DefaultContext().Engine().Count(&PullRequest{})
assert.NoError(t, err)
assert.EqualValues(t, countBefore, countAfter)
}
-71
View File
@@ -1,71 +0,0 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
import (
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
)
// DBContext represents a db context
type DBContext struct {
e Engine
}
// DefaultDBContext represents a DBContext with default Engine
func DefaultDBContext() DBContext {
return DBContext{x}
}
// Committer represents an interface to Commit or Close the dbcontext
type Committer interface {
Commit() error
Close() error
}
// TxDBContext represents a transaction DBContext
func TxDBContext() (DBContext, Committer, error) {
sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
return DBContext{}, nil, err
}
return DBContext{sess}, sess, nil
}
// WithContext represents executing database operations
func WithContext(f func(ctx DBContext) error) error {
return f(DBContext{x})
}
// WithTx represents executing database operations on a transaction
func WithTx(f func(ctx DBContext) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if err := f(DBContext{sess}); err != nil {
return err
}
return sess.Commit()
}
// Iterate iterates the databases and doing something
func Iterate(ctx DBContext, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error {
return ctx.e.Where(cond).
BufferSize(setting.Database.IterateBufferSize).
Iterate(tableBean, fun)
}
// Insert inserts records into database
func Insert(ctx DBContext, beans ...interface{}) error {
_, err := ctx.e.Insert(beans...)
return err
}
+86
View File
@@ -0,0 +1,86 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
"xorm.io/xorm"
)
// Context represents a db context
type Context struct {
e Engine
}
// Engine returns db engine
func (ctx *Context) Engine() Engine {
return ctx.e
}
// NewSession returns a new session
func (ctx *Context) NewSession() *xorm.Session {
e, ok := ctx.e.(*xorm.Engine)
if ok {
return e.NewSession()
}
return nil
}
// DefaultContext represents a Context with default Engine
func DefaultContext() *Context {
return &Context{x}
}
// Committer represents an interface to Commit or Close the Context
type Committer interface {
Commit() error
Close() error
}
// TxContext represents a transaction Context
func TxContext() (*Context, Committer, error) {
sess := x.NewSession()
if err := sess.Begin(); err != nil {
sess.Close()
return nil, nil, err
}
return &Context{sess}, sess, nil
}
// WithContext represents executing database operations
func WithContext(f func(ctx *Context) error) error {
return f(&Context{x})
}
// WithTx represents executing database operations on a transaction
func WithTx(f func(ctx *Context) error) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
if err := f(&Context{sess}); err != nil {
return err
}
return sess.Commit()
}
// Iterate iterates the databases and doing something
func Iterate(ctx *Context, tableBean interface{}, cond builder.Cond, fun func(idx int, bean interface{}) error) error {
return ctx.e.Where(cond).
BufferSize(setting.Database.IterateBufferSize).
Iterate(tableBean, fun)
}
// Insert inserts records into database
func Insert(ctx *Context, beans ...interface{}) error {
_, err := ctx.e.Insert(beans...)
return err
}
+1 -1
View File
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"fmt"
+37 -148
View File
@@ -3,13 +3,14 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"context"
"database/sql"
"errors"
"fmt"
"io"
"reflect"
"strings"
@@ -17,7 +18,6 @@ import (
// Needed for the MySQL driver
_ "github.com/go-sql-driver/mysql"
lru "github.com/hashicorp/golang-lru"
"xorm.io/xorm"
"xorm.io/xorm/names"
"xorm.io/xorm/schemas"
@@ -29,6 +29,15 @@ import (
_ "github.com/denisenkom/go-mssqldb"
)
var (
x *xorm.Engine
tables []interface{}
initFuncs []func() error
// HasEngine specifies if we have a xorm.Engine
HasEngine bool
)
// Engine represents a xorm engine or session.
type Engine interface {
Table(tableNameOrBean interface{}) *xorm.Session
@@ -51,96 +60,35 @@ type Engine interface {
Desc(colNames ...string) *xorm.Session
Limit(limit int, start ...int) *xorm.Session
SumInt(bean interface{}, columnName string) (res int64, err error)
Sync2(...interface{}) error
Select(string) *xorm.Session
NotIn(string, ...interface{}) *xorm.Session
OrderBy(string) *xorm.Session
Exist(...interface{}) (bool, error)
Distinct(...string) *xorm.Session
Query(...interface{}) ([]map[string][]byte, error)
Cols(...string) *xorm.Session
}
const (
// When queries are broken down in parts because of the number
// of parameters, attempt to break by this amount
maxQueryParameters = 300
)
// TableInfo returns table's information via an object
func TableInfo(v interface{}) (*schemas.Table, error) {
return x.TableInfo(v)
}
var (
x *xorm.Engine
tables []interface{}
// DumpTables dump tables information
func DumpTables(tables []*schemas.Table, w io.Writer, tp ...schemas.DBType) error {
return x.DumpTables(tables, w, tp...)
}
// HasEngine specifies if we have a xorm.Engine
HasEngine bool
)
// RegisterModel registers model, if initfunc provided, it will be invoked after data model sync
func RegisterModel(bean interface{}, initFunc ...func() error) {
tables = append(tables, bean)
if len(initFuncs) > 0 && initFunc[0] != nil {
initFuncs = append(initFuncs, initFunc[0])
}
}
func init() {
tables = append(tables,
new(User),
new(PublicKey),
new(AccessToken),
new(Repository),
new(DeployKey),
new(Collaboration),
new(Access),
new(Upload),
new(Watch),
new(Star),
new(Follow),
new(Action),
new(Issue),
new(PullRequest),
new(Comment),
new(Attachment),
new(Label),
new(IssueLabel),
new(Milestone),
new(Mirror),
new(Release),
new(LoginSource),
new(Webhook),
new(HookTask),
new(Team),
new(OrgUser),
new(TeamUser),
new(TeamRepo),
new(Notice),
new(EmailAddress),
new(Notification),
new(IssueUser),
new(LFSMetaObject),
new(TwoFactor),
new(GPGKey),
new(GPGKeyImport),
new(RepoUnit),
new(RepoRedirect),
new(ExternalLoginUser),
new(ProtectedBranch),
new(UserOpenID),
new(IssueWatch),
new(CommitStatus),
new(Stopwatch),
new(TrackedTime),
new(DeletedBranch),
new(RepoIndexerStatus),
new(IssueDependency),
new(LFSLock),
new(Reaction),
new(IssueAssignees),
new(U2FRegistration),
new(TeamUnit),
new(Review),
new(OAuth2Application),
new(OAuth2AuthorizationCode),
new(OAuth2Grant),
new(Task),
new(LanguageStat),
new(EmailHash),
new(UserRedirect),
new(Project),
new(ProjectBoard),
new(ProjectIssue),
new(Session),
new(RepoTransfer),
new(IssueIndex),
new(PushMirror),
new(RepoArchiver),
new(ProtectedTag),
)
gonicNames := []string{"SSL", "UID"}
for _, name := range gonicNames {
names.LintGonicMapper[name] = true
@@ -235,13 +183,10 @@ func NewEngine(ctx context.Context, migrateFunc func(*xorm.Engine) error) (err e
return fmt.Errorf("sync database struct error: %v", err)
}
if setting.SuccessfulTokensCacheSize > 0 {
successfulAccessTokenCache, err = lru.New(setting.SuccessfulTokensCacheSize)
if err != nil {
return fmt.Errorf("unable to allocate AccessToken cache: %v", err)
for _, initFunc := range initFuncs {
if err := initFunc(); err != nil {
return fmt.Errorf("initFunc failed: %v", err)
}
} else {
successfulAccessTokenCache = nil
}
return nil
@@ -277,62 +222,6 @@ func NamesToBean(names ...string) ([]interface{}, error) {
return beans, nil
}
// Statistic contains the database statistics
type Statistic struct {
Counter struct {
User, Org, PublicKey,
Repo, Watch, Star, Action, Access,
Issue, IssueClosed, IssueOpen,
Comment, Oauth, Follow,
Mirror, Release, LoginSource, Webhook,
Milestone, Label, HookTask,
Team, UpdateTask, Attachment int64
}
}
// GetStatistic returns the database statistics
func GetStatistic() (stats Statistic) {
stats.Counter.User = CountUsers()
stats.Counter.Org = CountOrganizations()
stats.Counter.PublicKey, _ = x.Count(new(PublicKey))
stats.Counter.Repo = CountRepositories(true)
stats.Counter.Watch, _ = x.Count(new(Watch))
stats.Counter.Star, _ = x.Count(new(Star))
stats.Counter.Action, _ = x.Count(new(Action))
stats.Counter.Access, _ = x.Count(new(Access))
type IssueCount struct {
Count int64
IsClosed bool
}
issueCounts := []IssueCount{}
_ = x.Select("COUNT(*) AS count, is_closed").Table("issue").GroupBy("is_closed").Find(&issueCounts)
for _, c := range issueCounts {
if c.IsClosed {
stats.Counter.IssueClosed = c.Count
} else {
stats.Counter.IssueOpen = c.Count
}
}
stats.Counter.Issue = stats.Counter.IssueClosed + stats.Counter.IssueOpen
stats.Counter.Comment, _ = x.Count(new(Comment))
stats.Counter.Oauth = 0
stats.Counter.Follow, _ = x.Count(new(Follow))
stats.Counter.Mirror, _ = x.Count(new(Mirror))
stats.Counter.Release, _ = x.Count(new(Release))
stats.Counter.LoginSource = CountLoginSources()
stats.Counter.Webhook, _ = x.Count(new(Webhook))
stats.Counter.Milestone, _ = x.Count(new(Milestone))
stats.Counter.Label, _ = x.Count(new(Label))
stats.Counter.HookTask, _ = x.Count(new(HookTask))
stats.Counter.Team, _ = x.Count(new(Team))
stats.Counter.Attachment, _ = x.Count(new(Attachment))
return
}
// Ping tests if database is alive
func Ping() error {
if x != nil {
+6 -9
View File
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"errors"
@@ -18,11 +18,8 @@ type ResourceIndex struct {
MaxIndex int64 `xorm:"index"`
}
// IssueIndex represents the issue index table
type IssueIndex ResourceIndex
// upsertResourceIndex the function will not return until it acquires the lock or receives an error.
func upsertResourceIndex(e Engine, tableName string, groupID int64) (err error) {
// UpsertResourceIndex the function will not return until it acquires the lock or receives an error.
func UpsertResourceIndex(e Engine, tableName string, groupID int64) (err error) {
// An atomic UPSERT operation (INSERT/UPDATE) is the only operation
// that ensures that the key is actually locked.
switch {
@@ -75,8 +72,8 @@ func GetNextResourceIndex(tableName string, groupID int64) (int64, error) {
return 0, ErrGetResourceIndexFailed
}
// deleteResouceIndex delete resource index
func deleteResouceIndex(e Engine, tableName string, groupID int64) error {
// DeleteResouceIndex delete resource index
func DeleteResouceIndex(e Engine, tableName string, groupID int64) error {
_, err := e.Exec(fmt.Sprintf("DELETE FROM %s WHERE group_id=?", tableName), groupID)
return err
}
@@ -94,7 +91,7 @@ func getNextResourceIndex(tableName string, groupID int64) (int64, error) {
return 0, err
}
if err := upsertResourceIndex(sess, tableName, groupID); err != nil {
if err := UpsertResourceIndex(sess, tableName, groupID); err != nil {
return 0, err
}
+1 -1
View File
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"fmt"
+14
View File
@@ -0,0 +1,14 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"path/filepath"
"testing"
)
func TestMain(m *testing.M) {
MainTest(m, filepath.Join("..", ".."))
}
+70
View File
@@ -0,0 +1,70 @@
// Copyright 2018 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package db
import (
"fmt"
"regexp"
"code.gitea.io/gitea/modules/setting"
)
// CountBadSequences looks for broken sequences from recreate-table mistakes
func CountBadSequences() (int64, error) {
if !setting.Database.UsePostgreSQL {
return 0, nil
}
sess := x.NewSession()
defer sess.Close()
var sequences []string
schema := x.Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return 0, err
}
sess.Engine().SetSchema(schema)
return int64(len(sequences)), nil
}
// FixBadSequences fixes for broken sequences from recreate-table mistakes
func FixBadSequences() error {
if !setting.Database.UsePostgreSQL {
return nil
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
}
var sequences []string
schema := sess.Engine().Dialect().URI().Schema
sess.Engine().SetSchema("")
if err := sess.Table("information_schema.sequences").Cols("sequence_name").Where("sequence_name LIKE 'tmp_recreate__%_id_seq%' AND sequence_catalog = ?", setting.Database.Name).Find(&sequences); err != nil {
return err
}
sess.Engine().SetSchema(schema)
sequenceRegexp := regexp.MustCompile(`tmp_recreate__(\w+)_id_seq.*`)
for _, sequence := range sequences {
tableName := sequenceRegexp.FindStringSubmatch(sequence)[1]
newSequenceName := tableName + "_id_seq"
if _, err := sess.Exec(fmt.Sprintf("ALTER SEQUENCE `%s` RENAME TO `%s`", sequence, newSequenceName)); err != nil {
return err
}
if _, err := sess.Exec(fmt.Sprintf("SELECT setval('%s', COALESCE((SELECT MAX(id)+1 FROM `%s`), 1), false)", newSequenceName, tableName)); err != nil {
return err
}
}
return sess.Commit()
}
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"database/sql"
+4 -2
View File
@@ -2,9 +2,11 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import "github.com/lafriks/xormstore"
import (
"github.com/lafriks/xormstore"
)
// CreateStore creates a xormstore for the provided table and key
func CreateStore(table, key string) (*xormstore.Store, error) {
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"fmt"
@@ -2,7 +2,7 @@
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
package db
import (
"fmt"
@@ -32,6 +32,11 @@ var (
fixturesDir string
)
// FixturesDir returns the fixture directory
func FixturesDir() string {
return fixturesDir
}
func fatalTestError(fmtStr string, args ...interface{}) {
fmt.Fprintf(os.Stderr, fmtStr, args...)
os.Exit(1)
@@ -152,6 +157,11 @@ func whereConditions(sess *xorm.Session, conditions []interface{}) {
}
}
// LoadBeanIfExists loads beans from fixture database if exist
func LoadBeanIfExists(bean interface{}, conditions ...interface{}) (bool, error) {
return loadBeanIfExists(bean, conditions...)
}
func loadBeanIfExists(bean interface{}, conditions ...interface{}) (bool, error) {
sess := x.NewSession()
defer sess.Close()
+35
View File
@@ -0,0 +1,35 @@
// Copyright 2019 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
import (
"io/ioutil"
"os"
"path/filepath"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestDumpDatabase(t *testing.T) {
assert.NoError(t, db.PrepareTestDatabase())
dir, err := ioutil.TempDir(os.TempDir(), "dump")
assert.NoError(t, err)
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
assert.NoError(t, db.DefaultContext().Engine().Sync2(new(Version)))
for _, dbName := range setting.SupportedDatabases {
dbType := setting.GetDBTypeByName(dbName)
assert.NoError(t, db.DumpDatabase(filepath.Join(dir, dbType+".sql"), dbType))
}
}
+15 -10
View File
@@ -7,6 +7,7 @@ package models
import (
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/structs"
"github.com/markbates/goth"
@@ -34,15 +35,19 @@ type ExternalLoginUser struct {
ExpiresAt time.Time
}
func init() {
db.RegisterModel(new(ExternalLoginUser))
}
// GetExternalLogin checks if a externalID in loginSourceID scope already exists
func GetExternalLogin(externalLoginUser *ExternalLoginUser) (bool, error) {
return x.Get(externalLoginUser)
return db.DefaultContext().Engine().Get(externalLoginUser)
}
// ListAccountLinks returns a map with the ExternalLoginUser and its LoginSource
func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) {
externalAccounts := make([]*ExternalLoginUser, 0, 5)
err := x.Where("user_id=?", user.ID).
err := db.DefaultContext().Engine().Where("user_id=?", user.ID).
Desc("login_source_id").
Find(&externalAccounts)
if err != nil {
@@ -54,7 +59,7 @@ func ListAccountLinks(user *User) ([]*ExternalLoginUser, error) {
// LinkExternalToUser link the external user to the user
func LinkExternalToUser(user *User, externalLoginUser *ExternalLoginUser) error {
has, err := x.Where("external_id=? AND login_source_id=?", externalLoginUser.ExternalID, externalLoginUser.LoginSourceID).
has, err := db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", externalLoginUser.ExternalID, externalLoginUser.LoginSourceID).
NoAutoCondition().
Exist(externalLoginUser)
if err != nil {
@@ -63,13 +68,13 @@ func LinkExternalToUser(user *User, externalLoginUser *ExternalLoginUser) error
return ErrExternalLoginUserAlreadyExist{externalLoginUser.ExternalID, user.ID, externalLoginUser.LoginSourceID}
}
_, err = x.Insert(externalLoginUser)
_, err = db.DefaultContext().Engine().Insert(externalLoginUser)
return err
}
// RemoveAccountLink will remove all external login sources for the given user
func RemoveAccountLink(user *User, loginSourceID int64) (int64, error) {
deleted, err := x.Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID})
deleted, err := db.DefaultContext().Engine().Delete(&ExternalLoginUser{UserID: user.ID, LoginSourceID: loginSourceID})
if err != nil {
return deleted, err
}
@@ -80,7 +85,7 @@ func RemoveAccountLink(user *User, loginSourceID int64) (int64, error) {
}
// removeAllAccountLinks will remove all external login sources for the given user
func removeAllAccountLinks(e Engine, user *User) error {
func removeAllAccountLinks(e db.Engine, user *User) error {
_, err := e.Delete(&ExternalLoginUser{UserID: user.ID})
return err
}
@@ -88,7 +93,7 @@ func removeAllAccountLinks(e Engine, user *User) error {
// GetUserIDByExternalUserID get user id according to provider and userID
func GetUserIDByExternalUserID(provider, userID string) (int64, error) {
var id int64
_, err := x.Table("external_login_user").
_, err := db.DefaultContext().Engine().Table("external_login_user").
Select("user_id").
Where("provider=?", provider).
And("external_id=?", userID).
@@ -125,7 +130,7 @@ func UpdateExternalUser(user *User, gothUser goth.User) error {
ExpiresAt: gothUser.ExpiresAt,
}
has, err := x.Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).
has, err := db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).
NoAutoCondition().
Exist(externalLoginUser)
if err != nil {
@@ -134,7 +139,7 @@ func UpdateExternalUser(user *User, gothUser goth.User) error {
return ErrExternalLoginUserNotExist{user.ID, loginSource.ID}
}
_, err = x.Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).AllCols().Update(externalLoginUser)
_, err = db.DefaultContext().Engine().Where("external_id=? AND login_source_id=?", gothUser.UserID, loginSource.ID).AllCols().Update(externalLoginUser)
return err
}
@@ -156,7 +161,7 @@ func (opts FindExternalUserOptions) toConds() builder.Cond {
// FindExternalUsersByProvider represents external users via provider
func FindExternalUsersByProvider(opts FindExternalUserOptions) ([]ExternalLoginUser, error) {
var users []ExternalLoginUser
err := x.Where(opts.toConds()).
err := db.DefaultContext().Engine().Where(opts.toConds()).
Limit(opts.Limit, opts.Start).
OrderBy("login_source_id ASC, external_id ASC").
Find(&users)
+4 -2
View File
@@ -7,13 +7,15 @@ package models
import (
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
)
// GetYamlFixturesAccess returns a string containing the contents
// for the access table, as recalculated using repo.RecalculateAccesses()
func GetYamlFixturesAccess() (string, error) {
repos := make([]*Repository, 0, 50)
if err := x.Find(&repos); err != nil {
if err := db.DefaultContext().Engine().Find(&repos); err != nil {
return "", err
}
@@ -27,7 +29,7 @@ func GetYamlFixturesAccess() (string, error) {
var b strings.Builder
accesses := make([]*Access, 0, 200)
if err := x.OrderBy("user_id, repo_id").Find(&accesses); err != nil {
if err := db.DefaultContext().Engine().OrderBy("user_id, repo_id").Find(&accesses); err != nil {
return "", err
}
+3 -2
View File
@@ -9,20 +9,21 @@ import (
"path/filepath"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/util"
"github.com/stretchr/testify/assert"
)
func TestFixtureGeneration(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(gen func() (string, error), name string) {
expected, err := gen()
if !assert.NoError(t, err) {
return
}
bytes, err := ioutil.ReadFile(filepath.Join(fixturesDir, name+".yml"))
bytes, err := ioutil.ReadFile(filepath.Join(db.FixturesDir(), name+".yml"))
if !assert.NoError(t, err) {
return
}
+19 -14
View File
@@ -9,6 +9,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
@@ -43,6 +44,10 @@ type GPGKey struct {
CanCertify bool
}
func init() {
db.RegisterModel(new(GPGKey))
}
// BeforeInsert will be invoked by XORM before inserting a record
func (key *GPGKey) BeforeInsert() {
key.AddedUnix = timeutil.TimeStampNow()
@@ -58,10 +63,10 @@ func (key *GPGKey) AfterLoad(session *xorm.Session) {
// ListGPGKeys returns a list of public keys belongs to given user.
func ListGPGKeys(uid int64, listOptions ListOptions) ([]*GPGKey, error) {
return listGPGKeys(x, uid, listOptions)
return listGPGKeys(db.DefaultContext().Engine(), uid, listOptions)
}
func listGPGKeys(e Engine, uid int64, listOptions ListOptions) ([]*GPGKey, error) {
func listGPGKeys(e db.Engine, uid int64, listOptions ListOptions) ([]*GPGKey, error) {
sess := e.Table(&GPGKey{}).Where("owner_id=? AND primary_key_id=''", uid)
if listOptions.Page != 0 {
sess = setSessionPagination(sess, &listOptions)
@@ -73,13 +78,13 @@ func listGPGKeys(e Engine, uid int64, listOptions ListOptions) ([]*GPGKey, error
// CountUserGPGKeys return number of gpg keys a user own
func CountUserGPGKeys(userID int64) (int64, error) {
return x.Where("owner_id=? AND primary_key_id=''", userID).Count(&GPGKey{})
return db.DefaultContext().Engine().Where("owner_id=? AND primary_key_id=''", userID).Count(&GPGKey{})
}
// GetGPGKeyByID returns public key by given ID.
func GetGPGKeyByID(keyID int64) (*GPGKey, error) {
key := new(GPGKey)
has, err := x.ID(keyID).Get(key)
has, err := db.DefaultContext().Engine().ID(keyID).Get(key)
if err != nil {
return nil, err
} else if !has {
@@ -91,7 +96,7 @@ func GetGPGKeyByID(keyID int64) (*GPGKey, error) {
// GetGPGKeysByKeyID returns public key by given ID.
func GetGPGKeysByKeyID(keyID string) ([]*GPGKey, error) {
keys := make([]*GPGKey, 0, 1)
return keys, x.Where("key_id=?", keyID).Find(&keys)
return keys, db.DefaultContext().Engine().Where("key_id=?", keyID).Find(&keys)
}
// GPGKeyToEntity retrieve the imported key and the traducted entity
@@ -195,7 +200,7 @@ func parseGPGKey(ownerID int64, e *openpgp.Entity, verified bool) (*GPGKey, erro
}
// deleteGPGKey does the actual key deletion
func deleteGPGKey(e *xorm.Session, keyID string) (int64, error) {
func deleteGPGKey(e db.Engine, keyID string) (int64, error) {
if keyID == "" {
return 0, fmt.Errorf("empty KeyId forbidden") // Should never happen but just to be sure
}
@@ -222,17 +227,17 @@ func DeleteGPGKey(doer *User, id int64) (err error) {
return ErrGPGKeyAccessDenied{doer.ID, key.ID}
}
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if _, err = deleteGPGKey(ctx.Engine(), key.KeyID); err != nil {
return err
}
if _, err = deleteGPGKey(sess, key.KeyID); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
func checkKeyEmails(email string, keys ...*GPGKey) (bool, string) {
+10 -8
View File
@@ -7,6 +7,7 @@ package models
import (
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"github.com/keybase/go-crypto/openpgp"
@@ -28,7 +29,7 @@ import (
// This file contains functions relating to adding GPG Keys
// addGPGKey add key, import and subkeys to database
func addGPGKey(e Engine, key *GPGKey, content string) (err error) {
func addGPGKey(e db.Engine, key *GPGKey, content string) (err error) {
// Add GPGKeyImport
if _, err = e.Insert(GPGKeyImport{
KeyID: key.KeyID,
@@ -50,7 +51,7 @@ func addGPGKey(e Engine, key *GPGKey, content string) (err error) {
}
// addGPGSubKey add subkeys to database
func addGPGSubKey(e Engine, key *GPGKey) (err error) {
func addGPGSubKey(e db.Engine, key *GPGKey) (err error) {
// Save GPG primary key.
if _, err = e.Insert(key); err != nil {
return err
@@ -71,11 +72,12 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return nil, err
}
defer committer.Close()
keys := make([]*GPGKey, 0, len(ekeys))
verified := false
@@ -101,7 +103,7 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
for _, ekey := range ekeys {
// Key ID cannot be duplicated.
has, err := sess.Where("key_id=?", ekey.PrimaryKey.KeyIdString()).
has, err := ctx.Engine().Where("key_id=?", ekey.PrimaryKey.KeyIdString()).
Get(new(GPGKey))
if err != nil {
return nil, err
@@ -116,10 +118,10 @@ func AddGPGKey(ownerID int64, content, token, signature string) ([]*GPGKey, erro
return nil, err
}
if err = addGPGKey(sess, key, content); err != nil {
if err = addGPGKey(ctx.Engine(), key, content); err != nil {
return nil, err
}
keys = append(keys, key)
}
return keys, sess.Commit()
return keys, committer.Commit()
}
+7 -1
View File
@@ -4,6 +4,8 @@
package models
import "code.gitea.io/gitea/models/db"
// __________________ ________ ____ __.
// / _____/\______ \/ _____/ | |/ _|____ ___.__.
// / \ ___ | ___/ \ ___ | <_/ __ < | |
@@ -25,10 +27,14 @@ type GPGKeyImport struct {
Content string `xorm:"TEXT NOT NULL"`
}
func init() {
db.RegisterModel(new(GPGKeyImport))
}
// GetGPGImportByKeyID returns the import public armored key by given KeyID.
func GetGPGImportByKeyID(keyID string) (*GPGKeyImport, error) {
key := new(GPGKeyImport)
has, err := x.ID(keyID).Get(key)
has, err := db.DefaultContext().Engine().ID(keyID).Get(key)
if err != nil {
return nil, err
} else if !has {
+3 -2
View File
@@ -8,6 +8,7 @@ import (
"testing"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert"
@@ -192,9 +193,9 @@ Unknown GPG key with good email
}
func TestCheckGPGUserEmail(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
_ = AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
_ = db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
testEmailWithUpperCaseLetters := `-----BEGIN PGP PUBLIC KEY BLOCK-----
Version: GnuPG v1
+7 -6
View File
@@ -8,6 +8,7 @@ import (
"strconv"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/log"
)
@@ -29,15 +30,15 @@ import (
// VerifyGPGKey marks a GPG key as verified
func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error) {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return "", err
}
defer committer.Close()
key := new(GPGKey)
has, err := sess.Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key)
has, err := ctx.Engine().Where("owner_id = ? AND key_id = ?", ownerID, keyID).Get(key)
if err != nil {
return "", err
} else if !has {
@@ -91,11 +92,11 @@ func VerifyGPGKey(ownerID int64, keyID, token, signature string) (string, error)
}
key.Verified = true
if _, err := sess.ID(key.ID).SetExpr("verified", true).Update(new(GPGKey)); err != nil {
if _, err := ctx.Engine().ID(key.ID).SetExpr("verified", true).Update(new(GPGKey)); err != nil {
return "", err
}
if err := sess.Commit(); err != nil {
if err := committer.Commit(); err != nil {
return "", err
}
-27
View File
@@ -1,27 +0,0 @@
// Copyright 2021 The Gitea Authors. All rights reserved.
// Use of this source code is governed by a MIT-style
// license that can be found in the LICENSE file.
package models
import (
"fmt"
"sync"
"testing"
"github.com/stretchr/testify/assert"
)
func TestResourceIndex(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
testInsertIssue(t, fmt.Sprintf("issue %d", i+1), "my issue", 0)
wg.Done()
}(i)
}
wg.Wait()
}
+168 -157
View File
@@ -12,6 +12,7 @@ import (
"strconv"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/references"
@@ -82,12 +83,18 @@ const (
issueTasksDoneRegexpStr = `(^\s*[-*]\s\[[xX]\]\s.)|(\n\s*[-*]\s\[[xX]\]\s.)`
)
// IssueIndex represents the issue index table
type IssueIndex db.ResourceIndex
func init() {
issueTasksPat = regexp.MustCompile(issueTasksRegexpStr)
issueTasksDonePat = regexp.MustCompile(issueTasksDoneRegexpStr)
db.RegisterModel(new(Issue))
db.RegisterModel(new(IssueIndex))
}
func (issue *Issue) loadTotalTimes(e Engine) (err error) {
func (issue *Issue) loadTotalTimes(e db.Engine) (err error) {
opts := FindTrackedTimesOptions{IssueID: issue.ID}
issue.TotalTrackedTime, err = opts.toSession(e).SumInt(&TrackedTime{}, "time")
if err != nil {
@@ -106,10 +113,10 @@ func (issue *Issue) IsOverdue() bool {
// LoadRepo loads issue's repository
func (issue *Issue) LoadRepo() error {
return issue.loadRepo(x)
return issue.loadRepo(db.DefaultContext().Engine())
}
func (issue *Issue) loadRepo(e Engine) (err error) {
func (issue *Issue) loadRepo(e db.Engine) (err error) {
if issue.Repo == nil {
issue.Repo, err = getRepositoryByID(e, issue.RepoID)
if err != nil {
@@ -121,10 +128,10 @@ func (issue *Issue) loadRepo(e Engine) (err error) {
// IsTimetrackerEnabled returns true if the repo enables timetracking
func (issue *Issue) IsTimetrackerEnabled() bool {
return issue.isTimetrackerEnabled(x)
return issue.isTimetrackerEnabled(db.DefaultContext().Engine())
}
func (issue *Issue) isTimetrackerEnabled(e Engine) bool {
func (issue *Issue) isTimetrackerEnabled(e db.Engine) bool {
if err := issue.loadRepo(e); err != nil {
log.Error(fmt.Sprintf("loadRepo: %v", err))
return false
@@ -138,7 +145,7 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) {
return nil, fmt.Errorf("Issue is not a pull request")
}
pr, err = getPullRequestByIssueID(x, issue.ID)
pr, err = getPullRequestByIssueID(db.DefaultContext().Engine(), issue.ID)
if err != nil {
return nil, err
}
@@ -148,10 +155,10 @@ func (issue *Issue) GetPullRequest() (pr *PullRequest, err error) {
// LoadLabels loads labels
func (issue *Issue) LoadLabels() error {
return issue.loadLabels(x)
return issue.loadLabels(db.DefaultContext().Engine())
}
func (issue *Issue) loadLabels(e Engine) (err error) {
func (issue *Issue) loadLabels(e db.Engine) (err error) {
if issue.Labels == nil {
issue.Labels, err = getLabelsByIssueID(e, issue.ID)
if err != nil {
@@ -163,10 +170,10 @@ func (issue *Issue) loadLabels(e Engine) (err error) {
// LoadPoster loads poster
func (issue *Issue) LoadPoster() error {
return issue.loadPoster(x)
return issue.loadPoster(db.DefaultContext().Engine())
}
func (issue *Issue) loadPoster(e Engine) (err error) {
func (issue *Issue) loadPoster(e db.Engine) (err error) {
if issue.Poster == nil {
issue.Poster, err = getUserByID(e, issue.PosterID)
if err != nil {
@@ -182,7 +189,7 @@ func (issue *Issue) loadPoster(e Engine) (err error) {
return
}
func (issue *Issue) loadPullRequest(e Engine) (err error) {
func (issue *Issue) loadPullRequest(e db.Engine) (err error) {
if issue.IsPull && issue.PullRequest == nil {
issue.PullRequest, err = getPullRequestByIssueID(e, issue.ID)
if err != nil {
@@ -198,19 +205,19 @@ func (issue *Issue) loadPullRequest(e Engine) (err error) {
// LoadPullRequest loads pull request info
func (issue *Issue) LoadPullRequest() error {
return issue.loadPullRequest(x)
return issue.loadPullRequest(db.DefaultContext().Engine())
}
func (issue *Issue) loadComments(e Engine) (err error) {
func (issue *Issue) loadComments(e db.Engine) (err error) {
return issue.loadCommentsByType(e, CommentTypeUnknown)
}
// LoadDiscussComments loads discuss comments
func (issue *Issue) LoadDiscussComments() error {
return issue.loadCommentsByType(x, CommentTypeComment)
return issue.loadCommentsByType(db.DefaultContext().Engine(), CommentTypeComment)
}
func (issue *Issue) loadCommentsByType(e Engine, tp CommentType) (err error) {
func (issue *Issue) loadCommentsByType(e db.Engine, tp CommentType) (err error) {
if issue.Comments != nil {
return nil
}
@@ -221,7 +228,7 @@ func (issue *Issue) loadCommentsByType(e Engine, tp CommentType) (err error) {
return err
}
func (issue *Issue) loadReactions(e Engine) (err error) {
func (issue *Issue) loadReactions(e db.Engine) (err error) {
if issue.Reactions != nil {
return nil
}
@@ -255,7 +262,7 @@ func (issue *Issue) loadReactions(e Engine) (err error) {
return nil
}
func (issue *Issue) loadMilestone(e Engine) (err error) {
func (issue *Issue) loadMilestone(e db.Engine) (err error) {
if (issue.Milestone == nil || issue.Milestone.ID != issue.MilestoneID) && issue.MilestoneID > 0 {
issue.Milestone, err = getMilestoneByRepoID(e, issue.RepoID, issue.MilestoneID)
if err != nil && !IsErrMilestoneNotExist(err) {
@@ -265,7 +272,7 @@ func (issue *Issue) loadMilestone(e Engine) (err error) {
return nil
}
func (issue *Issue) loadAttributes(e Engine) (err error) {
func (issue *Issue) loadAttributes(e db.Engine) (err error) {
if err = issue.loadRepo(e); err != nil {
return
}
@@ -320,18 +327,18 @@ func (issue *Issue) loadAttributes(e Engine) (err error) {
// LoadAttributes loads the attribute of this issue.
func (issue *Issue) LoadAttributes() error {
return issue.loadAttributes(x)
return issue.loadAttributes(db.DefaultContext().Engine())
}
// LoadMilestone load milestone of this issue.
func (issue *Issue) LoadMilestone() error {
return issue.loadMilestone(x)
return issue.loadMilestone(db.DefaultContext().Engine())
}
// GetIsRead load the `IsRead` field of the issue
func (issue *Issue) GetIsRead(userID int64) error {
issueUser := &IssueUser{IssueID: issue.ID, UID: userID}
if has, err := x.Get(issueUser); err != nil {
if has, err := db.DefaultContext().Engine().Get(issueUser); err != nil {
return err
} else if !has {
issue.IsRead = false
@@ -398,13 +405,13 @@ func (issue *Issue) IsPoster(uid int64) bool {
return issue.OriginalAuthorID == 0 && issue.PosterID == uid
}
func (issue *Issue) hasLabel(e Engine, labelID int64) bool {
func (issue *Issue) hasLabel(e db.Engine, labelID int64) bool {
return hasIssueLabel(e, issue.ID, labelID)
}
// HasLabel returns true if issue has been labeled by given ID.
func (issue *Issue) HasLabel(labelID int64) bool {
return issue.hasLabel(x, labelID)
return issue.hasLabel(db.DefaultContext().Engine(), labelID)
}
// ReplyReference returns tokenized address to use for email reply headers
@@ -419,15 +426,15 @@ func (issue *Issue) ReplyReference() string {
return fmt.Sprintf("%s/%s/%d@%s", issue.Repo.FullName(), path, issue.Index, setting.Domain)
}
func (issue *Issue) addLabel(e *xorm.Session, label *Label, doer *User) error {
func (issue *Issue) addLabel(e db.Engine, label *Label, doer *User) error {
return newIssueLabel(e, issue, label, doer)
}
func (issue *Issue) addLabels(e *xorm.Session, labels []*Label, doer *User) error {
func (issue *Issue) addLabels(e db.Engine, labels []*Label, doer *User) error {
return newIssueLabels(e, issue, labels, doer)
}
func (issue *Issue) getLabels(e Engine) (err error) {
func (issue *Issue) getLabels(e db.Engine) (err error) {
if len(issue.Labels) > 0 {
return nil
}
@@ -439,11 +446,11 @@ func (issue *Issue) getLabels(e Engine) (err error) {
return nil
}
func (issue *Issue) removeLabel(e *xorm.Session, doer *User, label *Label) error {
func (issue *Issue) removeLabel(e db.Engine, doer *User, label *Label) error {
return deleteIssueLabel(e, issue, label, doer)
}
func (issue *Issue) clearLabels(e *xorm.Session, doer *User) (err error) {
func (issue *Issue) clearLabels(e db.Engine, doer *User) (err error) {
if err = issue.getLabels(e); err != nil {
return fmt.Errorf("getLabels: %v", err)
}
@@ -460,19 +467,19 @@ func (issue *Issue) clearLabels(e *xorm.Session, doer *User) (err error) {
// ClearLabels removes all issue labels as the given user.
// Triggers appropriate WebHooks, if any.
func (issue *Issue) ClearLabels(doer *User) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err := issue.loadRepo(ctx.Engine()); err != nil {
return err
} else if err = issue.loadPullRequest(ctx.Engine()); err != nil {
return err
}
if err := issue.loadRepo(sess); err != nil {
return err
} else if err = issue.loadPullRequest(sess); err != nil {
return err
}
perm, err := getUserRepoPermission(sess, issue.Repo, doer)
perm, err := getUserRepoPermission(ctx.Engine(), issue.Repo, doer)
if err != nil {
return err
}
@@ -480,11 +487,11 @@ func (issue *Issue) ClearLabels(doer *User) (err error) {
return ErrRepoLabelNotExist{}
}
if err = issue.clearLabels(sess, doer); err != nil {
if err = issue.clearLabels(ctx.Engine(), doer); err != nil {
return err
}
if err = sess.Commit(); err != nil {
if err = committer.Commit(); err != nil {
return fmt.Errorf("Commit: %v", err)
}
@@ -508,17 +515,17 @@ func (ts labelSorter) Swap(i, j int) {
// ReplaceLabels removes all current labels and add new labels to the issue.
// Triggers appropriate WebHooks, if any.
func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = issue.loadRepo(ctx.Engine()); err != nil {
return err
}
if err = issue.loadRepo(sess); err != nil {
return err
}
if err = issue.loadLabels(sess); err != nil {
if err = issue.loadLabels(ctx.Engine()); err != nil {
return err
}
@@ -554,23 +561,23 @@ func (issue *Issue) ReplaceLabels(labels []*Label, doer *User) (err error) {
toRemove = append(toRemove, issue.Labels[removeIndex:]...)
if len(toAdd) > 0 {
if err = issue.addLabels(sess, toAdd, doer); err != nil {
if err = issue.addLabels(ctx.Engine(), toAdd, doer); err != nil {
return fmt.Errorf("addLabels: %v", err)
}
}
for _, l := range toRemove {
if err = issue.removeLabel(sess, doer, l); err != nil {
if err = issue.removeLabel(ctx.Engine(), doer, l); err != nil {
return fmt.Errorf("removeLabel: %v", err)
}
}
issue.Labels = nil
if err = issue.loadLabels(sess); err != nil {
if err = issue.loadLabels(ctx.Engine()); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
// ReadBy sets issue to be read by given user.
@@ -579,17 +586,17 @@ func (issue *Issue) ReadBy(userID int64) error {
return err
}
return setIssueNotificationStatusReadIfUnread(x, userID, issue.ID)
return setIssueNotificationStatusReadIfUnread(db.DefaultContext().Engine(), userID, issue.ID)
}
func updateIssueCols(e Engine, issue *Issue, cols ...string) error {
func updateIssueCols(e db.Engine, issue *Issue, cols ...string) error {
if _, err := e.ID(issue.ID).Cols(cols...).Update(issue); err != nil {
return err
}
return nil
}
func (issue *Issue) changeStatus(e *xorm.Session, doer *User, isClosed, isMergePull bool) (*Comment, error) {
func (issue *Issue) changeStatus(e db.Engine, doer *User, isClosed, isMergePull bool) (*Comment, error) {
// Reload the issue
currentIssue, err := getIssueByID(e, issue.ID)
if err != nil {
@@ -612,7 +619,7 @@ func (issue *Issue) changeStatus(e *xorm.Session, doer *User, isClosed, isMergeP
return issue.doChangeStatus(e, doer, isMergePull)
}
func (issue *Issue) doChangeStatus(e *xorm.Session, doer *User, isMergePull bool) (*Comment, error) {
func (issue *Issue) doChangeStatus(e db.Engine, doer *User, isMergePull bool) (*Comment, error) {
// Check for open dependencies
if issue.IsClosed && issue.Repo.isDependenciesEnabled(e) {
// only check if dependencies are enabled and we're about to close an issue, otherwise reopening an issue would fail when there are unsatisfied dependencies
@@ -675,25 +682,25 @@ func (issue *Issue) doChangeStatus(e *xorm.Session, doer *User, isMergePull bool
// ChangeStatus changes issue status to open or closed.
func (issue *Issue) ChangeStatus(doer *User, isClosed bool) (*Comment, error) {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return nil, err
}
defer committer.Close()
if err := issue.loadRepo(ctx.Engine()); err != nil {
return nil, err
}
if err := issue.loadPoster(ctx.Engine()); err != nil {
return nil, err
}
if err := issue.loadRepo(sess); err != nil {
return nil, err
}
if err := issue.loadPoster(sess); err != nil {
return nil, err
}
comment, err := issue.changeStatus(sess, doer, isClosed, false)
comment, err := issue.changeStatus(ctx.Engine(), doer, isClosed, false)
if err != nil {
return nil, err
}
if err = sess.Commit(); err != nil {
if err = committer.Commit(); err != nil {
return nil, fmt.Errorf("Commit: %v", err)
}
@@ -702,18 +709,17 @@ func (issue *Issue) ChangeStatus(doer *User, isClosed bool) (*Comment, error) {
// ChangeTitle changes the title of this issue, as the given user.
func (issue *Issue) ChangeTitle(doer *User, oldTitle string) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = updateIssueCols(sess, issue, "name"); err != nil {
if err = updateIssueCols(ctx.Engine(), issue, "name"); err != nil {
return fmt.Errorf("updateIssueCols: %v", err)
}
if err = issue.loadRepo(sess); err != nil {
if err = issue.loadRepo(ctx.Engine()); err != nil {
return fmt.Errorf("loadRepo: %v", err)
}
@@ -725,43 +731,42 @@ func (issue *Issue) ChangeTitle(doer *User, oldTitle string) (err error) {
OldTitle: oldTitle,
NewTitle: issue.Title,
}
if _, err = createComment(sess, opts); err != nil {
if _, err = createComment(ctx.Engine(), opts); err != nil {
return fmt.Errorf("createComment: %v", err)
}
if err = issue.addCrossReferences(sess, doer, true); err != nil {
if err = issue.addCrossReferences(ctx.Engine(), doer, true); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
// ChangeRef changes the branch of this issue, as the given user.
func (issue *Issue) ChangeRef(doer *User, oldRef string) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = updateIssueCols(sess, issue, "ref"); err != nil {
if err = updateIssueCols(ctx.Engine(), issue, "ref"); err != nil {
return fmt.Errorf("updateIssueCols: %v", err)
}
return sess.Commit()
return committer.Commit()
}
// AddDeletePRBranchComment adds delete branch comment for pull request issue
func AddDeletePRBranchComment(doer *User, repo *Repository, issueID int64, branchName string) error {
issue, err := getIssueByID(x, issueID)
issue, err := getIssueByID(db.DefaultContext().Engine(), issueID)
if err != nil {
return err
}
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
opts := &CreateCommentOptions{
Type: CommentTypeDeleteBranch,
Doer: doer,
@@ -769,52 +774,52 @@ func AddDeletePRBranchComment(doer *User, repo *Repository, issueID int64, branc
Issue: issue,
OldRef: branchName,
}
if _, err = createComment(sess, opts); err != nil {
if _, err = createComment(ctx.Engine(), opts); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
// UpdateAttachments update attachments by UUIDs for the issue
func (issue *Issue) UpdateAttachments(uuids []string) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
attachments, err := getAttachmentsByUUIDs(sess, uuids)
defer committer.Close()
attachments, err := getAttachmentsByUUIDs(ctx.Engine(), uuids)
if err != nil {
return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %v", uuids, err)
}
for i := 0; i < len(attachments); i++ {
attachments[i].IssueID = issue.ID
if err := updateAttachment(sess, attachments[i]); err != nil {
if err := updateAttachment(ctx.Engine(), attachments[i]); err != nil {
return fmt.Errorf("update attachment [id: %d]: %v", attachments[i].ID, err)
}
}
return sess.Commit()
return committer.Commit()
}
// ChangeContent changes issue content, as the given user.
func (issue *Issue) ChangeContent(doer *User, content string) (err error) {
issue.Content = content
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = updateIssueCols(sess, issue, "content"); err != nil {
if err = updateIssueCols(ctx.Engine(), issue, "content"); err != nil {
return fmt.Errorf("UpdateIssueCols: %v", err)
}
if err = issue.addCrossReferences(sess, doer, true); err != nil {
if err = issue.addCrossReferences(ctx.Engine(), doer, true); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
// GetTasks returns the amount of tasks in the issues content
@@ -849,7 +854,7 @@ func (issue *Issue) GetLastEventLabel() string {
// GetLastComment return last comment for the current issue.
func (issue *Issue) GetLastComment() (*Comment, error) {
var c Comment
exist, err := x.Where("type = ?", CommentTypeComment).
exist, err := db.DefaultContext().Engine().Where("type = ?", CommentTypeComment).
And("issue_id = ?", issue.ID).Desc("id").Get(&c)
if err != nil {
return nil, err
@@ -880,7 +885,7 @@ type NewIssueOptions struct {
IsPull bool
}
func newIssue(e *xorm.Session, doer *User, opts NewIssueOptions) (err error) {
func newIssue(e db.Engine, doer *User, opts NewIssueOptions) (err error) {
opts.Issue.Title = strings.TrimSpace(opts.Issue.Title)
if opts.Issue.MilestoneID > 0 {
@@ -985,44 +990,44 @@ func newIssue(e *xorm.Session, doer *User, opts NewIssueOptions) (err error) {
// RecalculateIssueIndexForRepo create issue_index for repo if not exist and
// update it based on highest index of existing issues assigned to a repo
func RecalculateIssueIndexForRepo(repoID int64) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err := upsertResourceIndex(sess, "issue_index", repoID); err != nil {
if err := db.UpsertResourceIndex(ctx.Engine(), "issue_index", repoID); err != nil {
return err
}
var max int64
if _, err := sess.Select(" MAX(`index`)").Table("issue").Where("repo_id=?", repoID).Get(&max); err != nil {
if _, err := ctx.Engine().Select(" MAX(`index`)").Table("issue").Where("repo_id=?", repoID).Get(&max); err != nil {
return err
}
if _, err := sess.Exec("UPDATE `issue_index` SET max_index=? WHERE group_id=?", max, repoID); err != nil {
if _, err := ctx.Engine().Exec("UPDATE `issue_index` SET max_index=? WHERE group_id=?", max, repoID); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
// NewIssue creates new issue with labels for repository.
func NewIssue(repo *Repository, issue *Issue, labelIDs []int64, uuids []string) (err error) {
idx, err := GetNextResourceIndex("issue_index", repo.ID)
idx, err := db.GetNextResourceIndex("issue_index", repo.ID)
if err != nil {
return fmt.Errorf("generate issue index failed: %v", err)
}
issue.Index = idx
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = newIssue(sess, issue.Poster, NewIssueOptions{
if err = newIssue(ctx.Engine(), issue.Poster, NewIssueOptions{
Repo: repo,
Issue: issue,
LabelIDs: labelIDs,
@@ -1034,7 +1039,7 @@ func NewIssue(repo *Repository, issue *Issue, labelIDs []int64, uuids []string)
return fmt.Errorf("newIssue: %v", err)
}
if err = sess.Commit(); err != nil {
if err = committer.Commit(); err != nil {
return fmt.Errorf("Commit: %v", err)
}
@@ -1050,7 +1055,7 @@ func GetIssueByIndex(repoID, index int64) (*Issue, error) {
RepoID: repoID,
Index: index,
}
has, err := x.Get(issue)
has, err := db.DefaultContext().Engine().Get(issue)
if err != nil {
return nil, err
} else if !has {
@@ -1068,7 +1073,7 @@ func GetIssueWithAttrsByIndex(repoID, index int64) (*Issue, error) {
return issue, issue.LoadAttributes()
}
func getIssueByID(e Engine, id int64) (*Issue, error) {
func getIssueByID(e db.Engine, id int64) (*Issue, error) {
issue := new(Issue)
has, err := e.ID(id).Get(issue)
if err != nil {
@@ -1081,24 +1086,24 @@ func getIssueByID(e Engine, id int64) (*Issue, error) {
// GetIssueWithAttrsByID returns an issue with attributes by given ID.
func GetIssueWithAttrsByID(id int64) (*Issue, error) {
issue, err := getIssueByID(x, id)
issue, err := getIssueByID(db.DefaultContext().Engine(), id)
if err != nil {
return nil, err
}
return issue, issue.loadAttributes(x)
return issue, issue.loadAttributes(db.DefaultContext().Engine())
}
// GetIssueByID returns an issue by given ID.
func GetIssueByID(id int64) (*Issue, error) {
return getIssueByID(x, id)
return getIssueByID(db.DefaultContext().Engine(), id)
}
func getIssuesByIDs(e Engine, issueIDs []int64) ([]*Issue, error) {
func getIssuesByIDs(e db.Engine, issueIDs []int64) ([]*Issue, error) {
issues := make([]*Issue, 0, 10)
return issues, e.In("id", issueIDs).Find(&issues)
}
func getIssueIDsByRepoID(e Engine, repoID int64) ([]int64, error) {
func getIssueIDsByRepoID(e db.Engine, repoID int64) ([]int64, error) {
ids := make([]int64, 0, 10)
err := e.Table("issue").Cols("id").Where("repo_id = ?", repoID).Find(&ids)
return ids, err
@@ -1106,12 +1111,12 @@ func getIssueIDsByRepoID(e Engine, repoID int64) ([]int64, error) {
// GetIssueIDsByRepoID returns all issue ids by repo id
func GetIssueIDsByRepoID(repoID int64) ([]int64, error) {
return getIssueIDsByRepoID(x, repoID)
return getIssueIDsByRepoID(db.DefaultContext().Engine(), repoID)
}
// GetIssuesByIDs return issues with the given IDs.
func GetIssuesByIDs(issueIDs []int64) ([]*Issue, error) {
return getIssuesByIDs(x, issueIDs)
return getIssuesByIDs(db.DefaultContext().Engine(), issueIDs)
}
// IssuesOptions represents options of an issue.
@@ -1311,7 +1316,7 @@ func applyReviewRequestedCondition(sess *xorm.Session, reviewRequestedID int64)
// CountIssuesByRepo map from repoID to number of issues matching the options
func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
@@ -1339,7 +1344,7 @@ func CountIssuesByRepo(opts *IssuesOptions) (map[int64]int64, error) {
// GetRepoIDsForIssuesOptions find all repo ids for the given options
func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error) {
repoIDs := make([]int64, 0, 5)
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
@@ -1359,7 +1364,7 @@ func GetRepoIDsForIssuesOptions(opts *IssuesOptions, user *User) ([]int64, error
// Issues returns a list of issues by given conditions.
func Issues(opts *IssuesOptions) ([]*Issue, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
sess.Join("INNER", "repository", "`issue`.repo_id = `repository`.id")
@@ -1381,7 +1386,7 @@ func Issues(opts *IssuesOptions) ([]*Issue, error) {
// CountIssues number return of issues by given conditions.
func CountIssues(opts *IssuesOptions) (int64, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
countsSlice := make([]*struct {
@@ -1406,7 +1411,7 @@ func CountIssues(opts *IssuesOptions) (int64, error) {
// User permissions must be verified elsewhere if required.
func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) {
userIDs := make([]int64, 0, 5)
return userIDs, x.Table("comment").
return userIDs, db.DefaultContext().Engine().Table("comment").
Cols("poster_id").
Where("issue_id = ?", issueID).
And("type in (?,?,?)", CommentTypeComment, CommentTypeCode, CommentTypeReview).
@@ -1416,7 +1421,7 @@ func GetParticipantsIDsByIssueID(issueID int64) ([]int64, error) {
// IsUserParticipantsOfIssue return true if user is participants of an issue
func IsUserParticipantsOfIssue(user *User, issue *Issue) bool {
userIDs, err := issue.getParticipantIDsByIssue(x)
userIDs, err := issue.getParticipantIDsByIssue(db.DefaultContext().Engine())
if err != nil {
log.Error(err.Error())
return false
@@ -1425,7 +1430,7 @@ func IsUserParticipantsOfIssue(user *User, issue *Issue) bool {
}
// UpdateIssueMentions updates issue-user relations for mentioned users.
func UpdateIssueMentions(ctx DBContext, issueID int64, mentions []*User) error {
func UpdateIssueMentions(ctx *db.Context, issueID int64, mentions []*User) error {
if len(mentions) == 0 {
return nil
}
@@ -1482,6 +1487,12 @@ type IssueStatsOptions struct {
IssueIDs []int64
}
const (
// When queries are broken down in parts because of the number
// of parameters, attempt to break by this amount
maxQueryParameters = 300
)
// GetIssueStats returns issue statistic information by given conditions.
func GetIssueStats(opts *IssueStatsOptions) (*IssueStats, error) {
if len(opts.IssueIDs) <= maxQueryParameters {
@@ -1518,7 +1529,7 @@ func getIssueStatsChunk(opts *IssueStatsOptions, issueIDs []int64) (*IssueStats,
stats := &IssueStats{}
countSession := func(opts *IssueStatsOptions) *xorm.Session {
sess := x.
sess := db.DefaultContext().Engine().
Where("issue.repo_id = ?", opts.RepoID)
if len(opts.IssueIDs) > 0 {
@@ -1612,7 +1623,7 @@ func GetUserIssueStats(opts UserIssueStatsOptions) (*IssueStats, error) {
}
sess := func(cond builder.Cond) *xorm.Session {
s := x.Where(cond)
s := db.DefaultContext().Engine().Where(cond)
if len(opts.LabelIDs) > 0 {
s.Join("INNER", "issue_label", "issue_label.issue_id = issue.id").
In("issue_label.label_id", opts.LabelIDs)
@@ -1724,7 +1735,7 @@ func GetUserIssueStats(opts UserIssueStatsOptions) (*IssueStats, error) {
// GetRepoIssueStats returns number of open and closed repository issues by given filter mode.
func GetRepoIssueStats(repoID, uid int64, filterMode int, isPull bool) (numOpen, numClosed int64) {
countSession := func(isClosed, isPull bool, repoID int64) *xorm.Session {
sess := x.
sess := db.DefaultContext().Engine().
Where("is_closed = ?", isClosed).
And("is_pull = ?", isPull).
And("repo_id = ?", repoID)
@@ -1776,7 +1787,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6
ID int64
UpdatedUnix int64
}, 0, limit)
err := x.Distinct("id", "updated_unix").Table("issue").Where(cond).
err := db.DefaultContext().Engine().Distinct("id", "updated_unix").Table("issue").Where(cond).
OrderBy("`updated_unix` DESC").Limit(limit, start).
Find(&res)
if err != nil {
@@ -1786,7 +1797,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6
ids = append(ids, r.ID)
}
total, err := x.Distinct("id").Table("issue").Where(cond).Count()
total, err := db.DefaultContext().Engine().Distinct("id").Table("issue").Where(cond).Count()
if err != nil {
return 0, nil, err
}
@@ -1798,7 +1809,7 @@ func SearchIssueIDsByKeyword(kw string, repoIDs []int64, limit, start int) (int6
// If the issue status is changed a statusChangeComment is returned
// similarly if the title is changed the titleChanged bool is set to true
func UpdateIssueByAPI(issue *Issue, doer *User) (statusChangeComment *Comment, titleChanged bool, err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return nil, false, err
@@ -1857,7 +1868,7 @@ func UpdateIssueDeadline(issue *Issue, deadlineUnix timeutil.TimeStamp, doer *Us
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -1883,7 +1894,7 @@ type DependencyInfo struct {
}
// getParticipantIDsByIssue returns all userIDs who are participated in comments of an issue and issue author
func (issue *Issue) getParticipantIDsByIssue(e Engine) ([]int64, error) {
func (issue *Issue) getParticipantIDsByIssue(e db.Engine) ([]int64, error) {
if issue == nil {
return nil, nil
}
@@ -1905,7 +1916,7 @@ func (issue *Issue) getParticipantIDsByIssue(e Engine) ([]int64, error) {
}
// Get Blocked By Dependencies, aka all issues this issue is blocked by.
func (issue *Issue) getBlockedByDependencies(e Engine) (issueDeps []*DependencyInfo, err error) {
func (issue *Issue) getBlockedByDependencies(e db.Engine) (issueDeps []*DependencyInfo, err error) {
return issueDeps, e.
Table("issue").
Join("INNER", "repository", "repository.id = issue.repo_id").
@@ -1917,7 +1928,7 @@ func (issue *Issue) getBlockedByDependencies(e Engine) (issueDeps []*DependencyI
}
// Get Blocking Dependencies, aka all issues this issue blocks.
func (issue *Issue) getBlockingDependencies(e Engine) (issueDeps []*DependencyInfo, err error) {
func (issue *Issue) getBlockingDependencies(e db.Engine) (issueDeps []*DependencyInfo, err error) {
return issueDeps, e.
Table("issue").
Join("INNER", "repository", "repository.id = issue.repo_id").
@@ -1930,15 +1941,15 @@ func (issue *Issue) getBlockingDependencies(e Engine) (issueDeps []*DependencyIn
// BlockedByDependencies finds all Dependencies an issue is blocked by
func (issue *Issue) BlockedByDependencies() ([]*DependencyInfo, error) {
return issue.getBlockedByDependencies(x)
return issue.getBlockedByDependencies(db.DefaultContext().Engine())
}
// BlockingDependencies returns all blocking dependencies, aka all other issues a given issue blocks
func (issue *Issue) BlockingDependencies() ([]*DependencyInfo, error) {
return issue.getBlockingDependencies(x)
return issue.getBlockingDependencies(db.DefaultContext().Engine())
}
func (issue *Issue) updateClosedNum(e Engine) (err error) {
func (issue *Issue) updateClosedNum(e db.Engine) (err error) {
if issue.IsPull {
_, err = e.Exec("UPDATE `repository` SET num_closed_pulls=(SELECT count(*) FROM issue WHERE repo_id=? AND is_pull=? AND is_closed=?) WHERE id=?",
issue.RepoID,
@@ -1958,7 +1969,7 @@ func (issue *Issue) updateClosedNum(e Engine) (err error) {
}
// FindAndUpdateIssueMentions finds users mentioned in the given content string, and saves them in the database.
func (issue *Issue) FindAndUpdateIssueMentions(ctx DBContext, doer *User, content string) (mentions []*User, err error) {
func (issue *Issue) FindAndUpdateIssueMentions(ctx *db.Context, doer *User, content string) (mentions []*User, err error) {
rawMentions := references.FindAllMentionsMarkdown(content)
mentions, err = issue.ResolveMentionsByVisibility(ctx, doer, rawMentions)
if err != nil {
@@ -1972,18 +1983,18 @@ func (issue *Issue) FindAndUpdateIssueMentions(ctx DBContext, doer *User, conten
// ResolveMentionsByVisibility returns the users mentioned in an issue, removing those that
// don't have access to reading it. Teams are expanded into their users, but organizations are ignored.
func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, mentions []string) (users []*User, err error) {
func (issue *Issue) ResolveMentionsByVisibility(ctx *db.Context, doer *User, mentions []string) (users []*User, err error) {
if len(mentions) == 0 {
return
}
if err = issue.loadRepo(ctx.e); err != nil {
if err = issue.loadRepo(ctx.Engine()); err != nil {
return
}
resolved := make(map[string]bool, 10)
var mentionTeams []string
if err := issue.Repo.getOwner(ctx.e); err != nil {
if err := issue.Repo.getOwner(ctx.Engine()); err != nil {
return nil, err
}
@@ -2012,7 +2023,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
if issue.Repo.Owner.IsOrganization() && len(mentionTeams) > 0 {
teams := make([]*Team, 0, len(mentionTeams))
if err := ctx.e.
if err := ctx.Engine().
Join("INNER", "team_repo", "team_repo.team_id = team.id").
Where("team_repo.repo_id=?", issue.Repo.ID).
In("team.lower_name", mentionTeams).
@@ -2031,7 +2042,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
resolved[issue.Repo.Owner.LowerName+"/"+team.LowerName] = true
continue
}
has, err := ctx.e.Get(&TeamUnit{OrgID: issue.Repo.Owner.ID, TeamID: team.ID, Type: unittype})
has, err := ctx.Engine().Get(&TeamUnit{OrgID: issue.Repo.Owner.ID, TeamID: team.ID, Type: unittype})
if err != nil {
return nil, fmt.Errorf("get team units (%d): %v", team.ID, err)
}
@@ -2042,7 +2053,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
}
if len(checked) != 0 {
teamusers := make([]*User, 0, 20)
if err := ctx.e.
if err := ctx.Engine().
Join("INNER", "team_user", "team_user.uid = `user`.id").
In("`team_user`.team_id", checked).
And("`user`.is_active = ?", true).
@@ -2079,7 +2090,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
}
unchecked := make([]*User, 0, len(mentionUsers))
if err := ctx.e.
if err := ctx.Engine().
Where("`user`.is_active = ?", true).
And("`user`.prohibit_login = ?", false).
In("`user`.lower_name", mentionUsers).
@@ -2091,7 +2102,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
continue
}
// Normal users must have read access to the referencing issue
perm, err := getUserRepoPermission(ctx.e, issue.Repo, user)
perm, err := getUserRepoPermission(ctx.Engine(), issue.Repo, user)
if err != nil {
return nil, fmt.Errorf("getUserRepoPermission [%d]: %v", user.ID, err)
}
@@ -2106,7 +2117,7 @@ func (issue *Issue) ResolveMentionsByVisibility(ctx DBContext, doer *User, menti
// UpdateIssuesMigrationsByType updates all migrated repositories' issues from gitServiceType to replace originalAuthorID to posterID
func UpdateIssuesMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, posterID int64) error {
_, err := x.Table("issue").
_, err := db.DefaultContext().Engine().Table("issue").
Where("repo_id IN (SELECT id FROM repository WHERE original_service_type = ?)", gitServiceType).
And("original_author_id = ?", originalAuthorID).
Update(map[string]interface{}{
@@ -2119,7 +2130,7 @@ func UpdateIssuesMigrationsByType(gitServiceType structs.GitServiceType, origina
// UpdateReactionsMigrationsByType updates all migrated repositories' reactions from gitServiceType to replace originalAuthorID to posterID
func UpdateReactionsMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, userID int64) error {
_, err := x.Table("reaction").
_, err := db.DefaultContext().Engine().Table("reaction").
Where("original_author_id = ?", originalAuthorID).
And(migratedIssueCond(gitServiceType)).
Update(map[string]interface{}{
@@ -2130,7 +2141,7 @@ func UpdateReactionsMigrationsByType(gitServiceType structs.GitServiceType, orig
return err
}
func deleteIssuesByRepoID(sess Engine, repoID int64) (attachmentPaths []string, err error) {
func deleteIssuesByRepoID(sess db.Engine, repoID int64) (attachmentPaths []string, err error) {
deleteCond := builder.Select("id").From("issue").Where(builder.Eq{"issue.repo_id": repoID})
// Delete comments and attachments
+14 -9
View File
@@ -7,6 +7,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/util"
"xorm.io/xorm"
@@ -19,13 +20,17 @@ type IssueAssignees struct {
IssueID int64 `xorm:"INDEX"`
}
func init() {
db.RegisterModel(new(IssueAssignees))
}
// LoadAssignees load assignees of this issue.
func (issue *Issue) LoadAssignees() error {
return issue.loadAssignees(x)
return issue.loadAssignees(db.DefaultContext().Engine())
}
// This loads all assignees of an issue
func (issue *Issue) loadAssignees(e Engine) (err error) {
func (issue *Issue) loadAssignees(e db.Engine) (err error) {
// Reset maybe preexisting assignees
issue.Assignees = []*User{}
@@ -51,7 +56,7 @@ func (issue *Issue) loadAssignees(e Engine) (err error) {
// User permissions must be verified elsewhere if required.
func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) {
userIDs := make([]int64, 0, 5)
return userIDs, x.Table("issue_assignees").
return userIDs, db.DefaultContext().Engine().Table("issue_assignees").
Cols("assignee_id").
Where("issue_id = ?", issueID).
Distinct("assignee_id").
@@ -60,10 +65,10 @@ func GetAssigneeIDsByIssue(issueID int64) ([]int64, error) {
// GetAssigneesByIssue returns everyone assigned to that issue
func GetAssigneesByIssue(issue *Issue) (assignees []*User, err error) {
return getAssigneesByIssue(x, issue)
return getAssigneesByIssue(db.DefaultContext().Engine(), issue)
}
func getAssigneesByIssue(e Engine, issue *Issue) (assignees []*User, err error) {
func getAssigneesByIssue(e db.Engine, issue *Issue) (assignees []*User, err error) {
err = issue.loadAssignees(e)
if err != nil {
return assignees, err
@@ -74,22 +79,22 @@ func getAssigneesByIssue(e Engine, issue *Issue) (assignees []*User, err error)
// IsUserAssignedToIssue returns true when the user is assigned to the issue
func IsUserAssignedToIssue(issue *Issue, user *User) (isAssigned bool, err error) {
return isUserAssignedToIssue(x, issue, user)
return isUserAssignedToIssue(db.DefaultContext().Engine(), issue, user)
}
func isUserAssignedToIssue(e Engine, issue *Issue, user *User) (isAssigned bool, err error) {
func isUserAssignedToIssue(e db.Engine, issue *Issue, user *User) (isAssigned bool, err error) {
return e.Get(&IssueAssignees{IssueID: issue.ID, AssigneeID: user.ID})
}
// ClearAssigneeByUserID deletes all assignments of an user
func clearAssigneeByUserID(sess Engine, userID int64) (err error) {
func clearAssigneeByUserID(sess db.Engine, userID int64) (err error) {
_, err = sess.Delete(&IssueAssignees{AssigneeID: userID})
return
}
// ToggleAssignee changes a user between assigned and not assigned for this issue, and make issue comment for it.
func (issue *Issue) ToggleAssignee(doer *User, assigneeID int64) (removed bool, comment *Comment, err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
+5 -4
View File
@@ -7,11 +7,12 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestUpdateAssignee(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// Fake issue with assignees
issue, err := GetIssueWithAttrsByID(1)
@@ -61,10 +62,10 @@ func TestUpdateAssignee(t *testing.T) {
}
func TestMakeIDsFromAPIAssigneesToAdd(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
_ = AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
_ = AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
_ = db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
_ = db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
IDs, err := MakeIDsFromAPIAssigneesToAdd("", []string{""})
assert.NoError(t, err)
+46 -41
View File
@@ -13,6 +13,7 @@ import (
"strings"
"unicode/utf8"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/json"
"code.gitea.io/gitea/modules/log"
@@ -197,6 +198,10 @@ type Comment struct {
IsForcePush bool `xorm:"-"`
}
func init() {
db.RegisterModel(new(Comment))
}
// PushActionContent is content of push pull comment
type PushActionContent struct {
IsForcePush bool `json:"is_force_push"`
@@ -205,10 +210,10 @@ type PushActionContent struct {
// LoadIssue loads issue from database
func (c *Comment) LoadIssue() (err error) {
return c.loadIssue(x)
return c.loadIssue(db.DefaultContext().Engine())
}
func (c *Comment) loadIssue(e Engine) (err error) {
func (c *Comment) loadIssue(e db.Engine) (err error) {
if c.Issue != nil {
return nil
}
@@ -243,7 +248,7 @@ func (c *Comment) AfterLoad(session *xorm.Session) {
}
}
func (c *Comment) loadPoster(e Engine) (err error) {
func (c *Comment) loadPoster(e db.Engine) (err error) {
if c.PosterID <= 0 || c.Poster != nil {
return nil
}
@@ -279,7 +284,7 @@ func (c *Comment) HTMLURL() string {
log.Error("LoadIssue(%d): %v", c.IssueID, err)
return ""
}
err = c.Issue.loadRepo(x)
err = c.Issue.loadRepo(db.DefaultContext().Engine())
if err != nil { // Silently dropping errors :unamused:
log.Error("loadRepo(%d): %v", c.Issue.RepoID, err)
return ""
@@ -308,7 +313,7 @@ func (c *Comment) APIURL() string {
log.Error("LoadIssue(%d): %v", c.IssueID, err)
return ""
}
err = c.Issue.loadRepo(x)
err = c.Issue.loadRepo(db.DefaultContext().Engine())
if err != nil { // Silently dropping errors :unamused:
log.Error("loadRepo(%d): %v", c.Issue.RepoID, err)
return ""
@@ -329,7 +334,7 @@ func (c *Comment) IssueURL() string {
return ""
}
err = c.Issue.loadRepo(x)
err = c.Issue.loadRepo(db.DefaultContext().Engine())
if err != nil { // Silently dropping errors :unamused:
log.Error("loadRepo(%d): %v", c.Issue.RepoID, err)
return ""
@@ -345,7 +350,7 @@ func (c *Comment) PRURL() string {
return ""
}
err = c.Issue.loadRepo(x)
err = c.Issue.loadRepo(db.DefaultContext().Engine())
if err != nil { // Silently dropping errors :unamused:
log.Error("loadRepo(%d): %v", c.Issue.RepoID, err)
return ""
@@ -375,7 +380,7 @@ func (c *Comment) EventTag() string {
// LoadLabel if comment.Type is CommentTypeLabel, then load Label
func (c *Comment) LoadLabel() error {
var label Label
has, err := x.ID(c.LabelID).Get(&label)
has, err := db.DefaultContext().Engine().ID(c.LabelID).Get(&label)
if err != nil {
return err
} else if has {
@@ -392,7 +397,7 @@ func (c *Comment) LoadLabel() error {
func (c *Comment) LoadProject() error {
if c.OldProjectID > 0 {
var oldProject Project
has, err := x.ID(c.OldProjectID).Get(&oldProject)
has, err := db.DefaultContext().Engine().ID(c.OldProjectID).Get(&oldProject)
if err != nil {
return err
} else if has {
@@ -402,7 +407,7 @@ func (c *Comment) LoadProject() error {
if c.ProjectID > 0 {
var project Project
has, err := x.ID(c.ProjectID).Get(&project)
has, err := db.DefaultContext().Engine().ID(c.ProjectID).Get(&project)
if err != nil {
return err
} else if has {
@@ -417,7 +422,7 @@ func (c *Comment) LoadProject() error {
func (c *Comment) LoadMilestone() error {
if c.OldMilestoneID > 0 {
var oldMilestone Milestone
has, err := x.ID(c.OldMilestoneID).Get(&oldMilestone)
has, err := db.DefaultContext().Engine().ID(c.OldMilestoneID).Get(&oldMilestone)
if err != nil {
return err
} else if has {
@@ -427,7 +432,7 @@ func (c *Comment) LoadMilestone() error {
if c.MilestoneID > 0 {
var milestone Milestone
has, err := x.ID(c.MilestoneID).Get(&milestone)
has, err := db.DefaultContext().Engine().ID(c.MilestoneID).Get(&milestone)
if err != nil {
return err
} else if has {
@@ -439,7 +444,7 @@ func (c *Comment) LoadMilestone() error {
// LoadPoster loads comment poster
func (c *Comment) LoadPoster() error {
return c.loadPoster(x)
return c.loadPoster(db.DefaultContext().Engine())
}
// LoadAttachments loads attachments
@@ -449,7 +454,7 @@ func (c *Comment) LoadAttachments() error {
}
var err error
c.Attachments, err = getAttachmentsByCommentID(x, c.ID)
c.Attachments, err = getAttachmentsByCommentID(db.DefaultContext().Engine(), c.ID)
if err != nil {
log.Error("getAttachmentsByCommentID[%d]: %v", c.ID, err)
}
@@ -458,7 +463,7 @@ func (c *Comment) LoadAttachments() error {
// UpdateAttachments update attachments by UUIDs for the comment
func (c *Comment) UpdateAttachments(uuids []string) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -482,7 +487,7 @@ func (c *Comment) LoadAssigneeUserAndTeam() error {
var err error
if c.AssigneeID > 0 && c.Assignee == nil {
c.Assignee, err = getUserByID(x, c.AssigneeID)
c.Assignee, err = getUserByID(db.DefaultContext().Engine(), c.AssigneeID)
if err != nil {
if !IsErrUserNotExist(err) {
return err
@@ -517,7 +522,7 @@ func (c *Comment) LoadResolveDoer() (err error) {
if c.ResolveDoerID == 0 || c.Type != CommentTypeCode {
return nil
}
c.ResolveDoer, err = getUserByID(x, c.ResolveDoerID)
c.ResolveDoer, err = getUserByID(db.DefaultContext().Engine(), c.ResolveDoerID)
if err != nil {
if IsErrUserNotExist(err) {
c.ResolveDoer = NewGhostUser()
@@ -537,7 +542,7 @@ func (c *Comment) LoadDepIssueDetails() (err error) {
if c.DependentIssueID <= 0 || c.DependentIssue != nil {
return nil
}
c.DependentIssue, err = getIssueByID(x, c.DependentIssueID)
c.DependentIssue, err = getIssueByID(db.DefaultContext().Engine(), c.DependentIssueID)
return err
}
@@ -551,7 +556,7 @@ func (c *Comment) LoadTime() error {
return err
}
func (c *Comment) loadReactions(e Engine, repo *Repository) (err error) {
func (c *Comment) loadReactions(e db.Engine, repo *Repository) (err error) {
if c.Reactions != nil {
return nil
}
@@ -571,10 +576,10 @@ func (c *Comment) loadReactions(e Engine, repo *Repository) (err error) {
// LoadReactions loads comment reactions
func (c *Comment) LoadReactions(repo *Repository) error {
return c.loadReactions(x, repo)
return c.loadReactions(db.DefaultContext().Engine(), repo)
}
func (c *Comment) loadReview(e Engine) (err error) {
func (c *Comment) loadReview(e db.Engine) (err error) {
if c.Review == nil {
if c.Review, err = getReviewByID(e, c.ReviewID); err != nil {
return err
@@ -586,7 +591,7 @@ func (c *Comment) loadReview(e Engine) (err error) {
// LoadReview loads the associated review
func (c *Comment) LoadReview() error {
return c.loadReview(x)
return c.loadReview(db.DefaultContext().Engine())
}
var notEnoughLines = regexp.MustCompile(`fatal: file .* has only \d+ lines?`)
@@ -637,7 +642,7 @@ func (c *Comment) CodeCommentURL() string {
log.Error("LoadIssue(%d): %v", c.IssueID, err)
return ""
}
err = c.Issue.loadRepo(x)
err = c.Issue.loadRepo(db.DefaultContext().Engine())
if err != nil { // Silently dropping errors :unamused:
log.Error("loadRepo(%d): %v", c.Issue.RepoID, err)
return ""
@@ -681,7 +686,7 @@ func (c *Comment) LoadPushCommits() (err error) {
return err
}
func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err error) {
func createComment(e db.Engine, opts *CreateCommentOptions) (_ *Comment, err error) {
var LabelID int64
if opts.Label != nil {
LabelID = opts.Label.ID
@@ -740,7 +745,7 @@ func createComment(e *xorm.Session, opts *CreateCommentOptions) (_ *Comment, err
return comment, nil
}
func updateCommentInfos(e *xorm.Session, opts *CreateCommentOptions, comment *Comment) (err error) {
func updateCommentInfos(e db.Engine, opts *CreateCommentOptions, comment *Comment) (err error) {
// Check comment type.
switch opts.Type {
case CommentTypeCode:
@@ -894,7 +899,7 @@ type CreateCommentOptions struct {
// CreateComment creates comment of issue or commit.
func CreateComment(opts *CreateCommentOptions) (comment *Comment, err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return nil, err
@@ -919,7 +924,7 @@ func CreateRefComment(doer *User, repo *Repository, issue *Issue, content, commi
}
// Check if same reference from same commit has already existed.
has, err := x.Get(&Comment{
has, err := db.DefaultContext().Engine().Get(&Comment{
Type: CommentTypeCommitRef,
IssueID: issue.ID,
CommitSHA: commitSHA,
@@ -943,10 +948,10 @@ func CreateRefComment(doer *User, repo *Repository, issue *Issue, content, commi
// GetCommentByID returns the comment by given ID.
func GetCommentByID(id int64) (*Comment, error) {
return getCommentByID(x, id)
return getCommentByID(db.DefaultContext().Engine(), id)
}
func getCommentByID(e Engine, id int64) (*Comment, error) {
func getCommentByID(e db.Engine, id int64) (*Comment, error) {
c := new(Comment)
has, err := e.ID(id).Get(c)
if err != nil {
@@ -999,7 +1004,7 @@ func (opts *FindCommentsOptions) toConds() builder.Cond {
return cond
}
func findComments(e Engine, opts *FindCommentsOptions) ([]*Comment, error) {
func findComments(e db.Engine, opts *FindCommentsOptions) ([]*Comment, error) {
comments := make([]*Comment, 0, 10)
sess := e.Where(opts.toConds())
if opts.RepoID > 0 {
@@ -1020,12 +1025,12 @@ func findComments(e Engine, opts *FindCommentsOptions) ([]*Comment, error) {
// FindComments returns all comments according options
func FindComments(opts *FindCommentsOptions) ([]*Comment, error) {
return findComments(x, opts)
return findComments(db.DefaultContext().Engine(), opts)
}
// CountComments count all comments according options by ignoring pagination
func CountComments(opts *FindCommentsOptions) (int64, error) {
sess := x.Where(opts.toConds())
sess := db.DefaultContext().Engine().Where(opts.toConds())
if opts.RepoID > 0 {
sess.Join("INNER", "issue", "issue.id = comment.issue_id")
}
@@ -1034,7 +1039,7 @@ func CountComments(opts *FindCommentsOptions) (int64, error) {
// UpdateComment updates information of comment.
func UpdateComment(c *Comment, doer *User) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -1058,7 +1063,7 @@ func UpdateComment(c *Comment, doer *User) error {
// DeleteComment deletes the comment
func DeleteComment(comment *Comment) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -1071,7 +1076,7 @@ func DeleteComment(comment *Comment) error {
return sess.Commit()
}
func deleteComment(e Engine, comment *Comment) error {
func deleteComment(e db.Engine, comment *Comment) error {
if _, err := e.Delete(&Comment{
ID: comment.ID,
}); err != nil {
@@ -1097,11 +1102,11 @@ func deleteComment(e Engine, comment *Comment) error {
// CodeComments represents comments on code by using this structure: FILENAME -> LINE (+ == proposed; - == previous) -> COMMENTS
type CodeComments map[string]map[int64][]*Comment
func fetchCodeComments(e Engine, issue *Issue, currentUser *User) (CodeComments, error) {
func fetchCodeComments(e db.Engine, issue *Issue, currentUser *User) (CodeComments, error) {
return fetchCodeCommentsByReview(e, issue, currentUser, nil)
}
func fetchCodeCommentsByReview(e Engine, issue *Issue, currentUser *User, review *Review) (CodeComments, error) {
func fetchCodeCommentsByReview(e db.Engine, issue *Issue, currentUser *User, review *Review) (CodeComments, error) {
pathToLineToComment := make(CodeComments)
if review == nil {
review = &Review{ID: 0}
@@ -1126,7 +1131,7 @@ func fetchCodeCommentsByReview(e Engine, issue *Issue, currentUser *User, review
return pathToLineToComment, nil
}
func findCodeComments(e Engine, opts FindCommentsOptions, issue *Issue, currentUser *User, review *Review) ([]*Comment, error) {
func findCodeComments(e db.Engine, opts FindCommentsOptions, issue *Issue, currentUser *User, review *Review) ([]*Comment, error) {
var comments []*Comment
if review == nil {
review = &Review{ID: 0}
@@ -1202,17 +1207,17 @@ func FetchCodeCommentsByLine(issue *Issue, currentUser *User, treePath string, l
TreePath: treePath,
Line: line,
}
return findCodeComments(x, opts, issue, currentUser, nil)
return findCodeComments(db.DefaultContext().Engine(), opts, issue, currentUser, nil)
}
// FetchCodeComments will return a 2d-map: ["Path"]["Line"] = Comments at line
func FetchCodeComments(issue *Issue, currentUser *User) (CodeComments, error) {
return fetchCodeComments(x, issue, currentUser)
return fetchCodeComments(db.DefaultContext().Engine(), issue, currentUser)
}
// UpdateCommentsMigrationsByType updates comments' migrations information via given git service type and original id and poster id
func UpdateCommentsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error {
_, err := x.Table("comment").
_, err := db.DefaultContext().Engine().Table("comment").
Where(builder.In("issue_id",
builder.Select("issue.id").
From("issue").
+16 -14
View File
@@ -4,6 +4,8 @@
package models
import "code.gitea.io/gitea/models/db"
// CommentList defines a list of comments
type CommentList []*Comment
@@ -17,7 +19,7 @@ func (comments CommentList) getPosterIDs() []int64 {
return keysInt64(posterIDs)
}
func (comments CommentList) loadPosters(e Engine) error {
func (comments CommentList) loadPosters(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -70,7 +72,7 @@ func (comments CommentList) getLabelIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadLabels(e Engine) error {
func (comments CommentList) loadLabels(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -120,7 +122,7 @@ func (comments CommentList) getMilestoneIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadMilestones(e Engine) error {
func (comments CommentList) loadMilestones(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -163,7 +165,7 @@ func (comments CommentList) getOldMilestoneIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadOldMilestones(e Engine) error {
func (comments CommentList) loadOldMilestones(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -206,7 +208,7 @@ func (comments CommentList) getAssigneeIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadAssignees(e Engine) error {
func (comments CommentList) loadAssignees(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -280,7 +282,7 @@ func (comments CommentList) Issues() IssueList {
return issueList
}
func (comments CommentList) loadIssues(e Engine) error {
func (comments CommentList) loadIssues(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -337,7 +339,7 @@ func (comments CommentList) getDependentIssueIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadDependentIssues(e Engine) error {
func (comments CommentList) loadDependentIssues(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -386,7 +388,7 @@ func (comments CommentList) loadDependentIssues(e Engine) error {
return nil
}
func (comments CommentList) loadAttachments(e Engine) (err error) {
func (comments CommentList) loadAttachments(e db.Engine) (err error) {
if len(comments) == 0 {
return nil
}
@@ -438,7 +440,7 @@ func (comments CommentList) getReviewIDs() []int64 {
return keysInt64(ids)
}
func (comments CommentList) loadReviews(e Engine) error {
func (comments CommentList) loadReviews(e db.Engine) error {
if len(comments) == 0 {
return nil
}
@@ -481,7 +483,7 @@ func (comments CommentList) loadReviews(e Engine) error {
}
// loadAttributes loads all attributes
func (comments CommentList) loadAttributes(e Engine) (err error) {
func (comments CommentList) loadAttributes(e db.Engine) (err error) {
if err = comments.loadPosters(e); err != nil {
return
}
@@ -524,20 +526,20 @@ func (comments CommentList) loadAttributes(e Engine) (err error) {
// LoadAttributes loads attributes of the comments, except for attachments and
// comments
func (comments CommentList) LoadAttributes() error {
return comments.loadAttributes(x)
return comments.loadAttributes(db.DefaultContext().Engine())
}
// LoadAttachments loads attachments
func (comments CommentList) LoadAttachments() error {
return comments.loadAttachments(x)
return comments.loadAttachments(db.DefaultContext().Engine())
}
// LoadPosters loads posters
func (comments CommentList) LoadPosters() error {
return comments.loadPosters(x)
return comments.loadPosters(db.DefaultContext().Engine())
}
// LoadIssues loads issues of comments
func (comments CommentList) LoadIssues() error {
return comments.loadIssues(x)
return comments.loadIssues(db.DefaultContext().Engine())
}
+13 -12
View File
@@ -8,15 +8,16 @@ import (
"testing"
"time"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestCreateComment(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{}).(*Issue)
repo := AssertExistsAndLoadBean(t, &Repository{ID: issue.RepoID}).(*Repository)
doer := AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
issue := db.AssertExistsAndLoadBean(t, &Issue{}).(*Issue)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: issue.RepoID}).(*Repository)
doer := db.AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
now := time.Now().Unix()
comment, err := CreateComment(&CreateCommentOptions{
@@ -33,18 +34,18 @@ func TestCreateComment(t *testing.T) {
assert.EqualValues(t, "Hello", comment.Content)
assert.EqualValues(t, issue.ID, comment.IssueID)
assert.EqualValues(t, doer.ID, comment.PosterID)
AssertInt64InRange(t, now, then, int64(comment.CreatedUnix))
AssertExistsAndLoadBean(t, comment) // assert actually added to DB
db.AssertInt64InRange(t, now, then, int64(comment.CreatedUnix))
db.AssertExistsAndLoadBean(t, comment) // assert actually added to DB
updatedIssue := AssertExistsAndLoadBean(t, &Issue{ID: issue.ID}).(*Issue)
AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix))
updatedIssue := db.AssertExistsAndLoadBean(t, &Issue{ID: issue.ID}).(*Issue)
db.AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix))
}
func TestFetchCodeComments(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue)
user := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue)
user := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
res, err := FetchCodeComments(issue, user)
assert.NoError(t, err)
assert.Contains(t, res, "README.md")
@@ -52,7 +53,7 @@ func TestFetchCodeComments(t *testing.T) {
assert.Len(t, res["README.md"][4], 1)
assert.Equal(t, int64(4), res["README.md"][4][0].ID)
user2 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
res, err = FetchCodeComments(issue, user2)
assert.NoError(t, err)
assert.Len(t, res, 1)
+12 -7
View File
@@ -5,6 +5,7 @@
package models
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
@@ -20,6 +21,10 @@ type IssueDependency struct {
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(IssueDependency))
}
// DependencyType Defines Dependency Type Constants
type DependencyType int
@@ -31,7 +36,7 @@ const (
// CreateIssueDependency creates a new dependency for an issue
func CreateIssueDependency(user *User, issue, dep *Issue) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -72,7 +77,7 @@ func CreateIssueDependency(user *User, issue, dep *Issue) error {
// RemoveIssueDependency removes a dependency from an issue
func RemoveIssueDependency(user *User, issue, dep *Issue, depType DependencyType) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -107,16 +112,16 @@ func RemoveIssueDependency(user *User, issue, dep *Issue, depType DependencyType
}
// Check if the dependency already exists
func issueDepExists(e Engine, issueID, depID int64) (bool, error) {
func issueDepExists(e db.Engine, issueID, depID int64) (bool, error) {
return e.Where("(issue_id = ? AND dependency_id = ?)", issueID, depID).Exist(&IssueDependency{})
}
// IssueNoDependenciesLeft checks if issue can be closed
func IssueNoDependenciesLeft(issue *Issue) (bool, error) {
return issueNoDependenciesLeft(x, issue)
return issueNoDependenciesLeft(db.DefaultContext().Engine(), issue)
}
func issueNoDependenciesLeft(e Engine, issue *Issue) (bool, error) {
func issueNoDependenciesLeft(e db.Engine, issue *Issue) (bool, error) {
exists, err := e.
Table("issue_dependency").
Select("issue.*").
@@ -130,10 +135,10 @@ func issueNoDependenciesLeft(e Engine, issue *Issue) (bool, error) {
// IsDependenciesEnabled returns if dependencies are enabled and returns the default setting if not set.
func (repo *Repository) IsDependenciesEnabled() bool {
return repo.isDependenciesEnabled(x)
return repo.isDependenciesEnabled(db.DefaultContext().Engine())
}
func (repo *Repository) isDependenciesEnabled(e Engine) bool {
func (repo *Repository) isDependenciesEnabled(e db.Engine) bool {
var u *RepoUnit
var err error
if u, err = repo.getUnit(e, UnitTypeIssues); err != nil {
+3 -2
View File
@@ -7,12 +7,13 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestCreateIssueDependency(t *testing.T) {
// Prepare
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1, err := GetUserByID(1)
assert.NoError(t, err)
@@ -37,7 +38,7 @@ func TestCreateIssueDependency(t *testing.T) {
assert.Error(t, err)
assert.True(t, IsErrCircularDependency(err))
_ = AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeAddDependency, PosterID: user1.ID, IssueID: issue1.ID})
_ = db.AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeAddDependency, PosterID: user1.ID, IssueID: issue1.ID})
// Check if dependencies left is correct
left, err := IssueNoDependenciesLeft(issue1)
+57 -51
View File
@@ -13,10 +13,10 @@ import (
"strconv"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
"xorm.io/xorm"
)
// LabelColorPattern is a regexp witch can validate LabelColor
@@ -43,6 +43,11 @@ type Label struct {
IsExcluded bool `xorm:"-"`
}
func init() {
db.RegisterModel(new(Label))
db.RegisterModel(new(IssueLabel))
}
// GetLabelTemplateFile loads the label template file by given name,
// then parses and returns a list of name-color pairs and optionally description.
func GetLabelTemplateFile(name string) ([][3]string, error) {
@@ -209,7 +214,7 @@ func LoadLabelsFormatted(labelTemplate string) (string, error) {
return strings.Join(labels, ", "), err
}
func initializeLabels(e Engine, id int64, labelTemplate string, isOrg bool) error {
func initializeLabels(e db.Engine, id int64, labelTemplate string, isOrg bool) error {
list, err := GetLabelTemplateFile(labelTemplate)
if err != nil {
return err
@@ -237,11 +242,11 @@ func initializeLabels(e Engine, id int64, labelTemplate string, isOrg bool) erro
}
// InitializeLabels adds a label set to a repository using a template
func InitializeLabels(ctx DBContext, repoID int64, labelTemplate string, isOrg bool) error {
return initializeLabels(ctx.e, repoID, labelTemplate, isOrg)
func InitializeLabels(ctx *db.Context, repoID int64, labelTemplate string, isOrg bool) error {
return initializeLabels(ctx.Engine(), repoID, labelTemplate, isOrg)
}
func newLabel(e Engine, label *Label) error {
func newLabel(e db.Engine, label *Label) error {
_, err := e.Insert(label)
return err
}
@@ -251,25 +256,26 @@ func NewLabel(label *Label) error {
if !LabelColorPattern.MatchString(label.Color) {
return fmt.Errorf("bad color code: %s", label.Color)
}
return newLabel(x, label)
return newLabel(db.DefaultContext().Engine(), label)
}
// NewLabels creates new labels
func NewLabels(labels ...*Label) error {
sess := x.NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
for _, label := range labels {
if !LabelColorPattern.MatchString(label.Color) {
return fmt.Errorf("bad color code: %s", label.Color)
}
if err := newLabel(sess, label); err != nil {
if err := newLabel(ctx.Engine(), label); err != nil {
return err
}
}
return sess.Commit()
return committer.Commit()
}
// UpdateLabel updates label information.
@@ -277,7 +283,7 @@ func UpdateLabel(l *Label) error {
if !LabelColorPattern.MatchString(l.Color) {
return fmt.Errorf("bad color code: %s", l.Color)
}
return updateLabelCols(x, l, "name", "description", "color")
return updateLabelCols(db.DefaultContext().Engine(), l, "name", "description", "color")
}
// DeleteLabel delete a label
@@ -290,7 +296,7 @@ func DeleteLabel(id, labelID int64) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -320,7 +326,7 @@ func DeleteLabel(id, labelID int64) error {
}
// getLabelByID returns a label by label id
func getLabelByID(e Engine, labelID int64) (*Label, error) {
func getLabelByID(e db.Engine, labelID int64) (*Label, error) {
if labelID <= 0 {
return nil, ErrLabelNotExist{labelID}
}
@@ -337,13 +343,13 @@ func getLabelByID(e Engine, labelID int64) (*Label, error) {
// GetLabelByID returns a label by given ID.
func GetLabelByID(id int64) (*Label, error) {
return getLabelByID(x, id)
return getLabelByID(db.DefaultContext().Engine(), id)
}
// GetLabelsByIDs returns a list of labels by IDs
func GetLabelsByIDs(labelIDs []int64) ([]*Label, error) {
labels := make([]*Label, 0, len(labelIDs))
return labels, x.Table("label").
return labels, db.DefaultContext().Engine().Table("label").
In("id", labelIDs).
Asc("name").
Cols("id", "repo_id", "org_id").
@@ -358,7 +364,7 @@ func GetLabelsByIDs(labelIDs []int64) ([]*Label, error) {
// \/ \/|__| \/ \/
// getLabelInRepoByName returns a label by Name in given repository.
func getLabelInRepoByName(e Engine, repoID int64, labelName string) (*Label, error) {
func getLabelInRepoByName(e db.Engine, repoID int64, labelName string) (*Label, error) {
if len(labelName) == 0 || repoID <= 0 {
return nil, ErrRepoLabelNotExist{0, repoID}
}
@@ -377,7 +383,7 @@ func getLabelInRepoByName(e Engine, repoID int64, labelName string) (*Label, err
}
// getLabelInRepoByID returns a label by ID in given repository.
func getLabelInRepoByID(e Engine, repoID, labelID int64) (*Label, error) {
func getLabelInRepoByID(e db.Engine, repoID, labelID int64) (*Label, error) {
if labelID <= 0 || repoID <= 0 {
return nil, ErrRepoLabelNotExist{labelID, repoID}
}
@@ -397,7 +403,7 @@ func getLabelInRepoByID(e Engine, repoID, labelID int64) (*Label, error) {
// GetLabelInRepoByName returns a label by name in given repository.
func GetLabelInRepoByName(repoID int64, labelName string) (*Label, error) {
return getLabelInRepoByName(x, repoID, labelName)
return getLabelInRepoByName(db.DefaultContext().Engine(), repoID, labelName)
}
// GetLabelIDsInRepoByNames returns a list of labelIDs by names in a given
@@ -405,7 +411,7 @@ func GetLabelInRepoByName(repoID int64, labelName string) (*Label, error) {
// it silently ignores label names that do not belong to the repository.
func GetLabelIDsInRepoByNames(repoID int64, labelNames []string) ([]int64, error) {
labelIDs := make([]int64, 0, len(labelNames))
return labelIDs, x.Table("label").
return labelIDs, db.DefaultContext().Engine().Table("label").
Where("repo_id = ?", repoID).
In("name", labelNames).
Asc("name").
@@ -426,21 +432,21 @@ func BuildLabelNamesIssueIDsCondition(labelNames []string) *builder.Builder {
// GetLabelInRepoByID returns a label by ID in given repository.
func GetLabelInRepoByID(repoID, labelID int64) (*Label, error) {
return getLabelInRepoByID(x, repoID, labelID)
return getLabelInRepoByID(db.DefaultContext().Engine(), repoID, labelID)
}
// GetLabelsInRepoByIDs returns a list of labels by IDs in given repository,
// it silently ignores label IDs that do not belong to the repository.
func GetLabelsInRepoByIDs(repoID int64, labelIDs []int64) ([]*Label, error) {
labels := make([]*Label, 0, len(labelIDs))
return labels, x.
return labels, db.DefaultContext().Engine().
Where("repo_id = ?", repoID).
In("id", labelIDs).
Asc("name").
Find(&labels)
}
func getLabelsByRepoID(e Engine, repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
func getLabelsByRepoID(e db.Engine, repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
if repoID <= 0 {
return nil, ErrRepoLabelNotExist{0, repoID}
}
@@ -467,12 +473,12 @@ func getLabelsByRepoID(e Engine, repoID int64, sortType string, listOptions List
// GetLabelsByRepoID returns all labels that belong to given repository by ID.
func GetLabelsByRepoID(repoID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
return getLabelsByRepoID(x, repoID, sortType, listOptions)
return getLabelsByRepoID(db.DefaultContext().Engine(), repoID, sortType, listOptions)
}
// CountLabelsByRepoID count number of all labels that belong to given repository by ID.
func CountLabelsByRepoID(repoID int64) (int64, error) {
return x.Where("repo_id = ?", repoID).Count(&Label{})
return db.DefaultContext().Engine().Where("repo_id = ?", repoID).Count(&Label{})
}
// ________
@@ -483,7 +489,7 @@ func CountLabelsByRepoID(repoID int64) (int64, error) {
// \/ /_____/
// getLabelInOrgByName returns a label by Name in given organization
func getLabelInOrgByName(e Engine, orgID int64, labelName string) (*Label, error) {
func getLabelInOrgByName(e db.Engine, orgID int64, labelName string) (*Label, error) {
if len(labelName) == 0 || orgID <= 0 {
return nil, ErrOrgLabelNotExist{0, orgID}
}
@@ -502,7 +508,7 @@ func getLabelInOrgByName(e Engine, orgID int64, labelName string) (*Label, error
}
// getLabelInOrgByID returns a label by ID in given organization.
func getLabelInOrgByID(e Engine, orgID, labelID int64) (*Label, error) {
func getLabelInOrgByID(e db.Engine, orgID, labelID int64) (*Label, error) {
if labelID <= 0 || orgID <= 0 {
return nil, ErrOrgLabelNotExist{labelID, orgID}
}
@@ -522,7 +528,7 @@ func getLabelInOrgByID(e Engine, orgID, labelID int64) (*Label, error) {
// GetLabelInOrgByName returns a label by name in given organization.
func GetLabelInOrgByName(orgID int64, labelName string) (*Label, error) {
return getLabelInOrgByName(x, orgID, labelName)
return getLabelInOrgByName(db.DefaultContext().Engine(), orgID, labelName)
}
// GetLabelIDsInOrgByNames returns a list of labelIDs by names in a given
@@ -533,7 +539,7 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error)
}
labelIDs := make([]int64, 0, len(labelNames))
return labelIDs, x.Table("label").
return labelIDs, db.DefaultContext().Engine().Table("label").
Where("org_id = ?", orgID).
In("name", labelNames).
Asc("name").
@@ -543,21 +549,21 @@ func GetLabelIDsInOrgByNames(orgID int64, labelNames []string) ([]int64, error)
// GetLabelInOrgByID returns a label by ID in given organization.
func GetLabelInOrgByID(orgID, labelID int64) (*Label, error) {
return getLabelInOrgByID(x, orgID, labelID)
return getLabelInOrgByID(db.DefaultContext().Engine(), orgID, labelID)
}
// GetLabelsInOrgByIDs returns a list of labels by IDs in given organization,
// it silently ignores label IDs that do not belong to the organization.
func GetLabelsInOrgByIDs(orgID int64, labelIDs []int64) ([]*Label, error) {
labels := make([]*Label, 0, len(labelIDs))
return labels, x.
return labels, db.DefaultContext().Engine().
Where("org_id = ?", orgID).
In("id", labelIDs).
Asc("name").
Find(&labels)
}
func getLabelsByOrgID(e Engine, orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
func getLabelsByOrgID(e db.Engine, orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
if orgID <= 0 {
return nil, ErrOrgLabelNotExist{0, orgID}
}
@@ -584,12 +590,12 @@ func getLabelsByOrgID(e Engine, orgID int64, sortType string, listOptions ListOp
// GetLabelsByOrgID returns all labels that belong to given organization by ID.
func GetLabelsByOrgID(orgID int64, sortType string, listOptions ListOptions) ([]*Label, error) {
return getLabelsByOrgID(x, orgID, sortType, listOptions)
return getLabelsByOrgID(db.DefaultContext().Engine(), orgID, sortType, listOptions)
}
// CountLabelsByOrgID count all labels that belong to given organization by ID.
func CountLabelsByOrgID(orgID int64) (int64, error) {
return x.Where("org_id = ?", orgID).Count(&Label{})
return db.DefaultContext().Engine().Where("org_id = ?", orgID).Count(&Label{})
}
// .___
@@ -599,7 +605,7 @@ func CountLabelsByOrgID(orgID int64) (int64, error) {
// |___/____ >____ >____/ \___ |
// \/ \/ \/
func getLabelsByIssueID(e Engine, issueID int64) ([]*Label, error) {
func getLabelsByIssueID(e db.Engine, issueID int64) ([]*Label, error) {
var labels []*Label
return labels, e.Where("issue_label.issue_id = ?", issueID).
Join("LEFT", "issue_label", "issue_label.label_id = label.id").
@@ -609,10 +615,10 @@ func getLabelsByIssueID(e Engine, issueID int64) ([]*Label, error) {
// GetLabelsByIssueID returns all labels that belong to given issue by ID.
func GetLabelsByIssueID(issueID int64) ([]*Label, error) {
return getLabelsByIssueID(x, issueID)
return getLabelsByIssueID(db.DefaultContext().Engine(), issueID)
}
func updateLabelCols(e Engine, l *Label, cols ...string) error {
func updateLabelCols(e db.Engine, l *Label, cols ...string) error {
_, err := e.ID(l.ID).
SetExpr("num_issues",
builder.Select("count(*)").From("issue_label").
@@ -644,19 +650,19 @@ type IssueLabel struct {
LabelID int64 `xorm:"UNIQUE(s)"`
}
func hasIssueLabel(e Engine, issueID, labelID int64) bool {
func hasIssueLabel(e db.Engine, issueID, labelID int64) bool {
has, _ := e.Where("issue_id = ? AND label_id = ?", issueID, labelID).Get(new(IssueLabel))
return has
}
// HasIssueLabel returns true if issue has been labeled.
func HasIssueLabel(issueID, labelID int64) bool {
return hasIssueLabel(x, issueID, labelID)
return hasIssueLabel(db.DefaultContext().Engine(), issueID, labelID)
}
// newIssueLabel this function creates a new label it does not check if the label is valid for the issue
// YOU MUST CHECK THIS BEFORE THIS FUNCTION
func newIssueLabel(e *xorm.Session, issue *Issue, label *Label, doer *User) (err error) {
func newIssueLabel(e db.Engine, issue *Issue, label *Label, doer *User) (err error) {
if _, err = e.Insert(&IssueLabel{
IssueID: issue.ID,
LabelID: label.ID,
@@ -689,7 +695,7 @@ func NewIssueLabel(issue *Issue, label *Label, doer *User) (err error) {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -717,7 +723,7 @@ func NewIssueLabel(issue *Issue, label *Label, doer *User) (err error) {
}
// newIssueLabels add labels to an issue. It will check if the labels are valid for the issue
func newIssueLabels(e *xorm.Session, issue *Issue, labels []*Label, doer *User) (err error) {
func newIssueLabels(e db.Engine, issue *Issue, labels []*Label, doer *User) (err error) {
if err = issue.loadRepo(e); err != nil {
return err
}
@@ -738,25 +744,25 @@ func newIssueLabels(e *xorm.Session, issue *Issue, labels []*Label, doer *User)
// NewIssueLabels creates a list of issue-label relations.
func NewIssueLabels(issue *Issue, labels []*Label, doer *User) (err error) {
sess := x.NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
ctx, committer, err := db.TxContext()
if err != nil {
return err
}
defer committer.Close()
if err = newIssueLabels(sess, issue, labels, doer); err != nil {
if err = newIssueLabels(ctx.Engine(), issue, labels, doer); err != nil {
return err
}
issue.Labels = nil
if err = issue.loadLabels(sess); err != nil {
if err = issue.loadLabels(ctx.Engine()); err != nil {
return err
}
return sess.Commit()
return committer.Commit()
}
func deleteIssueLabel(e *xorm.Session, issue *Issue, label *Label, doer *User) (err error) {
func deleteIssueLabel(e db.Engine, issue *Issue, label *Label, doer *User) (err error) {
if count, err := e.Delete(&IssueLabel{
IssueID: issue.ID,
LabelID: label.ID,
@@ -786,7 +792,7 @@ func deleteIssueLabel(e *xorm.Session, issue *Issue, label *Label, doer *User) (
// DeleteIssueLabel deletes issue-label relation.
func DeleteIssueLabel(issue *Issue, label *Label, doer *User) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -804,7 +810,7 @@ func DeleteIssueLabel(issue *Issue, label *Label, doer *User) (err error) {
return sess.Commit()
}
func deleteLabelsByRepoID(sess Engine, repoID int64) error {
func deleteLabelsByRepoID(sess db.Engine, repoID int64) error {
deleteCond := builder.Select("id").From("label").Where(builder.Eq{"label.repo_id": repoID})
if _, err := sess.In("label_id", deleteCond).
+66 -65
View File
@@ -8,29 +8,30 @@ import (
"html/template"
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
// TODO TestGetLabelTemplateFile
func TestLabel_CalOpenIssues(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.NoError(t, db.PrepareTestDatabase())
label := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
label.CalOpenIssues()
assert.EqualValues(t, 2, label.NumOpenIssues)
}
func TestLabel_ForegroundColor(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.NoError(t, db.PrepareTestDatabase())
label := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.Equal(t, template.CSS("#000"), label.ForegroundColor())
label = AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
label = db.AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
assert.Equal(t, template.CSS("#fff"), label.ForegroundColor())
}
func TestNewLabels(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
labels := []*Label{
{RepoID: 2, Name: "labelName2", Color: "#123456"},
{RepoID: 3, Name: "labelName3", Color: "#23456F"},
@@ -39,27 +40,27 @@ func TestNewLabels(t *testing.T) {
assert.Error(t, NewLabel(&Label{RepoID: 3, Name: "invalid Color", Color: "123456"}))
assert.Error(t, NewLabel(&Label{RepoID: 3, Name: "invalid Color", Color: "#12345G"}))
for _, label := range labels {
AssertNotExistsBean(t, label)
db.AssertNotExistsBean(t, label)
}
assert.NoError(t, NewLabels(labels...))
for _, label := range labels {
AssertExistsAndLoadBean(t, label, Cond("id = ?", label.ID))
db.AssertExistsAndLoadBean(t, label, db.Cond("id = ?", label.ID))
}
CheckConsistencyFor(t, &Label{}, &Repository{})
}
func TestGetLabelByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
label, err := GetLabelByID(1)
assert.NoError(t, err)
assert.EqualValues(t, 1, label.ID)
_, err = GetLabelByID(NonexistentID)
_, err = GetLabelByID(db.NonexistentID)
assert.True(t, IsErrLabelNotExist(err))
}
func TestGetLabelInRepoByName(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
label, err := GetLabelInRepoByName(1, "label1")
assert.NoError(t, err)
assert.EqualValues(t, 1, label.ID)
@@ -68,12 +69,12 @@ func TestGetLabelInRepoByName(t *testing.T) {
_, err = GetLabelInRepoByName(1, "")
assert.True(t, IsErrRepoLabelNotExist(err))
_, err = GetLabelInRepoByName(NonexistentID, "nonexistent")
_, err = GetLabelInRepoByName(db.NonexistentID, "nonexistent")
assert.True(t, IsErrRepoLabelNotExist(err))
}
func TestGetLabelInRepoByNames(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
labelIDs, err := GetLabelIDsInRepoByNames(1, []string{"label1", "label2"})
assert.NoError(t, err)
@@ -84,7 +85,7 @@ func TestGetLabelInRepoByNames(t *testing.T) {
}
func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// label3 doesn't exists.. See labels.yml
labelIDs, err := GetLabelIDsInRepoByNames(1, []string{"label1", "label2", "label3"})
assert.NoError(t, err)
@@ -97,7 +98,7 @@ func TestGetLabelInRepoByNamesDiscardsNonExistentLabels(t *testing.T) {
}
func TestGetLabelInRepoByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
label, err := GetLabelInRepoByID(1, 1)
assert.NoError(t, err)
assert.EqualValues(t, 1, label.ID)
@@ -105,13 +106,13 @@ func TestGetLabelInRepoByID(t *testing.T) {
_, err = GetLabelInRepoByID(1, -1)
assert.True(t, IsErrRepoLabelNotExist(err))
_, err = GetLabelInRepoByID(NonexistentID, NonexistentID)
_, err = GetLabelInRepoByID(db.NonexistentID, db.NonexistentID)
assert.True(t, IsErrRepoLabelNotExist(err))
}
func TestGetLabelsInRepoByIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
labels, err := GetLabelsInRepoByIDs(1, []int64{1, 2, NonexistentID})
assert.NoError(t, db.PrepareTestDatabase())
labels, err := GetLabelsInRepoByIDs(1, []int64{1, 2, db.NonexistentID})
assert.NoError(t, err)
if assert.Len(t, labels, 2) {
assert.EqualValues(t, 1, labels[0].ID)
@@ -120,7 +121,7 @@ func TestGetLabelsInRepoByIDs(t *testing.T) {
}
func TestGetLabelsByRepoID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(repoID int64, sortType string, expectedIssueIDs []int64) {
labels, err := GetLabelsByRepoID(repoID, sortType, ListOptions{})
assert.NoError(t, err)
@@ -138,7 +139,7 @@ func TestGetLabelsByRepoID(t *testing.T) {
// Org versions
func TestGetLabelInOrgByName(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
label, err := GetLabelInOrgByName(3, "orglabel3")
assert.NoError(t, err)
assert.EqualValues(t, 3, label.ID)
@@ -153,12 +154,12 @@ func TestGetLabelInOrgByName(t *testing.T) {
_, err = GetLabelInOrgByName(-1, "orglabel3")
assert.True(t, IsErrOrgLabelNotExist(err))
_, err = GetLabelInOrgByName(NonexistentID, "nonexistent")
_, err = GetLabelInOrgByName(db.NonexistentID, "nonexistent")
assert.True(t, IsErrOrgLabelNotExist(err))
}
func TestGetLabelInOrgByNames(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
labelIDs, err := GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4"})
assert.NoError(t, err)
@@ -169,7 +170,7 @@ func TestGetLabelInOrgByNames(t *testing.T) {
}
func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// orglabel99 doesn't exists.. See labels.yml
labelIDs, err := GetLabelIDsInOrgByNames(3, []string{"orglabel3", "orglabel4", "orglabel99"})
assert.NoError(t, err)
@@ -182,7 +183,7 @@ func TestGetLabelInOrgByNamesDiscardsNonExistentLabels(t *testing.T) {
}
func TestGetLabelInOrgByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
label, err := GetLabelInOrgByID(3, 3)
assert.NoError(t, err)
assert.EqualValues(t, 3, label.ID)
@@ -196,13 +197,13 @@ func TestGetLabelInOrgByID(t *testing.T) {
_, err = GetLabelInOrgByID(-1, 3)
assert.True(t, IsErrOrgLabelNotExist(err))
_, err = GetLabelInOrgByID(NonexistentID, NonexistentID)
_, err = GetLabelInOrgByID(db.NonexistentID, db.NonexistentID)
assert.True(t, IsErrOrgLabelNotExist(err))
}
func TestGetLabelsInOrgByIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
labels, err := GetLabelsInOrgByIDs(3, []int64{3, 4, NonexistentID})
assert.NoError(t, db.PrepareTestDatabase())
labels, err := GetLabelsInOrgByIDs(3, []int64{3, 4, db.NonexistentID})
assert.NoError(t, err)
if assert.Len(t, labels, 2) {
assert.EqualValues(t, 3, labels[0].ID)
@@ -211,7 +212,7 @@ func TestGetLabelsInOrgByIDs(t *testing.T) {
}
func TestGetLabelsByOrgID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(orgID int64, sortType string, expectedIssueIDs []int64) {
labels, err := GetLabelsByOrgID(orgID, sortType, ListOptions{})
assert.NoError(t, err)
@@ -236,21 +237,21 @@ func TestGetLabelsByOrgID(t *testing.T) {
//
func TestGetLabelsByIssueID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
labels, err := GetLabelsByIssueID(1)
assert.NoError(t, err)
if assert.Len(t, labels, 1) {
assert.EqualValues(t, 1, labels[0].ID)
}
labels, err = GetLabelsByIssueID(NonexistentID)
labels, err = GetLabelsByIssueID(db.NonexistentID)
assert.NoError(t, err)
assert.Len(t, labels, 0)
}
func TestUpdateLabel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.NoError(t, db.PrepareTestDatabase())
label := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
// make sure update wont overwrite it
update := &Label{
ID: label.ID,
@@ -261,7 +262,7 @@ func TestUpdateLabel(t *testing.T) {
label.Color = update.Color
label.Name = update.Name
assert.NoError(t, UpdateLabel(update))
newLabel := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
newLabel := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.EqualValues(t, label.ID, newLabel.ID)
assert.EqualValues(t, label.Color, newLabel.Color)
assert.EqualValues(t, label.Name, newLabel.Name)
@@ -270,43 +271,43 @@ func TestUpdateLabel(t *testing.T) {
}
func TestDeleteLabel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.NoError(t, db.PrepareTestDatabase())
label := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.NoError(t, DeleteLabel(label.RepoID, label.ID))
AssertNotExistsBean(t, &Label{ID: label.ID, RepoID: label.RepoID})
db.AssertNotExistsBean(t, &Label{ID: label.ID, RepoID: label.RepoID})
assert.NoError(t, DeleteLabel(label.RepoID, label.ID))
AssertNotExistsBean(t, &Label{ID: label.ID})
db.AssertNotExistsBean(t, &Label{ID: label.ID})
assert.NoError(t, DeleteLabel(NonexistentID, NonexistentID))
assert.NoError(t, DeleteLabel(db.NonexistentID, db.NonexistentID))
CheckConsistencyFor(t, &Label{}, &Repository{})
}
func TestHasIssueLabel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.True(t, HasIssueLabel(1, 1))
assert.False(t, HasIssueLabel(1, 2))
assert.False(t, HasIssueLabel(NonexistentID, NonexistentID))
assert.False(t, HasIssueLabel(db.NonexistentID, db.NonexistentID))
}
func TestNewIssueLabel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label := AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
issue := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
doer := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
label := db.AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
doer := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
// add new IssueLabel
prevNumIssues := label.NumIssues
assert.NoError(t, NewIssueLabel(issue, label, doer))
AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label.ID})
AssertExistsAndLoadBean(t, &Comment{
db.AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label.ID})
db.AssertExistsAndLoadBean(t, &Comment{
Type: CommentTypeLabel,
PosterID: doer.ID,
IssueID: issue.ID,
LabelID: label.ID,
Content: "1",
})
label = AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
label = db.AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
assert.EqualValues(t, prevNumIssues+1, label.NumIssues)
// re-add existing IssueLabel
@@ -315,26 +316,26 @@ func TestNewIssueLabel(t *testing.T) {
}
func TestNewIssueLabels(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
label1 := AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
label2 := AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
issue := AssertExistsAndLoadBean(t, &Issue{ID: 5}).(*Issue)
doer := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
label1 := db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
label2 := db.AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 5}).(*Issue)
doer := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, NewIssueLabels(issue, []*Label{label1, label2}, doer))
AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
AssertExistsAndLoadBean(t, &Comment{
db.AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
db.AssertExistsAndLoadBean(t, &Comment{
Type: CommentTypeLabel,
PosterID: doer.ID,
IssueID: issue.ID,
LabelID: label1.ID,
Content: "1",
})
AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
label1 = AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
db.AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label1.ID})
label1 = db.AssertExistsAndLoadBean(t, &Label{ID: 1}).(*Label)
assert.EqualValues(t, 3, label1.NumIssues)
assert.EqualValues(t, 1, label1.NumClosedIssues)
label2 = AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
label2 = db.AssertExistsAndLoadBean(t, &Label{ID: 2}).(*Label)
assert.EqualValues(t, 1, label2.NumIssues)
assert.EqualValues(t, 1, label2.NumClosedIssues)
@@ -345,15 +346,15 @@ func TestNewIssueLabels(t *testing.T) {
}
func TestDeleteIssueLabel(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(labelID, issueID, doerID int64) {
label := AssertExistsAndLoadBean(t, &Label{ID: labelID}).(*Label)
issue := AssertExistsAndLoadBean(t, &Issue{ID: issueID}).(*Issue)
doer := AssertExistsAndLoadBean(t, &User{ID: doerID}).(*User)
label := db.AssertExistsAndLoadBean(t, &Label{ID: labelID}).(*Label)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: issueID}).(*Issue)
doer := db.AssertExistsAndLoadBean(t, &User{ID: doerID}).(*User)
expectedNumIssues := label.NumIssues
expectedNumClosedIssues := label.NumClosedIssues
if BeanExists(t, &IssueLabel{IssueID: issueID, LabelID: labelID}) {
if db.BeanExists(t, &IssueLabel{IssueID: issueID, LabelID: labelID}) {
expectedNumIssues--
if issue.IsClosed {
expectedNumClosedIssues--
@@ -361,14 +362,14 @@ func TestDeleteIssueLabel(t *testing.T) {
}
assert.NoError(t, DeleteIssueLabel(issue, label, doer))
AssertNotExistsBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID})
AssertExistsAndLoadBean(t, &Comment{
db.AssertNotExistsBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID})
db.AssertExistsAndLoadBean(t, &Comment{
Type: CommentTypeLabel,
PosterID: doerID,
IssueID: issueID,
LabelID: labelID,
}, `content=""`)
label = AssertExistsAndLoadBean(t, &Label{ID: labelID}).(*Label)
label = db.AssertExistsAndLoadBean(t, &Label{ID: labelID}).(*Label)
assert.EqualValues(t, expectedNumIssues, label.NumIssues)
assert.EqualValues(t, expectedNumClosedIssues, label.NumClosedIssues)
}
+19 -18
View File
@@ -7,6 +7,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"xorm.io/builder"
)
@@ -28,7 +29,7 @@ func (issues IssueList) getRepoIDs() []int64 {
return keysInt64(repoIDs)
}
func (issues IssueList) loadRepositories(e Engine) ([]*Repository, error) {
func (issues IssueList) loadRepositories(e db.Engine) ([]*Repository, error) {
if len(issues) == 0 {
return nil, nil
}
@@ -62,7 +63,7 @@ func (issues IssueList) loadRepositories(e Engine) ([]*Repository, error) {
// LoadRepositories loads issues' all repositories
func (issues IssueList) LoadRepositories() ([]*Repository, error) {
return issues.loadRepositories(x)
return issues.loadRepositories(db.DefaultContext().Engine())
}
func (issues IssueList) getPosterIDs() []int64 {
@@ -75,7 +76,7 @@ func (issues IssueList) getPosterIDs() []int64 {
return keysInt64(posterIDs)
}
func (issues IssueList) loadPosters(e Engine) error {
func (issues IssueList) loadPosters(e db.Engine) error {
if len(issues) == 0 {
return nil
}
@@ -118,7 +119,7 @@ func (issues IssueList) getIssueIDs() []int64 {
return ids
}
func (issues IssueList) loadLabels(e Engine) error {
func (issues IssueList) loadLabels(e db.Engine) error {
if len(issues) == 0 {
return nil
}
@@ -181,7 +182,7 @@ func (issues IssueList) getMilestoneIDs() []int64 {
return keysInt64(ids)
}
func (issues IssueList) loadMilestones(e Engine) error {
func (issues IssueList) loadMilestones(e db.Engine) error {
milestoneIDs := issues.getMilestoneIDs()
if len(milestoneIDs) == 0 {
return nil
@@ -210,7 +211,7 @@ func (issues IssueList) loadMilestones(e Engine) error {
return nil
}
func (issues IssueList) loadAssignees(e Engine) error {
func (issues IssueList) loadAssignees(e db.Engine) error {
if len(issues) == 0 {
return nil
}
@@ -271,7 +272,7 @@ func (issues IssueList) getPullIssueIDs() []int64 {
return ids
}
func (issues IssueList) loadPullRequests(e Engine) error {
func (issues IssueList) loadPullRequests(e db.Engine) error {
issuesIDs := issues.getPullIssueIDs()
if len(issuesIDs) == 0 {
return nil
@@ -315,7 +316,7 @@ func (issues IssueList) loadPullRequests(e Engine) error {
return nil
}
func (issues IssueList) loadAttachments(e Engine) (err error) {
func (issues IssueList) loadAttachments(e db.Engine) (err error) {
if len(issues) == 0 {
return nil
}
@@ -360,7 +361,7 @@ func (issues IssueList) loadAttachments(e Engine) (err error) {
return nil
}
func (issues IssueList) loadComments(e Engine, cond builder.Cond) (err error) {
func (issues IssueList) loadComments(e db.Engine, cond builder.Cond) (err error) {
if len(issues) == 0 {
return nil
}
@@ -406,7 +407,7 @@ func (issues IssueList) loadComments(e Engine, cond builder.Cond) (err error) {
return nil
}
func (issues IssueList) loadTotalTrackedTimes(e Engine) (err error) {
func (issues IssueList) loadTotalTrackedTimes(e db.Engine) (err error) {
type totalTimesByIssue struct {
IssueID int64
Time int64
@@ -466,7 +467,7 @@ func (issues IssueList) loadTotalTrackedTimes(e Engine) (err error) {
}
// loadAttributes loads all attributes, expect for attachments and comments
func (issues IssueList) loadAttributes(e Engine) error {
func (issues IssueList) loadAttributes(e db.Engine) error {
if _, err := issues.loadRepositories(e); err != nil {
return fmt.Errorf("issue.loadAttributes: loadRepositories: %v", err)
}
@@ -501,36 +502,36 @@ func (issues IssueList) loadAttributes(e Engine) error {
// LoadAttributes loads attributes of the issues, except for attachments and
// comments
func (issues IssueList) LoadAttributes() error {
return issues.loadAttributes(x)
return issues.loadAttributes(db.DefaultContext().Engine())
}
// LoadAttachments loads attachments
func (issues IssueList) LoadAttachments() error {
return issues.loadAttachments(x)
return issues.loadAttachments(db.DefaultContext().Engine())
}
// LoadComments loads comments
func (issues IssueList) LoadComments() error {
return issues.loadComments(x, builder.NewCond())
return issues.loadComments(db.DefaultContext().Engine(), builder.NewCond())
}
// LoadDiscussComments loads discuss comments
func (issues IssueList) LoadDiscussComments() error {
return issues.loadComments(x, builder.Eq{"comment.type": CommentTypeComment})
return issues.loadComments(db.DefaultContext().Engine(), builder.Eq{"comment.type": CommentTypeComment})
}
// LoadPullRequests loads pull requests
func (issues IssueList) LoadPullRequests() error {
return issues.loadPullRequests(x)
return issues.loadPullRequests(db.DefaultContext().Engine())
}
// GetApprovalCounts returns a map of issue ID to slice of approval counts
// FIXME: only returns official counts due to double counting of non-official approvals
func (issues IssueList) GetApprovalCounts() (map[int64][]*ReviewCount, error) {
return issues.getApprovalCounts(x)
return issues.getApprovalCounts(db.DefaultContext().Engine())
}
func (issues IssueList) getApprovalCounts(e Engine) (map[int64][]*ReviewCount, error) {
func (issues IssueList) getApprovalCounts(e db.Engine) (map[int64][]*ReviewCount, error) {
rCounts := make([]*ReviewCount, 0, 2*len(issues))
ids := make([]int64, len(issues))
for i, issue := range issues {
+9 -8
View File
@@ -7,18 +7,19 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
)
func TestIssueList_LoadRepositories(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
issueList := IssueList{
AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue),
AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue),
AssertExistsAndLoadBean(t, &Issue{ID: 4}).(*Issue),
db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue),
db.AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue),
db.AssertExistsAndLoadBean(t, &Issue{ID: 4}).(*Issue),
}
repos, err := issueList.LoadRepositories()
@@ -30,11 +31,11 @@ func TestIssueList_LoadRepositories(t *testing.T) {
}
func TestIssueList_LoadAttributes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
setting.Service.EnableTimetracking = true
issueList := IssueList{
AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue),
AssertExistsAndLoadBean(t, &Issue{ID: 4}).(*Issue),
db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue),
db.AssertExistsAndLoadBean(t, &Issue{ID: 4}).(*Issue),
}
assert.NoError(t, issueList.LoadAttributes())
@@ -42,7 +43,7 @@ func TestIssueList_LoadAttributes(t *testing.T) {
assert.EqualValues(t, issue.RepoID, issue.Repo.ID)
for _, label := range issue.Labels {
assert.EqualValues(t, issue.RepoID, label.RepoID)
AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label.ID})
db.AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issue.ID, LabelID: label.ID})
}
if issue.PosterID > 0 {
assert.EqualValues(t, issue.PosterID, issue.Poster.ID)
+3 -1
View File
@@ -4,6 +4,8 @@
package models
import "code.gitea.io/gitea/models/db"
// IssueLockOptions defines options for locking and/or unlocking an issue/PR
type IssueLockOptions struct {
Doer *User
@@ -35,7 +37,7 @@ func updateIssueLock(opts *IssueLockOptions, lock bool) error {
commentType = CommentTypeUnlock
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
+35 -30
View File
@@ -9,6 +9,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
api "code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
@@ -42,6 +43,10 @@ type Milestone struct {
TimeSinceUpdate int64 `xorm:"-"`
}
func init() {
db.RegisterModel(new(Milestone))
}
// BeforeUpdate is invoked from XORM before updating this object.
func (m *Milestone) BeforeUpdate() {
if m.NumIssues > 0 {
@@ -80,7 +85,7 @@ func (m *Milestone) State() api.StateType {
// NewMilestone creates new milestone of repository.
func NewMilestone(m *Milestone) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -98,7 +103,7 @@ func NewMilestone(m *Milestone) (err error) {
return sess.Commit()
}
func getMilestoneByRepoID(e Engine, repoID, id int64) (*Milestone, error) {
func getMilestoneByRepoID(e db.Engine, repoID, id int64) (*Milestone, error) {
m := new(Milestone)
has, err := e.ID(id).Where("repo_id=?", repoID).Get(m)
if err != nil {
@@ -111,13 +116,13 @@ func getMilestoneByRepoID(e Engine, repoID, id int64) (*Milestone, error) {
// GetMilestoneByRepoID returns the milestone in a repository.
func GetMilestoneByRepoID(repoID, id int64) (*Milestone, error) {
return getMilestoneByRepoID(x, repoID, id)
return getMilestoneByRepoID(db.DefaultContext().Engine(), repoID, id)
}
// GetMilestoneByRepoIDANDName return a milestone if one exist by name and repo
func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error) {
var mile Milestone
has, err := x.Where("repo_id=? AND name=?", repoID, name).Get(&mile)
has, err := db.DefaultContext().Engine().Where("repo_id=? AND name=?", repoID, name).Get(&mile)
if err != nil {
return nil, err
}
@@ -129,10 +134,10 @@ func GetMilestoneByRepoIDANDName(repoID int64, name string) (*Milestone, error)
// GetMilestoneByID returns the milestone via id .
func GetMilestoneByID(id int64) (*Milestone, error) {
return getMilestoneByID(x, id)
return getMilestoneByID(db.DefaultContext().Engine(), id)
}
func getMilestoneByID(e Engine, id int64) (*Milestone, error) {
func getMilestoneByID(e db.Engine, id int64) (*Milestone, error) {
var m Milestone
has, err := e.ID(id).Get(&m)
if err != nil {
@@ -145,7 +150,7 @@ func getMilestoneByID(e Engine, id int64) (*Milestone, error) {
// UpdateMilestone updates information of given milestone.
func UpdateMilestone(m *Milestone, oldIsClosed bool) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -169,7 +174,7 @@ func UpdateMilestone(m *Milestone, oldIsClosed bool) error {
return sess.Commit()
}
func updateMilestone(e Engine, m *Milestone) error {
func updateMilestone(e db.Engine, m *Milestone) error {
m.Name = strings.TrimSpace(m.Name)
_, err := e.ID(m.ID).AllCols().Update(m)
if err != nil {
@@ -179,7 +184,7 @@ func updateMilestone(e Engine, m *Milestone) error {
}
// updateMilestoneCounters calculates NumIssues, NumClosesIssues and Completeness
func updateMilestoneCounters(e Engine, id int64) error {
func updateMilestoneCounters(e db.Engine, id int64) error {
_, err := e.ID(id).
SetExpr("num_issues", builder.Select("count(*)").From("issue").Where(
builder.Eq{"milestone_id": id},
@@ -202,7 +207,7 @@ func updateMilestoneCounters(e Engine, id int64) error {
// ChangeMilestoneStatusByRepoIDAndID changes a milestone open/closed status if the milestone ID is in the repo.
func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -229,7 +234,7 @@ func ChangeMilestoneStatusByRepoIDAndID(repoID, milestoneID int64, isClosed bool
// ChangeMilestoneStatus changes the milestone open/closed status.
func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -242,7 +247,7 @@ func ChangeMilestoneStatus(m *Milestone, isClosed bool) (err error) {
return sess.Commit()
}
func changeMilestoneStatus(e Engine, m *Milestone, isClosed bool) error {
func changeMilestoneStatus(e db.Engine, m *Milestone, isClosed bool) error {
m.IsClosed = isClosed
if isClosed {
m.ClosedDateUnix = timeutil.TimeStampNow()
@@ -298,7 +303,7 @@ func changeMilestoneAssign(e *xorm.Session, doer *User, issue *Issue, oldMilesto
// ChangeMilestoneAssign changes assignment of milestone for issue.
func ChangeMilestoneAssign(issue *Issue, doer *User, oldMilestoneID int64) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -329,7 +334,7 @@ func DeleteMilestoneByRepoID(repoID, id int64) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -405,7 +410,7 @@ func (opts GetMilestonesOption) toCond() builder.Cond {
// GetMilestones returns milestones filtered by GetMilestonesOption's
func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) {
sess := x.Where(opts.toCond())
sess := db.DefaultContext().Engine().Where(opts.toCond())
if opts.Page != 0 {
sess = setSessionPagination(sess, &opts)
@@ -436,7 +441,7 @@ func GetMilestones(opts GetMilestonesOption) (MilestoneList, int64, error) {
// SearchMilestones search milestones
func SearchMilestones(repoCond builder.Cond, page int, isClosed bool, sortType string, keyword string) (MilestoneList, error) {
miles := make([]*Milestone, 0, setting.UI.IssuePagingNum)
sess := x.Where("is_closed = ?", isClosed)
sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@@ -497,7 +502,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro
var err error
stats := &MilestonesStats{}
sess := x.Where("is_closed = ?", false)
sess := db.DefaultContext().Engine().Where("is_closed = ?", false)
if repoCond.IsValid() {
sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
}
@@ -506,7 +511,7 @@ func GetMilestonesStatsByRepoCond(repoCond builder.Cond) (*MilestonesStats, erro
return nil, err
}
sess = x.Where("is_closed = ?", true)
sess = db.DefaultContext().Engine().Where("is_closed = ?", true)
if repoCond.IsValid() {
sess.And(builder.In("repo_id", builder.Select("id").From("repository").Where(repoCond)))
}
@@ -523,7 +528,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*
var err error
stats := &MilestonesStats{}
sess := x.Where("is_closed = ?", false)
sess := db.DefaultContext().Engine().Where("is_closed = ?", false)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@@ -535,7 +540,7 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*
return nil, err
}
sess = x.Where("is_closed = ?", true)
sess = db.DefaultContext().Engine().Where("is_closed = ?", true)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@@ -550,13 +555,13 @@ func GetMilestonesStatsByRepoCondAndKw(repoCond builder.Cond, keyword string) (*
return stats, nil
}
func countRepoMilestones(e Engine, repoID int64) (int64, error) {
func countRepoMilestones(e db.Engine, repoID int64) (int64, error) {
return e.
Where("repo_id=?", repoID).
Count(new(Milestone))
}
func countRepoClosedMilestones(e Engine, repoID int64) (int64, error) {
func countRepoClosedMilestones(e db.Engine, repoID int64) (int64, error) {
return e.
Where("repo_id=? AND is_closed=?", repoID, true).
Count(new(Milestone))
@@ -564,12 +569,12 @@ func countRepoClosedMilestones(e Engine, repoID int64) (int64, error) {
// CountRepoClosedMilestones returns number of closed milestones in given repository.
func CountRepoClosedMilestones(repoID int64) (int64, error) {
return countRepoClosedMilestones(x, repoID)
return countRepoClosedMilestones(db.DefaultContext().Engine(), repoID)
}
// CountMilestonesByRepoCond map from repo conditions to number of milestones matching the options`
func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]int64, error) {
sess := x.Where("is_closed = ?", isClosed)
sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed)
if repoCond.IsValid() {
sess.In("repo_id", builder.Select("id").From("repository").Where(repoCond))
}
@@ -594,7 +599,7 @@ func CountMilestonesByRepoCond(repoCond builder.Cond, isClosed bool) (map[int64]
// CountMilestonesByRepoCondAndKw map from repo conditions and the keyword of milestones' name to number of milestones matching the options`
func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClosed bool) (map[int64]int64, error) {
sess := x.Where("is_closed = ?", isClosed)
sess := db.DefaultContext().Engine().Where("is_closed = ?", isClosed)
if len(keyword) > 0 {
sess = sess.And(builder.Like{"UPPER(name)", strings.ToUpper(keyword)})
}
@@ -620,7 +625,7 @@ func CountMilestonesByRepoCondAndKw(repoCond builder.Cond, keyword string, isClo
return countMap, nil
}
func updateRepoMilestoneNum(e Engine, repoID int64) error {
func updateRepoMilestoneNum(e db.Engine, repoID int64) error {
_, err := e.Exec("UPDATE `repository` SET num_milestones=(SELECT count(*) FROM milestone WHERE repo_id=?),num_closed_milestones=(SELECT count(*) FROM milestone WHERE repo_id=? AND is_closed=?) WHERE id=?",
repoID,
repoID,
@@ -637,7 +642,7 @@ func updateRepoMilestoneNum(e Engine, repoID int64) error {
// |_||_| \__,_|\___|_|\_\___|\__,_| |_| |_|_| |_| |_|\___||___/
//
func (milestones MilestoneList) loadTotalTrackedTimes(e Engine) error {
func (milestones MilestoneList) loadTotalTrackedTimes(e db.Engine) error {
type totalTimesByMilestone struct {
MilestoneID int64
Time int64
@@ -677,7 +682,7 @@ func (milestones MilestoneList) loadTotalTrackedTimes(e Engine) error {
return nil
}
func (m *Milestone) loadTotalTrackedTime(e Engine) error {
func (m *Milestone) loadTotalTrackedTime(e db.Engine) error {
type totalTimesByMilestone struct {
MilestoneID int64
Time int64
@@ -702,10 +707,10 @@ func (m *Milestone) loadTotalTrackedTime(e Engine) error {
// LoadTotalTrackedTimes loads for every milestone in the list the TotalTrackedTime by a batch request
func (milestones MilestoneList) LoadTotalTrackedTimes() error {
return milestones.loadTotalTrackedTimes(x)
return milestones.loadTotalTrackedTimes(db.DefaultContext().Engine())
}
// LoadTotalTrackedTime loads the tracked time for the milestone
func (m *Milestone) LoadTotalTrackedTime() error {
return m.loadTotalTrackedTime(x)
return m.loadTotalTrackedTime(db.DefaultContext().Engine())
}
+51 -50
View File
@@ -8,6 +8,7 @@ import (
"sort"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
api "code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
@@ -22,7 +23,7 @@ func TestMilestone_State(t *testing.T) {
}
func TestNewMilestone(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
milestone := &Milestone{
RepoID: 1,
Name: "milestoneName",
@@ -30,26 +31,26 @@ func TestNewMilestone(t *testing.T) {
}
assert.NoError(t, NewMilestone(milestone))
AssertExistsAndLoadBean(t, milestone)
db.AssertExistsAndLoadBean(t, milestone)
CheckConsistencyFor(t, &Repository{ID: milestone.RepoID}, &Milestone{})
}
func TestGetMilestoneByRepoID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
milestone, err := GetMilestoneByRepoID(1, 1)
assert.NoError(t, err)
assert.EqualValues(t, 1, milestone.ID)
assert.EqualValues(t, 1, milestone.RepoID)
_, err = GetMilestoneByRepoID(NonexistentID, NonexistentID)
_, err = GetMilestoneByRepoID(db.NonexistentID, db.NonexistentID)
assert.True(t, IsErrMilestoneNotExist(err))
}
func TestGetMilestonesByRepoID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID int64, state api.StateType) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
milestones, _, err := GetMilestones(GetMilestonesOption{
RepoID: repo.ID,
State: state,
@@ -88,7 +89,7 @@ func TestGetMilestonesByRepoID(t *testing.T) {
test(3, api.StateAll)
milestones, _, err := GetMilestones(GetMilestonesOption{
RepoID: NonexistentID,
RepoID: db.NonexistentID,
State: api.StateOpen,
})
assert.NoError(t, err)
@@ -96,8 +97,8 @@ func TestGetMilestonesByRepoID(t *testing.T) {
}
func TestGetMilestones(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
test := func(sortType string, sortCond func(*Milestone) int) {
for _, page := range []int{0, 1} {
milestones, _, err := GetMilestones(GetMilestonesOption{
@@ -157,22 +158,22 @@ func TestGetMilestones(t *testing.T) {
}
func TestUpdateMilestone(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
milestone := AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
milestone := db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
milestone.Name = " newMilestoneName "
milestone.Content = "newMilestoneContent"
assert.NoError(t, UpdateMilestone(milestone, milestone.IsClosed))
milestone = AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
milestone = db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
assert.EqualValues(t, "newMilestoneName", milestone.Name)
CheckConsistencyFor(t, &Milestone{})
}
func TestCountRepoMilestones(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID int64) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
count, err := countRepoMilestones(x, repoID)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
count, err := countRepoMilestones(db.DefaultContext().Engine(), repoID)
assert.NoError(t, err)
assert.EqualValues(t, repo.NumMilestones, count)
}
@@ -180,15 +181,15 @@ func TestCountRepoMilestones(t *testing.T) {
test(2)
test(3)
count, err := countRepoMilestones(x, NonexistentID)
count, err := countRepoMilestones(db.DefaultContext().Engine(), db.NonexistentID)
assert.NoError(t, err)
assert.EqualValues(t, 0, count)
}
func TestCountRepoClosedMilestones(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID int64) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
count, err := CountRepoClosedMilestones(repoID)
assert.NoError(t, err)
assert.EqualValues(t, repo.NumClosedMilestones, count)
@@ -197,55 +198,55 @@ func TestCountRepoClosedMilestones(t *testing.T) {
test(2)
test(3)
count, err := CountRepoClosedMilestones(NonexistentID)
count, err := CountRepoClosedMilestones(db.NonexistentID)
assert.NoError(t, err)
assert.EqualValues(t, 0, count)
}
func TestChangeMilestoneStatus(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
milestone := AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
assert.NoError(t, db.PrepareTestDatabase())
milestone := db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
assert.NoError(t, ChangeMilestoneStatus(milestone, true))
AssertExistsAndLoadBean(t, &Milestone{ID: 1}, "is_closed=1")
db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}, "is_closed=1")
CheckConsistencyFor(t, &Repository{ID: milestone.RepoID}, &Milestone{})
assert.NoError(t, ChangeMilestoneStatus(milestone, false))
AssertExistsAndLoadBean(t, &Milestone{ID: 1}, "is_closed=0")
db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}, "is_closed=0")
CheckConsistencyFor(t, &Repository{ID: milestone.RepoID}, &Milestone{})
}
func TestUpdateMilestoneCounters(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{MilestoneID: 1},
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{MilestoneID: 1},
"is_closed=0").(*Issue)
issue.IsClosed = true
issue.ClosedUnix = timeutil.TimeStampNow()
_, err := x.ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue)
_, err := db.DefaultContext().Engine().ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue)
assert.NoError(t, err)
assert.NoError(t, updateMilestoneCounters(x, issue.MilestoneID))
assert.NoError(t, updateMilestoneCounters(db.DefaultContext().Engine(), issue.MilestoneID))
CheckConsistencyFor(t, &Milestone{})
issue.IsClosed = false
issue.ClosedUnix = 0
_, err = x.ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue)
_, err = db.DefaultContext().Engine().ID(issue.ID).Cols("is_closed", "closed_unix").Update(issue)
assert.NoError(t, err)
assert.NoError(t, updateMilestoneCounters(x, issue.MilestoneID))
assert.NoError(t, updateMilestoneCounters(db.DefaultContext().Engine(), issue.MilestoneID))
CheckConsistencyFor(t, &Milestone{})
}
func TestChangeMilestoneAssign(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{RepoID: 1}).(*Issue)
doer := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{RepoID: 1}).(*Issue)
doer := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NotNil(t, issue)
assert.NotNil(t, doer)
oldMilestoneID := issue.MilestoneID
issue.MilestoneID = 2
assert.NoError(t, ChangeMilestoneAssign(issue, doer, oldMilestoneID))
AssertExistsAndLoadBean(t, &Comment{
db.AssertExistsAndLoadBean(t, &Comment{
IssueID: issue.ID,
Type: CommentTypeMilestone,
MilestoneID: issue.MilestoneID,
@@ -255,18 +256,18 @@ func TestChangeMilestoneAssign(t *testing.T) {
}
func TestDeleteMilestoneByRepoID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.NoError(t, DeleteMilestoneByRepoID(1, 1))
AssertNotExistsBean(t, &Milestone{ID: 1})
db.AssertNotExistsBean(t, &Milestone{ID: 1})
CheckConsistencyFor(t, &Repository{ID: 1})
assert.NoError(t, DeleteMilestoneByRepoID(NonexistentID, NonexistentID))
assert.NoError(t, DeleteMilestoneByRepoID(db.NonexistentID, db.NonexistentID))
}
func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
miles := MilestoneList{
AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone),
db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone),
}
assert.NoError(t, miles.LoadTotalTrackedTimes())
@@ -275,9 +276,9 @@ func TestMilestoneList_LoadTotalTrackedTimes(t *testing.T) {
}
func TestCountMilestonesByRepoIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
milestonesCount := func(repoID int64) (int, int) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
return repo.NumOpenMilestones, repo.NumClosedMilestones
}
repo1OpenCount, repo1ClosedCount := milestonesCount(1)
@@ -295,9 +296,9 @@ func TestCountMilestonesByRepoIDs(t *testing.T) {
}
func TestGetMilestonesByRepoIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo2 := AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo2 := db.AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository)
test := func(sortType string, sortCond func(*Milestone) int) {
for _, page := range []int{0, 1} {
openMilestones, err := GetMilestonesByRepoIDs([]int64{repo1.ID, repo2.ID}, page, false, sortType)
@@ -340,8 +341,8 @@ func TestGetMilestonesByRepoIDs(t *testing.T) {
}
func TestLoadTotalTrackedTime(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
milestone := AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
assert.NoError(t, db.PrepareTestDatabase())
milestone := db.AssertExistsAndLoadBean(t, &Milestone{ID: 1}).(*Milestone)
assert.NoError(t, milestone.LoadTotalTrackedTime())
@@ -349,10 +350,10 @@ func TestLoadTotalTrackedTime(t *testing.T) {
}
func TestGetMilestonesStats(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID int64) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
stats, err := GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": repoID}))
assert.NoError(t, err)
assert.EqualValues(t, repo.NumMilestones-repo.NumClosedMilestones, stats.OpenCount)
@@ -362,13 +363,13 @@ func TestGetMilestonesStats(t *testing.T) {
test(2)
test(3)
stats, err := GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": NonexistentID}))
stats, err := GetMilestonesStatsByRepoCond(builder.And(builder.Eq{"repo_id": db.NonexistentID}))
assert.NoError(t, err)
assert.EqualValues(t, 0, stats.OpenCount)
assert.EqualValues(t, 0, stats.ClosedCount)
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo2 := AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository)
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo2 := db.AssertExistsAndLoadBean(t, &Repository{ID: 2}).(*Repository)
milestoneStats, err := GetMilestonesStatsByRepoCond(builder.In("repo_id", []int64{repo1.ID, repo2.ID}))
assert.NoError(t, err)
+14 -9
View File
@@ -8,6 +8,7 @@ import (
"bytes"
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
@@ -28,6 +29,10 @@ type Reaction struct {
CreatedUnix timeutil.TimeStamp `xorm:"INDEX created"`
}
func init() {
db.RegisterModel(new(Reaction))
}
// FindReactionsOptions describes the conditions to Find reactions
type FindReactionsOptions struct {
ListOptions
@@ -66,7 +71,7 @@ func (opts *FindReactionsOptions) toConds() builder.Cond {
// FindCommentReactions returns a ReactionList of all reactions from an comment
func FindCommentReactions(comment *Comment) (ReactionList, error) {
return findReactions(x, FindReactionsOptions{
return findReactions(db.DefaultContext().Engine(), FindReactionsOptions{
IssueID: comment.IssueID,
CommentID: comment.ID,
})
@@ -74,14 +79,14 @@ func FindCommentReactions(comment *Comment) (ReactionList, error) {
// FindIssueReactions returns a ReactionList of all reactions from an issue
func FindIssueReactions(issue *Issue, listOptions ListOptions) (ReactionList, error) {
return findReactions(x, FindReactionsOptions{
return findReactions(db.DefaultContext().Engine(), FindReactionsOptions{
ListOptions: listOptions,
IssueID: issue.ID,
CommentID: -1,
})
}
func findReactions(e Engine, opts FindReactionsOptions) ([]*Reaction, error) {
func findReactions(e db.Engine, opts FindReactionsOptions) ([]*Reaction, error) {
e = e.
Where(opts.toConds()).
In("reaction.`type`", setting.UI.Reactions).
@@ -143,7 +148,7 @@ func CreateReaction(opts *ReactionOptions) (*Reaction, error) {
return nil, ErrForbiddenIssueReaction{opts.Type}
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return nil, err
@@ -179,7 +184,7 @@ func CreateCommentReaction(doer *User, issue *Issue, comment *Comment, content s
})
}
func deleteReaction(e Engine, opts *ReactionOptions) error {
func deleteReaction(e db.Engine, opts *ReactionOptions) error {
reaction := &Reaction{
Type: opts.Type,
}
@@ -198,7 +203,7 @@ func deleteReaction(e Engine, opts *ReactionOptions) error {
// DeleteReaction deletes reaction for issue or comment.
func DeleteReaction(opts *ReactionOptions) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -235,7 +240,7 @@ func (r *Reaction) LoadUser() (*User, error) {
if r.User != nil {
return r.User, nil
}
user, err := getUserByID(x, r.UserID)
user, err := getUserByID(db.DefaultContext().Engine(), r.UserID)
if err != nil {
return nil, err
}
@@ -281,7 +286,7 @@ func (list ReactionList) getUserIDs() []int64 {
return keysInt64(userIDs)
}
func (list ReactionList) loadUsers(e Engine, repo *Repository) ([]*User, error) {
func (list ReactionList) loadUsers(e db.Engine, repo *Repository) ([]*User, error) {
if len(list) == 0 {
return nil, nil
}
@@ -309,7 +314,7 @@ func (list ReactionList) loadUsers(e Engine, repo *Repository) ([]*User, error)
// LoadUsers loads reactions' all users
func (list ReactionList) LoadUsers(repo *Repository) ([]*User, error) {
return list.loadUsers(x, repo)
return list.loadUsers(db.DefaultContext().Engine(), repo)
}
// GetFirstUsers returns first reacted user display names separated by comma
+38 -37
View File
@@ -6,6 +6,7 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"github.com/stretchr/testify/assert"
@@ -24,23 +25,23 @@ func addReaction(t *testing.T, doer *User, issue *Issue, comment *Comment, conte
}
func TestIssueAddReaction(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
addReaction(t, user1, issue1, nil, "heart")
AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID})
db.AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID})
}
func TestIssueAddDuplicateReaction(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
addReaction(t, user1, issue1, nil, "heart")
@@ -52,37 +53,37 @@ func TestIssueAddDuplicateReaction(t *testing.T) {
assert.Error(t, err)
assert.Equal(t, ErrReactionAlreadyExist{Reaction: "heart"}, err)
existingR := AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID}).(*Reaction)
existingR := db.AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID}).(*Reaction)
assert.Equal(t, existingR.ID, reaction.ID)
}
func TestIssueDeleteReaction(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
addReaction(t, user1, issue1, nil, "heart")
err := DeleteIssueReaction(user1, issue1, "heart")
assert.NoError(t, err)
AssertNotExistsBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID})
db.AssertNotExistsBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID})
}
func TestIssueReactionCount(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
setting.UI.ReactionMaxUserNum = 2
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user2 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user4 := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user4 := db.AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
ghost := NewGhostUser()
issue := AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 2}).(*Issue)
addReaction(t, user1, issue, nil, "heart")
addReaction(t, user2, issue, nil, "heart")
@@ -92,7 +93,7 @@ func TestIssueReactionCount(t *testing.T) {
addReaction(t, user4, issue, nil, "heart")
addReaction(t, ghost, issue, nil, "-1")
err := issue.loadReactions(x)
err := issue.loadReactions(db.DefaultContext().Engine())
assert.NoError(t, err)
assert.Len(t, issue.Reactions, 7)
@@ -110,31 +111,31 @@ func TestIssueReactionCount(t *testing.T) {
}
func TestIssueCommentAddReaction(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
comment1 := AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
comment1 := db.AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
addReaction(t, user1, issue1, comment1, "heart")
AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID, CommentID: comment1.ID})
db.AssertExistsAndLoadBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID, CommentID: comment1.ID})
}
func TestIssueCommentDeleteReaction(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user2 := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user4 := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user2 := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
user4 := db.AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
repo1 := AssertExistsAndLoadBean(t, &Repository{ID: issue1.RepoID}).(*Repository)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
repo1 := db.AssertExistsAndLoadBean(t, &Repository{ID: issue1.RepoID}).(*Repository)
comment1 := AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
comment1 := db.AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
addReaction(t, user1, issue1, comment1, "heart")
addReaction(t, user2, issue1, comment1, "heart")
@@ -151,16 +152,16 @@ func TestIssueCommentDeleteReaction(t *testing.T) {
}
func TestIssueCommentReactionCount(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1 := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
user1 := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
issue1 := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
issue1 := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
comment1 := AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
comment1 := db.AssertExistsAndLoadBean(t, &Comment{ID: 1}).(*Comment)
addReaction(t, user1, issue1, comment1, "heart")
assert.NoError(t, DeleteCommentReaction(user1, issue1, comment1, "heart"))
AssertNotExistsBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID, CommentID: comment1.ID})
db.AssertNotExistsBean(t, &Reaction{Type: "heart", UserID: user1.ID, IssueID: issue1.ID, CommentID: comment1.ID})
}
+13 -8
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/xorm"
@@ -21,6 +22,10 @@ type Stopwatch struct {
CreatedUnix timeutil.TimeStamp `xorm:"created"`
}
func init() {
db.RegisterModel(new(Stopwatch))
}
// Seconds returns the amount of time passed since creation, based on local server time
func (s Stopwatch) Seconds() int64 {
return int64(timeutil.TimeStampNow() - s.CreatedUnix)
@@ -31,7 +36,7 @@ func (s Stopwatch) Duration() string {
return SecToTime(s.Seconds())
}
func getStopwatch(e Engine, userID, issueID int64) (sw *Stopwatch, exists bool, err error) {
func getStopwatch(e db.Engine, userID, issueID int64) (sw *Stopwatch, exists bool, err error) {
sw = new(Stopwatch)
exists, err = e.
Where("user_id = ?", userID).
@@ -43,7 +48,7 @@ func getStopwatch(e Engine, userID, issueID int64) (sw *Stopwatch, exists bool,
// GetUserStopwatches return list of all stopwatches of a user
func GetUserStopwatches(userID int64, listOptions ListOptions) ([]*Stopwatch, error) {
sws := make([]*Stopwatch, 0, 8)
sess := x.Where("stopwatch.user_id = ?", userID)
sess := db.DefaultContext().Engine().Where("stopwatch.user_id = ?", userID)
if listOptions.Page != 0 {
sess = setSessionPagination(sess, &listOptions)
}
@@ -57,21 +62,21 @@ func GetUserStopwatches(userID int64, listOptions ListOptions) ([]*Stopwatch, er
// CountUserStopwatches return count of all stopwatches of a user
func CountUserStopwatches(userID int64) (int64, error) {
return x.Where("user_id = ?", userID).Count(&Stopwatch{})
return db.DefaultContext().Engine().Where("user_id = ?", userID).Count(&Stopwatch{})
}
// StopwatchExists returns true if the stopwatch exists
func StopwatchExists(userID, issueID int64) bool {
_, exists, _ := getStopwatch(x, userID, issueID)
_, exists, _ := getStopwatch(db.DefaultContext().Engine(), userID, issueID)
return exists
}
// HasUserStopwatch returns true if the user has a stopwatch
func HasUserStopwatch(userID int64) (exists bool, sw *Stopwatch, err error) {
return hasUserStopwatch(x, userID)
return hasUserStopwatch(db.DefaultContext().Engine(), userID)
}
func hasUserStopwatch(e Engine, userID int64) (exists bool, sw *Stopwatch, err error) {
func hasUserStopwatch(e db.Engine, userID int64) (exists bool, sw *Stopwatch, err error) {
sw = new(Stopwatch)
exists, err = e.
Where("user_id = ?", userID).
@@ -81,7 +86,7 @@ func hasUserStopwatch(e Engine, userID int64) (exists bool, sw *Stopwatch, err e
// CreateOrStopIssueStopwatch will create or remove a stopwatch and will log it into issue's timeline.
func CreateOrStopIssueStopwatch(user *User, issue *Issue) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -170,7 +175,7 @@ func createOrStopIssueStopwatch(e *xorm.Session, user *User, issue *Issue) error
// CancelStopwatch removes the given stopwatch and logs it into issue's timeline.
func CancelStopwatch(user *User, issue *Issue) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
+10 -9
View File
@@ -7,13 +7,14 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert"
)
func TestCancelStopwatch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user1, err := GetUserByID(1)
assert.NoError(t, err)
@@ -25,22 +26,22 @@ func TestCancelStopwatch(t *testing.T) {
err = CancelStopwatch(user1, issue1)
assert.NoError(t, err)
AssertNotExistsBean(t, &Stopwatch{UserID: user1.ID, IssueID: issue1.ID})
db.AssertNotExistsBean(t, &Stopwatch{UserID: user1.ID, IssueID: issue1.ID})
_ = AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeCancelTracking, PosterID: user1.ID, IssueID: issue1.ID})
_ = db.AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeCancelTracking, PosterID: user1.ID, IssueID: issue1.ID})
assert.Nil(t, CancelStopwatch(user1, issue2))
}
func TestStopwatchExists(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.True(t, StopwatchExists(1, 1))
assert.False(t, StopwatchExists(1, 2))
}
func TestHasUserStopwatch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
exists, sw, err := HasUserStopwatch(1)
assert.NoError(t, err)
@@ -53,7 +54,7 @@ func TestHasUserStopwatch(t *testing.T) {
}
func TestCreateOrStopIssueStopwatch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user2, err := GetUserByID(2)
assert.NoError(t, err)
@@ -66,10 +67,10 @@ func TestCreateOrStopIssueStopwatch(t *testing.T) {
assert.NoError(t, err)
assert.NoError(t, CreateOrStopIssueStopwatch(user3, issue1))
sw := AssertExistsAndLoadBean(t, &Stopwatch{UserID: 3, IssueID: 1}).(*Stopwatch)
sw := db.AssertExistsAndLoadBean(t, &Stopwatch{UserID: 3, IssueID: 1}).(*Stopwatch)
assert.LessOrEqual(t, sw.CreatedUnix, timeutil.TimeStampNow())
assert.NoError(t, CreateOrStopIssueStopwatch(user2, issue2))
AssertNotExistsBean(t, &Stopwatch{UserID: 2, IssueID: 2})
AssertExistsAndLoadBean(t, &TrackedTime{UserID: 2, IssueID: 2})
db.AssertNotExistsBean(t, &Stopwatch{UserID: 2, IssueID: 2})
db.AssertExistsAndLoadBean(t, &TrackedTime{UserID: 2, IssueID: 2})
}
+58 -41
View File
@@ -5,29 +5,32 @@
package models
import (
"fmt"
"sort"
"sync"
"testing"
"time"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestIssue_ReplaceLabels(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(issueID int64, labelIDs []int64) {
issue := AssertExistsAndLoadBean(t, &Issue{ID: issueID}).(*Issue)
repo := AssertExistsAndLoadBean(t, &Repository{ID: issue.RepoID}).(*Repository)
doer := AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: issueID}).(*Issue)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: issue.RepoID}).(*Repository)
doer := db.AssertExistsAndLoadBean(t, &User{ID: repo.OwnerID}).(*User)
labels := make([]*Label, len(labelIDs))
for i, labelID := range labelIDs {
labels[i] = AssertExistsAndLoadBean(t, &Label{ID: labelID, RepoID: repo.ID}).(*Label)
labels[i] = db.AssertExistsAndLoadBean(t, &Label{ID: labelID, RepoID: repo.ID}).(*Label)
}
assert.NoError(t, issue.ReplaceLabels(labels, doer))
AssertCount(t, &IssueLabel{IssueID: issueID}, len(labelIDs))
db.AssertCount(t, &IssueLabel{IssueID: issueID}, len(labelIDs))
for _, labelID := range labelIDs {
AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID})
db.AssertExistsAndLoadBean(t, &IssueLabel{IssueID: issueID, LabelID: labelID})
}
}
@@ -37,7 +40,7 @@ func TestIssue_ReplaceLabels(t *testing.T) {
}
func Test_GetIssueIDsByRepoID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
ids, err := GetIssueIDsByRepoID(1)
assert.NoError(t, err)
@@ -45,8 +48,8 @@ func Test_GetIssueIDsByRepoID(t *testing.T) {
}
func TestIssueAPIURL(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
err := issue.LoadAttributes()
assert.NoError(t, err)
@@ -54,7 +57,7 @@ func TestIssueAPIURL(t *testing.T) {
}
func TestGetIssuesByIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(expectedIssueIDs, nonExistentIssueIDs []int64) {
issues, err := GetIssuesByIDs(append(expectedIssueIDs, nonExistentIssueIDs...))
assert.NoError(t, err)
@@ -65,16 +68,16 @@ func TestGetIssuesByIDs(t *testing.T) {
assert.Equal(t, expectedIssueIDs, actualIssueIDs)
}
testSuccess([]int64{1, 2, 3}, []int64{})
testSuccess([]int64{1, 2, 3}, []int64{NonexistentID})
testSuccess([]int64{1, 2, 3}, []int64{db.NonexistentID})
}
func TestGetParticipantIDsByIssue(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
checkParticipants := func(issueID int64, userIDs []int) {
issue, err := GetIssueByID(issueID)
assert.NoError(t, err)
participants, err := issue.getParticipantIDsByIssue(x)
participants, err := issue.getParticipantIDsByIssue(db.DefaultContext().Engine())
if assert.NoError(t, err) {
participantsIDs := make([]int, len(participants))
for i, uid := range participants {
@@ -103,17 +106,17 @@ func TestIssue_ClearLabels(t *testing.T) {
{3, 2}, // pull-request, has no labels
}
for _, test := range tests {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: test.issueID}).(*Issue)
doer := AssertExistsAndLoadBean(t, &User{ID: test.doerID}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: test.issueID}).(*Issue)
doer := db.AssertExistsAndLoadBean(t, &User{ID: test.doerID}).(*User)
assert.NoError(t, issue.ClearLabels(doer))
AssertNotExistsBean(t, &IssueLabel{IssueID: test.issueID})
db.AssertNotExistsBean(t, &IssueLabel{IssueID: test.issueID})
}
}
func TestUpdateIssueCols(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{}).(*Issue)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{}).(*Issue)
const newTitle = "New Title for unit test"
issue.Title = newTitle
@@ -122,17 +125,17 @@ func TestUpdateIssueCols(t *testing.T) {
issue.Content = "This should have no effect"
now := time.Now().Unix()
assert.NoError(t, updateIssueCols(x, issue, "name"))
assert.NoError(t, updateIssueCols(db.DefaultContext().Engine(), issue, "name"))
then := time.Now().Unix()
updatedIssue := AssertExistsAndLoadBean(t, &Issue{ID: issue.ID}).(*Issue)
updatedIssue := db.AssertExistsAndLoadBean(t, &Issue{ID: issue.ID}).(*Issue)
assert.EqualValues(t, newTitle, updatedIssue.Title)
assert.EqualValues(t, prevContent, updatedIssue.Content)
AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix))
db.AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix))
}
func TestIssues(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
for _, test := range []struct {
Opts IssuesOptions
ExpectedIssueIDs []int64
@@ -187,7 +190,7 @@ func TestIssues(t *testing.T) {
}
func TestGetUserIssueStats(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
for _, test := range []struct {
Opts UserIssueStatsOptions
ExpectedIssueStats IssueStats
@@ -284,15 +287,15 @@ func TestGetUserIssueStats(t *testing.T) {
}
func TestIssue_loadTotalTimes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
ms, err := GetIssueByID(2)
assert.NoError(t, err)
assert.NoError(t, ms.loadTotalTimes(x))
assert.NoError(t, ms.loadTotalTimes(db.DefaultContext().Engine()))
assert.Equal(t, int64(3682), ms.TotalTrackedTime)
}
func TestIssue_SearchIssueIDsByKeyword(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
total, ids, err := SearchIssueIDsByKeyword("issue2", []int64{1}, 10, 0)
assert.NoError(t, err)
assert.EqualValues(t, 1, total)
@@ -316,8 +319,8 @@ func TestIssue_SearchIssueIDsByKeyword(t *testing.T) {
}
func TestGetRepoIDsForIssuesOptions(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
for _, test := range []struct {
Opts IssuesOptions
ExpectedRepoIDs []int64
@@ -348,8 +351,8 @@ func TestGetRepoIDsForIssuesOptions(t *testing.T) {
func testInsertIssue(t *testing.T, title, content string, expectIndex int64) *Issue {
var newIssue Issue
t.Run(title, func(t *testing.T) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
issue := Issue{
RepoID: repo.ID,
@@ -360,7 +363,7 @@ func testInsertIssue(t *testing.T, title, content string, expectIndex int64) *Is
err := NewIssue(repo, &issue, nil, nil)
assert.NoError(t, err)
has, err := x.ID(issue.ID).Get(&newIssue)
has, err := db.DefaultContext().Engine().ID(issue.ID).Get(&newIssue)
assert.NoError(t, err)
assert.True(t, has)
assert.EqualValues(t, issue.Title, newIssue.Title)
@@ -373,28 +376,28 @@ func testInsertIssue(t *testing.T, title, content string, expectIndex int64) *Is
}
func TestIssue_InsertIssue(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// there are 5 issues and max index is 5 on repository 1, so this one should 6
issue := testInsertIssue(t, "my issue1", "special issue's comments?", 6)
_, err := x.ID(issue.ID).Delete(new(Issue))
_, err := db.DefaultContext().Engine().ID(issue.ID).Delete(new(Issue))
assert.NoError(t, err)
issue = testInsertIssue(t, `my issue2, this is my son's love \n \r \ `, "special issue's '' comments?", 7)
_, err = x.ID(issue.ID).Delete(new(Issue))
_, err = db.DefaultContext().Engine().ID(issue.ID).Delete(new(Issue))
assert.NoError(t, err)
}
func TestIssue_ResolveMentions(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(owner, repo, doer string, mentions []string, expected []int64) {
o := AssertExistsAndLoadBean(t, &User{LowerName: owner}).(*User)
r := AssertExistsAndLoadBean(t, &Repository{OwnerID: o.ID, LowerName: repo}).(*Repository)
o := db.AssertExistsAndLoadBean(t, &User{LowerName: owner}).(*User)
r := db.AssertExistsAndLoadBean(t, &Repository{OwnerID: o.ID, LowerName: repo}).(*Repository)
issue := &Issue{RepoID: r.ID}
d := AssertExistsAndLoadBean(t, &User{LowerName: doer}).(*User)
resolved, err := issue.ResolveMentionsByVisibility(DefaultDBContext(), d, mentions)
d := db.AssertExistsAndLoadBean(t, &User{LowerName: doer}).(*User)
resolved, err := issue.ResolveMentionsByVisibility(db.DefaultContext(), d, mentions)
assert.NoError(t, err)
ids := make([]int64, len(resolved))
for i, user := range resolved {
@@ -417,3 +420,17 @@ func TestIssue_ResolveMentions(t *testing.T) {
// Private repo, whole team
testSuccess("user17", "big_test_private_4", "user15", []string{"user17/owners"}, []int64{18})
}
func TestResourceIndex(t *testing.T) {
assert.NoError(t, db.PrepareTestDatabase())
var wg sync.WaitGroup
for i := 0; i < 100; i++ {
wg.Add(1)
go func(i int) {
testInsertIssue(t, fmt.Sprintf("issue %d", i+1), "my issue", 0)
wg.Done()
}(i)
}
wg.Wait()
}
+20 -15
View File
@@ -7,6 +7,7 @@ package models
import (
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"xorm.io/builder"
@@ -25,6 +26,10 @@ type TrackedTime struct {
Deleted bool `xorm:"NOT NULL DEFAULT false"`
}
func init() {
db.RegisterModel(new(TrackedTime))
}
// TrackedTimeList is a List of TrackedTime's
type TrackedTimeList []*TrackedTime
@@ -35,10 +40,10 @@ func (t *TrackedTime) AfterLoad() {
// LoadAttributes load Issue, User
func (t *TrackedTime) LoadAttributes() (err error) {
return t.loadAttributes(x)
return t.loadAttributes(db.DefaultContext().Engine())
}
func (t *TrackedTime) loadAttributes(e Engine) (err error) {
func (t *TrackedTime) loadAttributes(e db.Engine) (err error) {
if t.Issue == nil {
t.Issue, err = getIssueByID(e, t.IssueID)
if err != nil {
@@ -104,7 +109,7 @@ func (opts *FindTrackedTimesOptions) toCond() builder.Cond {
}
// toSession will convert the given options to a xorm Session by using the conditions from toCond and joining with issue table if required
func (opts *FindTrackedTimesOptions) toSession(e Engine) Engine {
func (opts *FindTrackedTimesOptions) toSession(e db.Engine) db.Engine {
sess := e
if opts.RepositoryID > 0 || opts.MilestoneID > 0 {
sess = e.Join("INNER", "issue", "issue.id = tracked_time.issue_id")
@@ -119,37 +124,37 @@ func (opts *FindTrackedTimesOptions) toSession(e Engine) Engine {
return sess
}
func getTrackedTimes(e Engine, options *FindTrackedTimesOptions) (trackedTimes TrackedTimeList, err error) {
func getTrackedTimes(e db.Engine, options *FindTrackedTimesOptions) (trackedTimes TrackedTimeList, err error) {
err = options.toSession(e).Find(&trackedTimes)
return
}
// GetTrackedTimes returns all tracked times that fit to the given options.
func GetTrackedTimes(opts *FindTrackedTimesOptions) (TrackedTimeList, error) {
return getTrackedTimes(x, opts)
return getTrackedTimes(db.DefaultContext().Engine(), opts)
}
// CountTrackedTimes returns count of tracked times that fit to the given options.
func CountTrackedTimes(opts *FindTrackedTimesOptions) (int64, error) {
sess := x.Where(opts.toCond())
sess := db.DefaultContext().Engine().Where(opts.toCond())
if opts.RepositoryID > 0 || opts.MilestoneID > 0 {
sess = sess.Join("INNER", "issue", "issue.id = tracked_time.issue_id")
}
return sess.Count(&TrackedTime{})
}
func getTrackedSeconds(e Engine, opts FindTrackedTimesOptions) (trackedSeconds int64, err error) {
func getTrackedSeconds(e db.Engine, opts FindTrackedTimesOptions) (trackedSeconds int64, err error) {
return opts.toSession(e).SumInt(&TrackedTime{}, "time")
}
// GetTrackedSeconds return sum of seconds
func GetTrackedSeconds(opts FindTrackedTimesOptions) (int64, error) {
return getTrackedSeconds(x, opts)
return getTrackedSeconds(db.DefaultContext().Engine(), opts)
}
// AddTime will add the given time (in seconds) to the issue
func AddTime(user *User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
@@ -179,7 +184,7 @@ func AddTime(user *User, issue *Issue, amount int64, created time.Time) (*Tracke
return t, sess.Commit()
}
func addTime(e Engine, user *User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
func addTime(e db.Engine, user *User, issue *Issue, amount int64, created time.Time) (*TrackedTime, error) {
if created.IsZero() {
created = time.Now()
}
@@ -225,7 +230,7 @@ func TotalTimes(options *FindTrackedTimesOptions) (map[*User]string, error) {
// DeleteIssueUserTimes deletes times for issue
func DeleteIssueUserTimes(issue *Issue, user *User) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
@@ -263,7 +268,7 @@ func DeleteIssueUserTimes(issue *Issue, user *User) error {
// DeleteTime delete a specific Time
func DeleteTime(t *TrackedTime) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
@@ -291,7 +296,7 @@ func DeleteTime(t *TrackedTime) error {
return sess.Commit()
}
func deleteTimes(e Engine, opts FindTrackedTimesOptions) (removedTime int64, err error) {
func deleteTimes(e db.Engine, opts FindTrackedTimesOptions) (removedTime int64, err error) {
removedTime, err = getTrackedSeconds(e, opts)
if err != nil || removedTime == 0 {
return
@@ -301,7 +306,7 @@ func deleteTimes(e Engine, opts FindTrackedTimesOptions) (removedTime int64, err
return
}
func deleteTime(e Engine, t *TrackedTime) error {
func deleteTime(e db.Engine, t *TrackedTime) error {
if t.Deleted {
return ErrNotExist{ID: t.ID}
}
@@ -313,7 +318,7 @@ func deleteTime(e Engine, t *TrackedTime) error {
// GetTrackedTimeByID returns raw TrackedTime without loading attributes by id
func GetTrackedTimeByID(id int64) (*TrackedTime, error) {
time := new(TrackedTime)
has, err := x.ID(id).Get(time)
has, err := db.DefaultContext().Engine().ID(id).Get(time)
if err != nil {
return nil, err
} else if !has {
+6 -5
View File
@@ -8,11 +8,12 @@ import (
"testing"
"time"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestAddTime(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
user3, err := GetUserByID(3)
assert.NoError(t, err)
@@ -27,15 +28,15 @@ func TestAddTime(t *testing.T) {
assert.Equal(t, int64(1), trackedTime.IssueID)
assert.Equal(t, int64(3661), trackedTime.Time)
tt := AssertExistsAndLoadBean(t, &TrackedTime{UserID: 3, IssueID: 1}).(*TrackedTime)
tt := db.AssertExistsAndLoadBean(t, &TrackedTime{UserID: 3, IssueID: 1}).(*TrackedTime)
assert.Equal(t, int64(3661), tt.Time)
comment := AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeAddTimeManual, PosterID: 3, IssueID: 1}).(*Comment)
comment := db.AssertExistsAndLoadBean(t, &Comment{Type: CommentTypeAddTimeManual, PosterID: 3, IssueID: 1}).(*Comment)
assert.Equal(t, comment.Content, "1h 1min 1s")
}
func TestGetTrackedTimes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// by Issue
times, err := GetTrackedTimes(&FindTrackedTimesOptions{IssueID: 1})
@@ -76,7 +77,7 @@ func TestGetTrackedTimes(t *testing.T) {
}
func TestTotalTimes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
total, err := TotalTimes(&FindTrackedTimesOptions{IssueID: 1})
assert.NoError(t, err)
+12 -6
View File
@@ -6,6 +6,8 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
)
// IssueUser represents an issue-user relation.
@@ -17,7 +19,11 @@ type IssueUser struct {
IsMentioned bool
}
func newIssueUsers(e Engine, repo *Repository, issue *Issue) error {
func init() {
db.RegisterModel(new(IssueUser))
}
func newIssueUsers(e db.Engine, repo *Repository, issue *Issue) error {
assignees, err := repo.getAssignees(e)
if err != nil {
return fmt.Errorf("getAssignees: %v", err)
@@ -51,27 +57,27 @@ func newIssueUsers(e Engine, repo *Repository, issue *Issue) error {
// UpdateIssueUserByRead updates issue-user relation for reading.
func UpdateIssueUserByRead(uid, issueID int64) error {
_, err := x.Exec("UPDATE `issue_user` SET is_read=? WHERE uid=? AND issue_id=?", true, uid, issueID)
_, err := db.DefaultContext().Engine().Exec("UPDATE `issue_user` SET is_read=? WHERE uid=? AND issue_id=?", true, uid, issueID)
return err
}
// UpdateIssueUsersByMentions updates issue-user pairs by mentioning.
func UpdateIssueUsersByMentions(ctx DBContext, issueID int64, uids []int64) error {
func UpdateIssueUsersByMentions(ctx *db.Context, issueID int64, uids []int64) error {
for _, uid := range uids {
iu := &IssueUser{
UID: uid,
IssueID: issueID,
}
has, err := ctx.e.Get(iu)
has, err := ctx.Engine().Get(iu)
if err != nil {
return err
}
iu.IsMentioned = true
if has {
_, err = ctx.e.ID(iu.ID).Cols("is_mentioned").Update(iu)
_, err = ctx.Engine().ID(iu.ID).Cols("is_mentioned").Update(iu)
} else {
_, err = ctx.e.Insert(iu)
_, err = ctx.Engine().Insert(iu)
}
if err != nil {
return err
+16 -15
View File
@@ -7,13 +7,14 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func Test_newIssueUsers(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
newIssue := &Issue{
RepoID: repo.ID,
PosterID: 4,
@@ -23,35 +24,35 @@ func Test_newIssueUsers(t *testing.T) {
}
// artificially insert new issue
AssertSuccessfulInsert(t, newIssue)
db.AssertSuccessfulInsert(t, newIssue)
assert.NoError(t, newIssueUsers(x, repo, newIssue))
assert.NoError(t, newIssueUsers(db.DefaultContext().Engine(), repo, newIssue))
// issue_user table should now have entries for new issue
AssertExistsAndLoadBean(t, &IssueUser{IssueID: newIssue.ID, UID: newIssue.PosterID})
AssertExistsAndLoadBean(t, &IssueUser{IssueID: newIssue.ID, UID: repo.OwnerID})
db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: newIssue.ID, UID: newIssue.PosterID})
db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: newIssue.ID, UID: repo.OwnerID})
}
func TestUpdateIssueUserByRead(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, UpdateIssueUserByRead(4, issue.ID))
AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: 4}, "is_read=1")
db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: 4}, "is_read=1")
assert.NoError(t, UpdateIssueUserByRead(4, issue.ID))
AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: 4}, "is_read=1")
db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: 4}, "is_read=1")
assert.NoError(t, UpdateIssueUserByRead(NonexistentID, NonexistentID))
assert.NoError(t, UpdateIssueUserByRead(db.NonexistentID, db.NonexistentID))
}
func TestUpdateIssueUsersByMentions(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
uids := []int64{2, 5}
assert.NoError(t, UpdateIssueUsersByMentions(DefaultDBContext(), issue.ID, uids))
assert.NoError(t, UpdateIssueUsersByMentions(db.DefaultContext(), issue.ID, uids))
for _, uid := range uids {
AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: uid}, "is_mentioned=1")
db.AssertExistsAndLoadBean(t, &IssueUser{IssueID: issue.ID, UID: uid}, "is_mentioned=1")
}
}
+17 -12
View File
@@ -5,6 +5,7 @@
package models
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
)
@@ -18,12 +19,16 @@ type IssueWatch struct {
UpdatedUnix timeutil.TimeStamp `xorm:"updated NOT NULL"`
}
func init() {
db.RegisterModel(new(IssueWatch))
}
// IssueWatchList contains IssueWatch
type IssueWatchList []*IssueWatch
// CreateOrUpdateIssueWatch set watching for a user and issue
func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error {
iw, exists, err := getIssueWatch(x, userID, issueID)
iw, exists, err := getIssueWatch(db.DefaultContext().Engine(), userID, issueID)
if err != nil {
return err
}
@@ -35,13 +40,13 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error {
IsWatching: isWatching,
}
if _, err := x.Insert(iw); err != nil {
if _, err := db.DefaultContext().Engine().Insert(iw); err != nil {
return err
}
} else {
iw.IsWatching = isWatching
if _, err := x.ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil {
if _, err := db.DefaultContext().Engine().ID(iw.ID).Cols("is_watching", "updated_unix").Update(iw); err != nil {
return err
}
}
@@ -51,11 +56,11 @@ func CreateOrUpdateIssueWatch(userID, issueID int64, isWatching bool) error {
// GetIssueWatch returns all IssueWatch objects from db by user and issue
// the current Web-UI need iw object for watchers AND explicit non-watchers
func GetIssueWatch(userID, issueID int64) (iw *IssueWatch, exists bool, err error) {
return getIssueWatch(x, userID, issueID)
return getIssueWatch(db.DefaultContext().Engine(), userID, issueID)
}
// Return watcher AND explicit non-watcher if entry in db exist
func getIssueWatch(e Engine, userID, issueID int64) (iw *IssueWatch, exists bool, err error) {
func getIssueWatch(e db.Engine, userID, issueID int64) (iw *IssueWatch, exists bool, err error) {
iw = new(IssueWatch)
exists, err = e.
Where("user_id = ?", userID).
@@ -67,14 +72,14 @@ func getIssueWatch(e Engine, userID, issueID int64) (iw *IssueWatch, exists bool
// CheckIssueWatch check if an user is watching an issue
// it takes participants and repo watch into account
func CheckIssueWatch(user *User, issue *Issue) (bool, error) {
iw, exist, err := getIssueWatch(x, user.ID, issue.ID)
iw, exist, err := getIssueWatch(db.DefaultContext().Engine(), user.ID, issue.ID)
if err != nil {
return false, err
}
if exist {
return iw.IsWatching, nil
}
w, err := getWatch(x, user.ID, issue.RepoID)
w, err := getWatch(db.DefaultContext().Engine(), user.ID, issue.RepoID)
if err != nil {
return false, err
}
@@ -85,10 +90,10 @@ func CheckIssueWatch(user *User, issue *Issue) (bool, error) {
// but avoids joining with `user` for performance reasons
// User permissions must be verified elsewhere if required
func GetIssueWatchersIDs(issueID int64, watching bool) ([]int64, error) {
return getIssueWatchersIDs(x, issueID, watching)
return getIssueWatchersIDs(db.DefaultContext().Engine(), issueID, watching)
}
func getIssueWatchersIDs(e Engine, issueID int64, watching bool) ([]int64, error) {
func getIssueWatchersIDs(e db.Engine, issueID int64, watching bool) ([]int64, error) {
ids := make([]int64, 0, 64)
return ids, e.Table("issue_watch").
Where("issue_id=?", issueID).
@@ -99,10 +104,10 @@ func getIssueWatchersIDs(e Engine, issueID int64, watching bool) ([]int64, error
// GetIssueWatchers returns watchers/unwatchers of a given issue
func GetIssueWatchers(issueID int64, listOptions ListOptions) (IssueWatchList, error) {
return getIssueWatchers(x, issueID, listOptions)
return getIssueWatchers(db.DefaultContext().Engine(), issueID, listOptions)
}
func getIssueWatchers(e Engine, issueID int64, listOptions ListOptions) (IssueWatchList, error) {
func getIssueWatchers(e db.Engine, issueID int64, listOptions ListOptions) (IssueWatchList, error) {
sess := e.
Where("`issue_watch`.issue_id = ?", issueID).
And("`issue_watch`.is_watching = ?", true).
@@ -119,7 +124,7 @@ func getIssueWatchers(e Engine, issueID int64, listOptions ListOptions) (IssueWa
return watches, sess.Find(&watches)
}
func removeIssueWatchersByRepoID(e Engine, userID, repoID int64) error {
func removeIssueWatchersByRepoID(e db.Engine, userID, repoID int64) error {
_, err := e.
Join("INNER", "issue", "`issue`.id = `issue_watch`.issue_id AND `issue`.repo_id = ?", repoID).
Where("`issue_watch`.user_id = ?", userID).
+6 -5
View File
@@ -7,23 +7,24 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestCreateOrUpdateIssueWatch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.NoError(t, CreateOrUpdateIssueWatch(3, 1, true))
iw := AssertExistsAndLoadBean(t, &IssueWatch{UserID: 3, IssueID: 1}).(*IssueWatch)
iw := db.AssertExistsAndLoadBean(t, &IssueWatch{UserID: 3, IssueID: 1}).(*IssueWatch)
assert.True(t, iw.IsWatching)
assert.NoError(t, CreateOrUpdateIssueWatch(1, 1, false))
iw = AssertExistsAndLoadBean(t, &IssueWatch{UserID: 1, IssueID: 1}).(*IssueWatch)
iw = db.AssertExistsAndLoadBean(t, &IssueWatch{UserID: 1, IssueID: 1}).(*IssueWatch)
assert.False(t, iw.IsWatching)
}
func TestGetIssueWatch(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
_, exists, err := GetIssueWatch(9, 1)
assert.True(t, exists)
@@ -40,7 +41,7 @@ func TestGetIssueWatch(t *testing.T) {
}
func TestGetIssueWatchers(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
iws, err := GetIssueWatchers(1, ListOptions{})
assert.NoError(t, err)
+12 -13
View File
@@ -7,10 +7,9 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/references"
"xorm.io/xorm"
)
type crossReference struct {
@@ -27,7 +26,7 @@ type crossReferencesContext struct {
RemoveOld bool
}
func findOldCrossReferences(e Engine, issueID, commentID int64) ([]*Comment, error) {
func findOldCrossReferences(e db.Engine, issueID, commentID int64) ([]*Comment, error) {
active := make([]*Comment, 0, 10)
return active, e.Where("`ref_action` IN (?, ?, ?)", references.XRefActionNone, references.XRefActionCloses, references.XRefActionReopens).
And("`ref_issue_id` = ?", issueID).
@@ -35,7 +34,7 @@ func findOldCrossReferences(e Engine, issueID, commentID int64) ([]*Comment, err
Find(&active)
}
func neuterCrossReferences(e Engine, issueID, commentID int64) error {
func neuterCrossReferences(e db.Engine, issueID, commentID int64) error {
active, err := findOldCrossReferences(e, issueID, commentID)
if err != nil {
return err
@@ -47,7 +46,7 @@ func neuterCrossReferences(e Engine, issueID, commentID int64) error {
return neuterCrossReferencesIds(e, ids)
}
func neuterCrossReferencesIds(e Engine, ids []int64) error {
func neuterCrossReferencesIds(e db.Engine, ids []int64) error {
_, err := e.In("id", ids).Cols("`ref_action`").Update(&Comment{RefAction: references.XRefActionNeutered})
return err
}
@@ -60,7 +59,7 @@ func neuterCrossReferencesIds(e Engine, ids []int64) error {
// \/ \/ \/
//
func (issue *Issue) addCrossReferences(e *xorm.Session, doer *User, removeOld bool) error {
func (issue *Issue) addCrossReferences(e db.Engine, doer *User, removeOld bool) error {
var commentType CommentType
if issue.IsPull {
commentType = CommentTypePullRef
@@ -76,7 +75,7 @@ func (issue *Issue) addCrossReferences(e *xorm.Session, doer *User, removeOld bo
return issue.createCrossReferences(e, ctx, issue.Title, issue.Content)
}
func (issue *Issue) createCrossReferences(e *xorm.Session, ctx *crossReferencesContext, plaincontent, mdcontent string) error {
func (issue *Issue) createCrossReferences(e db.Engine, ctx *crossReferencesContext, plaincontent, mdcontent string) error {
xreflist, err := ctx.OrigIssue.getCrossReferences(e, ctx, plaincontent, mdcontent)
if err != nil {
return err
@@ -134,7 +133,7 @@ func (issue *Issue) createCrossReferences(e *xorm.Session, ctx *crossReferencesC
return nil
}
func (issue *Issue) getCrossReferences(e *xorm.Session, ctx *crossReferencesContext, plaincontent, mdcontent string) ([]*crossReference, error) {
func (issue *Issue) getCrossReferences(e db.Engine, ctx *crossReferencesContext, plaincontent, mdcontent string) ([]*crossReference, error) {
xreflist := make([]*crossReference, 0, 5)
var (
refRepo *Repository
@@ -192,7 +191,7 @@ func (issue *Issue) updateCrossReferenceList(list []*crossReference, xref *cross
}
// verifyReferencedIssue will check if the referenced issue exists, and whether the doer has permission to do what
func (issue *Issue) verifyReferencedIssue(e Engine, ctx *crossReferencesContext, repo *Repository,
func (issue *Issue) verifyReferencedIssue(e db.Engine, ctx *crossReferencesContext, repo *Repository,
ref references.IssueReference) (*Issue, references.XRefAction, error) {
refIssue := &Issue{RepoID: repo.ID, Index: ref.Index}
refAction := ref.Action
@@ -241,7 +240,7 @@ func (issue *Issue) verifyReferencedIssue(e Engine, ctx *crossReferencesContext,
// \/ \/ \/ \/ \/
//
func (comment *Comment) addCrossReferences(e *xorm.Session, doer *User, removeOld bool) error {
func (comment *Comment) addCrossReferences(e db.Engine, doer *User, removeOld bool) error {
if comment.Type != CommentTypeCode && comment.Type != CommentTypeComment {
return nil
}
@@ -258,7 +257,7 @@ func (comment *Comment) addCrossReferences(e *xorm.Session, doer *User, removeOl
return comment.Issue.createCrossReferences(e, ctx, "", comment.Content)
}
func (comment *Comment) neuterCrossReferences(e Engine) error {
func (comment *Comment) neuterCrossReferences(e db.Engine) error {
return neuterCrossReferences(e, comment.IssueID, comment.ID)
}
@@ -278,7 +277,7 @@ func (comment *Comment) LoadRefIssue() (err error) {
}
comment.RefIssue, err = GetIssueByID(comment.RefIssueID)
if err == nil {
err = comment.RefIssue.loadRepo(x)
err = comment.RefIssue.loadRepo(db.DefaultContext().Engine())
}
return
}
@@ -338,7 +337,7 @@ func (comment *Comment) RefIssueIdent() string {
// ResolveCrossReferences will return the list of references to close/reopen by this PR
func (pr *PullRequest) ResolveCrossReferences() ([]*Comment, error) {
unfiltered := make([]*Comment, 0, 5)
if err := x.
if err := db.DefaultContext().Engine().
Where("ref_repo_id = ? AND ref_issue_id = ?", pr.Issue.RepoID, pr.Issue.ID).
In("ref_action", []references.XRefAction{references.XRefActionCloses, references.XRefActionReopens}).
OrderBy("id").
+27 -26
View File
@@ -8,13 +8,14 @@ import (
"fmt"
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/references"
"github.com/stretchr/testify/assert"
)
func TestXRef_AddCrossReferences(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// Issue #1 to test against
itarget := testCreateIssue(t, 1, 2, "title1", "content1", false)
@@ -22,7 +23,7 @@ func TestXRef_AddCrossReferences(t *testing.T) {
// PR to close issue #1
content := fmt.Sprintf("content2, closes #%d", itarget.Index)
pr := testCreateIssue(t, 1, 2, "title2", content, true)
ref := AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: pr.ID, RefCommentID: 0}).(*Comment)
ref := db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: pr.ID, RefCommentID: 0}).(*Comment)
assert.Equal(t, CommentTypePullRef, ref.Type)
assert.Equal(t, pr.RepoID, ref.RefRepoID)
assert.True(t, ref.RefIsPull)
@@ -31,7 +32,7 @@ func TestXRef_AddCrossReferences(t *testing.T) {
// Comment on PR to reopen issue #1
content = fmt.Sprintf("content2, reopens #%d", itarget.Index)
c := testCreateComment(t, 1, 2, pr.ID, content)
ref = AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: pr.ID, RefCommentID: c.ID}).(*Comment)
ref = db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: pr.ID, RefCommentID: c.ID}).(*Comment)
assert.Equal(t, CommentTypeCommentRef, ref.Type)
assert.Equal(t, pr.RepoID, ref.RefRepoID)
assert.True(t, ref.RefIsPull)
@@ -40,7 +41,7 @@ func TestXRef_AddCrossReferences(t *testing.T) {
// Issue mentioning issue #1
content = fmt.Sprintf("content3, mentions #%d", itarget.Index)
i := testCreateIssue(t, 1, 2, "title3", content, false)
ref = AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
ref = db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
assert.Equal(t, CommentTypeIssueRef, ref.Type)
assert.Equal(t, pr.RepoID, ref.RefRepoID)
assert.False(t, ref.RefIsPull)
@@ -52,7 +53,7 @@ func TestXRef_AddCrossReferences(t *testing.T) {
// Cross-reference to issue #4 by admin
content = fmt.Sprintf("content5, mentions user3/repo3#%d", itarget.Index)
i = testCreateIssue(t, 2, 1, "title5", content, false)
ref = AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
ref = db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
assert.Equal(t, CommentTypeIssueRef, ref.Type)
assert.Equal(t, i.RepoID, ref.RefRepoID)
assert.False(t, ref.RefIsPull)
@@ -61,11 +62,11 @@ func TestXRef_AddCrossReferences(t *testing.T) {
// Cross-reference to issue #4 with no permission
content = fmt.Sprintf("content6, mentions user3/repo3#%d", itarget.Index)
i = testCreateIssue(t, 4, 5, "title6", content, false)
AssertNotExistsBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0})
db.AssertNotExistsBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0})
}
func TestXRef_NeuterCrossReferences(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
// Issue #1 to test against
itarget := testCreateIssue(t, 1, 2, "title1", "content1", false)
@@ -73,23 +74,23 @@ func TestXRef_NeuterCrossReferences(t *testing.T) {
// Issue mentioning issue #1
title := fmt.Sprintf("title2, mentions #%d", itarget.Index)
i := testCreateIssue(t, 1, 2, title, "content2", false)
ref := AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
ref := db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
assert.Equal(t, CommentTypeIssueRef, ref.Type)
assert.Equal(t, references.XRefActionNone, ref.RefAction)
d := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
d := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
i.Title = "title2, no mentions"
assert.NoError(t, i.ChangeTitle(d, title))
ref = AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
ref = db.AssertExistsAndLoadBean(t, &Comment{IssueID: itarget.ID, RefIssueID: i.ID, RefCommentID: 0}).(*Comment)
assert.Equal(t, CommentTypeIssueRef, ref.Type)
assert.Equal(t, references.XRefActionNeutered, ref.RefAction)
}
func TestXRef_ResolveCrossReferences(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
d := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
d := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
i1 := testCreateIssue(t, 1, 2, "title1", "content1", false)
i2 := testCreateIssue(t, 1, 2, "title2", "content2", false)
@@ -98,21 +99,21 @@ func TestXRef_ResolveCrossReferences(t *testing.T) {
assert.NoError(t, err)
pr := testCreatePR(t, 1, 2, "titlepr", fmt.Sprintf("closes #%d", i1.Index))
rp := AssertExistsAndLoadBean(t, &Comment{IssueID: i1.ID, RefIssueID: pr.Issue.ID, RefCommentID: 0}).(*Comment)
rp := db.AssertExistsAndLoadBean(t, &Comment{IssueID: i1.ID, RefIssueID: pr.Issue.ID, RefCommentID: 0}).(*Comment)
c1 := testCreateComment(t, 1, 2, pr.Issue.ID, fmt.Sprintf("closes #%d", i2.Index))
r1 := AssertExistsAndLoadBean(t, &Comment{IssueID: i2.ID, RefIssueID: pr.Issue.ID, RefCommentID: c1.ID}).(*Comment)
r1 := db.AssertExistsAndLoadBean(t, &Comment{IssueID: i2.ID, RefIssueID: pr.Issue.ID, RefCommentID: c1.ID}).(*Comment)
// Must be ignored
c2 := testCreateComment(t, 1, 2, pr.Issue.ID, fmt.Sprintf("mentions #%d", i2.Index))
AssertExistsAndLoadBean(t, &Comment{IssueID: i2.ID, RefIssueID: pr.Issue.ID, RefCommentID: c2.ID})
db.AssertExistsAndLoadBean(t, &Comment{IssueID: i2.ID, RefIssueID: pr.Issue.ID, RefCommentID: c2.ID})
// Must be superseded by c4/r4
c3 := testCreateComment(t, 1, 2, pr.Issue.ID, fmt.Sprintf("reopens #%d", i3.Index))
AssertExistsAndLoadBean(t, &Comment{IssueID: i3.ID, RefIssueID: pr.Issue.ID, RefCommentID: c3.ID})
db.AssertExistsAndLoadBean(t, &Comment{IssueID: i3.ID, RefIssueID: pr.Issue.ID, RefCommentID: c3.ID})
c4 := testCreateComment(t, 1, 2, pr.Issue.ID, fmt.Sprintf("closes #%d", i3.Index))
r4 := AssertExistsAndLoadBean(t, &Comment{IssueID: i3.ID, RefIssueID: pr.Issue.ID, RefCommentID: c4.ID}).(*Comment)
r4 := db.AssertExistsAndLoadBean(t, &Comment{IssueID: i3.ID, RefIssueID: pr.Issue.ID, RefCommentID: c4.ID}).(*Comment)
refs, err := pr.ResolveCrossReferences()
assert.NoError(t, err)
@@ -123,10 +124,10 @@ func TestXRef_ResolveCrossReferences(t *testing.T) {
}
func testCreateIssue(t *testing.T, repo, doer int64, title, content string, ispull bool) *Issue {
r := AssertExistsAndLoadBean(t, &Repository{ID: repo}).(*Repository)
d := AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
r := db.AssertExistsAndLoadBean(t, &Repository{ID: repo}).(*Repository)
d := db.AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
idx, err := GetNextResourceIndex("issue_index", r.ID)
idx, err := db.GetNextResourceIndex("issue_index", r.ID)
assert.NoError(t, err)
i := &Issue{
RepoID: r.ID,
@@ -138,7 +139,7 @@ func testCreateIssue(t *testing.T, repo, doer int64, title, content string, ispu
Index: idx,
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
assert.NoError(t, sess.Begin())
@@ -155,8 +156,8 @@ func testCreateIssue(t *testing.T, repo, doer int64, title, content string, ispu
}
func testCreatePR(t *testing.T, repo, doer int64, title, content string) *PullRequest {
r := AssertExistsAndLoadBean(t, &Repository{ID: repo}).(*Repository)
d := AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
r := db.AssertExistsAndLoadBean(t, &Repository{ID: repo}).(*Repository)
d := db.AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
i := &Issue{RepoID: r.ID, PosterID: d.ID, Poster: d, Title: title, Content: content, IsPull: true}
pr := &PullRequest{HeadRepoID: repo, BaseRepoID: repo, HeadBranch: "head", BaseBranch: "base", Status: PullRequestStatusMergeable}
assert.NoError(t, NewPullRequest(r, i, nil, nil, pr))
@@ -165,11 +166,11 @@ func testCreatePR(t *testing.T, repo, doer int64, title, content string) *PullRe
}
func testCreateComment(t *testing.T, repo, doer, issue int64, content string) *Comment {
d := AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
i := AssertExistsAndLoadBean(t, &Issue{ID: issue}).(*Issue)
d := db.AssertExistsAndLoadBean(t, &User{ID: doer}).(*User)
i := db.AssertExistsAndLoadBean(t, &Issue{ID: issue}).(*Issue)
c := &Comment{Type: CommentTypeComment, PosterID: doer, Poster: d, IssueID: issue, Issue: i, Content: content}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
assert.NoError(t, sess.Begin())
_, err := sess.Insert(c)
+14 -9
View File
@@ -7,6 +7,7 @@ package models
import (
"errors"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/lfs"
"code.gitea.io/gitea/modules/timeutil"
@@ -22,6 +23,10 @@ type LFSMetaObject struct {
CreatedUnix timeutil.TimeStamp `xorm:"created"`
}
func init() {
db.RegisterModel(new(LFSMetaObject))
}
// LFSTokenResponse defines the JSON structure in which the JWT token is stored.
// This structure is fetched via SSH and passed by the Git LFS client to the server
// endpoint for authorization.
@@ -39,7 +44,7 @@ var ErrLFSObjectNotExist = errors.New("LFS Meta object does not exist")
func NewLFSMetaObject(m *LFSMetaObject) (*LFSMetaObject, error) {
var err error
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return nil, err
@@ -71,7 +76,7 @@ func (repo *Repository) GetLFSMetaObjectByOid(oid string) (*LFSMetaObject, error
}
m := &LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}, RepositoryID: repo.ID}
has, err := x.Get(m)
has, err := db.DefaultContext().Engine().Get(m)
if err != nil {
return nil, err
} else if !has {
@@ -87,7 +92,7 @@ func (repo *Repository) RemoveLFSMetaObjectByOid(oid string) (int64, error) {
return 0, ErrLFSObjectNotExist
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return -1, err
@@ -108,7 +113,7 @@ func (repo *Repository) RemoveLFSMetaObjectByOid(oid string) (int64, error) {
// GetLFSMetaObjects returns all LFSMetaObjects associated with a repository
func (repo *Repository) GetLFSMetaObjects(page, pageSize int) ([]*LFSMetaObject, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if page >= 0 && pageSize > 0 {
@@ -124,23 +129,23 @@ func (repo *Repository) GetLFSMetaObjects(page, pageSize int) ([]*LFSMetaObject,
// CountLFSMetaObjects returns a count of all LFSMetaObjects associated with a repository
func (repo *Repository) CountLFSMetaObjects() (int64, error) {
return x.Count(&LFSMetaObject{RepositoryID: repo.ID})
return db.DefaultContext().Engine().Count(&LFSMetaObject{RepositoryID: repo.ID})
}
// LFSObjectAccessible checks if a provided Oid is accessible to the user
func LFSObjectAccessible(user *User, oid string) (bool, error) {
if user.IsAdmin {
count, err := x.Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}})
count, err := db.DefaultContext().Engine().Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}})
return count > 0, err
}
cond := accessibleRepositoryCondition(user)
count, err := x.Where(cond).Join("INNER", "repository", "`lfs_meta_object`.repository_id = `repository`.id").Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}})
count, err := db.DefaultContext().Engine().Where(cond).Join("INNER", "repository", "`lfs_meta_object`.repository_id = `repository`.id").Count(&LFSMetaObject{Pointer: lfs.Pointer{Oid: oid}})
return count > 0, err
}
// LFSAutoAssociate auto associates accessible LFSMetaObjects
func LFSAutoAssociate(metas []*LFSMetaObject, user *User, repoID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -179,7 +184,7 @@ func IterateLFS(f func(mo *LFSMetaObject) error) error {
const batchSize = 100
for {
mos := make([]*LFSMetaObject, 0, batchSize)
if err := x.Limit(batchSize, start).Find(&mos); err != nil {
if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&mos); err != nil {
return err
}
if len(mos) == 0 {
+11 -6
View File
@@ -10,6 +10,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"xorm.io/xorm"
@@ -26,6 +27,10 @@ type LFSLock struct {
Created time.Time `xorm:"created"`
}
func init() {
db.RegisterModel(new(LFSLock))
}
// BeforeInsert is invoked from XORM before inserting an object of this type.
func (l *LFSLock) BeforeInsert() {
l.OwnerID = l.Owner.ID
@@ -67,7 +72,7 @@ func CreateLFSLock(lock *LFSLock) (*LFSLock, error) {
return nil, err
}
_, err = x.InsertOne(lock)
_, err = db.DefaultContext().Engine().InsertOne(lock)
return lock, err
}
@@ -75,7 +80,7 @@ func CreateLFSLock(lock *LFSLock) (*LFSLock, error) {
func GetLFSLock(repo *Repository, path string) (*LFSLock, error) {
path = cleanPath(path)
rel := &LFSLock{RepoID: repo.ID}
has, err := x.Where("lower(path) = ?", strings.ToLower(path)).Get(rel)
has, err := db.DefaultContext().Engine().Where("lower(path) = ?", strings.ToLower(path)).Get(rel)
if err != nil {
return nil, err
}
@@ -88,7 +93,7 @@ func GetLFSLock(repo *Repository, path string) (*LFSLock, error) {
// GetLFSLockByID returns release by given id.
func GetLFSLockByID(id int64) (*LFSLock, error) {
lock := new(LFSLock)
has, err := x.ID(id).Get(lock)
has, err := db.DefaultContext().Engine().ID(id).Get(lock)
if err != nil {
return nil, err
} else if !has {
@@ -99,7 +104,7 @@ func GetLFSLockByID(id int64) (*LFSLock, error) {
// GetLFSLockByRepoID returns a list of locks of repository.
func GetLFSLockByRepoID(repoID int64, page, pageSize int) ([]*LFSLock, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if page >= 0 && pageSize > 0 {
@@ -115,7 +120,7 @@ func GetLFSLockByRepoID(repoID int64, page, pageSize int) ([]*LFSLock, error) {
// CountLFSLockByRepoID returns a count of all LFSLocks associated with a repository.
func CountLFSLockByRepoID(repoID int64) (int64, error) {
return x.Count(&LFSLock{RepoID: repoID})
return db.DefaultContext().Engine().Count(&LFSLock{RepoID: repoID})
}
// DeleteLFSLockByID deletes a lock by given ID.
@@ -134,7 +139,7 @@ func DeleteLFSLockByID(id int64, u *User, force bool) (*LFSLock, error) {
return nil, fmt.Errorf("user doesn't own lock and force flag is not set")
}
_, err = x.ID(id).Delete(new(LFSLock))
_, err = db.DefaultContext().Engine().ID(id).Delete(new(LFSLock))
return lock, err
}
+3 -2
View File
@@ -5,6 +5,7 @@
package models
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm"
@@ -20,7 +21,7 @@ type Paginator interface {
func getPaginatedSession(p Paginator) *xorm.Session {
skip, take := p.GetSkipTake()
return x.Limit(take, skip)
return db.DefaultContext().Engine().Limit(take, skip)
}
// setSessionPagination sets pagination for a database session
@@ -31,7 +32,7 @@ func setSessionPagination(sess *xorm.Session, p Paginator) *xorm.Session {
}
// setSessionPagination sets pagination for a database engine
func setEnginePagination(e Engine, p Paginator) Engine {
func setEnginePagination(e db.Engine, p Paginator) db.Engine {
skip, take := p.GetSkipTake()
return e.Limit(take, skip)
+20 -15
View File
@@ -9,6 +9,7 @@ import (
"reflect"
"strconv"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/timeutil"
@@ -118,6 +119,10 @@ type LoginSource struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(LoginSource))
}
// Cell2Int64 converts a xorm.Cell type to int64,
// and handles possible irregular cases.
func Cell2Int64(val xorm.Cell) int64 {
@@ -203,7 +208,7 @@ func (source *LoginSource) SkipVerify() bool {
// CreateLoginSource inserts a LoginSource in the DB if not already
// existing with the given name.
func CreateLoginSource(source *LoginSource) error {
has, err := x.Where("name=?", source.Name).Exist(new(LoginSource))
has, err := db.DefaultContext().Engine().Where("name=?", source.Name).Exist(new(LoginSource))
if err != nil {
return err
} else if has {
@@ -214,7 +219,7 @@ func CreateLoginSource(source *LoginSource) error {
source.IsSyncEnabled = false
}
_, err = x.Insert(source)
_, err = db.DefaultContext().Engine().Insert(source)
if err != nil {
return err
}
@@ -235,7 +240,7 @@ func CreateLoginSource(source *LoginSource) error {
err = registerableSource.RegisterSource()
if err != nil {
// remove the LoginSource in case of errors while registering configuration
if _, err := x.Delete(source); err != nil {
if _, err := db.DefaultContext().Engine().Delete(source); err != nil {
log.Error("CreateLoginSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@@ -245,13 +250,13 @@ func CreateLoginSource(source *LoginSource) error {
// LoginSources returns a slice of all login sources found in DB.
func LoginSources() ([]*LoginSource, error) {
auths := make([]*LoginSource, 0, 6)
return auths, x.Find(&auths)
return auths, db.DefaultContext().Engine().Find(&auths)
}
// LoginSourcesByType returns all sources of the specified type
func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 1)
if err := x.Where("type = ?", loginType).Find(&sources); err != nil {
if err := db.DefaultContext().Engine().Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@@ -260,7 +265,7 @@ func LoginSourcesByType(loginType LoginType) ([]*LoginSource, error) {
// AllActiveLoginSources returns all active sources
func AllActiveLoginSources() ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 5)
if err := x.Where("is_active = ?", true).Find(&sources); err != nil {
if err := db.DefaultContext().Engine().Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@@ -269,7 +274,7 @@ func AllActiveLoginSources() ([]*LoginSource, error) {
// ActiveLoginSources returns all active sources of the specified type
func ActiveLoginSources(loginType LoginType) ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 1)
if err := x.Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil {
if err := db.DefaultContext().Engine().Where("is_active = ? and type = ?", true, loginType).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@@ -278,7 +283,7 @@ func ActiveLoginSources(loginType LoginType) ([]*LoginSource, error) {
// IsSSPIEnabled returns true if there is at least one activated login
// source of type LoginSSPI
func IsSSPIEnabled() bool {
if !HasEngine {
if !db.HasEngine {
return false
}
sources, err := ActiveLoginSources(LoginSSPI)
@@ -300,7 +305,7 @@ func GetLoginSourceByID(id int64) (*LoginSource, error) {
return source, nil
}
has, err := x.ID(id).Get(source)
has, err := db.DefaultContext().Engine().ID(id).Get(source)
if err != nil {
return nil, err
} else if !has {
@@ -320,7 +325,7 @@ func UpdateSource(source *LoginSource) error {
}
}
_, err := x.ID(source.ID).AllCols().Update(source)
_, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(source)
if err != nil {
return err
}
@@ -341,7 +346,7 @@ func UpdateSource(source *LoginSource) error {
err = registerableSource.RegisterSource()
if err != nil {
// restore original values since we cannot update the provider it self
if _, err := x.ID(source.ID).AllCols().Update(originalLoginSource); err != nil {
if _, err := db.DefaultContext().Engine().ID(source.ID).AllCols().Update(originalLoginSource); err != nil {
log.Error("UpdateSource: Error while wrapOpenIDConnectInitializeError: %v", err)
}
}
@@ -350,14 +355,14 @@ func UpdateSource(source *LoginSource) error {
// DeleteSource deletes a LoginSource record in DB.
func DeleteSource(source *LoginSource) error {
count, err := x.Count(&User{LoginSource: source.ID})
count, err := db.DefaultContext().Engine().Count(&User{LoginSource: source.ID})
if err != nil {
return err
} else if count > 0 {
return ErrLoginSourceInUse{source.ID}
}
count, err = x.Count(&ExternalLoginUser{LoginSourceID: source.ID})
count, err = db.DefaultContext().Engine().Count(&ExternalLoginUser{LoginSourceID: source.ID})
if err != nil {
return err
} else if count > 0 {
@@ -370,12 +375,12 @@ func DeleteSource(source *LoginSource) error {
}
}
_, err = x.ID(source.ID).Delete(new(LoginSource))
_, err = db.DefaultContext().Engine().ID(source.ID).Delete(new(LoginSource))
return err
}
// CountLoginSources returns number of login sources.
func CountLoginSources() int64 {
count, _ := x.Count(new(LoginSource))
count, _ := db.DefaultContext().Engine().Count(new(LoginSource))
return count
}
@@ -5,37 +5,16 @@
package models
import (
"encoding/json"
"io/ioutil"
"os"
"path/filepath"
"strings"
"testing"
"code.gitea.io/gitea/modules/setting"
"xorm.io/xorm/schemas"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/json"
"github.com/stretchr/testify/assert"
"xorm.io/xorm/schemas"
)
func TestDumpDatabase(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
dir, err := ioutil.TempDir(os.TempDir(), "dump")
assert.NoError(t, err)
type Version struct {
ID int64 `xorm:"pk autoincr"`
Version int64
}
assert.NoError(t, x.Sync2(new(Version)))
for _, dbName := range setting.SupportedDatabases {
dbType := setting.GetDBTypeByName(dbName)
assert.NoError(t, DumpDatabase(filepath.Join(dir, dbType+".sql"), dbType))
}
}
type TestSource struct {
Provider string
ClientID string
@@ -55,9 +34,9 @@ func (source *TestSource) ToDB() ([]byte, error) {
}
func TestDumpLoginSource(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
loginSourceSchema, err := x.TableInfo(new(LoginSource))
loginSourceSchema, err := db.TableInfo(new(LoginSource))
assert.NoError(t, err)
RegisterLoginTypeConfig(LoginOAuth2, new(TestSource))
@@ -74,7 +53,7 @@ func TestDumpLoginSource(t *testing.T) {
sb := new(strings.Builder)
x.DumpTables([]*schemas.Table{loginSourceSchema}, sb)
db.DumpTables([]*schemas.Table{loginSourceSchema}, sb)
assert.Contains(t, sb.String(), `"Provider":"ConvertibleSourceName"`)
}
+4 -2
View File
@@ -7,15 +7,17 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
// TestFixturesAreConsistent assert that test fixtures are consistent
func TestFixturesAreConsistent(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
CheckConsistencyForAll(t)
}
func TestMain(m *testing.M) {
MainTest(m, "..")
db.MainTest(m, "..")
}
+7 -6
View File
@@ -5,6 +5,7 @@
package models
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/structs"
"xorm.io/builder"
@@ -17,7 +18,7 @@ func InsertMilestones(ms ...*Milestone) (err error) {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -38,7 +39,7 @@ func InsertMilestones(ms ...*Milestone) (err error) {
// InsertIssues insert issues to database
func InsertIssues(issues ...*Issue) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -143,7 +144,7 @@ func InsertIssueComments(comments []*Comment) error {
issueIDs[comment.IssueID] = true
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -174,7 +175,7 @@ func InsertIssueComments(comments []*Comment) error {
// InsertPullRequests inserted pull requests
func InsertPullRequests(prs ...*PullRequest) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -194,7 +195,7 @@ func InsertPullRequests(prs ...*PullRequest) error {
// InsertReleases migrates release
func InsertReleases(rels ...*Release) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -232,7 +233,7 @@ func migratedIssueCond(tp structs.GitServiceType) builder.Cond {
// UpdateReviewsMigrationsByType updates reviews' migrations information via given git service type and original id and poster id
func UpdateReviewsMigrationsByType(tp structs.GitServiceType, originalAuthorID string, posterID int64) error {
_, err := x.Table("review").
_, err := db.DefaultContext().Engine().Table("review").
Where("original_author_id = ?", originalAuthorID).
And(migratedIssueCond(tp)).
Update(map[string]interface{}{
+6 -5
View File
@@ -14,12 +14,13 @@ import (
"testing"
"time"
"code.gitea.io/gitea/models"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
"github.com/stretchr/testify/assert"
"github.com/unknwon/com"
"xorm.io/xorm"
@@ -85,7 +86,7 @@ func removeAllWithRetry(dir string) error {
// SetEngine sets the xorm.Engine
func SetEngine() (*xorm.Engine, error) {
x, err := models.GetNewEngine()
x, err := db.GetNewEngine()
if err != nil {
return x, fmt.Errorf("Failed to connect to database: %v", err)
}
@@ -93,7 +94,7 @@ func SetEngine() (*xorm.Engine, error) {
x.SetMapper(names.GonicMapper{})
// WARNING: for serv command, MUST remove the output to os.stdout,
// so use log file to instead print to stdout.
x.SetLogger(models.NewXORMLogger(setting.Database.LogSQL))
x.SetLogger(db.NewXORMLogger(setting.Database.LogSQL))
x.ShowSQL(setting.Database.LogSQL)
x.SetMaxOpenConns(setting.Database.MaxOpenConns)
x.SetMaxIdleConns(setting.Database.MaxIdleConns)
@@ -240,11 +241,11 @@ func prepareTestEnv(t *testing.T, skip int, syncModels ...interface{}) (*xorm.En
if _, err := os.Stat(fixturesDir); err == nil {
t.Logf("initializing fixtures from: %s", fixturesDir)
if err := models.InitFixtures(fixturesDir, x); err != nil {
if err := db.InitFixtures(fixturesDir, x); err != nil {
t.Errorf("error whilst initializing fixtures from %s: %v", fixturesDir, err)
return x, deferFn
}
if err := models.LoadFixtures(x); err != nil {
if err := db.LoadFixtures(x); err != nil {
t.Errorf("error whilst loading fixtures from %s: %v", fixturesDir, err)
return x, deferFn
}
+41 -36
View File
@@ -8,6 +8,7 @@ import (
"fmt"
"strconv"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
@@ -67,6 +68,10 @@ type Notification struct {
UpdatedUnix timeutil.TimeStamp `xorm:"updated INDEX NOT NULL"`
}
func init() {
db.RegisterModel(new(Notification))
}
// FindNotificationOptions represent the filters for notifications. If an ID is 0 it will be ignored.
type FindNotificationOptions struct {
ListOptions
@@ -107,7 +112,7 @@ func (opts *FindNotificationOptions) ToCond() builder.Cond {
}
// ToSession will convert the given options to a xorm Session by using the conditions from ToCond and joining with issue table if required
func (opts *FindNotificationOptions) ToSession(e Engine) *xorm.Session {
func (opts *FindNotificationOptions) ToSession(e db.Engine) *xorm.Session {
sess := e.Where(opts.ToCond())
if opts.Page != 0 {
sess = setSessionPagination(sess, opts)
@@ -115,24 +120,24 @@ func (opts *FindNotificationOptions) ToSession(e Engine) *xorm.Session {
return sess
}
func getNotifications(e Engine, options *FindNotificationOptions) (nl NotificationList, err error) {
func getNotifications(e db.Engine, options *FindNotificationOptions) (nl NotificationList, err error) {
err = options.ToSession(e).OrderBy("notification.updated_unix DESC").Find(&nl)
return
}
// GetNotifications returns all notifications that fit to the given options.
func GetNotifications(opts *FindNotificationOptions) (NotificationList, error) {
return getNotifications(x, opts)
return getNotifications(db.DefaultContext().Engine(), opts)
}
// CountNotifications count all notifications that fit to the given options and ignore pagination.
func CountNotifications(opts *FindNotificationOptions) (int64, error) {
return x.Where(opts.ToCond()).Count(&Notification{})
return db.DefaultContext().Engine().Where(opts.ToCond()).Count(&Notification{})
}
// CreateRepoTransferNotification creates notification for the user a repository was transferred to
func CreateRepoTransferNotification(doer, newOwner *User, repo *Repository) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -174,7 +179,7 @@ func CreateRepoTransferNotification(doer, newOwner *User, repo *Repository) erro
// for each watcher, or updates it if already exists
// receiverID > 0 just send to reciver, else send to all watcher
func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID, receiverID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -187,7 +192,7 @@ func CreateOrUpdateIssueNotifications(issueID, commentID, notificationAuthorID,
return sess.Commit()
}
func createOrUpdateIssueNotifications(e Engine, issueID, commentID, notificationAuthorID, receiverID int64) error {
func createOrUpdateIssueNotifications(e db.Engine, issueID, commentID, notificationAuthorID, receiverID int64) error {
// init
var toNotify map[int64]struct{}
notifications, err := getNotificationsByIssueID(e, issueID)
@@ -277,7 +282,7 @@ func createOrUpdateIssueNotifications(e Engine, issueID, commentID, notification
return nil
}
func getNotificationsByIssueID(e Engine, issueID int64) (notifications []*Notification, err error) {
func getNotificationsByIssueID(e db.Engine, issueID int64) (notifications []*Notification, err error) {
err = e.
Where("issue_id = ?", issueID).
Find(&notifications)
@@ -294,7 +299,7 @@ func notificationExists(notifications []*Notification, issueID, userID int64) bo
return false
}
func createIssueNotification(e Engine, userID int64, issue *Issue, commentID, updatedByID int64) error {
func createIssueNotification(e db.Engine, userID int64, issue *Issue, commentID, updatedByID int64) error {
notification := &Notification{
UserID: userID,
RepoID: issue.RepoID,
@@ -314,7 +319,7 @@ func createIssueNotification(e Engine, userID int64, issue *Issue, commentID, up
return err
}
func updateIssueNotification(e Engine, userID, issueID, commentID, updatedByID int64) error {
func updateIssueNotification(e db.Engine, userID, issueID, commentID, updatedByID int64) error {
notification, err := getIssueNotification(e, userID, issueID)
if err != nil {
return err
@@ -336,7 +341,7 @@ func updateIssueNotification(e Engine, userID, issueID, commentID, updatedByID i
return err
}
func getIssueNotification(e Engine, userID, issueID int64) (*Notification, error) {
func getIssueNotification(e db.Engine, userID, issueID int64) (*Notification, error) {
notification := new(Notification)
_, err := e.
Where("user_id = ?", userID).
@@ -347,10 +352,10 @@ func getIssueNotification(e Engine, userID, issueID int64) (*Notification, error
// NotificationsForUser returns notifications for a given user and status
func NotificationsForUser(user *User, statuses []NotificationStatus, page, perPage int) (NotificationList, error) {
return notificationsForUser(x, user, statuses, page, perPage)
return notificationsForUser(db.DefaultContext().Engine(), user, statuses, page, perPage)
}
func notificationsForUser(e Engine, user *User, statuses []NotificationStatus, page, perPage int) (notifications []*Notification, err error) {
func notificationsForUser(e db.Engine, user *User, statuses []NotificationStatus, page, perPage int) (notifications []*Notification, err error) {
if len(statuses) == 0 {
return
}
@@ -370,10 +375,10 @@ func notificationsForUser(e Engine, user *User, statuses []NotificationStatus, p
// CountUnread count unread notifications for a user
func CountUnread(user *User) int64 {
return countUnread(x, user.ID)
return countUnread(db.DefaultContext().Engine(), user.ID)
}
func countUnread(e Engine, userID int64) int64 {
func countUnread(e db.Engine, userID int64) int64 {
exist, err := e.Where("user_id = ?", userID).And("status = ?", NotificationStatusUnread).Count(new(Notification))
if err != nil {
log.Error("countUnread", err)
@@ -384,10 +389,10 @@ func countUnread(e Engine, userID int64) int64 {
// LoadAttributes load Repo Issue User and Comment if not loaded
func (n *Notification) LoadAttributes() (err error) {
return n.loadAttributes(x)
return n.loadAttributes(db.DefaultContext().Engine())
}
func (n *Notification) loadAttributes(e Engine) (err error) {
func (n *Notification) loadAttributes(e db.Engine) (err error) {
if err = n.loadRepo(e); err != nil {
return
}
@@ -403,7 +408,7 @@ func (n *Notification) loadAttributes(e Engine) (err error) {
return
}
func (n *Notification) loadRepo(e Engine) (err error) {
func (n *Notification) loadRepo(e db.Engine) (err error) {
if n.Repository == nil {
n.Repository, err = getRepositoryByID(e, n.RepoID)
if err != nil {
@@ -413,7 +418,7 @@ func (n *Notification) loadRepo(e Engine) (err error) {
return nil
}
func (n *Notification) loadIssue(e Engine) (err error) {
func (n *Notification) loadIssue(e db.Engine) (err error) {
if n.Issue == nil && n.IssueID != 0 {
n.Issue, err = getIssueByID(e, n.IssueID)
if err != nil {
@@ -424,7 +429,7 @@ func (n *Notification) loadIssue(e Engine) (err error) {
return nil
}
func (n *Notification) loadComment(e Engine) (err error) {
func (n *Notification) loadComment(e db.Engine) (err error) {
if n.Comment == nil && n.CommentID != 0 {
n.Comment, err = getCommentByID(e, n.CommentID)
if err != nil {
@@ -434,7 +439,7 @@ func (n *Notification) loadComment(e Engine) (err error) {
return nil
}
func (n *Notification) loadUser(e Engine) (err error) {
func (n *Notification) loadUser(e db.Engine) (err error) {
if n.User == nil {
n.User, err = getUserByID(e, n.UserID)
if err != nil {
@@ -446,12 +451,12 @@ func (n *Notification) loadUser(e Engine) (err error) {
// GetRepo returns the repo of the notification
func (n *Notification) GetRepo() (*Repository, error) {
return n.Repository, n.loadRepo(x)
return n.Repository, n.loadRepo(db.DefaultContext().Engine())
}
// GetIssue returns the issue of the notification
func (n *Notification) GetIssue() (*Issue, error) {
return n.Issue, n.loadIssue(x)
return n.Issue, n.loadIssue(db.DefaultContext().Engine())
}
// HTMLURL formats a URL-string to the notification
@@ -516,7 +521,7 @@ func (nl NotificationList) LoadRepos() (RepositoryList, []int, error) {
if left < limit {
limit = left
}
rows, err := x.
rows, err := db.DefaultContext().Engine().
In("id", repoIDs[:limit]).
Rows(new(Repository))
if err != nil {
@@ -592,7 +597,7 @@ func (nl NotificationList) LoadIssues() ([]int, error) {
if left < limit {
limit = left
}
rows, err := x.
rows, err := db.DefaultContext().Engine().
In("id", issueIDs[:limit]).
Rows(new(Issue))
if err != nil {
@@ -678,7 +683,7 @@ func (nl NotificationList) LoadComments() ([]int, error) {
if left < limit {
limit = left
}
rows, err := x.
rows, err := db.DefaultContext().Engine().
In("id", commentIDs[:limit]).
Rows(new(Comment))
if err != nil {
@@ -718,10 +723,10 @@ func (nl NotificationList) LoadComments() ([]int, error) {
// GetNotificationCount returns the notification count for user
func GetNotificationCount(user *User, status NotificationStatus) (int64, error) {
return getNotificationCount(x, user, status)
return getNotificationCount(db.DefaultContext().Engine(), user, status)
}
func getNotificationCount(e Engine, user *User, status NotificationStatus) (count int64, err error) {
func getNotificationCount(e db.Engine, user *User, status NotificationStatus) (count int64, err error) {
count, err = e.
Where("user_id = ?", user.ID).
And("status = ?", status).
@@ -741,10 +746,10 @@ func GetUIDsAndNotificationCounts(since, until timeutil.TimeStamp) ([]UserIDCoun
`WHERE user_id IN (SELECT user_id FROM notification WHERE updated_unix >= ? AND ` +
`updated_unix < ?) AND status = ? GROUP BY user_id`
var res []UserIDCount
return res, x.SQL(sql, since, until, NotificationStatusUnread).Find(&res)
return res, db.DefaultContext().Engine().SQL(sql, since, until, NotificationStatusUnread).Find(&res)
}
func setIssueNotificationStatusReadIfUnread(e Engine, userID, issueID int64) error {
func setIssueNotificationStatusReadIfUnread(e db.Engine, userID, issueID int64) error {
notification, err := getIssueNotification(e, userID, issueID)
// ignore if not exists
if err != nil {
@@ -761,7 +766,7 @@ func setIssueNotificationStatusReadIfUnread(e Engine, userID, issueID int64) err
return err
}
func setRepoNotificationStatusReadIfUnread(e Engine, userID, repoID int64) error {
func setRepoNotificationStatusReadIfUnread(e db.Engine, userID, repoID int64) error {
_, err := e.Where(builder.Eq{
"user_id": userID,
"status": NotificationStatusUnread,
@@ -773,7 +778,7 @@ func setRepoNotificationStatusReadIfUnread(e Engine, userID, repoID int64) error
// SetNotificationStatus change the notification status
func SetNotificationStatus(notificationID int64, user *User, status NotificationStatus) (*Notification, error) {
notification, err := getNotificationByID(x, notificationID)
notification, err := getNotificationByID(db.DefaultContext().Engine(), notificationID)
if err != nil {
return notification, err
}
@@ -784,16 +789,16 @@ func SetNotificationStatus(notificationID int64, user *User, status Notification
notification.Status = status
_, err = x.ID(notificationID).Update(notification)
_, err = db.DefaultContext().Engine().ID(notificationID).Update(notification)
return notification, err
}
// GetNotificationByID return notification by ID
func GetNotificationByID(notificationID int64) (*Notification, error) {
return getNotificationByID(x, notificationID)
return getNotificationByID(db.DefaultContext().Engine(), notificationID)
}
func getNotificationByID(e Engine, notificationID int64) (*Notification, error) {
func getNotificationByID(e db.Engine, notificationID int64) (*Notification, error) {
notification := new(Notification)
ok, err := e.
Where("id = ?", notificationID).
@@ -812,7 +817,7 @@ func getNotificationByID(e Engine, notificationID int64) (*Notification, error)
// UpdateNotificationStatuses updates the statuses of all of a user's notifications that are of the currentStatus type to the desiredStatus
func UpdateNotificationStatuses(user *User, currentStatus, desiredStatus NotificationStatus) error {
n := &Notification{Status: desiredStatus, UpdatedBy: user.ID}
_, err := x.
_, err := db.DefaultContext().Engine().
Where("user_id = ? AND status = ?", user.ID, currentStatus).
Cols("status", "updated_by", "updated_unix").
Update(n)
+26 -25
View File
@@ -7,27 +7,28 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestCreateOrUpdateIssueNotifications(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
issue := AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, db.PrepareTestDatabase())
issue := db.AssertExistsAndLoadBean(t, &Issue{ID: 1}).(*Issue)
assert.NoError(t, CreateOrUpdateIssueNotifications(issue.ID, 0, 2, 0))
// User 9 is inactive, thus notifications for user 1 and 4 are created
notf := AssertExistsAndLoadBean(t, &Notification{UserID: 1, IssueID: issue.ID}).(*Notification)
notf := db.AssertExistsAndLoadBean(t, &Notification{UserID: 1, IssueID: issue.ID}).(*Notification)
assert.Equal(t, NotificationStatusUnread, notf.Status)
CheckConsistencyFor(t, &Issue{ID: issue.ID})
notf = AssertExistsAndLoadBean(t, &Notification{UserID: 4, IssueID: issue.ID}).(*Notification)
notf = db.AssertExistsAndLoadBean(t, &Notification{UserID: 4, IssueID: issue.ID}).(*Notification)
assert.Equal(t, NotificationStatusUnread, notf.Status)
}
func TestNotificationsForUser(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
statuses := []NotificationStatus{NotificationStatusRead, NotificationStatusUnread}
notfs, err := NotificationsForUser(user, statuses, 1, 10)
assert.NoError(t, err)
@@ -42,8 +43,8 @@ func TestNotificationsForUser(t *testing.T) {
}
func TestNotification_GetRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
notf := AssertExistsAndLoadBean(t, &Notification{RepoID: 1}).(*Notification)
assert.NoError(t, db.PrepareTestDatabase())
notf := db.AssertExistsAndLoadBean(t, &Notification{RepoID: 1}).(*Notification)
repo, err := notf.GetRepo()
assert.NoError(t, err)
assert.Equal(t, repo, notf.Repository)
@@ -51,8 +52,8 @@ func TestNotification_GetRepo(t *testing.T) {
}
func TestNotification_GetIssue(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
notf := AssertExistsAndLoadBean(t, &Notification{RepoID: 1}).(*Notification)
assert.NoError(t, db.PrepareTestDatabase())
notf := db.AssertExistsAndLoadBean(t, &Notification{RepoID: 1}).(*Notification)
issue, err := notf.GetIssue()
assert.NoError(t, err)
assert.Equal(t, issue, notf.Issue)
@@ -60,8 +61,8 @@ func TestNotification_GetIssue(t *testing.T) {
}
func TestGetNotificationCount(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 1}).(*User)
cnt, err := GetNotificationCount(user, NotificationStatusRead)
assert.NoError(t, err)
assert.EqualValues(t, 0, cnt)
@@ -72,35 +73,35 @@ func TestGetNotificationCount(t *testing.T) {
}
func TestSetNotificationStatus(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
notf := AssertExistsAndLoadBean(t,
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
notf := db.AssertExistsAndLoadBean(t,
&Notification{UserID: user.ID, Status: NotificationStatusRead}).(*Notification)
_, err := SetNotificationStatus(notf.ID, user, NotificationStatusPinned)
assert.NoError(t, err)
AssertExistsAndLoadBean(t,
db.AssertExistsAndLoadBean(t,
&Notification{ID: notf.ID, Status: NotificationStatusPinned})
_, err = SetNotificationStatus(1, user, NotificationStatusRead)
assert.Error(t, err)
_, err = SetNotificationStatus(NonexistentID, user, NotificationStatusRead)
_, err = SetNotificationStatus(db.NonexistentID, user, NotificationStatusRead)
assert.Error(t, err)
}
func TestUpdateNotificationStatuses(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
user := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
notfUnread := AssertExistsAndLoadBean(t,
assert.NoError(t, db.PrepareTestDatabase())
user := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
notfUnread := db.AssertExistsAndLoadBean(t,
&Notification{UserID: user.ID, Status: NotificationStatusUnread}).(*Notification)
notfRead := AssertExistsAndLoadBean(t,
notfRead := db.AssertExistsAndLoadBean(t,
&Notification{UserID: user.ID, Status: NotificationStatusRead}).(*Notification)
notfPinned := AssertExistsAndLoadBean(t,
notfPinned := db.AssertExistsAndLoadBean(t,
&Notification{UserID: user.ID, Status: NotificationStatusPinned}).(*Notification)
assert.NoError(t, UpdateNotificationStatuses(user, NotificationStatusUnread, NotificationStatusRead))
AssertExistsAndLoadBean(t,
db.AssertExistsAndLoadBean(t,
&Notification{ID: notfUnread.ID, Status: NotificationStatusRead})
AssertExistsAndLoadBean(t,
db.AssertExistsAndLoadBean(t,
&Notification{ID: notfRead.ID, Status: NotificationStatusRead})
AssertExistsAndLoadBean(t,
db.AssertExistsAndLoadBean(t,
&Notification{ID: notfPinned.ID, Status: NotificationStatusPinned})
}
+4 -2
View File
@@ -4,10 +4,12 @@
package models
import "code.gitea.io/gitea/models/db"
// GetActiveOAuth2ProviderLoginSources returns all actived LoginOAuth2 sources
func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) {
sources := make([]*LoginSource, 0, 1)
if err := x.Where("is_active = ? and type = ?", true, LoginOAuth2).Find(&sources); err != nil {
if err := db.DefaultContext().Engine().Where("is_active = ? and type = ?", true, LoginOAuth2).Find(&sources); err != nil {
return nil, err
}
return sources, nil
@@ -16,7 +18,7 @@ func GetActiveOAuth2ProviderLoginSources() ([]*LoginSource, error) {
// GetActiveOAuth2LoginSourceByName returns a OAuth2 LoginSource based on the given name
func GetActiveOAuth2LoginSourceByName(name string) (*LoginSource, error) {
loginSource := new(LoginSource)
has, err := x.Where("name = ? and type = ? and is_active = ?", name, LoginOAuth2, true).Get(loginSource)
has, err := db.DefaultContext().Engine().Where("name = ? and type = ? and is_active = ?", name, LoginOAuth2, true).Get(loginSource)
if !has || err != nil {
return nil, err
}
+40 -33
View File
@@ -11,6 +11,7 @@ import (
"net/url"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/secret"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
@@ -37,6 +38,12 @@ type OAuth2Application struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(OAuth2Application))
db.RegisterModel(new(OAuth2AuthorizationCode))
db.RegisterModel(new(OAuth2Grant))
}
// TableName sets the table name to `oauth2_application`
func (app *OAuth2Application) TableName() string {
return "oauth2_application"
@@ -74,7 +81,7 @@ func (app *OAuth2Application) GenerateClientSecret() (string, error) {
return "", err
}
app.ClientSecret = string(hashedSecret)
if _, err := x.ID(app.ID).Cols("client_secret").Update(app); err != nil {
if _, err := db.DefaultContext().Engine().ID(app.ID).Cols("client_secret").Update(app); err != nil {
return "", err
}
return clientSecret, nil
@@ -87,10 +94,10 @@ func (app *OAuth2Application) ValidateClientSecret(secret []byte) bool {
// GetGrantByUserID returns a OAuth2Grant by its user and application ID
func (app *OAuth2Application) GetGrantByUserID(userID int64) (*OAuth2Grant, error) {
return app.getGrantByUserID(x, userID)
return app.getGrantByUserID(db.DefaultContext().Engine(), userID)
}
func (app *OAuth2Application) getGrantByUserID(e Engine, userID int64) (grant *OAuth2Grant, err error) {
func (app *OAuth2Application) getGrantByUserID(e db.Engine, userID int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.Where("user_id = ? AND application_id = ?", userID, app.ID).Get(grant); err != nil {
return nil, err
@@ -102,10 +109,10 @@ func (app *OAuth2Application) getGrantByUserID(e Engine, userID int64) (grant *O
// CreateGrant generates a grant for an user
func (app *OAuth2Application) CreateGrant(userID int64, scope string) (*OAuth2Grant, error) {
return app.createGrant(x, userID, scope)
return app.createGrant(db.DefaultContext().Engine(), userID, scope)
}
func (app *OAuth2Application) createGrant(e Engine, userID int64, scope string) (*OAuth2Grant, error) {
func (app *OAuth2Application) createGrant(e db.Engine, userID int64, scope string) (*OAuth2Grant, error) {
grant := &OAuth2Grant{
ApplicationID: app.ID,
UserID: userID,
@@ -120,10 +127,10 @@ func (app *OAuth2Application) createGrant(e Engine, userID int64, scope string)
// GetOAuth2ApplicationByClientID returns the oauth2 application with the given client_id. Returns an error if not found.
func GetOAuth2ApplicationByClientID(clientID string) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByClientID(x, clientID)
return getOAuth2ApplicationByClientID(db.DefaultContext().Engine(), clientID)
}
func getOAuth2ApplicationByClientID(e Engine, clientID string) (app *OAuth2Application, err error) {
func getOAuth2ApplicationByClientID(e db.Engine, clientID string) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.Where("client_id = ?", clientID).Get(app)
if !has {
@@ -134,10 +141,10 @@ func getOAuth2ApplicationByClientID(e Engine, clientID string) (app *OAuth2Appli
// GetOAuth2ApplicationByID returns the oauth2 application with the given id. Returns an error if not found.
func GetOAuth2ApplicationByID(id int64) (app *OAuth2Application, err error) {
return getOAuth2ApplicationByID(x, id)
return getOAuth2ApplicationByID(db.DefaultContext().Engine(), id)
}
func getOAuth2ApplicationByID(e Engine, id int64) (app *OAuth2Application, err error) {
func getOAuth2ApplicationByID(e db.Engine, id int64) (app *OAuth2Application, err error) {
app = new(OAuth2Application)
has, err := e.ID(id).Get(app)
if err != nil {
@@ -151,10 +158,10 @@ func getOAuth2ApplicationByID(e Engine, id int64) (app *OAuth2Application, err e
// GetOAuth2ApplicationsByUserID returns all oauth2 applications owned by the user
func GetOAuth2ApplicationsByUserID(userID int64) (apps []*OAuth2Application, err error) {
return getOAuth2ApplicationsByUserID(x, userID)
return getOAuth2ApplicationsByUserID(db.DefaultContext().Engine(), userID)
}
func getOAuth2ApplicationsByUserID(e Engine, userID int64) (apps []*OAuth2Application, err error) {
func getOAuth2ApplicationsByUserID(e db.Engine, userID int64) (apps []*OAuth2Application, err error) {
apps = make([]*OAuth2Application, 0)
err = e.Where("uid = ?", userID).Find(&apps)
return
@@ -169,10 +176,10 @@ type CreateOAuth2ApplicationOptions struct {
// CreateOAuth2Application inserts a new oauth2 application
func CreateOAuth2Application(opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
return createOAuth2Application(x, opts)
return createOAuth2Application(db.DefaultContext().Engine(), opts)
}
func createOAuth2Application(e Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
func createOAuth2Application(e db.Engine, opts CreateOAuth2ApplicationOptions) (*OAuth2Application, error) {
clientID := uuid.New().String()
app := &OAuth2Application{
UID: opts.UserID,
@@ -196,7 +203,7 @@ type UpdateOAuth2ApplicationOptions struct {
// UpdateOAuth2Application updates an oauth2 application
func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Application, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
if err := sess.Begin(); err != nil {
return nil, err
}
@@ -221,7 +228,7 @@ func UpdateOAuth2Application(opts UpdateOAuth2ApplicationOptions) (*OAuth2Applic
return app, sess.Commit()
}
func updateOAuth2Application(e Engine, app *OAuth2Application) error {
func updateOAuth2Application(e db.Engine, app *OAuth2Application) error {
if _, err := e.ID(app.ID).Update(app); err != nil {
return err
}
@@ -257,7 +264,7 @@ func deleteOAuth2Application(sess *xorm.Session, id, userid int64) error {
// DeleteOAuth2Application deletes the application with the given id and the grants and auth codes related to it. It checks if the userid was the creator of the app.
func DeleteOAuth2Application(id, userid int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -270,7 +277,7 @@ func DeleteOAuth2Application(id, userid int64) error {
// ListOAuth2Applications returns a list of oauth2 applications belongs to given user.
func ListOAuth2Applications(uid int64, listOptions ListOptions) ([]*OAuth2Application, int64, error) {
sess := x.
sess := db.DefaultContext().Engine().
Where("uid=?", uid).
Desc("id")
@@ -322,10 +329,10 @@ func (code *OAuth2AuthorizationCode) GenerateRedirectURI(state string) (redirect
// Invalidate deletes the auth code from the database to invalidate this code
func (code *OAuth2AuthorizationCode) Invalidate() error {
return code.invalidate(x)
return code.invalidate(db.DefaultContext().Engine())
}
func (code *OAuth2AuthorizationCode) invalidate(e Engine) error {
func (code *OAuth2AuthorizationCode) invalidate(e db.Engine) error {
_, err := e.Delete(code)
return err
}
@@ -354,10 +361,10 @@ func (code *OAuth2AuthorizationCode) validateCodeChallenge(verifier string) bool
// GetOAuth2AuthorizationByCode returns an authorization by its code
func GetOAuth2AuthorizationByCode(code string) (*OAuth2AuthorizationCode, error) {
return getOAuth2AuthorizationByCode(x, code)
return getOAuth2AuthorizationByCode(db.DefaultContext().Engine(), code)
}
func getOAuth2AuthorizationByCode(e Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
func getOAuth2AuthorizationByCode(e db.Engine, code string) (auth *OAuth2AuthorizationCode, err error) {
auth = new(OAuth2AuthorizationCode)
if has, err := e.Where("code = ?", code).Get(auth); err != nil {
return nil, err
@@ -395,10 +402,10 @@ func (grant *OAuth2Grant) TableName() string {
// GenerateNewAuthorizationCode generates a new authorization code for a grant and saves it to the database
func (grant *OAuth2Grant) GenerateNewAuthorizationCode(redirectURI, codeChallenge, codeChallengeMethod string) (*OAuth2AuthorizationCode, error) {
return grant.generateNewAuthorizationCode(x, redirectURI, codeChallenge, codeChallengeMethod)
return grant.generateNewAuthorizationCode(db.DefaultContext().Engine(), redirectURI, codeChallenge, codeChallengeMethod)
}
func (grant *OAuth2Grant) generateNewAuthorizationCode(e Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
func (grant *OAuth2Grant) generateNewAuthorizationCode(e db.Engine, redirectURI, codeChallenge, codeChallengeMethod string) (code *OAuth2AuthorizationCode, err error) {
var codeSecret string
if codeSecret, err = secret.New(); err != nil {
return &OAuth2AuthorizationCode{}, err
@@ -419,10 +426,10 @@ func (grant *OAuth2Grant) generateNewAuthorizationCode(e Engine, redirectURI, co
// IncreaseCounter increases the counter and updates the grant
func (grant *OAuth2Grant) IncreaseCounter() error {
return grant.increaseCount(x)
return grant.increaseCount(db.DefaultContext().Engine())
}
func (grant *OAuth2Grant) increaseCount(e Engine) error {
func (grant *OAuth2Grant) increaseCount(e db.Engine) error {
_, err := e.ID(grant.ID).Incr("counter").Update(new(OAuth2Grant))
if err != nil {
return err
@@ -447,10 +454,10 @@ func (grant *OAuth2Grant) ScopeContains(scope string) bool {
// SetNonce updates the current nonce value of a grant
func (grant *OAuth2Grant) SetNonce(nonce string) error {
return grant.setNonce(x, nonce)
return grant.setNonce(db.DefaultContext().Engine(), nonce)
}
func (grant *OAuth2Grant) setNonce(e Engine, nonce string) error {
func (grant *OAuth2Grant) setNonce(e db.Engine, nonce string) error {
grant.Nonce = nonce
_, err := e.ID(grant.ID).Cols("nonce").Update(grant)
if err != nil {
@@ -461,10 +468,10 @@ func (grant *OAuth2Grant) setNonce(e Engine, nonce string) error {
// GetOAuth2GrantByID returns the grant with the given ID
func GetOAuth2GrantByID(id int64) (*OAuth2Grant, error) {
return getOAuth2GrantByID(x, id)
return getOAuth2GrantByID(db.DefaultContext().Engine(), id)
}
func getOAuth2GrantByID(e Engine, id int64) (grant *OAuth2Grant, err error) {
func getOAuth2GrantByID(e db.Engine, id int64) (grant *OAuth2Grant, err error) {
grant = new(OAuth2Grant)
if has, err := e.ID(id).Get(grant); err != nil {
return nil, err
@@ -476,10 +483,10 @@ func getOAuth2GrantByID(e Engine, id int64) (grant *OAuth2Grant, err error) {
// GetOAuth2GrantsByUserID lists all grants of a certain user
func GetOAuth2GrantsByUserID(uid int64) ([]*OAuth2Grant, error) {
return getOAuth2GrantsByUserID(x, uid)
return getOAuth2GrantsByUserID(db.DefaultContext().Engine(), uid)
}
func getOAuth2GrantsByUserID(e Engine, uid int64) ([]*OAuth2Grant, error) {
func getOAuth2GrantsByUserID(e db.Engine, uid int64) ([]*OAuth2Grant, error) {
type joinedOAuth2Grant struct {
Grant *OAuth2Grant `xorm:"extends"`
Application *OAuth2Application `xorm:"extends"`
@@ -508,10 +515,10 @@ func getOAuth2GrantsByUserID(e Engine, uid int64) ([]*OAuth2Grant, error) {
// RevokeOAuth2Grant deletes the grant with grantID and userID
func RevokeOAuth2Grant(grantID, userID int64) error {
return revokeOAuth2Grant(x, grantID, userID)
return revokeOAuth2Grant(db.DefaultContext().Engine(), grantID, userID)
}
func revokeOAuth2Grant(e Engine, grantID, userID int64) error {
func revokeOAuth2Grant(e db.Engine, grantID, userID int64) error {
_, err := e.Delete(&OAuth2Grant{ID: grantID, UserID: userID})
return err
}
+32 -31
View File
@@ -7,23 +7,24 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
//////////////////// Application
func TestOAuth2Application_GenerateClientSecret(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
app := AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
secret, err := app.GenerateClientSecret()
assert.NoError(t, err)
assert.True(t, len(secret) > 0)
AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1, ClientSecret: app.ClientSecret})
}
func BenchmarkOAuth2Application_GenerateClientSecret(b *testing.B) {
assert.NoError(b, PrepareTestDatabase())
app := AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(b, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(b, &OAuth2Application{ID: 1}).(*OAuth2Application)
for i := 0; i < b.N; i++ {
_, _ = app.GenerateClientSecret()
}
@@ -40,8 +41,8 @@ func TestOAuth2Application_ContainsRedirectURI(t *testing.T) {
}
func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
app := AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
secret, err := app.GenerateClientSecret()
assert.NoError(t, err)
assert.True(t, app.ValidateClientSecret([]byte(secret)))
@@ -49,7 +50,7 @@ func TestOAuth2Application_ValidateClientSecret(t *testing.T) {
}
func TestGetOAuth2ApplicationByClientID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
app, err := GetOAuth2ApplicationByClientID("da7da3ba-9a13-4167-856f-3899de0b0138")
assert.NoError(t, err)
assert.Equal(t, "da7da3ba-9a13-4167-856f-3899de0b0138", app.ClientID)
@@ -60,17 +61,17 @@ func TestGetOAuth2ApplicationByClientID(t *testing.T) {
}
func TestCreateOAuth2Application(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
app, err := CreateOAuth2Application(CreateOAuth2ApplicationOptions{Name: "newapp", UserID: 1})
assert.NoError(t, err)
assert.Equal(t, "newapp", app.Name)
assert.Len(t, app.ClientID, 36)
AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"})
db.AssertExistsAndLoadBean(t, &OAuth2Application{Name: "newapp"})
}
func TestOAuth2Application_LoadUser(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
app := AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, app.LoadUser())
assert.NotNil(t, app.User)
}
@@ -80,8 +81,8 @@ func TestOAuth2Application_TableName(t *testing.T) {
}
func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
app := AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.GetGrantByUserID(1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.UserID)
@@ -92,8 +93,8 @@ func TestOAuth2Application_GetGrantByUserID(t *testing.T) {
}
func TestOAuth2Application_CreateGrant(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
app := AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
assert.NoError(t, db.PrepareTestDatabase())
app := db.AssertExistsAndLoadBean(t, &OAuth2Application{ID: 1}).(*OAuth2Application)
grant, err := app.CreateGrant(2, "")
assert.NoError(t, err)
assert.NotNil(t, grant)
@@ -105,7 +106,7 @@ func TestOAuth2Application_CreateGrant(t *testing.T) {
//////////////////// Grant
func TestGetOAuth2GrantByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
grant, err := GetOAuth2GrantByID(1)
assert.NoError(t, err)
assert.Equal(t, int64(1), grant.ID)
@@ -116,16 +117,16 @@ func TestGetOAuth2GrantByID(t *testing.T) {
}
func TestOAuth2Grant_IncreaseCounter(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
grant := AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
assert.NoError(t, db.PrepareTestDatabase())
grant := db.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 1}).(*OAuth2Grant)
assert.NoError(t, grant.IncreaseCounter())
assert.Equal(t, int64(2), grant.Counter)
AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
db.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Counter: 2})
}
func TestOAuth2Grant_ScopeContains(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
grant := AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant)
assert.NoError(t, db.PrepareTestDatabase())
grant := db.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1, Scope: "openid profile"}).(*OAuth2Grant)
assert.True(t, grant.ScopeContains("openid"))
assert.True(t, grant.ScopeContains("profile"))
assert.False(t, grant.ScopeContains("profil"))
@@ -133,8 +134,8 @@ func TestOAuth2Grant_ScopeContains(t *testing.T) {
}
func TestOAuth2Grant_GenerateNewAuthorizationCode(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
grant := AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
assert.NoError(t, db.PrepareTestDatabase())
grant := db.AssertExistsAndLoadBean(t, &OAuth2Grant{ID: 1}).(*OAuth2Grant)
code, err := grant.GenerateNewAuthorizationCode("https://example2.com/callback", "CjvyTLSdR47G5zYenDA-eDWW4lRrO8yvjcWwbD_deOg", "S256")
assert.NoError(t, err)
assert.NotNil(t, code)
@@ -146,7 +147,7 @@ func TestOAuth2Grant_TableName(t *testing.T) {
}
func TestGetOAuth2GrantsByUserID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
result, err := GetOAuth2GrantsByUserID(1)
assert.NoError(t, err)
assert.Len(t, result, 1)
@@ -159,15 +160,15 @@ func TestGetOAuth2GrantsByUserID(t *testing.T) {
}
func TestRevokeOAuth2Grant(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
assert.NoError(t, RevokeOAuth2Grant(1, 1))
AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
db.AssertNotExistsBean(t, &OAuth2Grant{ID: 1, UserID: 1})
}
//////////////////// Authorization Code
func TestGetOAuth2AuthorizationByCode(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
code, err := GetOAuth2AuthorizationByCode("authcode")
assert.NoError(t, err)
assert.NotNil(t, code)
@@ -227,10 +228,10 @@ func TestOAuth2AuthorizationCode_GenerateRedirectURI(t *testing.T) {
}
func TestOAuth2AuthorizationCode_Invalidate(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
code := AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
assert.NoError(t, db.PrepareTestDatabase())
code := db.AssertExistsAndLoadBean(t, &OAuth2AuthorizationCode{Code: "authcode"}).(*OAuth2AuthorizationCode)
assert.NoError(t, code.Invalidate())
AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
db.AssertNotExistsBean(t, &OAuth2AuthorizationCode{Code: "authcode"})
}
func TestOAuth2AuthorizationCode_TableName(t *testing.T) {
+51 -46
View File
@@ -9,6 +9,7 @@ import (
"fmt"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
@@ -34,25 +35,25 @@ func (org *User) CanCreateOrgRepo(uid int64) (bool, error) {
return CanCreateOrgRepo(org.ID, uid)
}
func (org *User) getTeam(e Engine, name string) (*Team, error) {
func (org *User) getTeam(e db.Engine, name string) (*Team, error) {
return getTeam(e, org.ID, name)
}
// GetTeam returns named team of organization.
func (org *User) GetTeam(name string) (*Team, error) {
return org.getTeam(x, name)
return org.getTeam(db.DefaultContext().Engine(), name)
}
func (org *User) getOwnerTeam(e Engine) (*Team, error) {
func (org *User) getOwnerTeam(e db.Engine) (*Team, error) {
return org.getTeam(e, ownerTeamName)
}
// GetOwnerTeam returns owner team of organization.
func (org *User) GetOwnerTeam() (*Team, error) {
return org.getOwnerTeam(x)
return org.getOwnerTeam(db.DefaultContext().Engine())
}
func (org *User) loadTeams(e Engine) error {
func (org *User) loadTeams(e db.Engine) error {
if org.Teams != nil {
return nil
}
@@ -64,7 +65,7 @@ func (org *User) loadTeams(e Engine) error {
// LoadTeams load teams if not loaded.
func (org *User) LoadTeams() error {
return org.loadTeams(x)
return org.loadTeams(db.DefaultContext().Engine())
}
// GetMembers returns all members of organization.
@@ -84,7 +85,7 @@ type FindOrgMembersOpts struct {
// CountOrgMembers counts the organization's members
func CountOrgMembers(opts *FindOrgMembersOpts) (int64, error) {
sess := x.Where("org_id=?", opts.OrgID)
sess := db.DefaultContext().Engine().Where("org_id=?", opts.OrgID)
if opts.PublicOnly {
sess.And("is_public = ?", true)
}
@@ -122,13 +123,13 @@ func (org *User) RemoveMember(uid int64) error {
return RemoveOrgUser(org.ID, uid)
}
func (org *User) removeOrgRepo(e Engine, repoID int64) error {
func (org *User) removeOrgRepo(e db.Engine, repoID int64) error {
return removeOrgRepo(e, org.ID, repoID)
}
// RemoveOrgRepo removes all team-repository relations of organization.
func (org *User) RemoveOrgRepo(repoID int64) error {
return org.removeOrgRepo(x, repoID)
return org.removeOrgRepo(db.DefaultContext().Engine(), repoID)
}
// CreateOrganization creates record of a new organization.
@@ -161,7 +162,7 @@ func CreateOrganization(org, owner *User) (err error) {
org.NumMembers = 1
org.Type = UserTypeOrganization
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -237,7 +238,7 @@ func GetOrgByName(name string) (*User, error) {
LowerName: strings.ToLower(name),
Type: UserTypeOrganization,
}
has, err := x.Get(u)
has, err := db.DefaultContext().Engine().Get(u)
if err != nil {
return nil, err
} else if !has {
@@ -248,7 +249,7 @@ func GetOrgByName(name string) (*User, error) {
// CountOrganizations returns number of organizations.
func CountOrganizations() int64 {
count, _ := x.
count, _ := db.DefaultContext().Engine().
Where("type=1").
Count(new(User))
return count
@@ -260,7 +261,7 @@ func DeleteOrganization(org *User) (err error) {
return fmt.Errorf("%s is a user not an organization", org.Name)
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
@@ -334,7 +335,11 @@ type OrgUser struct {
IsPublic bool `xorm:"INDEX"`
}
func isOrganizationOwner(e Engine, orgID, uid int64) (bool, error) {
func init() {
db.RegisterModel(new(OrgUser))
}
func isOrganizationOwner(e db.Engine, orgID, uid int64) (bool, error) {
ownerTeam, err := getOwnerTeam(e, orgID)
if err != nil {
if IsErrTeamNotExist(err) {
@@ -348,15 +353,15 @@ func isOrganizationOwner(e Engine, orgID, uid int64) (bool, error) {
// IsOrganizationOwner returns true if given user is in the owner team.
func IsOrganizationOwner(orgID, uid int64) (bool, error) {
return isOrganizationOwner(x, orgID, uid)
return isOrganizationOwner(db.DefaultContext().Engine(), orgID, uid)
}
// IsOrganizationMember returns true if given user is member of organization.
func IsOrganizationMember(orgID, uid int64) (bool, error) {
return isOrganizationMember(x, orgID, uid)
return isOrganizationMember(db.DefaultContext().Engine(), orgID, uid)
}
func isOrganizationMember(e Engine, orgID, uid int64) (bool, error) {
func isOrganizationMember(e db.Engine, orgID, uid int64) (bool, error) {
return e.
Where("uid=?", uid).
And("org_id=?", orgID).
@@ -366,7 +371,7 @@ func isOrganizationMember(e Engine, orgID, uid int64) (bool, error) {
// IsPublicMembership returns true if given user public his/her membership.
func IsPublicMembership(orgID, uid int64) (bool, error) {
return x.
return db.DefaultContext().Engine().
Where("uid=?", uid).
And("org_id=?", orgID).
And("is_public=?", true).
@@ -379,7 +384,7 @@ func CanCreateOrgRepo(orgID, uid int64) (bool, error) {
if owner, err := IsOrganizationOwner(orgID, uid); owner || err != nil {
return owner, err
}
return x.
return db.DefaultContext().Engine().
Where(builder.Eq{"team.can_create_org_repo": true}).
Join("INNER", "team_user", "team_user.team_id = team.id").
And("team_user.uid = ?", uid).
@@ -389,12 +394,12 @@ func CanCreateOrgRepo(orgID, uid int64) (bool, error) {
// GetUsersWhoCanCreateOrgRepo returns users which are able to create repo in organization
func GetUsersWhoCanCreateOrgRepo(orgID int64) ([]*User, error) {
return getUsersWhoCanCreateOrgRepo(x, orgID)
return getUsersWhoCanCreateOrgRepo(db.DefaultContext().Engine(), orgID)
}
func getUsersWhoCanCreateOrgRepo(e Engine, orgID int64) ([]*User, error) {
func getUsersWhoCanCreateOrgRepo(e db.Engine, orgID int64) ([]*User, error) {
users := make([]*User, 0, 10)
return users, x.
return users, db.DefaultContext().Engine().
Join("INNER", "`team_user`", "`team_user`.uid=`user`.id").
Join("INNER", "`team`", "`team`.id=`team_user`.team_id").
Where(builder.Eq{"team.can_create_org_repo": true}.Or(builder.Eq{"team.authorize": AccessModeOwner})).
@@ -416,7 +421,7 @@ func getOrgsByUserID(sess *xorm.Session, userID int64, showAll bool) ([]*User, e
// GetOrgsByUserID returns a list of organizations that the given user ID
// has joined.
func GetOrgsByUserID(userID int64, showAll bool) ([]*User, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
return getOrgsByUserID(sess, userID, showAll)
}
@@ -426,10 +431,10 @@ type MinimalOrg = User
// GetUserOrgsList returns one user's all orgs list
func GetUserOrgsList(user *User) ([]*MinimalOrg, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
schema, err := x.TableInfo(new(User))
schema, err := db.TableInfo(new(User))
if err != nil {
return nil, err
}
@@ -497,10 +502,10 @@ func getOwnedOrgsByUserID(sess *xorm.Session, userID int64) ([]*User, error) {
// HasOrgOrUserVisible tells if the given user can see the given org or user
func HasOrgOrUserVisible(org, user *User) bool {
return hasOrgOrUserVisible(x, org, user)
return hasOrgOrUserVisible(db.DefaultContext().Engine(), org, user)
}
func hasOrgOrUserVisible(e Engine, orgOrUser, user *User) bool {
func hasOrgOrUserVisible(e db.Engine, orgOrUser, user *User) bool {
// Not SignedUser
if user == nil {
return orgOrUser.Visibility == structs.VisibleTypePublic
@@ -532,7 +537,7 @@ func HasOrgsVisible(orgs []*User, user *User) bool {
// GetOwnedOrgsByUserID returns a list of organizations are owned by given user ID.
func GetOwnedOrgsByUserID(userID int64) ([]*User, error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
return getOwnedOrgsByUserID(sess, userID)
}
@@ -540,7 +545,7 @@ func GetOwnedOrgsByUserID(userID int64) ([]*User, error) {
// GetOwnedOrgsByUserIDDesc returns a list of organizations are owned by
// given user ID, ordered descending by the given condition.
func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) {
return getOwnedOrgsByUserID(x.Desc(desc), userID)
return getOwnedOrgsByUserID(db.DefaultContext().Engine().Desc(desc), userID)
}
// GetOrgsCanCreateRepoByUserID returns a list of organizations where given user ID
@@ -548,7 +553,7 @@ func GetOwnedOrgsByUserIDDesc(userID int64, desc string) ([]*User, error) {
func GetOrgsCanCreateRepoByUserID(userID int64) ([]*User, error) {
orgs := make([]*User, 0, 10)
return orgs, x.Where(builder.In("id", builder.Select("`user`.id").From("`user`").
return orgs, db.DefaultContext().Engine().Where(builder.In("id", builder.Select("`user`.id").From("`user`").
Join("INNER", "`team_user`", "`team_user`.org_id = `user`.id").
Join("INNER", "`team`", "`team`.id = `team_user`.team_id").
Where(builder.Eq{"`team_user`.uid": userID}).
@@ -560,7 +565,7 @@ func GetOrgsCanCreateRepoByUserID(userID int64) ([]*User, error) {
// GetOrgUsersByUserID returns all organization-user relations by user ID.
func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUser, error) {
ous := make([]*OrgUser, 0, 10)
sess := x.
sess := db.DefaultContext().Engine().
Join("LEFT", "`user`", "`org_user`.org_id=`user`.id").
Where("`org_user`.uid=?", uid)
if !opts.All {
@@ -580,10 +585,10 @@ func GetOrgUsersByUserID(uid int64, opts *SearchOrganizationsOptions) ([]*OrgUse
// GetOrgUsersByOrgID returns all organization-user relations by organization ID.
func GetOrgUsersByOrgID(opts *FindOrgMembersOpts) ([]*OrgUser, error) {
return getOrgUsersByOrgID(x, opts)
return getOrgUsersByOrgID(db.DefaultContext().Engine(), opts)
}
func getOrgUsersByOrgID(e Engine, opts *FindOrgMembersOpts) ([]*OrgUser, error) {
func getOrgUsersByOrgID(e db.Engine, opts *FindOrgMembersOpts) ([]*OrgUser, error) {
sess := e.Where("org_id=?", opts.OrgID)
if opts.PublicOnly {
sess.And("is_public = ?", true)
@@ -602,7 +607,7 @@ func getOrgUsersByOrgID(e Engine, opts *FindOrgMembersOpts) ([]*OrgUser, error)
// ChangeOrgUserStatus changes public or private membership status.
func ChangeOrgUserStatus(orgID, uid int64, public bool) error {
ou := new(OrgUser)
has, err := x.
has, err := db.DefaultContext().Engine().
Where("uid=?", uid).
And("org_id=?", orgID).
Get(ou)
@@ -613,7 +618,7 @@ func ChangeOrgUserStatus(orgID, uid int64, public bool) error {
}
ou.IsPublic = public
_, err = x.ID(ou.ID).Cols("is_public").Update(ou)
_, err = db.DefaultContext().Engine().ID(ou.ID).Cols("is_public").Update(ou)
return err
}
@@ -624,7 +629,7 @@ func AddOrgUser(orgID, uid int64) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -733,7 +738,7 @@ func removeOrgUser(sess *xorm.Session, orgID, userID int64) error {
// RemoveOrgUser removes user from given organization.
func RemoveOrgUser(orgID, userID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -744,7 +749,7 @@ func RemoveOrgUser(orgID, userID int64) error {
return sess.Commit()
}
func removeOrgRepo(e Engine, orgID, repoID int64) error {
func removeOrgRepo(e db.Engine, orgID, repoID int64) error {
teamRepos := make([]*TeamRepo, 0, 10)
if err := e.Find(&teamRepos, &TeamRepo{OrgID: orgID, RepoID: repoID}); err != nil {
return err
@@ -770,7 +775,7 @@ func removeOrgRepo(e Engine, orgID, repoID int64) error {
return err
}
func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team, error) {
func (org *User) getUserTeams(e db.Engine, userID int64, cols ...string) ([]*Team, error) {
teams := make([]*Team, 0, org.NumTeams)
return teams, e.
Where("`team_user`.org_id = ?", org.ID).
@@ -782,7 +787,7 @@ func (org *User) getUserTeams(e Engine, userID int64, cols ...string) ([]*Team,
Find(&teams)
}
func (org *User) getUserTeamIDs(e Engine, userID int64) ([]int64, error) {
func (org *User) getUserTeamIDs(e db.Engine, userID int64) ([]int64, error) {
teamIDs := make([]int64, 0, org.NumTeams)
return teamIDs, e.
Table("team").
@@ -800,13 +805,13 @@ func (org *User) TeamsWithAccessToRepo(repoID int64, mode AccessMode) ([]*Team,
// GetUserTeamIDs returns of all team IDs of the organization that user is member of.
func (org *User) GetUserTeamIDs(userID int64) ([]int64, error) {
return org.getUserTeamIDs(x, userID)
return org.getUserTeamIDs(db.DefaultContext().Engine(), userID)
}
// GetUserTeams returns all teams that belong to user,
// and that the user has joined.
func (org *User) GetUserTeams(userID int64) ([]*Team, error) {
return org.getUserTeams(x, userID)
return org.getUserTeams(db.DefaultContext().Engine(), userID)
}
// AccessibleReposEnvironment operations involving the repositories that are
@@ -825,7 +830,7 @@ type accessibleReposEnv struct {
user *User
team *Team
teamIDs []int64
e Engine
e db.Engine
keyword string
orderBy SearchOrderBy
}
@@ -833,10 +838,10 @@ type accessibleReposEnv struct {
// AccessibleReposEnv builds an AccessibleReposEnvironment for the repositories in `org`
// that are accessible to the specified user.
func (org *User) AccessibleReposEnv(userID int64) (AccessibleReposEnvironment, error) {
return org.accessibleReposEnv(x, userID)
return org.accessibleReposEnv(db.DefaultContext().Engine(), userID)
}
func (org *User) accessibleReposEnv(e Engine, userID int64) (AccessibleReposEnvironment, error) {
func (org *User) accessibleReposEnv(e db.Engine, userID int64) (AccessibleReposEnvironment, error) {
var user *User
if userID > 0 {
@@ -866,7 +871,7 @@ func (org *User) AccessibleTeamReposEnv(team *Team) AccessibleReposEnvironment {
return &accessibleReposEnv{
org: org,
team: team,
e: x,
e: db.DefaultContext().Engine(),
orderBy: SearchOrderByRecentUpdated,
}
}
+58 -50
View File
@@ -11,6 +11,7 @@ import (
"sort"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
@@ -37,6 +38,13 @@ type Team struct {
CanCreateOrgRepo bool `xorm:"NOT NULL DEFAULT false"`
}
func init() {
db.RegisterModel(new(Team))
db.RegisterModel(new(TeamUser))
db.RegisterModel(new(TeamRepo))
db.RegisterModel(new(TeamUnit))
}
// SearchTeamOptions holds the search options
type SearchTeamOptions struct {
ListOptions
@@ -74,7 +82,7 @@ func SearchTeam(opts *SearchTeamOptions) ([]*Team, int64, error) {
cond = cond.And(builder.Eq{"org_id": opts.OrgID})
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
count, err := sess.
@@ -112,10 +120,10 @@ func (t *Team) ColorFormat(s fmt.State) {
// GetUnits return a list of available units for a team
func (t *Team) GetUnits() error {
return t.getUnits(x)
return t.getUnits(db.DefaultContext().Engine())
}
func (t *Team) getUnits(e Engine) (err error) {
func (t *Team) getUnits(e db.Engine) (err error) {
if t.Units != nil {
return nil
}
@@ -152,7 +160,7 @@ func (t *Team) IsMember(userID int64) bool {
return isMember
}
func (t *Team) getRepositories(e Engine) error {
func (t *Team) getRepositories(e db.Engine) error {
if t.Repos != nil {
return nil
}
@@ -165,13 +173,13 @@ func (t *Team) getRepositories(e Engine) error {
// GetRepositories returns paginated repositories in team of organization.
func (t *Team) GetRepositories(opts *SearchTeamOptions) error {
if opts.Page == 0 {
return t.getRepositories(x)
return t.getRepositories(db.DefaultContext().Engine())
}
return t.getRepositories(getPaginatedSession(opts))
}
func (t *Team) getMembers(e Engine) (err error) {
func (t *Team) getMembers(e db.Engine) (err error) {
t.Members, err = getTeamMembers(e, t.ID)
return err
}
@@ -179,7 +187,7 @@ func (t *Team) getMembers(e Engine) (err error) {
// GetMembers returns paginated members in team of organization.
func (t *Team) GetMembers(opts *SearchMembersOptions) (err error) {
if opts.Page == 0 {
return t.getMembers(x)
return t.getMembers(db.DefaultContext().Engine())
}
return t.getMembers(getPaginatedSession(opts))
@@ -196,16 +204,16 @@ func (t *Team) RemoveMember(userID int64) error {
return RemoveTeamMember(t, userID)
}
func (t *Team) hasRepository(e Engine, repoID int64) bool {
func (t *Team) hasRepository(e db.Engine, repoID int64) bool {
return hasTeamRepo(e, t.OrgID, t.ID, repoID)
}
// HasRepository returns true if given repository belong to team.
func (t *Team) HasRepository(repoID int64) bool {
return t.hasRepository(x, repoID)
return t.hasRepository(db.DefaultContext().Engine(), repoID)
}
func (t *Team) addRepository(e Engine, repo *Repository) (err error) {
func (t *Team) addRepository(e db.Engine, repo *Repository) (err error) {
if err = addTeamRepo(e, t.OrgID, t.ID, repo.ID); err != nil {
return err
}
@@ -237,7 +245,7 @@ func (t *Team) addRepository(e Engine, repo *Repository) (err error) {
// addAllRepositories adds all repositories to the team.
// If the team already has some repositories they will be left unchanged.
func (t *Team) addAllRepositories(e Engine) error {
func (t *Team) addAllRepositories(e db.Engine) error {
var orgRepos []Repository
if err := e.Where("owner_id = ?", t.OrgID).Find(&orgRepos); err != nil {
return fmt.Errorf("get org repos: %v", err)
@@ -256,7 +264,7 @@ func (t *Team) addAllRepositories(e Engine) error {
// AddAllRepositories adds all repositories to the team
func (t *Team) AddAllRepositories() (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -277,7 +285,7 @@ func (t *Team) AddRepository(repo *Repository) (err error) {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -296,7 +304,7 @@ func (t *Team) RemoveAllRepositories() (err error) {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -311,7 +319,7 @@ func (t *Team) RemoveAllRepositories() (err error) {
// removeAllRepositories removes all repositories from team and recalculates access
// Note: Shall not be called if team includes all repositories
func (t *Team) removeAllRepositories(e Engine) (err error) {
func (t *Team) removeAllRepositories(e db.Engine) (err error) {
// Delete all accesses.
for _, repo := range t.Repos {
if err := repo.recalculateTeamAccesses(e, t.ID); err != nil {
@@ -355,7 +363,7 @@ func (t *Team) removeAllRepositories(e Engine) (err error) {
// removeRepository removes a repository from a team and recalculates access
// Note: Repository shall not be removed from team if it includes all repositories (unless the repository is deleted)
func (t *Team) removeRepository(e Engine, repo *Repository, recalculate bool) (err error) {
func (t *Team) removeRepository(e db.Engine, repo *Repository, recalculate bool) (err error) {
if err = removeTeamRepo(e, t.ID, repo.ID); err != nil {
return err
}
@@ -413,7 +421,7 @@ func (t *Team) RemoveRepository(repoID int64) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -428,10 +436,10 @@ func (t *Team) RemoveRepository(repoID int64) error {
// UnitEnabled returns if the team has the given unit type enabled
func (t *Team) UnitEnabled(tp UnitType) bool {
return t.unitEnabled(x, tp)
return t.unitEnabled(db.DefaultContext().Engine(), tp)
}
func (t *Team) unitEnabled(e Engine, tp UnitType) bool {
func (t *Team) unitEnabled(e db.Engine, tp UnitType) bool {
if err := t.getUnits(e); err != nil {
log.Warn("Error loading team (ID: %d) units: %s", t.ID, err.Error())
}
@@ -465,7 +473,7 @@ func NewTeam(t *Team) (err error) {
return err
}
has, err := x.ID(t.OrgID).Get(new(User))
has, err := db.DefaultContext().Engine().ID(t.OrgID).Get(new(User))
if err != nil {
return err
}
@@ -474,7 +482,7 @@ func NewTeam(t *Team) (err error) {
}
t.LowerName = strings.ToLower(t.Name)
has, err = x.
has, err = db.DefaultContext().Engine().
Where("org_id=?", t.OrgID).
And("lower_name=?", t.LowerName).
Get(new(Team))
@@ -485,7 +493,7 @@ func NewTeam(t *Team) (err error) {
return ErrTeamAlreadyExist{t.OrgID, t.LowerName}
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -532,7 +540,7 @@ func NewTeam(t *Team) (err error) {
return sess.Commit()
}
func getTeam(e Engine, orgID int64, name string) (*Team, error) {
func getTeam(e db.Engine, orgID int64, name string) (*Team, error) {
t := &Team{
OrgID: orgID,
LowerName: strings.ToLower(name),
@@ -548,7 +556,7 @@ func getTeam(e Engine, orgID int64, name string) (*Team, error) {
// GetTeam returns team by given team name and organization.
func GetTeam(orgID int64, name string) (*Team, error) {
return getTeam(x, orgID, name)
return getTeam(db.DefaultContext().Engine(), orgID, name)
}
// GetTeamIDsByNames returns a slice of team ids corresponds to names.
@@ -569,11 +577,11 @@ func GetTeamIDsByNames(orgID int64, names []string, ignoreNonExistent bool) ([]i
}
// getOwnerTeam returns team by given team name and organization.
func getOwnerTeam(e Engine, orgID int64) (*Team, error) {
func getOwnerTeam(e db.Engine, orgID int64) (*Team, error) {
return getTeam(e, orgID, ownerTeamName)
}
func getTeamByID(e Engine, teamID int64) (*Team, error) {
func getTeamByID(e db.Engine, teamID int64) (*Team, error) {
t := new(Team)
has, err := e.ID(teamID).Get(t)
if err != nil {
@@ -586,7 +594,7 @@ func getTeamByID(e Engine, teamID int64) (*Team, error) {
// GetTeamByID returns team by given ID.
func GetTeamByID(teamID int64) (*Team, error) {
return getTeamByID(x, teamID)
return getTeamByID(db.DefaultContext().Engine(), teamID)
}
// GetTeamNamesByID returns team's lower name from a list of team ids.
@@ -596,7 +604,7 @@ func GetTeamNamesByID(teamIDs []int64) ([]string, error) {
}
var teamNames []string
err := x.Table("team").
err := db.DefaultContext().Engine().Table("team").
Select("lower_name").
In("id", teamIDs).
Asc("name").
@@ -615,7 +623,7 @@ func UpdateTeam(t *Team, authChanged, includeAllChanged bool) (err error) {
t.Description = t.Description[:255]
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -689,7 +697,7 @@ func DeleteTeam(t *Team) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -745,7 +753,7 @@ type TeamUser struct {
UID int64 `xorm:"UNIQUE(s)"`
}
func isTeamMember(e Engine, orgID, teamID, userID int64) (bool, error) {
func isTeamMember(e db.Engine, orgID, teamID, userID int64) (bool, error) {
return e.
Where("org_id=?", orgID).
And("team_id=?", teamID).
@@ -756,17 +764,17 @@ func isTeamMember(e Engine, orgID, teamID, userID int64) (bool, error) {
// IsTeamMember returns true if given user is a member of team.
func IsTeamMember(orgID, teamID, userID int64) (bool, error) {
return isTeamMember(x, orgID, teamID, userID)
return isTeamMember(db.DefaultContext().Engine(), orgID, teamID, userID)
}
func getTeamUsersByTeamID(e Engine, teamID int64) ([]*TeamUser, error) {
func getTeamUsersByTeamID(e db.Engine, teamID int64) ([]*TeamUser, error) {
teamUsers := make([]*TeamUser, 0, 10)
return teamUsers, e.
Where("team_id=?", teamID).
Find(&teamUsers)
}
func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) {
func getTeamMembers(e db.Engine, teamID int64) (_ []*User, err error) {
teamUsers, err := getTeamUsersByTeamID(e, teamID)
if err != nil {
return nil, fmt.Errorf("get team-users: %v", err)
@@ -787,10 +795,10 @@ func getTeamMembers(e Engine, teamID int64) (_ []*User, err error) {
// GetTeamMembers returns all members in given team of organization.
func GetTeamMembers(teamID int64) ([]*User, error) {
return getTeamMembers(x, teamID)
return getTeamMembers(db.DefaultContext().Engine(), teamID)
}
func getUserOrgTeams(e Engine, orgID, userID int64) (teams []*Team, err error) {
func getUserOrgTeams(e db.Engine, orgID, userID int64) (teams []*Team, err error) {
return teams, e.
Join("INNER", "team_user", "team_user.team_id = team.id").
Where("team.org_id = ?", orgID).
@@ -798,7 +806,7 @@ func getUserOrgTeams(e Engine, orgID, userID int64) (teams []*Team, err error) {
Find(&teams)
}
func getUserRepoTeams(e Engine, orgID, userID, repoID int64) (teams []*Team, err error) {
func getUserRepoTeams(e db.Engine, orgID, userID, repoID int64) (teams []*Team, err error) {
return teams, e.
Join("INNER", "team_user", "team_user.team_id = team.id").
Join("INNER", "team_repo", "team_repo.team_id = team.id").
@@ -810,7 +818,7 @@ func getUserRepoTeams(e Engine, orgID, userID, repoID int64) (teams []*Team, err
// GetUserOrgTeams returns all teams that user belongs to in given organization.
func GetUserOrgTeams(orgID, userID int64) ([]*Team, error) {
return getUserOrgTeams(x, orgID, userID)
return getUserOrgTeams(db.DefaultContext().Engine(), orgID, userID)
}
// AddTeamMember adds new membership of given team to given organization,
@@ -830,7 +838,7 @@ func AddTeamMember(team *Team, userID int64) error {
return err
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -925,7 +933,7 @@ func removeTeamMember(e *xorm.Session, team *Team, userID int64) error {
// RemoveTeamMember removes member from given team of given organization.
func RemoveTeamMember(team *Team, userID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -938,17 +946,17 @@ func RemoveTeamMember(team *Team, userID int64) error {
// IsUserInTeams returns if a user in some teams
func IsUserInTeams(userID int64, teamIDs []int64) (bool, error) {
return isUserInTeams(x, userID, teamIDs)
return isUserInTeams(db.DefaultContext().Engine(), userID, teamIDs)
}
func isUserInTeams(e Engine, userID int64, teamIDs []int64) (bool, error) {
func isUserInTeams(e db.Engine, userID int64, teamIDs []int64) (bool, error) {
return e.Where("uid=?", userID).In("team_id", teamIDs).Exist(new(TeamUser))
}
// UsersInTeamsCount counts the number of users which are in userIDs and teamIDs
func UsersInTeamsCount(userIDs, teamIDs []int64) (int64, error) {
var ids []int64
if err := x.In("uid", userIDs).In("team_id", teamIDs).
if err := db.DefaultContext().Engine().In("uid", userIDs).In("team_id", teamIDs).
Table("team_user").
Cols("uid").GroupBy("uid").Find(&ids); err != nil {
return 0, err
@@ -971,7 +979,7 @@ type TeamRepo struct {
RepoID int64 `xorm:"UNIQUE(s)"`
}
func hasTeamRepo(e Engine, orgID, teamID, repoID int64) bool {
func hasTeamRepo(e db.Engine, orgID, teamID, repoID int64) bool {
has, _ := e.
Where("org_id=?", orgID).
And("team_id=?", teamID).
@@ -982,10 +990,10 @@ func hasTeamRepo(e Engine, orgID, teamID, repoID int64) bool {
// HasTeamRepo returns true if given repository belongs to team.
func HasTeamRepo(orgID, teamID, repoID int64) bool {
return hasTeamRepo(x, orgID, teamID, repoID)
return hasTeamRepo(db.DefaultContext().Engine(), orgID, teamID, repoID)
}
func addTeamRepo(e Engine, orgID, teamID, repoID int64) error {
func addTeamRepo(e db.Engine, orgID, teamID, repoID int64) error {
_, err := e.InsertOne(&TeamRepo{
OrgID: orgID,
TeamID: teamID,
@@ -994,7 +1002,7 @@ func addTeamRepo(e Engine, orgID, teamID, repoID int64) error {
return err
}
func removeTeamRepo(e Engine, teamID, repoID int64) error {
func removeTeamRepo(e db.Engine, teamID, repoID int64) error {
_, err := e.Delete(&TeamRepo{
TeamID: teamID,
RepoID: repoID,
@@ -1005,7 +1013,7 @@ func removeTeamRepo(e Engine, teamID, repoID int64) error {
// GetTeamsWithAccessToRepo returns all teams in an organization that have given access level to the repository.
func GetTeamsWithAccessToRepo(orgID, repoID int64, mode AccessMode) ([]*Team, error) {
teams := make([]*Team, 0, 5)
return teams, x.Where("team.authorize >= ?", mode).
return teams, db.DefaultContext().Engine().Where("team.authorize >= ?", mode).
Join("INNER", "team_repo", "team_repo.team_id = team.id").
And("team_repo.org_id = ?", orgID).
And("team_repo.repo_id = ?", repoID).
@@ -1032,13 +1040,13 @@ func (t *TeamUnit) Unit() Unit {
return Units[t.Type]
}
func getUnitsByTeamID(e Engine, teamID int64) (units []*TeamUnit, err error) {
func getUnitsByTeamID(e db.Engine, teamID int64) (units []*TeamUnit, err error) {
return units, e.Where("team_id = ?", teamID).Find(&units)
}
// UpdateTeamUnits updates a teams's units
func UpdateTeamUnits(team *Team, units []TeamUnit) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
+80 -79
View File
@@ -8,42 +8,43 @@ import (
"strings"
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestTeam_IsOwnerTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
team := AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
assert.True(t, team.IsOwnerTeam())
team = AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team = db.AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
assert.False(t, team.IsOwnerTeam())
}
func TestTeam_IsMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
team := AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
assert.True(t, team.IsMember(2))
assert.False(t, team.IsMember(4))
assert.False(t, team.IsMember(NonexistentID))
assert.False(t, team.IsMember(db.NonexistentID))
team = AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team = db.AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
assert.True(t, team.IsMember(2))
assert.True(t, team.IsMember(4))
assert.False(t, team.IsMember(NonexistentID))
assert.False(t, team.IsMember(db.NonexistentID))
}
func TestTeam_GetRepositories(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, team.GetRepositories(&SearchTeamOptions{}))
assert.Len(t, team.Repos, team.NumRepos)
for _, repo := range team.Repos {
AssertExistsAndLoadBean(t, &TeamRepo{TeamID: teamID, RepoID: repo.ID})
db.AssertExistsAndLoadBean(t, &TeamRepo{TeamID: teamID, RepoID: repo.ID})
}
}
test(1)
@@ -51,14 +52,14 @@ func TestTeam_GetRepositories(t *testing.T) {
}
func TestTeam_GetMembers(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, team.GetMembers(&SearchMembersOptions{}))
assert.Len(t, team.Members, team.NumMembers)
for _, member := range team.Members {
AssertExistsAndLoadBean(t, &TeamUser{UID: member.ID, TeamID: teamID})
db.AssertExistsAndLoadBean(t, &TeamUser{UID: member.ID, TeamID: teamID})
}
}
test(1)
@@ -66,12 +67,12 @@ func TestTeam_GetMembers(t *testing.T) {
}
func TestTeam_AddMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID, userID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, team.AddMember(userID))
AssertExistsAndLoadBean(t, &TeamUser{UID: userID, TeamID: teamID})
db.AssertExistsAndLoadBean(t, &TeamUser{UID: userID, TeamID: teamID})
CheckConsistencyFor(t, &Team{ID: teamID}, &User{ID: team.OrgID})
}
test(1, 2)
@@ -80,71 +81,71 @@ func TestTeam_AddMember(t *testing.T) {
}
func TestTeam_RemoveMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(teamID, userID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, team.RemoveMember(userID))
AssertNotExistsBean(t, &TeamUser{UID: userID, TeamID: teamID})
db.AssertNotExistsBean(t, &TeamUser{UID: userID, TeamID: teamID})
CheckConsistencyFor(t, &Team{ID: teamID})
}
testSuccess(1, 4)
testSuccess(2, 2)
testSuccess(3, 2)
testSuccess(3, NonexistentID)
testSuccess(3, db.NonexistentID)
team := AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
err := team.RemoveMember(2)
assert.True(t, IsErrLastOrgOwner(err))
}
func TestTeam_HasRepository(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID, repoID int64, expected bool) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.Equal(t, expected, team.HasRepository(repoID))
}
test(1, 1, false)
test(1, 3, true)
test(1, 5, true)
test(1, NonexistentID, false)
test(1, db.NonexistentID, false)
test(2, 3, true)
test(2, 5, false)
}
func TestTeam_AddRepository(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(teamID, repoID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
assert.NoError(t, team.AddRepository(repo))
AssertExistsAndLoadBean(t, &TeamRepo{TeamID: teamID, RepoID: repoID})
db.AssertExistsAndLoadBean(t, &TeamRepo{TeamID: teamID, RepoID: repoID})
CheckConsistencyFor(t, &Team{ID: teamID}, &Repository{ID: repoID})
}
testSuccess(2, 3)
testSuccess(2, 5)
team := AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
repo := AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 1}).(*Repository)
assert.Error(t, team.AddRepository(repo))
CheckConsistencyFor(t, &Team{ID: 1}, &Repository{ID: 1})
}
func TestTeam_RemoveRepository(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(teamID, repoID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, team.RemoveRepository(repoID))
AssertNotExistsBean(t, &TeamRepo{TeamID: teamID, RepoID: repoID})
db.AssertNotExistsBean(t, &TeamRepo{TeamID: teamID, RepoID: repoID})
CheckConsistencyFor(t, &Team{ID: teamID}, &Repository{ID: repoID})
}
testSuccess(2, 3)
testSuccess(2, 5)
testSuccess(1, NonexistentID)
testSuccess(1, db.NonexistentID)
}
func TestIsUsableTeamName(t *testing.T) {
@@ -153,17 +154,17 @@ func TestIsUsableTeamName(t *testing.T) {
}
func TestNewTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
const teamName = "newTeamName"
team := &Team{Name: teamName, OrgID: 3}
assert.NoError(t, NewTeam(team))
AssertExistsAndLoadBean(t, &Team{Name: teamName})
db.AssertExistsAndLoadBean(t, &Team{Name: teamName})
CheckConsistencyFor(t, &Team{}, &User{ID: team.OrgID})
}
func TestGetTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(orgID int64, name string) {
team, err := GetTeam(orgID, name)
@@ -176,12 +177,12 @@ func TestGetTeam(t *testing.T) {
_, err := GetTeam(3, "nonexistent")
assert.Error(t, err)
_, err = GetTeam(NonexistentID, "Owners")
_, err = GetTeam(db.NonexistentID, "Owners")
assert.Error(t, err)
}
func TestGetTeamByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(teamID int64) {
team, err := GetTeamByID(teamID)
@@ -193,25 +194,25 @@ func TestGetTeamByID(t *testing.T) {
testSuccess(3)
testSuccess(4)
_, err := GetTeamByID(NonexistentID)
_, err := GetTeamByID(db.NonexistentID)
assert.Error(t, err)
}
func TestUpdateTeam(t *testing.T) {
// successful update
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
team := AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team.LowerName = "newname"
team.Name = "newName"
team.Description = strings.Repeat("A long description!", 100)
team.Authorize = AccessModeAdmin
assert.NoError(t, UpdateTeam(team, true, false))
team = AssertExistsAndLoadBean(t, &Team{Name: "newName"}).(*Team)
team = db.AssertExistsAndLoadBean(t, &Team{Name: "newName"}).(*Team)
assert.True(t, strings.HasPrefix(team.Description, "A long description!"))
access := AssertExistsAndLoadBean(t, &Access{UserID: 4, RepoID: 3}).(*Access)
access := db.AssertExistsAndLoadBean(t, &Access{UserID: 4, RepoID: 3}).(*Access)
assert.EqualValues(t, AccessModeAdmin, access.Mode)
CheckConsistencyFor(t, &Team{ID: team.ID})
@@ -219,9 +220,9 @@ func TestUpdateTeam(t *testing.T) {
func TestUpdateTeam2(t *testing.T) {
// update to already-existing team
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
team := AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team.LowerName = "owners"
team.Name = "Owners"
team.Description = strings.Repeat("A long description!", 100)
@@ -232,24 +233,24 @@ func TestUpdateTeam2(t *testing.T) {
}
func TestDeleteTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
team := AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 2}).(*Team)
assert.NoError(t, DeleteTeam(team))
AssertNotExistsBean(t, &Team{ID: team.ID})
AssertNotExistsBean(t, &TeamRepo{TeamID: team.ID})
AssertNotExistsBean(t, &TeamUser{TeamID: team.ID})
db.AssertNotExistsBean(t, &Team{ID: team.ID})
db.AssertNotExistsBean(t, &TeamRepo{TeamID: team.ID})
db.AssertNotExistsBean(t, &TeamUser{TeamID: team.ID})
// check that team members don't have "leftover" access to repos
user := AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
repo := AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
user := db.AssertExistsAndLoadBean(t, &User{ID: 4}).(*User)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 3}).(*Repository)
accessMode, err := AccessLevel(user, repo)
assert.NoError(t, err)
assert.True(t, accessMode < AccessModeWrite)
}
func TestIsTeamMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(orgID, teamID, userID int64, expected bool) {
isMember, err := IsTeamMember(orgID, teamID, userID)
assert.NoError(t, err)
@@ -258,25 +259,25 @@ func TestIsTeamMember(t *testing.T) {
test(3, 1, 2, true)
test(3, 1, 4, false)
test(3, 1, NonexistentID, false)
test(3, 1, db.NonexistentID, false)
test(3, 2, 2, true)
test(3, 2, 4, true)
test(3, NonexistentID, NonexistentID, false)
test(NonexistentID, NonexistentID, NonexistentID, false)
test(3, db.NonexistentID, db.NonexistentID, false)
test(db.NonexistentID, db.NonexistentID, db.NonexistentID, false)
}
func TestGetTeamMembers(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
members, err := GetTeamMembers(teamID)
assert.NoError(t, err)
assert.Len(t, members, team.NumMembers)
for _, member := range members {
AssertExistsAndLoadBean(t, &TeamUser{UID: member.ID, TeamID: teamID})
db.AssertExistsAndLoadBean(t, &TeamUser{UID: member.ID, TeamID: teamID})
}
}
test(1)
@@ -284,41 +285,41 @@ func TestGetTeamMembers(t *testing.T) {
}
func TestGetUserTeams(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(userID int64) {
teams, _, err := SearchTeam(&SearchTeamOptions{UserID: userID})
assert.NoError(t, err)
for _, team := range teams {
AssertExistsAndLoadBean(t, &TeamUser{TeamID: team.ID, UID: userID})
db.AssertExistsAndLoadBean(t, &TeamUser{TeamID: team.ID, UID: userID})
}
}
test(2)
test(5)
test(NonexistentID)
test(db.NonexistentID)
}
func TestGetUserOrgTeams(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(orgID, userID int64) {
teams, err := GetUserOrgTeams(orgID, userID)
assert.NoError(t, err)
for _, team := range teams {
assert.EqualValues(t, orgID, team.OrgID)
AssertExistsAndLoadBean(t, &TeamUser{TeamID: team.ID, UID: userID})
db.AssertExistsAndLoadBean(t, &TeamUser{TeamID: team.ID, UID: userID})
}
}
test(3, 2)
test(3, 4)
test(3, NonexistentID)
test(3, db.NonexistentID)
}
func TestAddTeamMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID, userID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, AddTeamMember(team, userID))
AssertExistsAndLoadBean(t, &TeamUser{UID: userID, TeamID: teamID})
db.AssertExistsAndLoadBean(t, &TeamUser{UID: userID, TeamID: teamID})
CheckConsistencyFor(t, &Team{ID: teamID}, &User{ID: team.OrgID})
}
test(1, 2)
@@ -327,42 +328,42 @@ func TestAddTeamMember(t *testing.T) {
}
func TestRemoveTeamMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(teamID, userID int64) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.NoError(t, RemoveTeamMember(team, userID))
AssertNotExistsBean(t, &TeamUser{UID: userID, TeamID: teamID})
db.AssertNotExistsBean(t, &TeamUser{UID: userID, TeamID: teamID})
CheckConsistencyFor(t, &Team{ID: teamID})
}
testSuccess(1, 4)
testSuccess(2, 2)
testSuccess(3, 2)
testSuccess(3, NonexistentID)
testSuccess(3, db.NonexistentID)
team := AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: 1}).(*Team)
err := RemoveTeamMember(team, 2)
assert.True(t, IsErrLastOrgOwner(err))
}
func TestHasTeamRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamID, repoID int64, expected bool) {
team := AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
team := db.AssertExistsAndLoadBean(t, &Team{ID: teamID}).(*Team)
assert.Equal(t, expected, HasTeamRepo(team.OrgID, teamID, repoID))
}
test(1, 1, false)
test(1, 3, true)
test(1, 5, true)
test(1, NonexistentID, false)
test(1, db.NonexistentID, false)
test(2, 3, true)
test(2, 5, false)
}
func TestUsersInTeamsCount(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(teamIDs, userIDs []int64, expected int64) {
count, err := UsersInTeamsCount(teamIDs, userIDs)
+120 -119
View File
@@ -7,6 +7,7 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
@@ -14,7 +15,7 @@ import (
)
func TestUser_IsOwnedBy(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
for _, testCase := range []struct {
OrgID int64
UserID int64
@@ -27,7 +28,7 @@ func TestUser_IsOwnedBy(t *testing.T) {
{2, 2, false}, // user2 is not an organization
{2, 3, false},
} {
org := AssertExistsAndLoadBean(t, &User{ID: testCase.OrgID}).(*User)
org := db.AssertExistsAndLoadBean(t, &User{ID: testCase.OrgID}).(*User)
isOwner, err := org.IsOwnedBy(testCase.UserID)
assert.NoError(t, err)
assert.Equal(t, testCase.ExpectedOwner, isOwner)
@@ -35,7 +36,7 @@ func TestUser_IsOwnedBy(t *testing.T) {
}
func TestUser_IsOrgMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
for _, testCase := range []struct {
OrgID int64
UserID int64
@@ -48,7 +49,7 @@ func TestUser_IsOrgMember(t *testing.T) {
{2, 2, false}, // user2 is not an organization
{2, 3, false},
} {
org := AssertExistsAndLoadBean(t, &User{ID: testCase.OrgID}).(*User)
org := db.AssertExistsAndLoadBean(t, &User{ID: testCase.OrgID}).(*User)
isMember, err := org.IsOrgMember(testCase.UserID)
assert.NoError(t, err)
assert.Equal(t, testCase.ExpectedMember, isMember)
@@ -56,8 +57,8 @@ func TestUser_IsOrgMember(t *testing.T) {
}
func TestUser_GetTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
team, err := org.GetTeam("team1")
assert.NoError(t, err)
assert.Equal(t, org.ID, team.OrgID)
@@ -66,26 +67,26 @@ func TestUser_GetTeam(t *testing.T) {
_, err = org.GetTeam("does not exist")
assert.True(t, IsErrTeamNotExist(err))
nonOrg := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
nonOrg := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
_, err = nonOrg.GetTeam("team")
assert.True(t, IsErrTeamNotExist(err))
}
func TestUser_GetOwnerTeam(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
team, err := org.GetOwnerTeam()
assert.NoError(t, err)
assert.Equal(t, org.ID, team.OrgID)
nonOrg := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
nonOrg := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
_, err = nonOrg.GetOwnerTeam()
assert.True(t, IsErrTeamNotExist(err))
}
func TestUser_GetTeams(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, org.LoadTeams())
if assert.Len(t, org.Teams, 4) {
assert.Equal(t, int64(1), org.Teams[0].ID)
@@ -96,8 +97,8 @@ func TestUser_GetTeams(t *testing.T) {
}
func TestUser_GetMembers(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, org.GetMembers())
if assert.Len(t, org.Members, 3) {
assert.Equal(t, int64(2), org.Members[0].ID)
@@ -107,67 +108,67 @@ func TestUser_GetMembers(t *testing.T) {
}
func TestUser_AddMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
// add a user that is not a member
AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
db.AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
prevNumMembers := org.NumMembers
assert.NoError(t, org.AddMember(5))
AssertExistsAndLoadBean(t, &OrgUser{UID: 5, OrgID: 3})
org = AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
db.AssertExistsAndLoadBean(t, &OrgUser{UID: 5, OrgID: 3})
org = db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.Equal(t, prevNumMembers+1, org.NumMembers)
// add a user that is already a member
AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
db.AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
prevNumMembers = org.NumMembers
assert.NoError(t, org.AddMember(4))
AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
org = AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
db.AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
org = db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.Equal(t, prevNumMembers, org.NumMembers)
CheckConsistencyFor(t, &User{})
}
func TestUser_RemoveMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
// remove a user that is a member
AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
db.AssertExistsAndLoadBean(t, &OrgUser{UID: 4, OrgID: 3})
prevNumMembers := org.NumMembers
assert.NoError(t, org.RemoveMember(4))
AssertNotExistsBean(t, &OrgUser{UID: 4, OrgID: 3})
org = AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
db.AssertNotExistsBean(t, &OrgUser{UID: 4, OrgID: 3})
org = db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.Equal(t, prevNumMembers-1, org.NumMembers)
// remove a user that is not a member
AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
db.AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
prevNumMembers = org.NumMembers
assert.NoError(t, org.RemoveMember(5))
AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
org = AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
db.AssertNotExistsBean(t, &OrgUser{UID: 5, OrgID: 3})
org = db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.Equal(t, prevNumMembers, org.NumMembers)
CheckConsistencyFor(t, &User{}, &Team{})
}
func TestUser_RemoveOrgRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
repo := AssertExistsAndLoadBean(t, &Repository{OwnerID: org.ID}).(*Repository)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
repo := db.AssertExistsAndLoadBean(t, &Repository{OwnerID: org.ID}).(*Repository)
// remove a repo that does belong to org
AssertExistsAndLoadBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
db.AssertExistsAndLoadBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
assert.NoError(t, org.RemoveOrgRepo(repo.ID))
AssertNotExistsBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
AssertExistsAndLoadBean(t, &Repository{ID: repo.ID}) // repo should still exist
db.AssertNotExistsBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
db.AssertExistsAndLoadBean(t, &Repository{ID: repo.ID}) // repo should still exist
// remove a repo that does not belong to org
assert.NoError(t, org.RemoveOrgRepo(repo.ID))
AssertNotExistsBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
db.AssertNotExistsBean(t, &TeamRepo{RepoID: repo.ID, OrgID: org.ID})
assert.NoError(t, org.RemoveOrgRepo(NonexistentID))
assert.NoError(t, org.RemoveOrgRepo(db.NonexistentID))
CheckConsistencyFor(t,
&User{ID: org.ID},
@@ -177,49 +178,49 @@ func TestUser_RemoveOrgRepo(t *testing.T) {
func TestCreateOrganization(t *testing.T) {
// successful creation of org
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
const newOrgName = "neworg"
org := &User{
Name: newOrgName,
}
AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
assert.NoError(t, CreateOrganization(org, owner))
org = AssertExistsAndLoadBean(t,
org = db.AssertExistsAndLoadBean(t,
&User{Name: newOrgName, Type: UserTypeOrganization}).(*User)
ownerTeam := AssertExistsAndLoadBean(t,
ownerTeam := db.AssertExistsAndLoadBean(t,
&Team{Name: ownerTeamName, OrgID: org.ID}).(*Team)
AssertExistsAndLoadBean(t, &TeamUser{UID: owner.ID, TeamID: ownerTeam.ID})
db.AssertExistsAndLoadBean(t, &TeamUser{UID: owner.ID, TeamID: ownerTeam.ID})
CheckConsistencyFor(t, &User{}, &Team{})
}
func TestCreateOrganization2(t *testing.T) {
// unauthorized creation of org
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
owner := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
const newOrgName = "neworg"
org := &User{
Name: newOrgName,
}
AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
err := CreateOrganization(org, owner)
assert.Error(t, err)
assert.True(t, IsErrUserNotAllowedCreateOrg(err))
AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: newOrgName, Type: UserTypeOrganization})
CheckConsistencyFor(t, &User{}, &Team{})
}
func TestCreateOrganization3(t *testing.T) {
// create org with same name as existent org
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
org := &User{Name: "user3"} // should already exist
AssertExistsAndLoadBean(t, &User{Name: org.Name}) // sanity check
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
org := &User{Name: "user3"} // should already exist
db.AssertExistsAndLoadBean(t, &User{Name: org.Name}) // sanity check
err := CreateOrganization(org, owner)
assert.Error(t, err)
assert.True(t, IsErrUserAlreadyExist(err))
@@ -228,9 +229,9 @@ func TestCreateOrganization3(t *testing.T) {
func TestCreateOrganization4(t *testing.T) {
// create org with unusable name
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
err := CreateOrganization(&User{Name: "assets"}, owner)
assert.Error(t, err)
assert.True(t, IsErrNameReserved(err))
@@ -238,7 +239,7 @@ func TestCreateOrganization4(t *testing.T) {
}
func TestGetOrgByName(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
org, err := GetOrgByName("user3")
assert.NoError(t, err)
@@ -253,32 +254,32 @@ func TestGetOrgByName(t *testing.T) {
}
func TestCountOrganizations(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
expected, err := x.Where("type=?", UserTypeOrganization).Count(&User{})
assert.NoError(t, db.PrepareTestDatabase())
expected, err := db.DefaultContext().Engine().Where("type=?", UserTypeOrganization).Count(&User{})
assert.NoError(t, err)
assert.Equal(t, expected, CountOrganizations())
}
func TestDeleteOrganization(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 6}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 6}).(*User)
assert.NoError(t, DeleteOrganization(org))
AssertNotExistsBean(t, &User{ID: 6})
AssertNotExistsBean(t, &OrgUser{OrgID: 6})
AssertNotExistsBean(t, &Team{OrgID: 6})
db.AssertNotExistsBean(t, &User{ID: 6})
db.AssertNotExistsBean(t, &OrgUser{OrgID: 6})
db.AssertNotExistsBean(t, &Team{OrgID: 6})
org = AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
org = db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
err := DeleteOrganization(org)
assert.Error(t, err)
assert.True(t, IsErrUserOwnRepos(err))
user := AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
user := db.AssertExistsAndLoadBean(t, &User{ID: 5}).(*User)
assert.Error(t, DeleteOrganization(user))
CheckConsistencyFor(t, &User{}, &Team{})
}
func TestIsOrganizationOwner(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(orgID, userID int64, expected bool) {
isOwner, err := IsOrganizationOwner(orgID, userID)
assert.NoError(t, err)
@@ -288,11 +289,11 @@ func TestIsOrganizationOwner(t *testing.T) {
test(3, 3, false)
test(6, 5, true)
test(6, 4, false)
test(NonexistentID, NonexistentID, false)
test(db.NonexistentID, db.NonexistentID, false)
}
func TestIsOrganizationMember(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(orgID, userID int64, expected bool) {
isMember, err := IsOrganizationMember(orgID, userID)
assert.NoError(t, err)
@@ -303,11 +304,11 @@ func TestIsOrganizationMember(t *testing.T) {
test(3, 4, true)
test(6, 5, true)
test(6, 4, false)
test(NonexistentID, NonexistentID, false)
test(db.NonexistentID, db.NonexistentID, false)
}
func TestIsPublicMembership(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(orgID, userID int64, expected bool) {
isMember, err := IsPublicMembership(orgID, userID)
assert.NoError(t, err)
@@ -318,11 +319,11 @@ func TestIsPublicMembership(t *testing.T) {
test(3, 4, false)
test(6, 5, true)
test(6, 4, false)
test(NonexistentID, NonexistentID, false)
test(db.NonexistentID, db.NonexistentID, false)
}
func TestGetOrgsByUserID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
orgs, err := GetOrgsByUserID(4, true)
assert.NoError(t, err)
@@ -336,7 +337,7 @@ func TestGetOrgsByUserID(t *testing.T) {
}
func TestGetOwnedOrgsByUserID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
orgs, err := GetOwnedOrgsByUserID(2)
assert.NoError(t, err)
@@ -350,7 +351,7 @@ func TestGetOwnedOrgsByUserID(t *testing.T) {
}
func TestGetOwnedOrgsByUserIDDesc(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
orgs, err := GetOwnedOrgsByUserIDDesc(5, "id")
assert.NoError(t, err)
@@ -365,7 +366,7 @@ func TestGetOwnedOrgsByUserIDDesc(t *testing.T) {
}
func TestGetOrgUsersByUserID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
orgUsers, err := GetOrgUsersByUserID(5, &SearchOrganizationsOptions{All: true})
assert.NoError(t, err)
@@ -395,7 +396,7 @@ func TestGetOrgUsersByUserID(t *testing.T) {
}
func TestGetOrgUsersByOrgID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
orgUsers, err := GetOrgUsersByOrgID(&FindOrgMembersOpts{
ListOptions: ListOptions{},
@@ -420,7 +421,7 @@ func TestGetOrgUsersByOrgID(t *testing.T) {
orgUsers, err = GetOrgUsersByOrgID(&FindOrgMembersOpts{
ListOptions: ListOptions{},
OrgID: NonexistentID,
OrgID: db.NonexistentID,
PublicOnly: false,
})
assert.NoError(t, err)
@@ -428,33 +429,33 @@ func TestGetOrgUsersByOrgID(t *testing.T) {
}
func TestChangeOrgUserStatus(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(orgID, userID int64, public bool) {
assert.NoError(t, ChangeOrgUserStatus(orgID, userID, public))
orgUser := AssertExistsAndLoadBean(t, &OrgUser{OrgID: orgID, UID: userID}).(*OrgUser)
orgUser := db.AssertExistsAndLoadBean(t, &OrgUser{OrgID: orgID, UID: userID}).(*OrgUser)
assert.Equal(t, public, orgUser.IsPublic)
}
testSuccess(3, 2, false)
testSuccess(3, 2, false)
testSuccess(3, 4, true)
assert.NoError(t, ChangeOrgUserStatus(NonexistentID, NonexistentID, true))
assert.NoError(t, ChangeOrgUserStatus(db.NonexistentID, db.NonexistentID, true))
}
func TestAddOrgUser(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(orgID, userID int64, isPublic bool) {
org := AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
org := db.AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
expectedNumMembers := org.NumMembers
if !BeanExists(t, &OrgUser{OrgID: orgID, UID: userID}) {
if !db.BeanExists(t, &OrgUser{OrgID: orgID, UID: userID}) {
expectedNumMembers++
}
assert.NoError(t, AddOrgUser(orgID, userID))
ou := &OrgUser{OrgID: orgID, UID: userID}
AssertExistsAndLoadBean(t, ou)
db.AssertExistsAndLoadBean(t, ou)
assert.Equal(t, isPublic, ou.IsPublic)
org = AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
org = db.AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
assert.EqualValues(t, expectedNumMembers, org.NumMembers)
}
@@ -470,16 +471,16 @@ func TestAddOrgUser(t *testing.T) {
}
func TestRemoveOrgUser(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(orgID, userID int64) {
org := AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
org := db.AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
expectedNumMembers := org.NumMembers
if BeanExists(t, &OrgUser{OrgID: orgID, UID: userID}) {
if db.BeanExists(t, &OrgUser{OrgID: orgID, UID: userID}) {
expectedNumMembers--
}
assert.NoError(t, RemoveOrgUser(orgID, userID))
AssertNotExistsBean(t, &OrgUser{OrgID: orgID, UID: userID})
org = AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
db.AssertNotExistsBean(t, &OrgUser{OrgID: orgID, UID: userID})
org = db.AssertExistsAndLoadBean(t, &User{ID: orgID}).(*User)
assert.EqualValues(t, expectedNumMembers, org.NumMembers)
}
testSuccess(3, 4)
@@ -488,13 +489,13 @@ func TestRemoveOrgUser(t *testing.T) {
err := RemoveOrgUser(7, 5)
assert.Error(t, err)
assert.True(t, IsErrLastOrgOwner(err))
AssertExistsAndLoadBean(t, &OrgUser{OrgID: 7, UID: 5})
db.AssertExistsAndLoadBean(t, &OrgUser{OrgID: 7, UID: 5})
CheckConsistencyFor(t, &User{}, &Team{})
}
func TestUser_GetUserTeamIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
testSuccess := func(userID int64, expected []int64) {
teamIDs, err := org.GetUserTeamIDs(userID)
assert.NoError(t, err)
@@ -502,12 +503,12 @@ func TestUser_GetUserTeamIDs(t *testing.T) {
}
testSuccess(2, []int64{1, 2})
testSuccess(4, []int64{2})
testSuccess(NonexistentID, []int64{})
testSuccess(db.NonexistentID, []int64{})
}
func TestAccessibleReposEnv_CountRepos(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
testSuccess := func(userID, expectedCount int64) {
env, err := org.AccessibleReposEnv(userID)
assert.NoError(t, err)
@@ -520,8 +521,8 @@ func TestAccessibleReposEnv_CountRepos(t *testing.T) {
}
func TestAccessibleReposEnv_RepoIDs(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
testSuccess := func(userID, _, pageSize int64, expectedRepoIDs []int64) {
env, err := org.AccessibleReposEnv(userID)
assert.NoError(t, err)
@@ -534,8 +535,8 @@ func TestAccessibleReposEnv_RepoIDs(t *testing.T) {
}
func TestAccessibleReposEnv_Repos(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
testSuccess := func(userID int64, expectedRepoIDs []int64) {
env, err := org.AccessibleReposEnv(userID)
assert.NoError(t, err)
@@ -543,7 +544,7 @@ func TestAccessibleReposEnv_Repos(t *testing.T) {
assert.NoError(t, err)
expectedRepos := make([]*Repository, len(expectedRepoIDs))
for i, repoID := range expectedRepoIDs {
expectedRepos[i] = AssertExistsAndLoadBean(t,
expectedRepos[i] = db.AssertExistsAndLoadBean(t,
&Repository{ID: repoID}).(*Repository)
}
assert.Equal(t, expectedRepos, repos)
@@ -553,8 +554,8 @@ func TestAccessibleReposEnv_Repos(t *testing.T) {
}
func TestAccessibleReposEnv_MirrorRepos(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
org := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
org := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
testSuccess := func(userID int64, expectedRepoIDs []int64) {
env, err := org.AccessibleReposEnv(userID)
assert.NoError(t, err)
@@ -562,7 +563,7 @@ func TestAccessibleReposEnv_MirrorRepos(t *testing.T) {
assert.NoError(t, err)
expectedRepos := make([]*Repository, len(expectedRepoIDs))
for i, repoID := range expectedRepoIDs {
expectedRepos[i] = AssertExistsAndLoadBean(t,
expectedRepos[i] = db.AssertExistsAndLoadBean(t,
&Repository{ID: repoID}).(*Repository)
}
assert.Equal(t, expectedRepos, repos)
@@ -572,9 +573,9 @@ func TestAccessibleReposEnv_MirrorRepos(t *testing.T) {
}
func TestHasOrgVisibleTypePublic(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
const newOrgName = "test-org-public"
org := &User{
@@ -582,9 +583,9 @@ func TestHasOrgVisibleTypePublic(t *testing.T) {
Visibility: structs.VisibleTypePublic,
}
AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
assert.NoError(t, CreateOrganization(org, owner))
org = AssertExistsAndLoadBean(t,
org = db.AssertExistsAndLoadBean(t,
&User{Name: org.Name, Type: UserTypeOrganization}).(*User)
test1 := HasOrgOrUserVisible(org, owner)
test2 := HasOrgOrUserVisible(org, user3)
@@ -595,9 +596,9 @@ func TestHasOrgVisibleTypePublic(t *testing.T) {
}
func TestHasOrgVisibleTypeLimited(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
const newOrgName = "test-org-limited"
org := &User{
@@ -605,9 +606,9 @@ func TestHasOrgVisibleTypeLimited(t *testing.T) {
Visibility: structs.VisibleTypeLimited,
}
AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
assert.NoError(t, CreateOrganization(org, owner))
org = AssertExistsAndLoadBean(t,
org = db.AssertExistsAndLoadBean(t,
&User{Name: org.Name, Type: UserTypeOrganization}).(*User)
test1 := HasOrgOrUserVisible(org, owner)
test2 := HasOrgOrUserVisible(org, user3)
@@ -618,9 +619,9 @@ func TestHasOrgVisibleTypeLimited(t *testing.T) {
}
func TestHasOrgVisibleTypePrivate(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
owner := AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
assert.NoError(t, db.PrepareTestDatabase())
owner := db.AssertExistsAndLoadBean(t, &User{ID: 2}).(*User)
user3 := db.AssertExistsAndLoadBean(t, &User{ID: 3}).(*User)
const newOrgName = "test-org-private"
org := &User{
@@ -628,9 +629,9 @@ func TestHasOrgVisibleTypePrivate(t *testing.T) {
Visibility: structs.VisibleTypePrivate,
}
AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
db.AssertNotExistsBean(t, &User{Name: org.Name, Type: UserTypeOrganization})
assert.NoError(t, CreateOrganization(org, owner))
org = AssertExistsAndLoadBean(t,
org = db.AssertExistsAndLoadBean(t,
&User{Name: org.Name, Type: UserTypeOrganization}).(*User)
test1 := HasOrgOrUserVisible(org, owner)
test2 := HasOrgOrUserVisible(org, user3)
@@ -641,7 +642,7 @@ func TestHasOrgVisibleTypePrivate(t *testing.T) {
}
func TestGetUsersWhoCanCreateOrgRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
users, err := GetUsersWhoCanCreateOrgRepo(3)
assert.NoError(t, err)
+18 -13
View File
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"
@@ -55,6 +56,10 @@ type Project struct {
ClosedDateUnix timeutil.TimeStamp
}
func init() {
db.RegisterModel(new(Project))
}
// GetProjectsConfig retrieves the types of configurations projects could have
func GetProjectsConfig() []ProjectsConfig {
return []ProjectsConfig{
@@ -85,10 +90,10 @@ type ProjectSearchOptions struct {
// GetProjects returns a list of all projects that have been created in the repository
func GetProjects(opts ProjectSearchOptions) ([]*Project, int64, error) {
return getProjects(x, opts)
return getProjects(db.DefaultContext().Engine(), opts)
}
func getProjects(e Engine, opts ProjectSearchOptions) ([]*Project, int64, error) {
func getProjects(e db.Engine, opts ProjectSearchOptions) ([]*Project, int64, error) {
projects := make([]*Project, 0, setting.UI.IssuePagingNum)
var cond builder.Cond = builder.Eq{"repo_id": opts.RepoID}
@@ -138,7 +143,7 @@ func NewProject(p *Project) error {
return errors.New("project type is not valid")
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
@@ -162,10 +167,10 @@ func NewProject(p *Project) error {
// GetProjectByID returns the projects in a repository
func GetProjectByID(id int64) (*Project, error) {
return getProjectByID(x, id)
return getProjectByID(db.DefaultContext().Engine(), id)
}
func getProjectByID(e Engine, id int64) (*Project, error) {
func getProjectByID(e db.Engine, id int64) (*Project, error) {
p := new(Project)
has, err := e.ID(id).Get(p)
@@ -180,10 +185,10 @@ func getProjectByID(e Engine, id int64) (*Project, error) {
// UpdateProject updates project properties
func UpdateProject(p *Project) error {
return updateProject(x, p)
return updateProject(db.DefaultContext().Engine(), p)
}
func updateProject(e Engine, p *Project) error {
func updateProject(e db.Engine, p *Project) error {
_, err := e.ID(p.ID).Cols(
"title",
"description",
@@ -191,7 +196,7 @@ func updateProject(e Engine, p *Project) error {
return err
}
func updateRepositoryProjectCount(e Engine, repoID int64) error {
func updateRepositoryProjectCount(e db.Engine, repoID int64) error {
if _, err := e.Exec(builder.Update(
builder.Eq{
"`num_projects`": builder.Select("count(*)").From("`project`").
@@ -215,7 +220,7 @@ func updateRepositoryProjectCount(e Engine, repoID int64) error {
// ChangeProjectStatusByRepoIDAndID toggles a project between opened and closed
func ChangeProjectStatusByRepoIDAndID(repoID, projectID int64, isClosed bool) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -239,7 +244,7 @@ func ChangeProjectStatusByRepoIDAndID(repoID, projectID int64, isClosed bool) er
// ChangeProjectStatus toggle a project between opened and closed
func ChangeProjectStatus(p *Project, isClosed bool) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -252,7 +257,7 @@ func ChangeProjectStatus(p *Project, isClosed bool) error {
return sess.Commit()
}
func changeProjectStatus(e Engine, p *Project, isClosed bool) error {
func changeProjectStatus(e db.Engine, p *Project, isClosed bool) error {
p.IsClosed = isClosed
p.ClosedDateUnix = timeutil.TimeStampNow()
count, err := e.ID(p.ID).Where("repo_id = ? AND is_closed = ?", p.RepoID, !isClosed).Cols("is_closed", "closed_date_unix").Update(p)
@@ -268,7 +273,7 @@ func changeProjectStatus(e Engine, p *Project, isClosed bool) error {
// DeleteProjectByID deletes a project from a repository.
func DeleteProjectByID(id int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -281,7 +286,7 @@ func DeleteProjectByID(id int64) error {
return sess.Commit()
}
func deleteProjectByID(e Engine, id int64) error {
func deleteProjectByID(e db.Engine, id int64) error {
p, err := getProjectByID(e, id)
if err != nil {
if IsErrProjectNotExist(err) {
+19 -16
View File
@@ -5,6 +5,7 @@
package models
import (
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
@@ -47,6 +48,10 @@ type ProjectBoard struct {
Issues []*Issue `xorm:"-"`
}
func init() {
db.RegisterModel(new(ProjectBoard))
}
// IsProjectBoardTypeValid checks if the project board type is valid
func IsProjectBoardTypeValid(p ProjectBoardType) bool {
switch p {
@@ -95,13 +100,13 @@ func createBoardsForProjectsType(sess *xorm.Session, project *Project) error {
// NewProjectBoard adds a new project board to a given project
func NewProjectBoard(board *ProjectBoard) error {
_, err := x.Insert(board)
_, err := db.DefaultContext().Engine().Insert(board)
return err
}
// DeleteProjectBoardByID removes all issues references to the project board.
func DeleteProjectBoardByID(boardID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -114,7 +119,7 @@ func DeleteProjectBoardByID(boardID int64) error {
return sess.Commit()
}
func deleteProjectBoardByID(e Engine, boardID int64) error {
func deleteProjectBoardByID(e db.Engine, boardID int64) error {
board, err := getProjectBoard(e, boardID)
if err != nil {
if IsErrProjectBoardNotExist(err) {
@@ -134,17 +139,17 @@ func deleteProjectBoardByID(e Engine, boardID int64) error {
return nil
}
func deleteProjectBoardByProjectID(e Engine, projectID int64) error {
func deleteProjectBoardByProjectID(e db.Engine, projectID int64) error {
_, err := e.Where("project_id=?", projectID).Delete(&ProjectBoard{})
return err
}
// GetProjectBoard fetches the current board of a project
func GetProjectBoard(boardID int64) (*ProjectBoard, error) {
return getProjectBoard(x, boardID)
return getProjectBoard(db.DefaultContext().Engine(), boardID)
}
func getProjectBoard(e Engine, boardID int64) (*ProjectBoard, error) {
func getProjectBoard(e db.Engine, boardID int64) (*ProjectBoard, error) {
board := new(ProjectBoard)
has, err := e.ID(boardID).Get(board)
@@ -159,10 +164,10 @@ func getProjectBoard(e Engine, boardID int64) (*ProjectBoard, error) {
// UpdateProjectBoard updates a project board
func UpdateProjectBoard(board *ProjectBoard) error {
return updateProjectBoard(x, board)
return updateProjectBoard(db.DefaultContext().Engine(), board)
}
func updateProjectBoard(e Engine, board *ProjectBoard) error {
func updateProjectBoard(e db.Engine, board *ProjectBoard) error {
var fieldToUpdate []string
if board.Sorting != 0 {
@@ -181,10 +186,10 @@ func updateProjectBoard(e Engine, board *ProjectBoard) error {
// GetProjectBoards fetches all boards related to a project
// if no default board set, first board is a temporary "Uncategorized" board
func GetProjectBoards(projectID int64) (ProjectBoardList, error) {
return getProjectBoards(x, projectID)
return getProjectBoards(db.DefaultContext().Engine(), projectID)
}
func getProjectBoards(e Engine, projectID int64) ([]*ProjectBoard, error) {
func getProjectBoards(e db.Engine, projectID int64) ([]*ProjectBoard, error) {
boards := make([]*ProjectBoard, 0, 5)
if err := e.Where("project_id=? AND `default`=?", projectID, false).OrderBy("Sorting").Find(&boards); err != nil {
@@ -200,7 +205,7 @@ func getProjectBoards(e Engine, projectID int64) ([]*ProjectBoard, error) {
}
// getDefaultBoard return default board and create a dummy if none exist
func getDefaultBoard(e Engine, projectID int64) (*ProjectBoard, error) {
func getDefaultBoard(e db.Engine, projectID int64) (*ProjectBoard, error) {
var board ProjectBoard
exist, err := e.Where("project_id=? AND `default`=?", projectID, true).Get(&board)
if err != nil {
@@ -221,9 +226,7 @@ func getDefaultBoard(e Engine, projectID int64) (*ProjectBoard, error) {
// SetDefaultBoard represents a board for issues not assigned to one
// if boardID is 0 unset default
func SetDefaultBoard(projectID, boardID int64) error {
sess := x
_, err := sess.Where(builder.Eq{
_, err := db.DefaultContext().Engine().Where(builder.Eq{
"project_id": projectID,
"`default`": true,
}).Cols("`default`").Update(&ProjectBoard{Default: false})
@@ -232,7 +235,7 @@ func SetDefaultBoard(projectID, boardID int64) error {
}
if boardID > 0 {
_, err = sess.ID(boardID).Where(builder.Eq{"project_id": projectID}).
_, err = db.DefaultContext().Engine().ID(boardID).Where(builder.Eq{"project_id": projectID}).
Cols("`default`").Update(&ProjectBoard{Default: true})
}
@@ -290,7 +293,7 @@ func (bs ProjectBoardList) LoadIssues() (IssueList, error) {
// UpdateProjectBoardSorting update project board sorting
func UpdateProjectBoardSorting(bs ProjectBoardList) error {
for i := range bs {
_, err := x.ID(bs[i].ID).Cols(
_, err := db.DefaultContext().Engine().ID(bs[i].ID).Cols(
"sorting",
).Update(bs[i])
if err != nil {
+19 -13
View File
@@ -7,6 +7,8 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"xorm.io/xorm"
)
@@ -20,7 +22,11 @@ type ProjectIssue struct {
ProjectBoardID int64 `xorm:"INDEX"`
}
func deleteProjectIssuesByProjectID(e Engine, projectID int64) error {
func init() {
db.RegisterModel(new(ProjectIssue))
}
func deleteProjectIssuesByProjectID(e db.Engine, projectID int64) error {
_, err := e.Where("project_id=?", projectID).Delete(&ProjectIssue{})
return err
}
@@ -33,10 +39,10 @@ func deleteProjectIssuesByProjectID(e Engine, projectID int64) error {
// LoadProject load the project the issue was assigned to
func (i *Issue) LoadProject() (err error) {
return i.loadProject(x)
return i.loadProject(db.DefaultContext().Engine())
}
func (i *Issue) loadProject(e Engine) (err error) {
func (i *Issue) loadProject(e db.Engine) (err error) {
if i.Project == nil {
var p Project
if _, err = e.Table("project").
@@ -52,10 +58,10 @@ func (i *Issue) loadProject(e Engine) (err error) {
// ProjectID return project id if issue was assigned to one
func (i *Issue) ProjectID() int64 {
return i.projectID(x)
return i.projectID(db.DefaultContext().Engine())
}
func (i *Issue) projectID(e Engine) int64 {
func (i *Issue) projectID(e db.Engine) int64 {
var ip ProjectIssue
has, err := e.Where("issue_id=?", i.ID).Get(&ip)
if err != nil || !has {
@@ -66,10 +72,10 @@ func (i *Issue) projectID(e Engine) int64 {
// ProjectBoardID return project board id if issue was assigned to one
func (i *Issue) ProjectBoardID() int64 {
return i.projectBoardID(x)
return i.projectBoardID(db.DefaultContext().Engine())
}
func (i *Issue) projectBoardID(e Engine) int64 {
func (i *Issue) projectBoardID(e db.Engine) int64 {
var ip ProjectIssue
has, err := e.Where("issue_id=?", i.ID).Get(&ip)
if err != nil || !has {
@@ -87,7 +93,7 @@ func (i *Issue) projectBoardID(e Engine) int64 {
// NumIssues return counter of all issues assigned to a project
func (p *Project) NumIssues() int {
c, err := x.Table("project_issue").
c, err := db.DefaultContext().Engine().Table("project_issue").
Where("project_id=?", p.ID).
GroupBy("issue_id").
Cols("issue_id").
@@ -100,7 +106,7 @@ func (p *Project) NumIssues() int {
// NumClosedIssues return counter of closed issues assigned to a project
func (p *Project) NumClosedIssues() int {
c, err := x.Table("project_issue").
c, err := db.DefaultContext().Engine().Table("project_issue").
Join("INNER", "issue", "project_issue.issue_id=issue.id").
Where("project_issue.project_id=? AND issue.is_closed=?", p.ID, true).
Cols("issue_id").
@@ -113,7 +119,7 @@ func (p *Project) NumClosedIssues() int {
// NumOpenIssues return counter of open issues assigned to a project
func (p *Project) NumOpenIssues() int {
c, err := x.Table("project_issue").
c, err := db.DefaultContext().Engine().Table("project_issue").
Join("INNER", "issue", "project_issue.issue_id=issue.id").
Where("project_issue.project_id=? AND issue.is_closed=?", p.ID, false).Count("issue.id")
if err != nil {
@@ -124,7 +130,7 @@ func (p *Project) NumOpenIssues() int {
// ChangeProjectAssign changes the project associated with an issue
func ChangeProjectAssign(issue *Issue, doer *User, newProjectID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -177,7 +183,7 @@ func addUpdateIssueProject(e *xorm.Session, issue *Issue, doer *User, newProject
// MoveIssueAcrossProjectBoards move a card from one board to another
func MoveIssueAcrossProjectBoards(issue *Issue, board *ProjectBoard) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -201,7 +207,7 @@ func MoveIssueAcrossProjectBoards(issue *Issue, board *ProjectBoard) error {
return sess.Commit()
}
func (pb *ProjectBoard) removeIssues(e Engine) error {
func (pb *ProjectBoard) removeIssues(e db.Engine) error {
_, err := e.Exec("UPDATE `project_issue` SET project_board_id = 0 WHERE project_board_id = ? ", pb.ID)
return err
}
+3 -2
View File
@@ -7,6 +7,7 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"github.com/stretchr/testify/assert"
@@ -31,7 +32,7 @@ func TestIsProjectTypeValid(t *testing.T) {
}
func TestGetProjects(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
projects, _, err := GetProjects(ProjectSearchOptions{RepoID: 1})
assert.NoError(t, err)
@@ -47,7 +48,7 @@ func TestGetProjects(t *testing.T) {
}
func TestProject(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
project := &Project{
Type: ProjectTypeRepository,
+10 -5
View File
@@ -8,6 +8,7 @@ import (
"regexp"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/timeutil"
@@ -28,21 +29,25 @@ type ProtectedTag struct {
UpdatedUnix timeutil.TimeStamp `xorm:"updated"`
}
func init() {
db.RegisterModel(new(ProtectedTag))
}
// InsertProtectedTag inserts a protected tag to database
func InsertProtectedTag(pt *ProtectedTag) error {
_, err := x.Insert(pt)
_, err := db.DefaultContext().Engine().Insert(pt)
return err
}
// UpdateProtectedTag updates the protected tag
func UpdateProtectedTag(pt *ProtectedTag) error {
_, err := x.ID(pt.ID).AllCols().Update(pt)
_, err := db.DefaultContext().Engine().ID(pt.ID).AllCols().Update(pt)
return err
}
// DeleteProtectedTag deletes a protected tag by ID
func DeleteProtectedTag(pt *ProtectedTag) error {
_, err := x.ID(pt.ID).Delete(&ProtectedTag{})
_, err := db.DefaultContext().Engine().ID(pt.ID).Delete(&ProtectedTag{})
return err
}
@@ -81,13 +86,13 @@ func (pt *ProtectedTag) IsUserAllowed(userID int64) (bool, error) {
// GetProtectedTags gets all protected tags of the repository
func (repo *Repository) GetProtectedTags() ([]*ProtectedTag, error) {
tags := make([]*ProtectedTag, 0)
return tags, x.Find(&tags, &ProtectedTag{RepoID: repo.ID})
return tags, db.DefaultContext().Engine().Find(&tags, &ProtectedTag{RepoID: repo.ID})
}
// GetProtectedTagByID gets the protected tag with the specific id
func GetProtectedTagByID(id int64) (*ProtectedTag, error) {
tag := new(ProtectedTag)
has, err := x.ID(id).Get(tag)
has, err := db.DefaultContext().Engine().ID(id).Get(tag)
if err != nil {
return nil, err
}
+2 -1
View File
@@ -7,11 +7,12 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestIsUserAllowed(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pt := &ProtectedTag{}
allowed, err := pt.IsUserAllowed(1)
+35 -30
View File
@@ -10,6 +10,7 @@ import (
"io"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/timeutil"
@@ -84,6 +85,10 @@ type PullRequest struct {
Flow PullRequestFlow `xorm:"NOT NULL DEFAULT 0"`
}
func init() {
db.RegisterModel(new(PullRequest))
}
// MustHeadUserName returns the HeadRepo's username if failed return blank
func (pr *PullRequest) MustHeadUserName() string {
if err := pr.LoadHeadRepo(); err != nil {
@@ -101,7 +106,7 @@ func (pr *PullRequest) MustHeadUserName() string {
}
// Note: don't try to get Issue because will end up recursive querying.
func (pr *PullRequest) loadAttributes(e Engine) (err error) {
func (pr *PullRequest) loadAttributes(e db.Engine) (err error) {
if pr.HasMerged && pr.Merger == nil {
pr.Merger, err = getUserByID(e, pr.MergerID)
if IsErrUserNotExist(err) {
@@ -117,10 +122,10 @@ func (pr *PullRequest) loadAttributes(e Engine) (err error) {
// LoadAttributes loads pull request attributes from database
func (pr *PullRequest) LoadAttributes() error {
return pr.loadAttributes(x)
return pr.loadAttributes(db.DefaultContext().Engine())
}
func (pr *PullRequest) loadHeadRepo(e Engine) (err error) {
func (pr *PullRequest) loadHeadRepo(e db.Engine) (err error) {
if !pr.isHeadRepoLoaded && pr.HeadRepo == nil && pr.HeadRepoID > 0 {
if pr.HeadRepoID == pr.BaseRepoID {
if pr.BaseRepo != nil {
@@ -143,15 +148,15 @@ func (pr *PullRequest) loadHeadRepo(e Engine) (err error) {
// LoadHeadRepo loads the head repository
func (pr *PullRequest) LoadHeadRepo() error {
return pr.loadHeadRepo(x)
return pr.loadHeadRepo(db.DefaultContext().Engine())
}
// LoadBaseRepo loads the target repository
func (pr *PullRequest) LoadBaseRepo() error {
return pr.loadBaseRepo(x)
return pr.loadBaseRepo(db.DefaultContext().Engine())
}
func (pr *PullRequest) loadBaseRepo(e Engine) (err error) {
func (pr *PullRequest) loadBaseRepo(e db.Engine) (err error) {
if pr.BaseRepo != nil {
return nil
}
@@ -175,10 +180,10 @@ func (pr *PullRequest) loadBaseRepo(e Engine) (err error) {
// LoadIssue loads issue information from database
func (pr *PullRequest) LoadIssue() (err error) {
return pr.loadIssue(x)
return pr.loadIssue(db.DefaultContext().Engine())
}
func (pr *PullRequest) loadIssue(e Engine) (err error) {
func (pr *PullRequest) loadIssue(e db.Engine) (err error) {
if pr.Issue != nil {
return nil
}
@@ -192,10 +197,10 @@ func (pr *PullRequest) loadIssue(e Engine) (err error) {
// LoadProtectedBranch loads the protected branch of the base branch
func (pr *PullRequest) LoadProtectedBranch() (err error) {
return pr.loadProtectedBranch(x)
return pr.loadProtectedBranch(db.DefaultContext().Engine())
}
func (pr *PullRequest) loadProtectedBranch(e Engine) (err error) {
func (pr *PullRequest) loadProtectedBranch(e db.Engine) (err error) {
if pr.ProtectedBranch == nil {
if pr.BaseRepo == nil {
if pr.BaseRepoID == 0 {
@@ -252,10 +257,10 @@ type ReviewCount struct {
// GetApprovalCounts returns the approval counts by type
// FIXME: Only returns official counts due to double counting of non-official counts
func (pr *PullRequest) GetApprovalCounts() ([]*ReviewCount, error) {
return pr.getApprovalCounts(x)
return pr.getApprovalCounts(db.DefaultContext().Engine())
}
func (pr *PullRequest) getApprovalCounts(e Engine) ([]*ReviewCount, error) {
func (pr *PullRequest) getApprovalCounts(e db.Engine) ([]*ReviewCount, error) {
rCounts := make([]*ReviewCount, 0, 6)
sess := e.Where("issue_id = ?", pr.IssueID)
return rCounts, sess.Select("issue_id, type, count(id) as `count`").Where("official = ? AND dismissed = ?", true, false).GroupBy("issue_id, type").Table("review").Find(&rCounts)
@@ -279,7 +284,7 @@ func (pr *PullRequest) getReviewedByLines(writer io.Writer) error {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -388,7 +393,7 @@ func (pr *PullRequest) SetMerged() (bool, error) {
pr.HasMerged = true
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return false, err
@@ -443,14 +448,14 @@ func (pr *PullRequest) SetMerged() (bool, error) {
// NewPullRequest creates new pull request with labels for repository.
func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []string, pr *PullRequest) (err error) {
idx, err := GetNextResourceIndex("issue_index", repo.ID)
idx, err := db.GetNextResourceIndex("issue_index", repo.ID)
if err != nil {
return fmt.Errorf("generate issue index failed: %v", err)
}
issue.Index = idx
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -487,7 +492,7 @@ func NewPullRequest(repo *Repository, issue *Issue, labelIDs []int64, uuids []st
// by given head/base and repo/branch.
func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch string, flow PullRequestFlow) (*PullRequest, error) {
pr := new(PullRequest)
has, err := x.
has, err := db.DefaultContext().Engine().
Where("head_repo_id=? AND head_branch=? AND base_repo_id=? AND base_branch=? AND has_merged=? AND flow = ? AND issue.is_closed=?",
headRepoID, headBranch, baseRepoID, baseBranch, false, flow, false).
Join("INNER", "issue", "issue.id=pull_request.issue_id").
@@ -505,7 +510,7 @@ func GetUnmergedPullRequest(headRepoID, baseRepoID int64, headBranch, baseBranch
// by given head information (repo and branch).
func GetLatestPullRequestByHeadInfo(repoID int64, branch string) (*PullRequest, error) {
pr := new(PullRequest)
has, err := x.
has, err := db.DefaultContext().Engine().
Where("head_repo_id = ? AND head_branch = ? AND flow = ?", repoID, branch, PullRequestFlowGithub).
OrderBy("id DESC").
Get(pr)
@@ -522,7 +527,7 @@ func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) {
Index: index,
}
has, err := x.Get(pr)
has, err := db.DefaultContext().Engine().Get(pr)
if err != nil {
return nil, err
} else if !has {
@@ -539,7 +544,7 @@ func GetPullRequestByIndex(repoID, index int64) (*PullRequest, error) {
return pr, nil
}
func getPullRequestByID(e Engine, id int64) (*PullRequest, error) {
func getPullRequestByID(e db.Engine, id int64) (*PullRequest, error) {
pr := new(PullRequest)
has, err := e.ID(id).Get(pr)
if err != nil {
@@ -552,13 +557,13 @@ func getPullRequestByID(e Engine, id int64) (*PullRequest, error) {
// GetPullRequestByID returns a pull request by given ID.
func GetPullRequestByID(id int64) (*PullRequest, error) {
return getPullRequestByID(x, id)
return getPullRequestByID(db.DefaultContext().Engine(), id)
}
// GetPullRequestByIssueIDWithNoAttributes returns pull request with no attributes loaded by given issue ID.
func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error) {
var pr PullRequest
has, err := x.Where("issue_id = ?", issueID).Get(&pr)
has, err := db.DefaultContext().Engine().Where("issue_id = ?", issueID).Get(&pr)
if err != nil {
return nil, err
}
@@ -568,7 +573,7 @@ func GetPullRequestByIssueIDWithNoAttributes(issueID int64) (*PullRequest, error
return &pr, nil
}
func getPullRequestByIssueID(e Engine, issueID int64) (*PullRequest, error) {
func getPullRequestByIssueID(e db.Engine, issueID int64) (*PullRequest, error) {
pr := &PullRequest{
IssueID: issueID,
}
@@ -586,7 +591,7 @@ func getPullRequestByIssueID(e Engine, issueID int64) (*PullRequest, error) {
func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) {
pulls := make([]*PullRequest, 0, 10)
err := x.
err := db.DefaultContext().Engine().
Where("has_merged=? AND flow = ? AND issue.is_closed=? AND issue.poster_id=?",
false, PullRequestFlowAGit, false, uid).
Join("INNER", "issue", "issue.id=pull_request.issue_id").
@@ -597,24 +602,24 @@ func GetAllUnmergedAgitPullRequestByPoster(uid int64) ([]*PullRequest, error) {
// GetPullRequestByIssueID returns pull request by given issue ID.
func GetPullRequestByIssueID(issueID int64) (*PullRequest, error) {
return getPullRequestByIssueID(x, issueID)
return getPullRequestByIssueID(db.DefaultContext().Engine(), issueID)
}
// Update updates all fields of pull request.
func (pr *PullRequest) Update() error {
_, err := x.ID(pr.ID).AllCols().Update(pr)
_, err := db.DefaultContext().Engine().ID(pr.ID).AllCols().Update(pr)
return err
}
// UpdateCols updates specific fields of pull request.
func (pr *PullRequest) UpdateCols(cols ...string) error {
_, err := x.ID(pr.ID).Cols(cols...).Update(pr)
_, err := db.DefaultContext().Engine().ID(pr.ID).Cols(cols...).Update(pr)
return err
}
// UpdateColsIfNotMerged updates specific fields of a pull request if it has not been merged
func (pr *PullRequest) UpdateColsIfNotMerged(cols ...string) error {
_, err := x.Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr)
_, err := db.DefaultContext().Engine().Where("id = ? AND has_merged = ?", pr.ID, false).Cols(cols...).Update(pr)
return err
}
@@ -660,10 +665,10 @@ func (pr *PullRequest) GetWorkInProgressPrefix() string {
// UpdateCommitDivergence update Divergence of a pull request
func (pr *PullRequest) UpdateCommitDivergence(ahead, behind int) error {
return pr.updateCommitDivergence(x, ahead, behind)
return pr.updateCommitDivergence(db.DefaultContext().Engine(), ahead, behind)
}
func (pr *PullRequest) updateCommitDivergence(e Engine, ahead, behind int) error {
func (pr *PullRequest) updateCommitDivergence(e db.Engine, ahead, behind int) error {
if pr.ID == 0 {
return fmt.Errorf("pull ID is 0")
}
+9 -8
View File
@@ -7,6 +7,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
@@ -24,7 +25,7 @@ type PullRequestsOptions struct {
}
func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xorm.Session, error) {
sess := x.Where("pull_request.base_repo_id=?", baseRepoID)
sess := db.DefaultContext().Engine().Where("pull_request.base_repo_id=?", baseRepoID)
sess.Join("INNER", "issue", "pull_request.issue_id = issue.id")
switch opts.State {
@@ -50,7 +51,7 @@ func listPullRequestStatement(baseRepoID int64, opts *PullRequestsOptions) (*xor
// by given head information (repo and branch).
func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
return prs, x.
return prs, db.DefaultContext().Engine().
Where("head_repo_id = ? AND head_branch = ? AND has_merged = ? AND issue.is_closed = ? AND flow = ?",
repoID, branch, false, false, PullRequestFlowGithub).
Join("INNER", "issue", "issue.id = pull_request.issue_id").
@@ -61,7 +62,7 @@ func GetUnmergedPullRequestsByHeadInfo(repoID int64, branch string) ([]*PullRequ
// by given base information (repo and branch).
func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequest, error) {
prs := make([]*PullRequest, 0, 2)
return prs, x.
return prs, db.DefaultContext().Engine().
Where("base_repo_id=? AND base_branch=? AND has_merged=? AND issue.is_closed=?",
repoID, branch, false, false).
Join("INNER", "issue", "issue.id=pull_request.issue_id").
@@ -71,7 +72,7 @@ func GetUnmergedPullRequestsByBaseInfo(repoID int64, branch string) ([]*PullRequ
// GetPullRequestIDsByCheckStatus returns all pull requests according the special checking status.
func GetPullRequestIDsByCheckStatus(status PullRequestStatus) ([]int64, error) {
prs := make([]int64, 0, 10)
return prs, x.Table("pull_request").
return prs, db.DefaultContext().Engine().Table("pull_request").
Where("status=?", status).
Cols("pull_request.id").
Find(&prs)
@@ -108,7 +109,7 @@ func PullRequests(baseRepoID int64, opts *PullRequestsOptions) ([]*PullRequest,
// PullRequestList defines a list of pull requests
type PullRequestList []*PullRequest
func (prs PullRequestList) loadAttributes(e Engine) error {
func (prs PullRequestList) loadAttributes(e db.Engine) error {
if len(prs) == 0 {
return nil
}
@@ -143,10 +144,10 @@ func (prs PullRequestList) getIssueIDs() []int64 {
// LoadAttributes load all the prs attributes
func (prs PullRequestList) LoadAttributes() error {
return prs.loadAttributes(x)
return prs.loadAttributes(db.DefaultContext().Engine())
}
func (prs PullRequestList) invalidateCodeComments(e Engine, doer *User, repo *git.Repository, branch string) error {
func (prs PullRequestList) invalidateCodeComments(e db.Engine, doer *User, repo *git.Repository, branch string) error {
if len(prs) == 0 {
return nil
}
@@ -168,5 +169,5 @@ func (prs PullRequestList) invalidateCodeComments(e Engine, doer *User, repo *gi
// InvalidateCodeComments will lookup the prs for code comments which got invalidated by change
func (prs PullRequestList) InvalidateCodeComments(doer *User, repo *git.Repository, branch string) error {
return prs.invalidateCodeComments(x, doer, repo, branch)
return prs.invalidateCodeComments(db.DefaultContext().Engine(), doer, repo, branch)
}
+33 -32
View File
@@ -7,20 +7,21 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestPullRequest_LoadAttributes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, pr.LoadAttributes())
assert.NotNil(t, pr.Merger)
assert.Equal(t, pr.MergerID, pr.Merger.ID)
}
func TestPullRequest_LoadIssue(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, pr.LoadIssue())
assert.NotNil(t, pr.Issue)
assert.Equal(t, int64(2), pr.Issue.ID)
@@ -30,8 +31,8 @@ func TestPullRequest_LoadIssue(t *testing.T) {
}
func TestPullRequest_LoadBaseRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, pr.LoadBaseRepo())
assert.NotNil(t, pr.BaseRepo)
assert.Equal(t, pr.BaseRepoID, pr.BaseRepo.ID)
@@ -41,8 +42,8 @@ func TestPullRequest_LoadBaseRepo(t *testing.T) {
}
func TestPullRequest_LoadHeadRepo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, pr.LoadHeadRepo())
assert.NotNil(t, pr.HeadRepo)
assert.Equal(t, pr.HeadRepoID, pr.HeadRepo.ID)
@@ -53,7 +54,7 @@ func TestPullRequest_LoadHeadRepo(t *testing.T) {
// TODO TestNewPullRequest
func TestPullRequestsNewest(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
prs, count, err := PullRequests(1, &PullRequestsOptions{
ListOptions: ListOptions{
Page: 1,
@@ -72,7 +73,7 @@ func TestPullRequestsNewest(t *testing.T) {
}
func TestPullRequestsOldest(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
prs, count, err := PullRequests(1, &PullRequestsOptions{
ListOptions: ListOptions{
Page: 1,
@@ -91,7 +92,7 @@ func TestPullRequestsOldest(t *testing.T) {
}
func TestGetUnmergedPullRequest(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr, err := GetUnmergedPullRequest(1, 1, "branch2", "master", PullRequestFlowGithub)
assert.NoError(t, err)
assert.Equal(t, int64(2), pr.ID)
@@ -102,7 +103,7 @@ func TestGetUnmergedPullRequest(t *testing.T) {
}
func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
prs, err := GetUnmergedPullRequestsByHeadInfo(1, "branch2")
assert.NoError(t, err)
assert.Len(t, prs, 1)
@@ -113,7 +114,7 @@ func TestGetUnmergedPullRequestsByHeadInfo(t *testing.T) {
}
func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
prs, err := GetUnmergedPullRequestsByBaseInfo(1, "master")
assert.NoError(t, err)
assert.Len(t, prs, 1)
@@ -124,7 +125,7 @@ func TestGetUnmergedPullRequestsByBaseInfo(t *testing.T) {
}
func TestGetPullRequestByIndex(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr, err := GetPullRequestByIndex(1, 2)
assert.NoError(t, err)
assert.Equal(t, int64(1), pr.BaseRepoID)
@@ -136,7 +137,7 @@ func TestGetPullRequestByIndex(t *testing.T) {
}
func TestGetPullRequestByID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr, err := GetPullRequestByID(1)
assert.NoError(t, err)
assert.Equal(t, int64(1), pr.ID)
@@ -148,7 +149,7 @@ func TestGetPullRequestByID(t *testing.T) {
}
func TestGetPullRequestByIssueID(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr, err := GetPullRequestByIssueID(2)
assert.NoError(t, err)
assert.Equal(t, int64(2), pr.IssueID)
@@ -159,20 +160,20 @@ func TestGetPullRequestByIssueID(t *testing.T) {
}
func TestPullRequest_Update(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
pr.BaseBranch = "baseBranch"
pr.HeadBranch = "headBranch"
pr.Update()
pr = AssertExistsAndLoadBean(t, &PullRequest{ID: pr.ID}).(*PullRequest)
pr = db.AssertExistsAndLoadBean(t, &PullRequest{ID: pr.ID}).(*PullRequest)
assert.Equal(t, "baseBranch", pr.BaseBranch)
assert.Equal(t, "headBranch", pr.HeadBranch)
CheckConsistencyFor(t, pr)
}
func TestPullRequest_UpdateCols(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr := &PullRequest{
ID: 1,
BaseBranch: "baseBranch",
@@ -180,18 +181,18 @@ func TestPullRequest_UpdateCols(t *testing.T) {
}
assert.NoError(t, pr.UpdateCols("head_branch"))
pr = AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
pr = db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest)
assert.Equal(t, "master", pr.BaseBranch)
assert.Equal(t, "headBranch", pr.HeadBranch)
CheckConsistencyFor(t, pr)
}
func TestPullRequestList_LoadAttributes(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
prs := []*PullRequest{
AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest),
AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest),
db.AssertExistsAndLoadBean(t, &PullRequest{ID: 1}).(*PullRequest),
db.AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest),
}
assert.NoError(t, PullRequestList(prs).LoadAttributes())
for _, pr := range prs {
@@ -205,9 +206,9 @@ func TestPullRequestList_LoadAttributes(t *testing.T) {
// TODO TestAddTestPullRequestTask
func TestPullRequest_IsWorkInProgress(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
pr.LoadIssue()
assert.False(t, pr.IsWorkInProgress())
@@ -220,9 +221,9 @@ func TestPullRequest_IsWorkInProgress(t *testing.T) {
}
func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
pr.LoadIssue()
assert.Empty(t, pr.GetWorkInProgressPrefix())
@@ -236,8 +237,8 @@ func TestPullRequest_GetWorkInProgressPrefixWorkInProgress(t *testing.T) {
}
func TestPullRequest_GetDefaultMergeMessage_InternalTracker(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
assert.NoError(t, db.PrepareTestDatabase())
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 2}).(*PullRequest)
assert.Equal(t, "Merge pull request 'issue3' (#3) from branch2 into master", pr.GetDefaultMergeMessage())
@@ -247,7 +248,7 @@ func TestPullRequest_GetDefaultMergeMessage_InternalTracker(t *testing.T) {
}
func TestPullRequest_GetDefaultMergeMessage_ExternalTracker(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
externalTracker := RepoUnit{
Type: UnitTypeExternalTracker,
@@ -259,7 +260,7 @@ func TestPullRequest_GetDefaultMergeMessage_ExternalTracker(t *testing.T) {
baseRepo.Owner = &User{Name: "testOwner"}
baseRepo.Units = []*RepoUnit{&externalTracker}
pr := AssertExistsAndLoadBean(t, &PullRequest{ID: 2, BaseRepo: baseRepo}).(*PullRequest)
pr := db.AssertExistsAndLoadBean(t, &PullRequest{ID: 2, BaseRepo: baseRepo}).(*PullRequest)
assert.Equal(t, "Merge pull request 'issue3' (!3) from branch2 into master", pr.GetDefaultMergeMessage())
+28 -23
View File
@@ -11,6 +11,7 @@ import (
"sort"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/modules/timeutil"
@@ -44,7 +45,11 @@ type Release struct {
CreatedUnix timeutil.TimeStamp `xorm:"INDEX"`
}
func (r *Release) loadAttributes(e Engine) error {
func init() {
db.RegisterModel(new(Release))
}
func (r *Release) loadAttributes(e db.Engine) error {
var err error
if r.Repo == nil {
r.Repo, err = GetRepositoryByID(r.RepoID)
@@ -67,7 +72,7 @@ func (r *Release) loadAttributes(e Engine) error {
// LoadAttributes load repo and publisher attributes for a release
func (r *Release) LoadAttributes() error {
return r.loadAttributes(x)
return r.loadAttributes(db.DefaultContext().Engine())
}
// APIURL the api url for a release. release must have attributes loaded
@@ -97,31 +102,31 @@ func IsReleaseExist(repoID int64, tagName string) (bool, error) {
return false, nil
}
return x.Get(&Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)})
return db.DefaultContext().Engine().Get(&Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)})
}
// InsertRelease inserts a release
func InsertRelease(rel *Release) error {
_, err := x.Insert(rel)
_, err := db.DefaultContext().Engine().Insert(rel)
return err
}
// InsertReleasesContext insert releases
func InsertReleasesContext(ctx DBContext, rels []*Release) error {
_, err := ctx.e.Insert(rels)
func InsertReleasesContext(ctx *db.Context, rels []*Release) error {
_, err := ctx.Engine().Insert(rels)
return err
}
// UpdateRelease updates all columns of a release
func UpdateRelease(ctx DBContext, rel *Release) error {
_, err := ctx.e.ID(rel.ID).AllCols().Update(rel)
func UpdateRelease(ctx *db.Context, rel *Release) error {
_, err := ctx.Engine().ID(rel.ID).AllCols().Update(rel)
return err
}
// AddReleaseAttachments adds a release attachments
func AddReleaseAttachments(ctx DBContext, releaseID int64, attachmentUUIDs []string) (err error) {
func AddReleaseAttachments(ctx *db.Context, releaseID int64, attachmentUUIDs []string) (err error) {
// Check attachments
attachments, err := getAttachmentsByUUIDs(ctx.e, attachmentUUIDs)
attachments, err := getAttachmentsByUUIDs(ctx.Engine(), attachmentUUIDs)
if err != nil {
return fmt.Errorf("GetAttachmentsByUUIDs [uuids: %v]: %v", attachmentUUIDs, err)
}
@@ -132,7 +137,7 @@ func AddReleaseAttachments(ctx DBContext, releaseID int64, attachmentUUIDs []str
}
attachments[i].ReleaseID = releaseID
// No assign value could be 0, so ignore AllCols().
if _, err = ctx.e.ID(attachments[i].ID).Update(attachments[i]); err != nil {
if _, err = ctx.Engine().ID(attachments[i].ID).Update(attachments[i]); err != nil {
return fmt.Errorf("update attachment [%d]: %v", attachments[i].ID, err)
}
}
@@ -150,14 +155,14 @@ func GetRelease(repoID int64, tagName string) (*Release, error) {
}
rel := &Release{RepoID: repoID, LowerTagName: strings.ToLower(tagName)}
_, err = x.Get(rel)
_, err = db.DefaultContext().Engine().Get(rel)
return rel, err
}
// GetReleaseByID returns release with given ID.
func GetReleaseByID(id int64) (*Release, error) {
rel := new(Release)
has, err := x.
has, err := db.DefaultContext().Engine().
ID(id).
Get(rel)
if err != nil {
@@ -203,7 +208,7 @@ func (opts *FindReleasesOptions) toConds(repoID int64) builder.Cond {
// GetReleasesByRepoID returns a list of releases of repository.
func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, error) {
sess := x.
sess := db.DefaultContext().Engine().
Desc("created_unix", "id").
Where(opts.toConds(repoID))
@@ -217,7 +222,7 @@ func GetReleasesByRepoID(repoID int64, opts FindReleasesOptions) ([]*Release, er
// CountReleasesByRepoID returns a number of releases matching FindReleaseOptions and RepoID.
func CountReleasesByRepoID(repoID int64, opts FindReleasesOptions) (int64, error) {
return x.Where(opts.toConds(repoID)).Count(new(Release))
return db.DefaultContext().Engine().Where(opts.toConds(repoID)).Count(new(Release))
}
// GetLatestReleaseByRepoID returns the latest release for a repository
@@ -229,7 +234,7 @@ func GetLatestReleaseByRepoID(repoID int64) (*Release, error) {
And(builder.Eq{"is_tag": false})
rel := new(Release)
has, err := x.
has, err := db.DefaultContext().Engine().
Desc("created_unix", "id").
Where(cond).
Get(rel)
@@ -243,8 +248,8 @@ func GetLatestReleaseByRepoID(repoID int64) (*Release, error) {
}
// GetReleasesByRepoIDAndNames returns a list of releases of repository according repoID and tagNames.
func GetReleasesByRepoIDAndNames(ctx DBContext, repoID int64, tagNames []string) (rels []*Release, err error) {
err = ctx.e.
func GetReleasesByRepoIDAndNames(ctx *db.Context, repoID int64, tagNames []string) (rels []*Release, err error) {
err = ctx.Engine().
In("tag_name", tagNames).
Desc("created_unix").
Find(&rels, Release{RepoID: repoID})
@@ -253,7 +258,7 @@ func GetReleasesByRepoIDAndNames(ctx DBContext, repoID int64, tagNames []string)
// GetReleaseCountByRepoID returns the count of releases of repository
func GetReleaseCountByRepoID(repoID int64, opts FindReleasesOptions) (int64, error) {
return x.Where(opts.toConds(repoID)).Count(&Release{})
return db.DefaultContext().Engine().Where(opts.toConds(repoID)).Count(&Release{})
}
type releaseMetaSearch struct {
@@ -276,10 +281,10 @@ func (s releaseMetaSearch) Less(i, j int) bool {
// GetReleaseAttachments retrieves the attachments for releases
func GetReleaseAttachments(rels ...*Release) (err error) {
return getReleaseAttachments(x, rels...)
return getReleaseAttachments(db.DefaultContext().Engine(), rels...)
}
func getReleaseAttachments(e Engine, rels ...*Release) (err error) {
func getReleaseAttachments(e db.Engine, rels ...*Release) (err error) {
if len(rels) == 0 {
return
}
@@ -347,13 +352,13 @@ func SortReleases(rels []*Release) {
// DeleteReleaseByID deletes a release from database by given ID.
func DeleteReleaseByID(id int64) error {
_, err := x.ID(id).Delete(new(Release))
_, err := db.DefaultContext().Engine().ID(id).Delete(new(Release))
return err
}
// UpdateReleasesMigrationsByType updates all migrated repositories' releases from gitServiceType to replace originalAuthorID to posterID
func UpdateReleasesMigrationsByType(gitServiceType structs.GitServiceType, originalAuthorID string, posterID int64) error {
_, err := x.Table("release").
_, err := db.DefaultContext().Engine().Table("release").
Where("repo_id IN (SELECT id FROM repository WHERE original_service_type = ?)", gitServiceType).
And("original_author_id = ?", originalAuthorID).
Update(map[string]interface{}{
+109 -104
View File
@@ -25,6 +25,7 @@ import (
"strings"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/lfs"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/markup"
@@ -251,6 +252,10 @@ type Repository struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func init() {
db.RegisterModel(new(Repository))
}
// SanitizedOriginalURL returns a sanitized OriginalURL
func (repo *Repository) SanitizedOriginalURL() string {
if repo.OriginalURL == "" {
@@ -300,7 +305,7 @@ func (repo *Repository) AfterLoad() {
// It creates a fake object that contains error details
// when error occurs.
func (repo *Repository) MustOwner() *User {
return repo.mustOwner(x)
return repo.mustOwner(db.DefaultContext().Engine())
}
// FullName returns the repository full name
@@ -340,7 +345,7 @@ func (repo *Repository) GetCommitsCountCacheKey(contextName string, isRef bool)
return fmt.Sprintf("commits-count-%d-%s-%s", repo.ID, prefix, contextName)
}
func (repo *Repository) getUnits(e Engine) (err error) {
func (repo *Repository) getUnits(e db.Engine) (err error) {
if repo.Units != nil {
return nil
}
@@ -352,10 +357,10 @@ func (repo *Repository) getUnits(e Engine) (err error) {
// CheckUnitUser check whether user could visit the unit of this repository
func (repo *Repository) CheckUnitUser(user *User, unitType UnitType) bool {
return repo.checkUnitUser(x, user, unitType)
return repo.checkUnitUser(db.DefaultContext().Engine(), user, unitType)
}
func (repo *Repository) checkUnitUser(e Engine, user *User, unitType UnitType) bool {
func (repo *Repository) checkUnitUser(e db.Engine, user *User, unitType UnitType) bool {
if user.IsAdmin {
return true
}
@@ -370,7 +375,7 @@ func (repo *Repository) checkUnitUser(e Engine, user *User, unitType UnitType) b
// UnitEnabled if this repository has the given unit enabled
func (repo *Repository) UnitEnabled(tp UnitType) bool {
if err := repo.getUnits(x); err != nil {
if err := repo.getUnits(db.DefaultContext().Engine()); err != nil {
log.Warn("Error loading repository (ID: %d) units: %s", repo.ID, err.Error())
}
for _, unit := range repo.Units {
@@ -432,10 +437,10 @@ func (repo *Repository) MustGetUnit(tp UnitType) *RepoUnit {
// GetUnit returns a RepoUnit object
func (repo *Repository) GetUnit(tp UnitType) (*RepoUnit, error) {
return repo.getUnit(x, tp)
return repo.getUnit(db.DefaultContext().Engine(), tp)
}
func (repo *Repository) getUnit(e Engine, tp UnitType) (*RepoUnit, error) {
func (repo *Repository) getUnit(e db.Engine, tp UnitType) (*RepoUnit, error) {
if err := repo.getUnits(e); err != nil {
return nil, err
}
@@ -447,7 +452,7 @@ func (repo *Repository) getUnit(e Engine, tp UnitType) (*RepoUnit, error) {
return nil, ErrUnitTypeNotExist{tp}
}
func (repo *Repository) getOwner(e Engine) (err error) {
func (repo *Repository) getOwner(e db.Engine) (err error) {
if repo.Owner != nil {
return nil
}
@@ -458,10 +463,10 @@ func (repo *Repository) getOwner(e Engine) (err error) {
// GetOwner returns the repository owner
func (repo *Repository) GetOwner() error {
return repo.getOwner(x)
return repo.getOwner(db.DefaultContext().Engine())
}
func (repo *Repository) mustOwner(e Engine) *User {
func (repo *Repository) mustOwner(e db.Engine) *User {
if err := repo.getOwner(e); err != nil {
return &User{
Name: "error",
@@ -496,7 +501,7 @@ func (repo *Repository) ComposeMetas() map[string]string {
repo.MustOwner()
if repo.Owner.IsOrganization() {
teams := make([]string, 0, 5)
_ = x.Table("team_repo").
_ = db.DefaultContext().Engine().Table("team_repo").
Join("INNER", "team", "team.id = team_repo.team_id").
Where("team_repo.repo_id = ?", repo.ID).
Select("team.lower_name").
@@ -524,7 +529,7 @@ func (repo *Repository) ComposeDocumentMetas() map[string]string {
return repo.DocumentRenderingMetas
}
func (repo *Repository) getAssignees(e Engine) (_ []*User, err error) {
func (repo *Repository) getAssignees(e db.Engine) (_ []*User, err error) {
if err = repo.getOwner(e); err != nil {
return nil, err
}
@@ -559,10 +564,10 @@ func (repo *Repository) getAssignees(e Engine) (_ []*User, err error) {
// GetAssignees returns all users that have write access and can be assigned to issues
// of the repository,
func (repo *Repository) GetAssignees() (_ []*User, err error) {
return repo.getAssignees(x)
return repo.getAssignees(db.DefaultContext().Engine())
}
func (repo *Repository) getReviewers(e Engine, doerID, posterID int64) ([]*User, error) {
func (repo *Repository) getReviewers(e db.Engine, doerID, posterID int64) ([]*User, error) {
// Get the owner of the repository - this often already pre-cached and if so saves complexity for the following queries
if err := repo.getOwner(e); err != nil {
return nil, err
@@ -611,7 +616,7 @@ func (repo *Repository) getReviewers(e Engine, doerID, posterID int64) ([]*User,
// all repo watchers and all organization members.
// TODO: may be we should have a busy choice for users to block review request to them.
func (repo *Repository) GetReviewers(doerID, posterID int64) ([]*User, error) {
return repo.getReviewers(x, doerID, posterID)
return repo.getReviewers(db.DefaultContext().Engine(), doerID, posterID)
}
// GetReviewerTeams get all teams can be requested to review
@@ -657,10 +662,10 @@ func (repo *Repository) LoadPushMirrors() (err error) {
// returns an error on failure (NOTE: no error is returned for
// non-fork repositories, and BaseRepo will be left untouched)
func (repo *Repository) GetBaseRepo() (err error) {
return repo.getBaseRepo(x)
return repo.getBaseRepo(db.DefaultContext().Engine())
}
func (repo *Repository) getBaseRepo(e Engine) (err error) {
func (repo *Repository) getBaseRepo(e db.Engine) (err error) {
if !repo.IsFork {
return nil
}
@@ -678,10 +683,10 @@ func (repo *Repository) IsGenerated() bool {
// returns an error on failure (NOTE: no error is returned for
// non-generated repositories, and TemplateRepo will be left untouched)
func (repo *Repository) GetTemplateRepo() (err error) {
return repo.getTemplateRepo(x)
return repo.getTemplateRepo(db.DefaultContext().Engine())
}
func (repo *Repository) getTemplateRepo(e Engine) (err error) {
func (repo *Repository) getTemplateRepo(e db.Engine) (err error) {
if !repo.IsGenerated() {
return nil
}
@@ -722,7 +727,7 @@ func (repo *Repository) ComposeCompareURL(oldCommitID, newCommitID string) strin
// UpdateDefaultBranch updates the default branch
func (repo *Repository) UpdateDefaultBranch() error {
_, err := x.ID(repo.ID).Cols("default_branch").Update(repo)
_, err := db.DefaultContext().Engine().ID(repo.ID).Cols("default_branch").Update(repo)
return err
}
@@ -731,7 +736,7 @@ func (repo *Repository) IsOwnedBy(userID int64) bool {
return repo.OwnerID == userID
}
func (repo *Repository) updateSize(e Engine) error {
func (repo *Repository) updateSize(e db.Engine) error {
size, err := util.GetDirectorySize(repo.RepoPath())
if err != nil {
return fmt.Errorf("updateSize: %v", err)
@@ -748,8 +753,8 @@ func (repo *Repository) updateSize(e Engine) error {
}
// UpdateSize updates the repository size, calculating it using util.GetDirectorySize
func (repo *Repository) UpdateSize(ctx DBContext) error {
return repo.updateSize(ctx.e)
func (repo *Repository) UpdateSize(ctx *db.Context) error {
return repo.updateSize(ctx.Engine())
}
// CanUserFork returns true if specified user can fork repository.
@@ -810,12 +815,12 @@ func (repo *Repository) CanEnableEditor() bool {
// GetReaders returns all users that have explicit read access or higher to the repository.
func (repo *Repository) GetReaders() (_ []*User, err error) {
return repo.getUsersWithAccessMode(x, AccessModeRead)
return repo.getUsersWithAccessMode(db.DefaultContext().Engine(), AccessModeRead)
}
// GetWriters returns all users that have write access to the repository.
func (repo *Repository) GetWriters() (_ []*User, err error) {
return repo.getUsersWithAccessMode(x, AccessModeWrite)
return repo.getUsersWithAccessMode(db.DefaultContext().Engine(), AccessModeWrite)
}
// IsReader returns true if user has explicit read access or higher to the repository.
@@ -823,11 +828,11 @@ func (repo *Repository) IsReader(userID int64) (bool, error) {
if repo.OwnerID == userID {
return true, nil
}
return x.Where("repo_id = ? AND user_id = ? AND mode >= ?", repo.ID, userID, AccessModeRead).Get(&Access{})
return db.DefaultContext().Engine().Where("repo_id = ? AND user_id = ? AND mode >= ?", repo.ID, userID, AccessModeRead).Get(&Access{})
}
// getUsersWithAccessMode returns users that have at least given access mode to the repository.
func (repo *Repository) getUsersWithAccessMode(e Engine, mode AccessMode) (_ []*User, err error) {
func (repo *Repository) getUsersWithAccessMode(e db.Engine, mode AccessMode) (_ []*User, err error) {
if err = repo.getOwner(e); err != nil {
return nil, err
}
@@ -872,10 +877,10 @@ func (repo *Repository) DescriptionHTML() template.HTML {
// ReadBy sets repo to be visited by given user.
func (repo *Repository) ReadBy(userID int64) error {
return setRepoNotificationStatusReadIfUnread(x, userID, repo.ID)
return setRepoNotificationStatusReadIfUnread(db.DefaultContext().Engine(), userID, repo.ID)
}
func isRepositoryExist(e Engine, u *User, repoName string) (bool, error) {
func isRepositoryExist(e db.Engine, u *User, repoName string) (bool, error) {
has, err := e.Get(&Repository{
OwnerID: u.ID,
LowerName: strings.ToLower(repoName),
@@ -889,7 +894,7 @@ func isRepositoryExist(e Engine, u *User, repoName string) (bool, error) {
// IsRepositoryExist returns true if the repository with given name under user has already existed.
func IsRepositoryExist(u *User, repoName string) (bool, error) {
return isRepositoryExist(x, u, repoName)
return isRepositoryExist(db.DefaultContext().Engine(), u, repoName)
}
// CloneLink represents different types of clone URLs of repository.
@@ -951,7 +956,7 @@ func CheckCreateRepository(doer, u *User, name string, overwriteOrAdopt bool) er
return err
}
has, err := isRepositoryExist(x, u, name)
has, err := isRepositoryExist(db.DefaultContext().Engine(), u, name)
if err != nil {
return fmt.Errorf("IsRepositoryExist: %v", err)
} else if has {
@@ -1040,12 +1045,12 @@ func IsUsableRepoName(name string) error {
}
// CreateRepository creates a repository for the user/organization.
func CreateRepository(ctx DBContext, doer, u *User, repo *Repository, overwriteOrAdopt bool) (err error) {
func CreateRepository(ctx *db.Context, doer, u *User, repo *Repository, overwriteOrAdopt bool) (err error) {
if err = IsUsableRepoName(repo.Name); err != nil {
return err
}
has, err := isRepositoryExist(ctx.e, u, repo.Name)
has, err := isRepositoryExist(ctx.Engine(), u, repo.Name)
if err != nil {
return fmt.Errorf("IsRepositoryExist: %v", err)
} else if has {
@@ -1066,10 +1071,10 @@ func CreateRepository(ctx DBContext, doer, u *User, repo *Repository, overwriteO
}
}
if _, err = ctx.e.Insert(repo); err != nil {
if _, err = ctx.Engine().Insert(repo); err != nil {
return err
}
if err = deleteRepoRedirect(ctx.e, u.ID, repo.Name); err != nil {
if err = deleteRepoRedirect(ctx.Engine(), u.ID, repo.Name); err != nil {
return err
}
@@ -1100,46 +1105,46 @@ func CreateRepository(ctx DBContext, doer, u *User, repo *Repository, overwriteO
}
}
if _, err = ctx.e.Insert(&units); err != nil {
if _, err = ctx.Engine().Insert(&units); err != nil {
return err
}
// Remember visibility preference.
u.LastRepoVisibility = repo.IsPrivate
if err = updateUserCols(ctx.e, u, "last_repo_visibility"); err != nil {
if err = updateUserCols(ctx.Engine(), u, "last_repo_visibility"); err != nil {
return fmt.Errorf("updateUser: %v", err)
}
if _, err = ctx.e.Incr("num_repos").ID(u.ID).Update(new(User)); err != nil {
if _, err = ctx.Engine().Incr("num_repos").ID(u.ID).Update(new(User)); err != nil {
return fmt.Errorf("increment user total_repos: %v", err)
}
u.NumRepos++
// Give access to all members in teams with access to all repositories.
if u.IsOrganization() {
if err := u.loadTeams(ctx.e); err != nil {
if err := u.loadTeams(ctx.Engine()); err != nil {
return fmt.Errorf("loadTeams: %v", err)
}
for _, t := range u.Teams {
if t.IncludesAllRepositories {
if err := t.addRepository(ctx.e, repo); err != nil {
if err := t.addRepository(ctx.Engine(), repo); err != nil {
return fmt.Errorf("addRepository: %v", err)
}
}
}
if isAdmin, err := isUserRepoAdmin(ctx.e, repo, doer); err != nil {
if isAdmin, err := isUserRepoAdmin(ctx.Engine(), repo, doer); err != nil {
return fmt.Errorf("isUserRepoAdmin: %v", err)
} else if !isAdmin {
// Make creator repo admin if it wan't assigned automatically
if err = repo.addCollaborator(ctx.e, doer); err != nil {
if err = repo.addCollaborator(ctx.Engine(), doer); err != nil {
return fmt.Errorf("AddCollaborator: %v", err)
}
if err = repo.changeCollaborationAccessMode(ctx.e, doer.ID, AccessModeAdmin); err != nil {
if err = repo.changeCollaborationAccessMode(ctx.Engine(), doer.ID, AccessModeAdmin); err != nil {
return fmt.Errorf("ChangeCollaborationAccessMode: %v", err)
}
}
} else if err = repo.recalculateAccesses(ctx.e); err != nil {
} else if err = repo.recalculateAccesses(ctx.Engine()); err != nil {
// Organization automatically called this in addRepository method.
return fmt.Errorf("recalculateAccesses: %v", err)
}
@@ -1155,12 +1160,12 @@ func CreateRepository(ctx DBContext, doer, u *User, repo *Repository, overwriteO
}
if setting.Service.AutoWatchNewRepos {
if err = watchRepo(ctx.e, doer.ID, repo.ID, true); err != nil {
if err = watchRepo(ctx.Engine(), doer.ID, repo.ID, true); err != nil {
return fmt.Errorf("watchRepo: %v", err)
}
}
if err = copyDefaultWebhooksToRepo(ctx.e, repo.ID); err != nil {
if err = copyDefaultWebhooksToRepo(ctx.Engine(), repo.ID); err != nil {
return fmt.Errorf("copyDefaultWebhooksToRepo: %v", err)
}
@@ -1168,7 +1173,7 @@ func CreateRepository(ctx DBContext, doer, u *User, repo *Repository, overwriteO
}
func countRepositories(userID int64, private bool) int64 {
sess := x.Where("id > 0")
sess := db.DefaultContext().Engine().Where("id > 0")
if userID > 0 {
sess.And("owner_id = ?", userID)
@@ -1204,14 +1209,14 @@ func RepoPath(userName, repoName string) string {
}
// IncrementRepoForkNum increment repository fork number
func IncrementRepoForkNum(ctx DBContext, repoID int64) error {
_, err := ctx.e.Exec("UPDATE `repository` SET num_forks=num_forks+1 WHERE id=?", repoID)
func IncrementRepoForkNum(ctx *db.Context, repoID int64) error {
_, err := ctx.Engine().Exec("UPDATE `repository` SET num_forks=num_forks+1 WHERE id=?", repoID)
return err
}
// DecrementRepoForkNum decrement repository fork number
func DecrementRepoForkNum(ctx DBContext, repoID int64) error {
_, err := ctx.e.Exec("UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repoID)
func DecrementRepoForkNum(ctx *db.Context, repoID int64) error {
_, err := ctx.Engine().Exec("UPDATE `repository` SET num_forks=num_forks-1 WHERE id=?", repoID)
return err
}
@@ -1251,7 +1256,7 @@ func ChangeRepositoryName(doer *User, repo *Repository, newRepoName string) (err
}
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return fmt.Errorf("sess.Begin: %v", err)
@@ -1264,7 +1269,7 @@ func ChangeRepositoryName(doer *User, repo *Repository, newRepoName string) (err
return sess.Commit()
}
func getRepositoriesByForkID(e Engine, forkID int64) ([]*Repository, error) {
func getRepositoriesByForkID(e db.Engine, forkID int64) ([]*Repository, error) {
repos := make([]*Repository, 0, 10)
return repos, e.
Where("fork_id=?", forkID).
@@ -1273,10 +1278,10 @@ func getRepositoriesByForkID(e Engine, forkID int64) ([]*Repository, error) {
// GetRepositoriesByForkID returns all repositories with given fork ID.
func GetRepositoriesByForkID(forkID int64) ([]*Repository, error) {
return getRepositoriesByForkID(x, forkID)
return getRepositoriesByForkID(db.DefaultContext().Engine(), forkID)
}
func updateRepository(e Engine, repo *Repository, visibilityChanged bool) (err error) {
func updateRepository(e db.Engine, repo *Repository, visibilityChanged bool) (err error) {
repo.LowerName = strings.ToLower(repo.Name)
if utf8.RuneCountInString(repo.Description) > 255 {
@@ -1351,13 +1356,13 @@ func updateRepository(e Engine, repo *Repository, visibilityChanged bool) (err e
}
// UpdateRepositoryCtx updates a repository with db context
func UpdateRepositoryCtx(ctx DBContext, repo *Repository, visibilityChanged bool) error {
return updateRepository(ctx.e, repo, visibilityChanged)
func UpdateRepositoryCtx(ctx *db.Context, repo *Repository, visibilityChanged bool) error {
return updateRepository(ctx.Engine(), repo, visibilityChanged)
}
// UpdateRepository updates a repository
func UpdateRepository(repo *Repository, visibilityChanged bool) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -1375,7 +1380,7 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error {
if ownerID == 0 {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -1392,13 +1397,13 @@ func UpdateRepositoryOwnerNames(ownerID int64, ownerName string) error {
// UpdateRepositoryUpdatedTime updates a repository's updated time
func UpdateRepositoryUpdatedTime(repoID int64, updateTime time.Time) error {
_, err := x.Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID)
_, err := db.DefaultContext().Engine().Exec("UPDATE repository SET updated_unix = ? WHERE id = ?", updateTime.Unix(), repoID)
return err
}
// UpdateRepositoryUnits updates a repository's units
func UpdateRepositoryUnits(repo *Repository, units []RepoUnit, deleteUnitTypes []UnitType) (err error) {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -1425,7 +1430,7 @@ func UpdateRepositoryUnits(repo *Repository, units []RepoUnit, deleteUnitTypes [
// DeleteRepository deletes a repository for a user or organization.
// make sure if you call this func to close open sessions (sqlite will otherwise get a deadlock)
func DeleteRepository(doer *User, uid, repoID int64) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -1533,7 +1538,7 @@ func DeleteRepository(doer *User, uid, repoID int64) error {
}
// Delete issue index
if err := deleteResouceIndex(sess, "issue_index", repoID); err != nil {
if err := db.DeleteResouceIndex(sess, "issue_index", repoID); err != nil {
return err
}
@@ -1641,21 +1646,21 @@ func DeleteRepository(doer *User, uid, repoID int64) error {
// Remove repository files.
repoPath := repo.RepoPath()
removeAllWithNotice(x, "Delete repository files", repoPath)
removeAllWithNotice(db.DefaultContext().Engine(), "Delete repository files", repoPath)
// Remove wiki files
if repo.HasWiki() {
removeAllWithNotice(x, "Delete repository wiki", repo.WikiPath())
removeAllWithNotice(db.DefaultContext().Engine(), "Delete repository wiki", repo.WikiPath())
}
// Remove archives
for i := range archivePaths {
removeStorageWithNotice(x, storage.RepoArchives, "Delete repo archive file", archivePaths[i])
removeStorageWithNotice(db.DefaultContext().Engine(), storage.RepoArchives, "Delete repo archive file", archivePaths[i])
}
// Remove lfs objects
for i := range lfsPaths {
removeStorageWithNotice(x, storage.LFS, "Delete orphaned LFS file", lfsPaths[i])
removeStorageWithNotice(db.DefaultContext().Engine(), storage.LFS, "Delete orphaned LFS file", lfsPaths[i])
}
// Remove issue attachment files.
@@ -1684,10 +1689,10 @@ func DeleteRepository(doer *User, uid, repoID int64) error {
// GetRepositoryByOwnerAndName returns the repository by given ownername and reponame.
func GetRepositoryByOwnerAndName(ownerName, repoName string) (*Repository, error) {
return getRepositoryByOwnerAndName(x, ownerName, repoName)
return getRepositoryByOwnerAndName(db.DefaultContext().Engine(), ownerName, repoName)
}
func getRepositoryByOwnerAndName(e Engine, ownerName, repoName string) (*Repository, error) {
func getRepositoryByOwnerAndName(e db.Engine, ownerName, repoName string) (*Repository, error) {
var repo Repository
has, err := e.Table("repository").Select("repository.*").
Join("INNER", "`user`", "`user`.id = repository.owner_id").
@@ -1708,7 +1713,7 @@ func GetRepositoryByName(ownerID int64, name string) (*Repository, error) {
OwnerID: ownerID,
LowerName: strings.ToLower(name),
}
has, err := x.Get(repo)
has, err := db.DefaultContext().Engine().Get(repo)
if err != nil {
return nil, err
} else if !has {
@@ -1717,7 +1722,7 @@ func GetRepositoryByName(ownerID int64, name string) (*Repository, error) {
return repo, err
}
func getRepositoryByID(e Engine, id int64) (*Repository, error) {
func getRepositoryByID(e db.Engine, id int64) (*Repository, error) {
repo := new(Repository)
has, err := e.ID(id).Get(repo)
if err != nil {
@@ -1730,18 +1735,18 @@ func getRepositoryByID(e Engine, id int64) (*Repository, error) {
// GetRepositoryByID returns the repository by given id if exists.
func GetRepositoryByID(id int64) (*Repository, error) {
return getRepositoryByID(x, id)
return getRepositoryByID(db.DefaultContext().Engine(), id)
}
// GetRepositoryByIDCtx returns the repository by given id if exists.
func GetRepositoryByIDCtx(ctx DBContext, id int64) (*Repository, error) {
return getRepositoryByID(ctx.e, id)
func GetRepositoryByIDCtx(ctx *db.Context, id int64) (*Repository, error) {
return getRepositoryByID(ctx.Engine(), id)
}
// GetRepositoriesMapByIDs returns the repositories by given id slice.
func GetRepositoriesMapByIDs(ids []int64) (map[int64]*Repository, error) {
repos := make(map[int64]*Repository, len(ids))
return repos, x.In("id", ids).Find(&repos)
return repos, db.DefaultContext().Engine().In("id", ids).Find(&repos)
}
// GetUserRepositories returns a list of repositories of given user.
@@ -1760,7 +1765,7 @@ func GetUserRepositories(opts *SearchRepoOptions) ([]*Repository, int64, error)
cond = cond.And(builder.In("lower_name", opts.LowerNames))
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
count, err := sess.Where(cond).Count(new(Repository))
@@ -1776,37 +1781,37 @@ func GetUserRepositories(opts *SearchRepoOptions) ([]*Repository, int64, error)
// GetUserMirrorRepositories returns a list of mirror repositories of given user.
func GetUserMirrorRepositories(userID int64) ([]*Repository, error) {
repos := make([]*Repository, 0, 10)
return repos, x.
return repos, db.DefaultContext().Engine().
Where("owner_id = ?", userID).
And("is_mirror = ?", true).
Find(&repos)
}
func getRepositoryCount(e Engine, u *User) (int64, error) {
func getRepositoryCount(e db.Engine, u *User) (int64, error) {
return e.Count(&Repository{OwnerID: u.ID})
}
func getPublicRepositoryCount(e Engine, u *User) (int64, error) {
func getPublicRepositoryCount(e db.Engine, u *User) (int64, error) {
return e.Where("is_private = ?", false).Count(&Repository{OwnerID: u.ID})
}
func getPrivateRepositoryCount(e Engine, u *User) (int64, error) {
func getPrivateRepositoryCount(e db.Engine, u *User) (int64, error) {
return e.Where("is_private = ?", true).Count(&Repository{OwnerID: u.ID})
}
// GetRepositoryCount returns the total number of repositories of user.
func GetRepositoryCount(u *User) (int64, error) {
return getRepositoryCount(x, u)
return getRepositoryCount(db.DefaultContext().Engine(), u)
}
// GetPublicRepositoryCount returns the total number of public repositories of user.
func GetPublicRepositoryCount(u *User) (int64, error) {
return getPublicRepositoryCount(x, u)
return getPublicRepositoryCount(db.DefaultContext().Engine(), u)
}
// GetPrivateRepositoryCount returns the total number of private repositories of user.
func GetPrivateRepositoryCount(u *User) (int64, error) {
return getPrivateRepositoryCount(x, u)
return getPrivateRepositoryCount(db.DefaultContext().Engine(), u)
}
// DeleteOldRepositoryArchives deletes old repository archives.
@@ -1815,7 +1820,7 @@ func DeleteOldRepositoryArchives(ctx context.Context, olderThan time.Duration) e
for {
var archivers []RepoArchiver
err := x.Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).
err := db.DefaultContext().Engine().Where("created_unix < ?", time.Now().Add(-olderThan).Unix()).
Asc("created_unix").
Limit(100).
Find(&archivers)
@@ -1845,7 +1850,7 @@ func deleteOldRepoArchiver(ctx context.Context, archiver *RepoArchiver) error {
if err != nil {
return err
}
_, err = x.ID(archiver.ID).Delete(delRepoArchiver)
_, err = db.DefaultContext().Engine().ID(archiver.ID).Delete(delRepoArchiver)
if err != nil {
return err
}
@@ -1861,7 +1866,7 @@ type repoChecker struct {
}
func repoStatsCheck(ctx context.Context, checker *repoChecker) {
results, err := x.Query(checker.querySQL)
results, err := db.DefaultContext().Engine().Query(checker.querySQL)
if err != nil {
log.Error("Select %s: %v", checker.desc, err)
return
@@ -1875,7 +1880,7 @@ func repoStatsCheck(ctx context.Context, checker *repoChecker) {
default:
}
log.Trace("Updating %s: %d", checker.desc, id)
_, err = x.Exec(checker.correctSQL, id, id)
_, err = db.DefaultContext().Engine().Exec(checker.correctSQL, id, id)
if err != nil {
log.Error("Update %s[%d]: %v", checker.desc, id, err)
}
@@ -1930,7 +1935,7 @@ func CheckRepoStats(ctx context.Context) error {
// ***** START: Repository.NumClosedIssues *****
desc := "repository count 'num_closed_issues'"
results, err := x.Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_issues!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, false)
results, err := db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_issues!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, false)
if err != nil {
log.Error("Select %s: %v", desc, err)
} else {
@@ -1943,7 +1948,7 @@ func CheckRepoStats(ctx context.Context) error {
default:
}
log.Trace("Updating %s: %d", desc, id)
_, err = x.Exec("UPDATE `repository` SET num_closed_issues=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, false, id)
_, err = db.DefaultContext().Engine().Exec("UPDATE `repository` SET num_closed_issues=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, false, id)
if err != nil {
log.Error("Update %s[%d]: %v", desc, id, err)
}
@@ -1953,7 +1958,7 @@ func CheckRepoStats(ctx context.Context) error {
// ***** START: Repository.NumClosedPulls *****
desc = "repository count 'num_closed_pulls'"
results, err = x.Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_pulls!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, true)
results, err = db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_closed_pulls!=(SELECT COUNT(*) FROM `issue` WHERE repo_id=repo.id AND is_closed=? AND is_pull=?)", true, true)
if err != nil {
log.Error("Select %s: %v", desc, err)
} else {
@@ -1966,7 +1971,7 @@ func CheckRepoStats(ctx context.Context) error {
default:
}
log.Trace("Updating %s: %d", desc, id)
_, err = x.Exec("UPDATE `repository` SET num_closed_pulls=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, true, id)
_, err = db.DefaultContext().Engine().Exec("UPDATE `repository` SET num_closed_pulls=(SELECT COUNT(*) FROM `issue` WHERE repo_id=? AND is_closed=? AND is_pull=?) WHERE id=?", id, true, true, id)
if err != nil {
log.Error("Update %s[%d]: %v", desc, id, err)
}
@@ -1976,7 +1981,7 @@ func CheckRepoStats(ctx context.Context) error {
// FIXME: use checker when stop supporting old fork repo format.
// ***** START: Repository.NumForks *****
results, err = x.Query("SELECT repo.id FROM `repository` repo WHERE repo.num_forks!=(SELECT COUNT(*) FROM `repository` WHERE fork_id=repo.id)")
results, err = db.DefaultContext().Engine().Query("SELECT repo.id FROM `repository` repo WHERE repo.num_forks!=(SELECT COUNT(*) FROM `repository` WHERE fork_id=repo.id)")
if err != nil {
log.Error("Select repository count 'num_forks': %v", err)
} else {
@@ -1996,7 +2001,7 @@ func CheckRepoStats(ctx context.Context) error {
continue
}
rawResult, err := x.Query("SELECT COUNT(*) FROM `repository` WHERE fork_id=?", repo.ID)
rawResult, err := db.DefaultContext().Engine().Query("SELECT COUNT(*) FROM `repository` WHERE fork_id=?", repo.ID)
if err != nil {
log.Error("Select count of forks[%d]: %v", repo.ID, err)
continue
@@ -2016,7 +2021,7 @@ func CheckRepoStats(ctx context.Context) error {
// SetArchiveRepoState sets if a repo is archived
func (repo *Repository) SetArchiveRepoState(isArchived bool) (err error) {
repo.IsArchived = isArchived
_, err = x.Where("id = ?", repo.ID).Cols("is_archived").NoAutoTime().Update(repo)
_, err = db.DefaultContext().Engine().Where("id = ?", repo.ID).Cols("is_archived").NoAutoTime().Update(repo)
return
}
@@ -2030,23 +2035,23 @@ func (repo *Repository) SetArchiveRepoState(isArchived bool) (err error) {
// HasForkedRepo checks if given user has already forked a repository with given ID.
func HasForkedRepo(ownerID, repoID int64) (*Repository, bool) {
repo := new(Repository)
has, _ := x.
has, _ := db.DefaultContext().Engine().
Where("owner_id=? AND fork_id=?", ownerID, repoID).
Get(repo)
return repo, has
}
// CopyLFS copies LFS data from one repo to another
func CopyLFS(ctx DBContext, newRepo, oldRepo *Repository) error {
func CopyLFS(ctx *db.Context, newRepo, oldRepo *Repository) error {
var lfsObjects []*LFSMetaObject
if err := ctx.e.Where("repository_id=?", oldRepo.ID).Find(&lfsObjects); err != nil {
if err := ctx.Engine().Where("repository_id=?", oldRepo.ID).Find(&lfsObjects); err != nil {
return err
}
for _, v := range lfsObjects {
v.ID = 0
v.RepositoryID = newRepo.ID
if _, err := ctx.e.Insert(v); err != nil {
if _, err := ctx.Engine().Insert(v); err != nil {
return err
}
}
@@ -2058,7 +2063,7 @@ func CopyLFS(ctx DBContext, newRepo, oldRepo *Repository) error {
func (repo *Repository) GetForks(listOptions ListOptions) ([]*Repository, error) {
if listOptions.Page == 0 {
forks := make([]*Repository, 0, repo.NumForks)
return forks, x.Find(&forks, &Repository{ForkID: repo.ID})
return forks, db.DefaultContext().Engine().Find(&forks, &Repository{ForkID: repo.ID})
}
sess := getPaginatedSession(&listOptions)
@@ -2069,7 +2074,7 @@ func (repo *Repository) GetForks(listOptions ListOptions) ([]*Repository, error)
// GetUserFork return user forked repository from this repository, if not forked return nil
func (repo *Repository) GetUserFork(userID int64) (*Repository, error) {
var forkedRepo Repository
has, err := x.Where("fork_id = ?", repo.ID).And("owner_id = ?", userID).Get(&forkedRepo)
has, err := db.DefaultContext().Engine().Where("fork_id = ?", repo.ID).And("owner_id = ?", userID).Get(&forkedRepo)
if err != nil {
return nil, err
}
@@ -2105,14 +2110,14 @@ func (repo *Repository) GetTreePathLock(treePath string) (*LFSLock, error) {
return nil, nil
}
func updateRepositoryCols(e Engine, repo *Repository, cols ...string) error {
func updateRepositoryCols(e db.Engine, repo *Repository, cols ...string) error {
_, err := e.ID(repo.ID).Cols(cols...).Update(repo)
return err
}
// UpdateRepositoryCols updates repository's columns
func UpdateRepositoryCols(repo *Repository, cols ...string) error {
return updateRepositoryCols(x, repo, cols...)
return updateRepositoryCols(db.DefaultContext().Engine(), repo, cols...)
}
// GetTrustModel will get the TrustModel for the repo or the default trust model
@@ -2130,7 +2135,7 @@ func (repo *Repository) GetTrustModel() TrustModelType {
// DoctorUserStarNum recalculate Stars number for all user
func DoctorUserStarNum() (err error) {
const batchSize = 100
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
for start := 0; ; start += batchSize {
@@ -2168,7 +2173,7 @@ func IterateRepository(f func(repo *Repository) error) error {
batchSize := setting.Database.IterateBufferSize
for {
repos := make([]*Repository, 0, batchSize)
if err := x.Limit(batchSize, start).Find(&repos); err != nil {
if err := db.DefaultContext().Engine().Limit(batchSize, start).Find(&repos); err != nil {
return err
}
if len(repos) == 0 {
+4 -3
View File
@@ -9,6 +9,7 @@ import (
"sort"
"time"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/git"
"xorm.io/xorm"
@@ -245,7 +246,7 @@ func (stats *ActivityStats) FillPullRequests(repoID int64, fromTime time.Time) e
}
func pullRequestsForActivityStatement(repoID int64, fromTime time.Time, merged bool) *xorm.Session {
sess := x.Where("pull_request.base_repo_id=?", repoID).
sess := db.DefaultContext().Engine().Where("pull_request.base_repo_id=?", repoID).
Join("INNER", "issue", "pull_request.issue_id = issue.id")
if merged {
@@ -313,7 +314,7 @@ func (stats *ActivityStats) FillUnresolvedIssues(repoID int64, fromTime time.Tim
}
func issuesForActivityStatement(repoID int64, fromTime time.Time, closed, unresolved bool) *xorm.Session {
sess := x.Where("issue.repo_id = ?", repoID).
sess := db.DefaultContext().Engine().Where("issue.repo_id = ?", repoID).
And("issue.is_closed = ?", closed)
if !unresolved {
@@ -355,7 +356,7 @@ func (stats *ActivityStats) FillReleases(repoID int64, fromTime time.Time) error
}
func releasesForActivityStatement(repoID int64, fromTime time.Time) *xorm.Session {
return x.Where("release.repo_id = ?", repoID).
return db.DefaultContext().Engine().Where("release.repo_id = ?", repoID).
And("release.is_draft = ?", false).
And("release.created_unix >= ?", fromTime.Unix())
}
+13 -8
View File
@@ -7,6 +7,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/timeutil"
)
@@ -31,6 +32,10 @@ type RepoArchiver struct {
CreatedUnix timeutil.TimeStamp `xorm:"INDEX NOT NULL created"`
}
func init() {
db.RegisterModel(new(RepoArchiver))
}
// LoadRepo loads repository
func (archiver *RepoArchiver) LoadRepo() (*Repository, error) {
if archiver.Repo != nil {
@@ -38,7 +43,7 @@ func (archiver *RepoArchiver) LoadRepo() (*Repository, error) {
}
var repo Repository
has, err := x.ID(archiver.RepoID).Get(&repo)
has, err := db.DefaultContext().Engine().ID(archiver.RepoID).Get(&repo)
if err != nil {
return nil, err
}
@@ -56,9 +61,9 @@ func (archiver *RepoArchiver) RelativePath() (string, error) {
}
// GetRepoArchiver get an archiver
func GetRepoArchiver(ctx DBContext, repoID int64, tp git.ArchiveType, commitID string) (*RepoArchiver, error) {
func GetRepoArchiver(ctx *db.Context, repoID int64, tp git.ArchiveType, commitID string) (*RepoArchiver, error) {
var archiver RepoArchiver
has, err := ctx.e.Where("repo_id=?", repoID).And("`type`=?", tp).And("commit_id=?", commitID).Get(&archiver)
has, err := ctx.Engine().Where("repo_id=?", repoID).And("`type`=?", tp).And("commit_id=?", commitID).Get(&archiver)
if err != nil {
return nil, err
}
@@ -69,19 +74,19 @@ func GetRepoArchiver(ctx DBContext, repoID int64, tp git.ArchiveType, commitID s
}
// AddRepoArchiver adds an archiver
func AddRepoArchiver(ctx DBContext, archiver *RepoArchiver) error {
_, err := ctx.e.Insert(archiver)
func AddRepoArchiver(ctx *db.Context, archiver *RepoArchiver) error {
_, err := ctx.Engine().Insert(archiver)
return err
}
// UpdateRepoArchiverStatus updates archiver's status
func UpdateRepoArchiverStatus(ctx DBContext, archiver *RepoArchiver) error {
_, err := ctx.e.ID(archiver.ID).Cols("status").Update(archiver)
func UpdateRepoArchiverStatus(ctx *db.Context, archiver *RepoArchiver) error {
_, err := ctx.Engine().ID(archiver.ID).Cols("status").Update(archiver)
return err
}
// DeleteAllRepoArchives deletes all repo archives records
func DeleteAllRepoArchives() error {
_, err := x.Where("1=1").Delete(new(RepoArchiver))
_, err := db.DefaultContext().Engine().Where("1=1").Delete(new(RepoArchiver))
return err
}
+9 -8
View File
@@ -13,6 +13,7 @@ import (
"strconv"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/avatar"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
@@ -25,7 +26,7 @@ func (repo *Repository) CustomAvatarRelativePath() string {
}
// generateRandomAvatar generates a random avatar for repository.
func (repo *Repository) generateRandomAvatar(e Engine) error {
func (repo *Repository) generateRandomAvatar(e db.Engine) error {
idToString := fmt.Sprintf("%d", repo.ID)
seed := idToString
@@ -56,7 +57,7 @@ func (repo *Repository) generateRandomAvatar(e Engine) error {
// RemoveRandomAvatars removes the randomly generated avatars that were created for repositories
func RemoveRandomAvatars(ctx context.Context) error {
return x.
return db.DefaultContext().Engine().
Where("id > 0").BufferSize(setting.Database.IterateBufferSize).
Iterate(new(Repository),
func(idx int, bean interface{}) error {
@@ -76,10 +77,10 @@ func RemoveRandomAvatars(ctx context.Context) error {
// RelAvatarLink returns a relative link to the repository's avatar.
func (repo *Repository) RelAvatarLink() string {
return repo.relAvatarLink(x)
return repo.relAvatarLink(db.DefaultContext().Engine())
}
func (repo *Repository) relAvatarLink(e Engine) string {
func (repo *Repository) relAvatarLink(e db.Engine) string {
// If no avatar - path is empty
avatarPath := repo.CustomAvatarRelativePath()
if len(avatarPath) == 0 {
@@ -100,11 +101,11 @@ func (repo *Repository) relAvatarLink(e Engine) string {
// AvatarLink returns a link to the repository's avatar.
func (repo *Repository) AvatarLink() string {
return repo.avatarLink(x)
return repo.avatarLink(db.DefaultContext().Engine())
}
// avatarLink returns user avatar absolute link.
func (repo *Repository) avatarLink(e Engine) string {
func (repo *Repository) avatarLink(e db.Engine) string {
link := repo.relAvatarLink(e)
// link may be empty!
if len(link) > 0 {
@@ -128,7 +129,7 @@ func (repo *Repository) UploadAvatar(data []byte) error {
return nil
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -171,7 +172,7 @@ func (repo *Repository) DeleteAvatar() error {
avatarPath := repo.CustomAvatarRelativePath()
log.Trace("DeleteAvatar[%d]: %s", repo.ID, avatarPath)
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
+23 -18
View File
@@ -8,6 +8,7 @@ package models
import (
"fmt"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/timeutil"
"xorm.io/builder"
@@ -23,7 +24,11 @@ type Collaboration struct {
UpdatedUnix timeutil.TimeStamp `xorm:"INDEX updated"`
}
func (repo *Repository) addCollaborator(e Engine, u *User) error {
func init() {
db.RegisterModel(new(Collaboration))
}
func (repo *Repository) addCollaborator(e db.Engine, u *User) error {
collaboration := &Collaboration{
RepoID: repo.ID,
UserID: u.ID,
@@ -46,7 +51,7 @@ func (repo *Repository) addCollaborator(e Engine, u *User) error {
// AddCollaborator adds new collaboration to a repository with default access mode.
func (repo *Repository) AddCollaborator(u *User) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -59,7 +64,7 @@ func (repo *Repository) AddCollaborator(u *User) error {
return sess.Commit()
}
func (repo *Repository) getCollaborations(e Engine, listOptions ListOptions) ([]*Collaboration, error) {
func (repo *Repository) getCollaborations(e db.Engine, listOptions ListOptions) ([]*Collaboration, error) {
if listOptions.Page == 0 {
collaborations := make([]*Collaboration, 0, 8)
return collaborations, e.Find(&collaborations, &Collaboration{RepoID: repo.ID})
@@ -77,7 +82,7 @@ type Collaborator struct {
Collaboration *Collaboration
}
func (repo *Repository) getCollaborators(e Engine, listOptions ListOptions) ([]*Collaborator, error) {
func (repo *Repository) getCollaborators(e db.Engine, listOptions ListOptions) ([]*Collaborator, error) {
collaborations, err := repo.getCollaborations(e, listOptions)
if err != nil {
return nil, fmt.Errorf("getCollaborations: %v", err)
@@ -99,15 +104,15 @@ func (repo *Repository) getCollaborators(e Engine, listOptions ListOptions) ([]*
// GetCollaborators returns the collaborators for a repository
func (repo *Repository) GetCollaborators(listOptions ListOptions) ([]*Collaborator, error) {
return repo.getCollaborators(x, listOptions)
return repo.getCollaborators(db.DefaultContext().Engine(), listOptions)
}
// CountCollaborators returns total number of collaborators for a repository
func (repo *Repository) CountCollaborators() (int64, error) {
return x.Where("repo_id = ? ", repo.ID).Count(&Collaboration{})
return db.DefaultContext().Engine().Where("repo_id = ? ", repo.ID).Count(&Collaboration{})
}
func (repo *Repository) getCollaboration(e Engine, uid int64) (*Collaboration, error) {
func (repo *Repository) getCollaboration(e db.Engine, uid int64) (*Collaboration, error) {
collaboration := &Collaboration{
RepoID: repo.ID,
UserID: uid,
@@ -119,16 +124,16 @@ func (repo *Repository) getCollaboration(e Engine, uid int64) (*Collaboration, e
return collaboration, err
}
func (repo *Repository) isCollaborator(e Engine, userID int64) (bool, error) {
func (repo *Repository) isCollaborator(e db.Engine, userID int64) (bool, error) {
return e.Get(&Collaboration{RepoID: repo.ID, UserID: userID})
}
// IsCollaborator check if a user is a collaborator of a repository
func (repo *Repository) IsCollaborator(userID int64) (bool, error) {
return repo.isCollaborator(x, userID)
return repo.isCollaborator(db.DefaultContext().Engine(), userID)
}
func (repo *Repository) changeCollaborationAccessMode(e Engine, uid int64, mode AccessMode) error {
func (repo *Repository) changeCollaborationAccessMode(e db.Engine, uid int64, mode AccessMode) error {
// Discard invalid input
if mode <= AccessModeNone || mode > AccessModeOwner {
return nil
@@ -164,7 +169,7 @@ func (repo *Repository) changeCollaborationAccessMode(e Engine, uid int64, mode
// ChangeCollaborationAccessMode sets new access mode for the collaboration.
func (repo *Repository) ChangeCollaborationAccessMode(uid int64, mode AccessMode) error {
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err := sess.Begin(); err != nil {
return err
@@ -184,7 +189,7 @@ func (repo *Repository) DeleteCollaboration(uid int64) (err error) {
UserID: uid,
}
sess := x.NewSession()
sess := db.DefaultContext().NewSession()
defer sess.Close()
if err = sess.Begin(); err != nil {
return err
@@ -212,7 +217,7 @@ func (repo *Repository) DeleteCollaboration(uid int64) (err error) {
return sess.Commit()
}
func (repo *Repository) reconsiderIssueAssignees(e Engine, uid int64) error {
func (repo *Repository) reconsiderIssueAssignees(e db.Engine, uid int64) error {
user, err := getUserByID(e, uid)
if err != nil {
return err
@@ -230,7 +235,7 @@ func (repo *Repository) reconsiderIssueAssignees(e Engine, uid int64) error {
return nil
}
func (repo *Repository) reconsiderWatches(e Engine, uid int64) error {
func (repo *Repository) reconsiderWatches(e db.Engine, uid int64) error {
if has, err := hasAccess(e, uid, repo); err != nil || has {
return err
}
@@ -243,7 +248,7 @@ func (repo *Repository) reconsiderWatches(e Engine, uid int64) error {
return removeIssueWatchersByRepoID(e, uid, repo.ID)
}
func (repo *Repository) getRepoTeams(e Engine) (teams []*Team, err error) {
func (repo *Repository) getRepoTeams(e db.Engine) (teams []*Team, err error) {
return teams, e.
Join("INNER", "team_repo", "team_repo.team_id = team.id").
Where("team.org_id = ?", repo.OwnerID).
@@ -254,7 +259,7 @@ func (repo *Repository) getRepoTeams(e Engine) (teams []*Team, err error) {
// GetRepoTeams gets the list of teams that has access to the repository
func (repo *Repository) GetRepoTeams() ([]*Team, error) {
return repo.getRepoTeams(x)
return repo.getRepoTeams(db.DefaultContext().Engine())
}
// IsOwnerMemberCollaborator checks if a provided user is the owner, a collaborator or a member of a team in a repository
@@ -262,7 +267,7 @@ func (repo *Repository) IsOwnerMemberCollaborator(userID int64) (bool, error) {
if repo.OwnerID == userID {
return true, nil
}
teamMember, err := x.Join("INNER", "team_repo", "team_repo.team_id = team_user.team_id").
teamMember, err := db.DefaultContext().Engine().Join("INNER", "team_repo", "team_repo.team_id = team_user.team_id").
Join("INNER", "team_unit", "team_unit.team_id = team_user.team_id").
Where("team_repo.repo_id = ?", repo.ID).
And("team_unit.`type` = ?", UnitTypeCode).
@@ -274,5 +279,5 @@ func (repo *Repository) IsOwnerMemberCollaborator(userID int64) (bool, error) {
return true, nil
}
return x.Get(&Collaboration{RepoID: repo.ID, UserID: userID})
return db.DefaultContext().Engine().Get(&Collaboration{RepoID: repo.ID, UserID: userID})
}
+19 -18
View File
@@ -7,16 +7,17 @@ package models
import (
"testing"
"code.gitea.io/gitea/models/db"
"github.com/stretchr/testify/assert"
)
func TestRepository_AddCollaborator(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
testSuccess := func(repoID, userID int64) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
assert.NoError(t, repo.GetOwner())
user := AssertExistsAndLoadBean(t, &User{ID: userID}).(*User)
user := db.AssertExistsAndLoadBean(t, &User{ID: userID}).(*User)
assert.NoError(t, repo.AddCollaborator(user))
CheckConsistencyFor(t, &Repository{ID: repoID}, &User{ID: userID})
}
@@ -26,12 +27,12 @@ func TestRepository_AddCollaborator(t *testing.T) {
}
func TestRepository_GetCollaborators(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID int64) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
collaborators, err := repo.GetCollaborators(ListOptions{})
assert.NoError(t, err)
expectedLen, err := x.Count(&Collaboration{RepoID: repoID})
expectedLen, err := db.DefaultContext().Engine().Count(&Collaboration{RepoID: repoID})
assert.NoError(t, err)
assert.Len(t, collaborators, int(expectedLen))
for _, collaborator := range collaborators {
@@ -46,49 +47,49 @@ func TestRepository_GetCollaborators(t *testing.T) {
}
func TestRepository_IsCollaborator(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
test := func(repoID, userID int64, expected bool) {
repo := AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: repoID}).(*Repository)
actual, err := repo.IsCollaborator(userID)
assert.NoError(t, err)
assert.Equal(t, expected, actual)
}
test(3, 2, true)
test(3, NonexistentID, false)
test(3, db.NonexistentID, false)
test(4, 2, false)
test(4, 4, true)
}
func TestRepository_ChangeCollaborationAccessMode(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
assert.NoError(t, repo.ChangeCollaborationAccessMode(4, AccessModeAdmin))
collaboration := AssertExistsAndLoadBean(t, &Collaboration{RepoID: repo.ID, UserID: 4}).(*Collaboration)
collaboration := db.AssertExistsAndLoadBean(t, &Collaboration{RepoID: repo.ID, UserID: 4}).(*Collaboration)
assert.EqualValues(t, AccessModeAdmin, collaboration.Mode)
access := AssertExistsAndLoadBean(t, &Access{UserID: 4, RepoID: repo.ID}).(*Access)
access := db.AssertExistsAndLoadBean(t, &Access{UserID: 4, RepoID: repo.ID}).(*Access)
assert.EqualValues(t, AccessModeAdmin, access.Mode)
assert.NoError(t, repo.ChangeCollaborationAccessMode(4, AccessModeAdmin))
assert.NoError(t, repo.ChangeCollaborationAccessMode(NonexistentID, AccessModeAdmin))
assert.NoError(t, repo.ChangeCollaborationAccessMode(db.NonexistentID, AccessModeAdmin))
CheckConsistencyFor(t, &Repository{ID: repo.ID})
}
func TestRepository_DeleteCollaboration(t *testing.T) {
assert.NoError(t, PrepareTestDatabase())
assert.NoError(t, db.PrepareTestDatabase())
repo := AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
repo := db.AssertExistsAndLoadBean(t, &Repository{ID: 4}).(*Repository)
assert.NoError(t, repo.GetOwner())
assert.NoError(t, repo.DeleteCollaboration(4))
AssertNotExistsBean(t, &Collaboration{RepoID: repo.ID, UserID: 4})
db.AssertNotExistsBean(t, &Collaboration{RepoID: repo.ID, UserID: 4})
assert.NoError(t, repo.DeleteCollaboration(4))
AssertNotExistsBean(t, &Collaboration{RepoID: repo.ID, UserID: 4})
db.AssertNotExistsBean(t, &Collaboration{RepoID: repo.ID, UserID: 4})
CheckConsistencyFor(t, &Repository{ID: repo.ID})
}
+11 -10
View File
@@ -10,6 +10,7 @@ import (
"strconv"
"strings"
"code.gitea.io/gitea/models/db"
"code.gitea.io/gitea/modules/git"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/storage"
@@ -67,9 +68,9 @@ func (gt GiteaTemplate) Globs() []glob.Glob {
}
// GenerateTopics generates topics from a template repository
func GenerateTopics(ctx DBContext, templateRepo, generateRepo *Repository) error {
func GenerateTopics(ctx *db.Context, templateRepo, generateRepo *Repository) error {
for _, topic := range templateRepo.Topics {
if _, err := addTopicByNameToRepo(ctx.e, generateRepo.ID, topic); err != nil {
if _, err := addTopicByNameToRepo(ctx.Engine(), generateRepo.ID, topic); err != nil {
return err
}
}
@@ -77,7 +78,7 @@ func GenerateTopics(ctx DBContext, templateRepo, generateRepo *Repository) error
}
// GenerateGitHooks generates git hooks from a template repository
func GenerateGitHooks(ctx DBContext, templateRepo, generateRepo *Repository) error {
func GenerateGitHooks(ctx *db.Context, templateRepo, generateRepo *Repository) error {
generateGitRepo, err := git.OpenRepository(generateRepo.RepoPath())
if err != nil {
return err
@@ -110,7 +111,7 @@ func GenerateGitHooks(ctx DBContext, templateRepo, generateRepo *Repository) err
}
// GenerateWebhooks generates webhooks from a template repository
func GenerateWebhooks(ctx DBContext, templateRepo, generateRepo *Repository) error {
func GenerateWebhooks(ctx *db.Context, templateRepo, generateRepo *Repository) error {
templateWebhooks, err := ListWebhooksByOpts(&ListWebhookOptions{RepoID: templateRepo.ID})
if err != nil {
return err
@@ -130,7 +131,7 @@ func GenerateWebhooks(ctx DBContext, templateRepo, generateRepo *Repository) err
Events: templateWebhook.Events,
Meta: templateWebhook.Meta,
}
if err := createWebhook(ctx.e, generateWebhook); err != nil {
if err := createWebhook(ctx.Engine(), generateWebhook); err != nil {
return err
}
}
@@ -138,18 +139,18 @@ func GenerateWebhooks(ctx DBContext, templateRepo, generateRepo *Repository) err
}
// GenerateAvatar generates the avatar from a template repository
func GenerateAvatar(ctx DBContext, templateRepo, generateRepo *Repository) error {
func GenerateAvatar(ctx *db.Context, templateRepo, generateRepo *Repository) error {
generateRepo.Avatar = strings.Replace(templateRepo.Avatar, strconv.FormatInt(templateRepo.ID, 10), strconv.FormatInt(generateRepo.ID, 10), 1)
if _, err := storage.Copy(storage.RepoAvatars, generateRepo.CustomAvatarRelativePath(), storage.RepoAvatars, templateRepo.CustomAvatarRelativePath()); err != nil {
return err
}
return updateRepositoryCols(ctx.e, generateRepo, "avatar")
return updateRepositoryCols(ctx.Engine(), generateRepo, "avatar")
}
// GenerateIssueLabels generates issue labels from a template repository
func GenerateIssueLabels(ctx DBContext, templateRepo, generateRepo *Repository) error {
templateLabels, err := getLabelsByRepoID(ctx.e, templateRepo.ID, "", ListOptions{})
func GenerateIssueLabels(ctx *db.Context, templateRepo, generateRepo *Repository) error {
templateLabels, err := getLabelsByRepoID(ctx.Engine(), templateRepo.ID, "", ListOptions{})
if err != nil {
return err
}
@@ -161,7 +162,7 @@ func GenerateIssueLabels(ctx DBContext, templateRepo, generateRepo *Repository)
Description: templateLabel.Description,
Color: templateLabel.Color,
}
if err := newLabel(ctx.e, generateLabel); err != nil {
if err := newLabel(ctx.Engine(), generateLabel); err != nil {
return err
}
}

Some files were not shown because too many files have changed in this diff Show More