mirror of
https://github.com/stashapp/stash.git
synced 2026-04-14 11:04:07 +02:00
Tag custom fields support for backend (#6546)
* Fix custom field import/export for studio * Update studio unit tests * Add tag create and update unit tests * Add custom fields to tag filter graphql * Add unit tests for tag filtering * Add filter unit tests for studio
This commit is contained in:
parent
f629191b28
commit
b278525647
42 changed files with 1356 additions and 135 deletions
|
|
@ -650,6 +650,8 @@ input TagFilterType {
|
|||
|
||||
"Filter by last update time"
|
||||
updated_at: TimestampCriterionInput
|
||||
|
||||
custom_fields: [CustomFieldCriterionInput!]
|
||||
}
|
||||
|
||||
input ImageFilterType {
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ type Tag {
|
|||
|
||||
parent_count: Int! # Resolver
|
||||
child_count: Int! # Resolver
|
||||
custom_fields: Map!
|
||||
}
|
||||
|
||||
input TagCreateInput {
|
||||
|
|
@ -41,6 +42,8 @@ input TagCreateInput {
|
|||
|
||||
parent_ids: [ID!]
|
||||
child_ids: [ID!]
|
||||
|
||||
custom_fields: Map
|
||||
}
|
||||
|
||||
input TagUpdateInput {
|
||||
|
|
@ -59,6 +62,8 @@ input TagUpdateInput {
|
|||
|
||||
parent_ids: [ID!]
|
||||
child_ids: [ID!]
|
||||
|
||||
custom_fields: CustomFieldsInput
|
||||
}
|
||||
|
||||
input TagDestroyInput {
|
||||
|
|
|
|||
|
|
@ -62,10 +62,11 @@ type Loaders struct {
|
|||
StudioByID *StudioLoader
|
||||
StudioCustomFields *CustomFieldsLoader
|
||||
|
||||
TagByID *TagLoader
|
||||
GroupByID *GroupLoader
|
||||
FileByID *FileLoader
|
||||
FolderByID *FolderLoader
|
||||
TagByID *TagLoader
|
||||
TagCustomFields *CustomFieldsLoader
|
||||
GroupByID *GroupLoader
|
||||
FileByID *FileLoader
|
||||
FolderByID *FolderLoader
|
||||
}
|
||||
|
||||
type Middleware struct {
|
||||
|
|
@ -116,6 +117,11 @@ func (m Middleware) Middleware(next http.Handler) http.Handler {
|
|||
maxBatch: maxBatch,
|
||||
fetch: m.fetchTags(ctx),
|
||||
},
|
||||
TagCustomFields: &CustomFieldsLoader{
|
||||
wait: wait,
|
||||
maxBatch: maxBatch,
|
||||
fetch: m.fetchTagCustomFields(ctx),
|
||||
},
|
||||
GroupByID: &GroupLoader{
|
||||
wait: wait,
|
||||
maxBatch: maxBatch,
|
||||
|
|
@ -283,6 +289,18 @@ func (m Middleware) fetchTags(ctx context.Context) func(keys []int) ([]*models.T
|
|||
}
|
||||
}
|
||||
|
||||
func (m Middleware) fetchTagCustomFields(ctx context.Context) func(keys []int) ([]models.CustomFieldMap, []error) {
|
||||
return func(keys []int) (ret []models.CustomFieldMap, errs []error) {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
var err error
|
||||
ret, err = m.Repository.Tag.GetCustomFieldsBulk(ctx, keys)
|
||||
return err
|
||||
})
|
||||
|
||||
return ret, toErrorSlice(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (m Middleware) fetchGroups(ctx context.Context) func(keys []int) ([]*models.Group, []error) {
|
||||
return func(keys []int) (ret []*models.Group, errs []error) {
|
||||
err := m.Repository.WithDB(ctx, func(ctx context.Context) error {
|
||||
|
|
|
|||
|
|
@ -181,3 +181,16 @@ func (r *tagResolver) ChildCount(ctx context.Context, obj *models.Tag) (ret int,
|
|||
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func (r *tagResolver) CustomFields(ctx context.Context, obj *models.Tag) (map[string]interface{}, error) {
|
||||
m, err := loaders.From(ctx).TagCustomFields.Load(obj.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if m == nil {
|
||||
return make(map[string]interface{}), nil
|
||||
}
|
||||
|
||||
return m, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -31,7 +31,10 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
|
|||
}
|
||||
|
||||
// Populate a new tag from the input
|
||||
newTag := models.NewTag()
|
||||
newTag := models.CreateTagInput{
|
||||
Tag: &models.Tag{},
|
||||
}
|
||||
*newTag.Tag = models.NewTag()
|
||||
|
||||
newTag.Name = strings.TrimSpace(input.Name)
|
||||
newTag.SortName = translator.string(input.SortName)
|
||||
|
|
@ -60,6 +63,8 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
|
|||
return nil, fmt.Errorf("converting child tag ids: %w", err)
|
||||
}
|
||||
|
||||
newTag.CustomFields = convertMapJSONNumbers(input.CustomFields)
|
||||
|
||||
// Process the base 64 encoded image string
|
||||
var imageData []byte
|
||||
if input.Image != nil {
|
||||
|
|
@ -73,7 +78,7 @@ func (r *mutationResolver) TagCreate(ctx context.Context, input TagCreateInput)
|
|||
if err := r.withTxn(ctx, func(ctx context.Context) error {
|
||||
qb := r.repository.Tag
|
||||
|
||||
if err := tag.ValidateCreate(ctx, newTag, qb); err != nil {
|
||||
if err := tag.ValidateCreate(ctx, *newTag.Tag, qb); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
@ -137,6 +142,13 @@ func (r *mutationResolver) TagUpdate(ctx context.Context, input TagUpdateInput)
|
|||
return nil, fmt.Errorf("converting child tag ids: %w", err)
|
||||
}
|
||||
|
||||
if input.CustomFields != nil {
|
||||
updatedTag.CustomFields = *input.CustomFields
|
||||
// convert json.Numbers to int/float
|
||||
updatedTag.CustomFields.Full = convertMapJSONNumbers(updatedTag.CustomFields.Full)
|
||||
updatedTag.CustomFields.Partial = convertMapJSONNumbers(updatedTag.CustomFields.Partial)
|
||||
}
|
||||
|
||||
var imageData []byte
|
||||
imageIncluded := translator.hasField("image")
|
||||
if input.Image != nil {
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ func createTag(ctx context.Context, qb models.TagWriter) error {
|
|||
Name: testName,
|
||||
}
|
||||
|
||||
err := qb.Create(ctx, &tag)
|
||||
err := qb.Create(ctx, &models.CreateTagInput{Tag: &tag})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -167,7 +167,9 @@ func (g sceneRelationships) tags(ctx context.Context) ([]int, error) {
|
|||
} else if createMissing {
|
||||
newTag := t.ToTag(endpoint, nil)
|
||||
|
||||
err := g.tagCreator.Create(ctx, newTag)
|
||||
err := g.tagCreator.Create(ctx, &models.CreateTagInput{
|
||||
Tag: newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tag: %w", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -368,14 +368,14 @@ func Test_sceneRelationships_tags(t *testing.T) {
|
|||
|
||||
db := mocks.NewDatabase()
|
||||
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
return p.Name == validName
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateTagInput) bool {
|
||||
return p.Tag.Name == validName
|
||||
})).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = validStoredIDInt
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = validStoredIDInt
|
||||
}).Return(nil)
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.Tag) bool {
|
||||
return p.Name == invalidName
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(p *models.CreateTagInput) bool {
|
||||
return p.Tag.Name == invalidName
|
||||
})).Return(errors.New("error creating tag"))
|
||||
|
||||
tr := sceneRelationships{
|
||||
|
|
|
|||
|
|
@ -249,7 +249,9 @@ func (i *Importer) createTags(ctx context.Context, names []string) ([]*models.Ta
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := i.TagWriter.Create(ctx, &newTag)
|
||||
err := i.TagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -289,9 +289,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -323,7 +323,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -126,7 +126,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names []
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := tagWriter.Create(ctx, &newTag)
|
||||
err := tagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -212,9 +212,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -247,7 +247,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -407,7 +407,9 @@ func createTags(ctx context.Context, tagWriter models.TagCreator, names []string
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := tagWriter.Create(ctx, &newTag)
|
||||
err := tagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -251,9 +251,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -285,7 +285,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -17,3 +17,7 @@ type CustomFieldsReader interface {
|
|||
GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error)
|
||||
GetCustomFieldsBulk(ctx context.Context, ids []int) ([]CustomFieldMap, error)
|
||||
}
|
||||
|
||||
type CustomFieldsWriter interface {
|
||||
SetCustomFields(ctx context.Context, id int, fields CustomFieldsInput) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ type Studio struct {
|
|||
Tags []string `json:"tags,omitempty"`
|
||||
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
|
||||
|
||||
CustomFields map[string]interface{} `json:"custom_fields,omitempty"`
|
||||
|
||||
// deprecated - for import only
|
||||
URL string `json:"url,omitempty"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,17 +11,18 @@ import (
|
|||
)
|
||||
|
||||
type Tag struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
SortName string `json:"sort_name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Favorite bool `json:"favorite,omitempty"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Parents []string `json:"parents,omitempty"`
|
||||
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
|
||||
StashIDs []models.StashID `json:"stash_ids,omitempty"`
|
||||
CreatedAt json.JSONTime `json:"created_at,omitempty"`
|
||||
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
|
||||
Name string `json:"name,omitempty"`
|
||||
SortName string `json:"sort_name,omitempty"`
|
||||
Description string `json:"description,omitempty"`
|
||||
Favorite bool `json:"favorite,omitempty"`
|
||||
Aliases []string `json:"aliases,omitempty"`
|
||||
Image string `json:"image,omitempty"`
|
||||
Parents []string `json:"parents,omitempty"`
|
||||
IgnoreAutoTag bool `json:"ignore_auto_tag,omitempty"`
|
||||
StashIDs []models.StashID `json:"stash_ids,omitempty"`
|
||||
CreatedAt json.JSONTime `json:"created_at,omitempty"`
|
||||
UpdatedAt json.JSONTime `json:"updated_at,omitempty"`
|
||||
CustomFields map[string]interface{} `json:"custom_fields,omitempty"`
|
||||
}
|
||||
|
||||
func (s Tag) Filename() string {
|
||||
|
|
|
|||
|
|
@ -101,11 +101,11 @@ func (_m *TagReaderWriter) CountByParentTagID(ctx context.Context, parentID int)
|
|||
}
|
||||
|
||||
// Create provides a mock function with given fields: ctx, newTag
|
||||
func (_m *TagReaderWriter) Create(ctx context.Context, newTag *models.Tag) error {
|
||||
func (_m *TagReaderWriter) Create(ctx context.Context, newTag *models.CreateTagInput) error {
|
||||
ret := _m.Called(ctx, newTag)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.Tag) error); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.CreateTagInput) error); ok {
|
||||
r0 = rf(ctx, newTag)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
|
|
@ -542,6 +542,52 @@ func (_m *TagReaderWriter) GetChildIDs(ctx context.Context, relatedID int) ([]in
|
|||
return r0, r1
|
||||
}
|
||||
|
||||
// GetCustomFields provides a mock function with given fields: ctx, id
|
||||
func (_m *TagReaderWriter) GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error) {
|
||||
ret := _m.Called(ctx, id)
|
||||
|
||||
var r0 map[string]interface{}
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int) map[string]interface{}); ok {
|
||||
r0 = rf(ctx, id)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).(map[string]interface{})
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, int) error); ok {
|
||||
r1 = rf(ctx, id)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetCustomFieldsBulk provides a mock function with given fields: ctx, ids
|
||||
func (_m *TagReaderWriter) GetCustomFieldsBulk(ctx context.Context, ids []int) ([]models.CustomFieldMap, error) {
|
||||
ret := _m.Called(ctx, ids)
|
||||
|
||||
var r0 []models.CustomFieldMap
|
||||
if rf, ok := ret.Get(0).(func(context.Context, []int) []models.CustomFieldMap); ok {
|
||||
r0 = rf(ctx, ids)
|
||||
} else {
|
||||
if ret.Get(0) != nil {
|
||||
r0 = ret.Get(0).([]models.CustomFieldMap)
|
||||
}
|
||||
}
|
||||
|
||||
var r1 error
|
||||
if rf, ok := ret.Get(1).(func(context.Context, []int) error); ok {
|
||||
r1 = rf(ctx, ids)
|
||||
} else {
|
||||
r1 = ret.Error(1)
|
||||
}
|
||||
|
||||
return r0, r1
|
||||
}
|
||||
|
||||
// GetImage provides a mock function with given fields: ctx, tagID
|
||||
func (_m *TagReaderWriter) GetImage(ctx context.Context, tagID int) ([]byte, error) {
|
||||
ret := _m.Called(ctx, tagID)
|
||||
|
|
@ -699,12 +745,26 @@ func (_m *TagReaderWriter) QueryForAutoTag(ctx context.Context, words []string)
|
|||
return r0, r1
|
||||
}
|
||||
|
||||
// SetCustomFields provides a mock function with given fields: ctx, id, fields
|
||||
func (_m *TagReaderWriter) SetCustomFields(ctx context.Context, id int, fields models.CustomFieldsInput) error {
|
||||
ret := _m.Called(ctx, id, fields)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, int, models.CustomFieldsInput) error); ok {
|
||||
r0 = rf(ctx, id, fields)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
}
|
||||
|
||||
return r0
|
||||
}
|
||||
|
||||
// Update provides a mock function with given fields: ctx, updatedTag
|
||||
func (_m *TagReaderWriter) Update(ctx context.Context, updatedTag *models.Tag) error {
|
||||
func (_m *TagReaderWriter) Update(ctx context.Context, updatedTag *models.UpdateTagInput) error {
|
||||
ret := _m.Called(ctx, updatedTag)
|
||||
|
||||
var r0 error
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.Tag) error); ok {
|
||||
if rf, ok := ret.Get(0).(func(context.Context, *models.UpdateTagInput) error); ok {
|
||||
r0 = rf(ctx, updatedTag)
|
||||
} else {
|
||||
r0 = ret.Error(0)
|
||||
|
|
|
|||
|
|
@ -29,6 +29,18 @@ func NewTag() Tag {
|
|||
}
|
||||
}
|
||||
|
||||
type CreateTagInput struct {
|
||||
*Tag
|
||||
|
||||
CustomFields map[string]interface{} `json:"custom_fields"`
|
||||
}
|
||||
|
||||
type UpdateTagInput struct {
|
||||
*Tag
|
||||
|
||||
CustomFields CustomFieldsInput `json:"custom_fields"`
|
||||
}
|
||||
|
||||
func (s *Tag) LoadAliases(ctx context.Context, l AliasLoader) error {
|
||||
return s.Aliases.load(func() ([]string, error) {
|
||||
return l.GetAliases(ctx, s.ID)
|
||||
|
|
@ -66,6 +78,8 @@ type TagPartial struct {
|
|||
ParentIDs *UpdateIDs
|
||||
ChildIDs *UpdateIDs
|
||||
StashIDs *UpdateStashIDs
|
||||
|
||||
CustomFields CustomFieldsInput
|
||||
}
|
||||
|
||||
func NewTagPartial() TagPartial {
|
||||
|
|
|
|||
|
|
@ -51,12 +51,12 @@ type TagCounter interface {
|
|||
|
||||
// TagCreator provides methods to create tags.
|
||||
type TagCreator interface {
|
||||
Create(ctx context.Context, newTag *Tag) error
|
||||
Create(ctx context.Context, newTag *CreateTagInput) error
|
||||
}
|
||||
|
||||
// TagUpdater provides methods to update tags.
|
||||
type TagUpdater interface {
|
||||
Update(ctx context.Context, updatedTag *Tag) error
|
||||
Update(ctx context.Context, updatedTag *UpdateTagInput) error
|
||||
UpdatePartial(ctx context.Context, id int, updateTag TagPartial) (*Tag, error)
|
||||
UpdateAliases(ctx context.Context, tagID int, aliases []string) error
|
||||
UpdateImage(ctx context.Context, tagID int, image []byte) error
|
||||
|
|
@ -77,6 +77,7 @@ type TagFinderCreator interface {
|
|||
type TagCreatorUpdater interface {
|
||||
TagCreator
|
||||
TagUpdater
|
||||
CustomFieldsWriter
|
||||
}
|
||||
|
||||
// TagReader provides all methods to read tags.
|
||||
|
|
@ -89,6 +90,7 @@ type TagReader interface {
|
|||
AliasLoader
|
||||
TagRelationLoader
|
||||
StashIDLoader
|
||||
CustomFieldsReader
|
||||
|
||||
All(ctx context.Context) ([]*Tag, error)
|
||||
GetImage(ctx context.Context, tagID int) ([]byte, error)
|
||||
|
|
@ -100,6 +102,7 @@ type TagWriter interface {
|
|||
TagCreator
|
||||
TagUpdater
|
||||
TagDestroyer
|
||||
CustomFieldsWriter
|
||||
|
||||
Merge(ctx context.Context, source []int, destination int) error
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,4 +56,7 @@ type TagFilterType struct {
|
|||
CreatedAt *TimestampCriterionInput `json:"created_at"`
|
||||
// Filter by updated at
|
||||
UpdatedAt *TimestampCriterionInput `json:"updated_at"`
|
||||
|
||||
// Filter by custom fields
|
||||
CustomFields []CustomFieldCriterionInput `json:"custom_fields"`
|
||||
}
|
||||
|
|
|
|||
|
|
@ -107,7 +107,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names []
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := tagWriter.Create(ctx, &newTag)
|
||||
err := tagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -111,9 +111,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -146,7 +146,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -549,7 +549,9 @@ func createTags(ctx context.Context, tagWriter models.TagCreator, names []string
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := tagWriter.Create(ctx, &newTag)
|
||||
err := tagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -508,9 +508,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -542,7 +542,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -678,6 +678,10 @@ func (db *Anonymiser) anonymiseStudios(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := db.anonymiseCustomFields(ctx, goqu.T(studiosCustomFieldsTable.GetTable()), "studio_id"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -873,6 +877,10 @@ func (db *Anonymiser) anonymiseTags(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
if err := db.anonymiseCustomFields(ctx, goqu.T(tagsCustomFieldsTable.GetTable()), "tag_id"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ const (
|
|||
cacheSizeEnv = "STASH_SQLITE_CACHE_SIZE"
|
||||
)
|
||||
|
||||
var appSchemaVersion uint = 76
|
||||
var appSchemaVersion uint = 77
|
||||
|
||||
//go:embed migrations/*.sql
|
||||
var migrationsBox embed.FS
|
||||
|
|
|
|||
9
pkg/sqlite/migrations/77_tag_custom_fields.up.sql
Normal file
9
pkg/sqlite/migrations/77_tag_custom_fields.up.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
CREATE TABLE `tag_custom_fields` (
|
||||
`tag_id` integer NOT NULL,
|
||||
`field` varchar(64) NOT NULL,
|
||||
`value` BLOB NOT NULL,
|
||||
PRIMARY KEY (`tag_id`, `field`),
|
||||
foreign key(`tag_id`) references `tags`(`id`) on delete CASCADE
|
||||
);
|
||||
|
||||
CREATE INDEX `index_tag_custom_fields_field_value` ON `tag_custom_fields` (`field`, `value`);
|
||||
|
|
@ -1709,6 +1709,18 @@ func tagStashID(i int) models.StashID {
|
|||
}
|
||||
}
|
||||
|
||||
func getTagCustomFields(index int) map[string]interface{} {
|
||||
if index%5 == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"string": getTagStringValue(index, "custom"),
|
||||
"int": int64(index % 5),
|
||||
"real": float64(index) / 10,
|
||||
}
|
||||
}
|
||||
|
||||
// createTags creates n tags with plain Name and o tags with camel cased NaMe included
|
||||
func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) error {
|
||||
const namePlain = "Name"
|
||||
|
|
@ -1736,7 +1748,10 @@ func createTags(ctx context.Context, tqb models.TagReaderWriter, n int, o int) e
|
|||
})
|
||||
}
|
||||
|
||||
err := tqb.Create(ctx, &tag)
|
||||
err := tqb.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &tag,
|
||||
CustomFields: getTagCustomFields(i),
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating tag %v+: %s", tag, err.Error())
|
||||
|
|
|
|||
|
|
@ -1694,6 +1694,251 @@ func TestStudioQueryFast(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func studiesToIDs(i []*models.Studio) []int {
|
||||
ret := make([]int, len(i))
|
||||
for i, v := range i {
|
||||
ret[i] = v.ID
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
func TestStudioQueryCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter *models.StudioFilterType
|
||||
includeIdxs []int
|
||||
excludeIdxs []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"equals",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not equals",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotEquals,
|
||||
Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"includes",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierIncludes,
|
||||
Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"excludes",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
Value: []any{getStudioStringValue(studioIdxWithTwoScenes, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"regex",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{".*1_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid regex",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"not matches regex",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{".*1_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid not matches regex",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"null",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "not existing",
|
||||
Modifier: models.CriterionModifierIsNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not null",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithTwoScenes, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithTwoScenes},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"between",
|
||||
&models.StudioFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierBetween,
|
||||
Value: []any{0.15, 0.25},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{studioIdxWithGroup},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not between",
|
||||
&models.StudioFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getStudioStringValue(studioIdxWithGroup, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierNotBetween,
|
||||
Value: []any{0.15, 0.25},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{studioIdxWithGroup},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
studios, _, err := db.Studio.Query(ctx, tt.filter, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("StudioStore.Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
ids := studiesToIDs(studios)
|
||||
include := indexesToIDs(studioIDs, tt.includeIdxs)
|
||||
exclude := indexesToIDs(studioIDs, tt.excludeIdxs)
|
||||
|
||||
for _, i := range include {
|
||||
assert.Contains(ids, i)
|
||||
}
|
||||
for _, e := range exclude {
|
||||
assert.NotContains(ids, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Create
|
||||
// TODO Update
|
||||
// TODO Destroy
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ var (
|
|||
tagsAliasesJoinTable = goqu.T(tagAliasesTable)
|
||||
tagRelationsJoinTable = goqu.T(tagRelationsTable)
|
||||
tagsStashIDsJoinTable = goqu.T("tag_stash_ids")
|
||||
tagsCustomFieldsTable = goqu.T("tag_custom_fields")
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
|
|||
|
|
@ -166,6 +166,7 @@ var (
|
|||
|
||||
type TagStore struct {
|
||||
blobJoinQueryBuilder
|
||||
customFieldsStore
|
||||
|
||||
tableMgr *table
|
||||
}
|
||||
|
|
@ -176,6 +177,10 @@ func NewTagStore(blobStore *BlobStore) *TagStore {
|
|||
blobStore: blobStore,
|
||||
joinTable: tagTable,
|
||||
},
|
||||
customFieldsStore: customFieldsStore{
|
||||
table: tagsCustomFieldsTable,
|
||||
fk: tagsCustomFieldsTable.Col(tagIDColumn),
|
||||
},
|
||||
tableMgr: tagTableMgr,
|
||||
}
|
||||
}
|
||||
|
|
@ -188,9 +193,9 @@ func (qb *TagStore) selectDataset() *goqu.SelectDataset {
|
|||
return dialect.From(qb.table()).Select(qb.table().All())
|
||||
}
|
||||
|
||||
func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error {
|
||||
func (qb *TagStore) Create(ctx context.Context, newObject *models.CreateTagInput) error {
|
||||
var r tagRow
|
||||
r.fromTag(*newObject)
|
||||
r.fromTag(*newObject.Tag)
|
||||
|
||||
id, err := qb.tableMgr.insertID(ctx, r)
|
||||
if err != nil {
|
||||
|
|
@ -221,12 +226,17 @@ func (qb *TagStore) Create(ctx context.Context, newObject *models.Tag) error {
|
|||
}
|
||||
}
|
||||
|
||||
const partial = false
|
||||
if err := qb.setCustomFields(ctx, id, newObject.CustomFields, partial); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updated, err := qb.find(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("finding after create: %w", err)
|
||||
}
|
||||
|
||||
*newObject = *updated
|
||||
*newObject.Tag = *updated
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
@ -270,12 +280,16 @@ func (qb *TagStore) UpdatePartial(ctx context.Context, id int, partial models.Ta
|
|||
}
|
||||
}
|
||||
|
||||
if err := qb.SetCustomFields(ctx, id, partial.CustomFields); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return qb.find(ctx, id)
|
||||
}
|
||||
|
||||
func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error {
|
||||
func (qb *TagStore) Update(ctx context.Context, updatedObject *models.UpdateTagInput) error {
|
||||
var r tagRow
|
||||
r.fromTag(*updatedObject)
|
||||
r.fromTag(*updatedObject.Tag)
|
||||
|
||||
if err := qb.tableMgr.updateByID(ctx, updatedObject.ID, r); err != nil {
|
||||
return err
|
||||
|
|
@ -305,6 +319,10 @@ func (qb *TagStore) Update(ctx context.Context, updatedObject *models.Tag) error
|
|||
}
|
||||
}
|
||||
|
||||
if err := qb.SetCustomFields(ctx, updatedObject.ID, updatedObject.CustomFields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -101,6 +101,13 @@ func (qb *tagFilterHandler) criterionHandler() criterionHandler {
|
|||
×tampCriterionHandler{tagFilter.CreatedAt, "tags.created_at", nil},
|
||||
×tampCriterionHandler{tagFilter.UpdatedAt, "tags.updated_at", nil},
|
||||
|
||||
&customFieldsFilterHandler{
|
||||
table: tagsCustomFieldsTable.GetTable(),
|
||||
fkCol: tagIDColumn,
|
||||
c: tagFilter.CustomFields,
|
||||
idCol: "tags.id",
|
||||
},
|
||||
|
||||
&relatedFilterHandler{
|
||||
relatedIDCol: "scenes_tags.scene_id",
|
||||
relatedRepo: sceneRepository.repository,
|
||||
|
|
|
|||
|
|
@ -1012,8 +1012,10 @@ func TestTagUpdateTagImage(t *testing.T) {
|
|||
|
||||
// create tag to test against
|
||||
const name = "TestTagUpdateTagImage"
|
||||
tag := models.Tag{
|
||||
Name: name,
|
||||
tag := models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
err := qb.Create(ctx, &tag)
|
||||
if err != nil {
|
||||
|
|
@ -1032,15 +1034,17 @@ func TestTagUpdateAlias(t *testing.T) {
|
|||
|
||||
// create tag to test against
|
||||
const name = "TestTagUpdateAlias"
|
||||
tag := models.Tag{
|
||||
Name: name,
|
||||
tag := models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
err := qb.Create(ctx, &tag)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error creating tag: %s", err.Error())
|
||||
}
|
||||
|
||||
aliases := []string{"alias1", "alias2"}
|
||||
aliases := []string{"updatedAlias1", "updatedAlias2"}
|
||||
err = qb.UpdateAliases(ctx, tag.ID, aliases)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error updating tag aliases: %s", err.Error())
|
||||
|
|
@ -1065,8 +1069,10 @@ func TestTagStashIDs(t *testing.T) {
|
|||
|
||||
// create tag to test against
|
||||
const name = "TestTagStashIDs"
|
||||
tag := models.Tag{
|
||||
Name: name,
|
||||
tag := models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
},
|
||||
}
|
||||
err := qb.Create(ctx, &tag)
|
||||
if err != nil {
|
||||
|
|
@ -1089,9 +1095,11 @@ func TestTagFindByStashID(t *testing.T) {
|
|||
const name = "TestTagFindByStashID"
|
||||
const stashID = "stashid"
|
||||
const endpoint = "endpoint"
|
||||
tag := models.Tag{
|
||||
Name: name,
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{{StashID: stashID, Endpoint: endpoint}}),
|
||||
tag := models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{{StashID: stashID, Endpoint: endpoint}}),
|
||||
},
|
||||
}
|
||||
err := qb.Create(ctx, &tag)
|
||||
if err != nil {
|
||||
|
|
@ -1263,8 +1271,626 @@ func TestTagMerge(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// TODO Create
|
||||
// TODO Update
|
||||
func loadTagRelationships(ctx context.Context, expected models.Tag, actual *models.Tag) error {
|
||||
if expected.Aliases.Loaded() {
|
||||
if err := actual.LoadAliases(ctx, db.Tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if expected.ParentIDs.Loaded() {
|
||||
if err := actual.LoadParentIDs(ctx, db.Tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if expected.ChildIDs.Loaded() {
|
||||
if err := actual.LoadChildIDs(ctx, db.Tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if expected.StashIDs.Loaded() {
|
||||
if err := actual.LoadStashIDs(ctx, db.Tag); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_TagStore_Create(t *testing.T) {
|
||||
var (
|
||||
name = "name"
|
||||
sortName = "sortName"
|
||||
description = "description"
|
||||
favorite = true
|
||||
ignoreAutoTag = true
|
||||
aliases = []string{"alias1", "alias2"}
|
||||
endpoint1 = "endpoint1"
|
||||
endpoint2 = "endpoint2"
|
||||
stashID1 = "stashid1"
|
||||
stashID2 = "stashid2"
|
||||
createdAt = epochTime
|
||||
updatedAt = epochTime
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
newObject models.CreateTagInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
SortName: sortName,
|
||||
Description: description,
|
||||
Favorite: favorite,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
ParentIDs: models.NewRelatedIDs([]int{tagIDs[tagIdxWithScene]}),
|
||||
ChildIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithScene]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
CustomFields: testCustomFields,
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid parent id",
|
||||
models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
ParentIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid child id",
|
||||
models.CreateTagInput{
|
||||
Tag: &models.Tag{
|
||||
Name: name,
|
||||
ChildIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
qb := db.Tag
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
p := tt.newObject
|
||||
if err := qb.Create(ctx, &p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("TagStore.Create() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
assert.Zero(p.ID)
|
||||
return
|
||||
}
|
||||
|
||||
assert.NotZero(p.ID)
|
||||
|
||||
copy := *tt.newObject.Tag
|
||||
copy.ID = p.ID
|
||||
|
||||
// load relationships
|
||||
if err := loadTagRelationships(ctx, copy, p.Tag); err != nil {
|
||||
t.Errorf("loadTagRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(copy, *p.Tag)
|
||||
|
||||
// ensure can find the tag
|
||||
found, err := qb.Find(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.Find() error = %v", err)
|
||||
}
|
||||
|
||||
if !assert.NotNil(found) {
|
||||
return
|
||||
}
|
||||
|
||||
// load relationships
|
||||
if err := loadTagRelationships(ctx, copy, found); err != nil {
|
||||
t.Errorf("loadTagRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
assert.Equal(copy, *found)
|
||||
|
||||
// ensure custom fields are set
|
||||
cf, err := qb.GetCustomFields(ctx, p.ID)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.newObject.CustomFields, cf)
|
||||
|
||||
return
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TagStore_Update(t *testing.T) {
|
||||
var (
|
||||
name = "name"
|
||||
sortName = "sortName"
|
||||
description = "description"
|
||||
favorite = true
|
||||
ignoreAutoTag = true
|
||||
aliases = []string{"alias1", "alias2"}
|
||||
endpoint1 = "endpoint1"
|
||||
endpoint2 = "endpoint2"
|
||||
stashID1 = "stashid1"
|
||||
stashID2 = "stashid2"
|
||||
createdAt = epochTime
|
||||
updatedAt = epochTime
|
||||
)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
updatedObject models.UpdateTagInput
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"full",
|
||||
models.UpdateTagInput{
|
||||
Tag: &models.Tag{
|
||||
ID: tagIDs[tagIdxWithGallery],
|
||||
Name: name,
|
||||
SortName: sortName,
|
||||
Description: description,
|
||||
Favorite: favorite,
|
||||
IgnoreAutoTag: ignoreAutoTag,
|
||||
Aliases: models.NewRelatedStrings(aliases),
|
||||
ParentIDs: models.NewRelatedIDs([]int{tagIDs[tagIdxWithScene]}),
|
||||
ChildIDs: models.NewRelatedIDs([]int{tagIDs[tagIdx1WithScene]}),
|
||||
StashIDs: models.NewRelatedStashIDs([]models.StashID{
|
||||
{
|
||||
StashID: stashID1,
|
||||
Endpoint: endpoint1,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
{
|
||||
StashID: stashID2,
|
||||
Endpoint: endpoint2,
|
||||
UpdatedAt: epochTime,
|
||||
},
|
||||
}),
|
||||
CreatedAt: createdAt,
|
||||
UpdatedAt: updatedAt,
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{
|
||||
"string": "updated",
|
||||
"int": int64(999),
|
||||
"real": 9.99,
|
||||
},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"set custom fields",
|
||||
models.UpdateTagInput{
|
||||
Tag: &models.Tag{
|
||||
ID: tagIDs[tagIdxWithGallery],
|
||||
Name: tagNames[tagIdxWithGallery],
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
models.UpdateTagInput{
|
||||
Tag: &models.Tag{
|
||||
ID: tagIDs[tagIdxWithGallery],
|
||||
Name: tagNames[tagIdxWithGallery],
|
||||
},
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid parent id",
|
||||
models.UpdateTagInput{
|
||||
Tag: &models.Tag{
|
||||
ID: tagIDs[tagIdxWithGallery],
|
||||
Name: tagNames[tagIdxWithGallery],
|
||||
ParentIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
{
|
||||
"invalid child id",
|
||||
models.UpdateTagInput{
|
||||
Tag: &models.Tag{
|
||||
ID: tagIDs[tagIdxWithGallery],
|
||||
Name: tagNames[tagIdxWithGallery],
|
||||
ChildIDs: models.NewRelatedIDs([]int{invalidID}),
|
||||
},
|
||||
},
|
||||
true,
|
||||
},
|
||||
}
|
||||
|
||||
qb := db.Tag
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
p := tt.updatedObject
|
||||
if err := qb.Update(ctx, &p); (err != nil) != tt.wantErr {
|
||||
t.Errorf("TagStore.Update() error = %v, wantErr = %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
|
||||
s, err := qb.Find(ctx, tt.updatedObject.ID)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.Find() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// load relationships
|
||||
if err := loadTagRelationships(ctx, *tt.updatedObject.Tag, s); err != nil {
|
||||
t.Errorf("loadTagRelationships() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(*tt.updatedObject.Tag, *s)
|
||||
|
||||
// ensure custom fields are correct
|
||||
if tt.updatedObject.CustomFields.Full != nil {
|
||||
cf, err := qb.GetCustomFields(ctx, tt.updatedObject.ID)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(tt.updatedObject.CustomFields.Full, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TagStore_UpdatePartialCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
id int
|
||||
partial models.TagPartial
|
||||
expected map[string]interface{} // nil to use the partial
|
||||
}{
|
||||
{
|
||||
"set custom fields",
|
||||
tagIDs[tagIdxWithGallery],
|
||||
models.TagPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: testCustomFields,
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"clear custom fields",
|
||||
tagIDs[tagIdxWithGallery],
|
||||
models.TagPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: map[string]interface{}{},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"partial custom fields",
|
||||
tagIDs[tagIdxWithGallery],
|
||||
models.TagPartial{
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Partial: map[string]interface{}{
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
},
|
||||
map[string]interface{}{
|
||||
"int": int64(2),
|
||||
"real": float64(1.7),
|
||||
"string": "bbb",
|
||||
"new_field": "new",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
qb := db.Tag
|
||||
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
_, err := qb.UpdatePartial(ctx, tt.id, tt.partial)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.UpdatePartial() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// ensure custom fields are correct
|
||||
cf, err := qb.GetCustomFields(ctx, tt.id)
|
||||
if err != nil {
|
||||
t.Errorf("TagStore.GetCustomFields() error = %v", err)
|
||||
return
|
||||
}
|
||||
if tt.expected == nil {
|
||||
assert.Equal(tt.partial.CustomFields.Full, cf)
|
||||
} else {
|
||||
assert.Equal(tt.expected, cf)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTagQueryCustomFields(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
filter *models.TagFilterType
|
||||
includeIdxs []int
|
||||
excludeIdxs []int
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"equals",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not equals",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotEquals,
|
||||
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{tagIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"includes",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierIncludes,
|
||||
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"excludes",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierExcludes,
|
||||
Value: []any{getTagStringValue(tagIdxWithGallery, "custom")[9:]},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{tagIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"regex",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{".*17_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid regex",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"not matches regex",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{".*17_custom"},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{tagIdxWithGallery},
|
||||
false,
|
||||
},
|
||||
{
|
||||
"invalid not matches regex",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotMatchesRegex,
|
||||
Value: []any{"["},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
"null",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "not existing",
|
||||
Modifier: models.CriterionModifierIsNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not null",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdxWithGallery, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "string",
|
||||
Modifier: models.CriterionModifierNotNull,
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdxWithGallery},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"between",
|
||||
&models.TagFilterType{
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierBetween,
|
||||
Value: []any{0.15, 0.25},
|
||||
},
|
||||
},
|
||||
},
|
||||
[]int{tagIdx2WithScene},
|
||||
nil,
|
||||
false,
|
||||
},
|
||||
{
|
||||
"not between",
|
||||
&models.TagFilterType{
|
||||
Name: &models.StringCriterionInput{
|
||||
Value: getTagStringValue(tagIdx2WithScene, "Name"),
|
||||
Modifier: models.CriterionModifierEquals,
|
||||
},
|
||||
CustomFields: []models.CustomFieldCriterionInput{
|
||||
{
|
||||
Field: "real",
|
||||
Modifier: models.CriterionModifierNotBetween,
|
||||
Value: []any{0.15, 0.25},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
[]int{tagIdx2WithScene},
|
||||
false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
runWithRollbackTxn(t, tt.name, func(t *testing.T, ctx context.Context) {
|
||||
assert := assert.New(t)
|
||||
|
||||
tags, _, err := db.Tag.Query(ctx, tt.filter, nil)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("TagStore.Query() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
ids := tagsToIDs(tags)
|
||||
include := indexesToIDs(tagIDs, tt.includeIdxs)
|
||||
exclude := indexesToIDs(tagIDs, tt.excludeIdxs)
|
||||
|
||||
for _, i := range include {
|
||||
assert.Contains(ids, i)
|
||||
}
|
||||
for _, e := range exclude {
|
||||
assert.NotContains(ids, e)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TODO Destroy
|
||||
// TODO Find
|
||||
// TODO FindBySceneID
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ type FinderImageStashIDGetter interface {
|
|||
models.URLLoader
|
||||
models.StashIDLoader
|
||||
GetImage(ctx context.Context, studioID int) ([]byte, error)
|
||||
models.CustomFieldsReader
|
||||
}
|
||||
|
||||
// ToJSON converts a Studio object into its JSON equivalent.
|
||||
|
|
@ -60,6 +61,12 @@ func ToJSON(ctx context.Context, reader FinderImageStashIDGetter, studio *models
|
|||
}
|
||||
newStudioJSON.StashIDs = studio.StashIDs.List()
|
||||
|
||||
var err error
|
||||
newStudioJSON.CustomFields, err = reader.GetCustomFields(ctx, studio.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting studio custom fields: %v", err)
|
||||
}
|
||||
|
||||
image, err := reader.GetImage(ctx, studio.ID)
|
||||
if err != nil {
|
||||
logger.Errorf("Error getting studio image: %v", err)
|
||||
|
|
|
|||
|
|
@ -18,18 +18,24 @@ const (
|
|||
errImageID = 3
|
||||
missingParentStudioID = 4
|
||||
errStudioID = 5
|
||||
customFieldsID = 6
|
||||
|
||||
parentStudioID = 10
|
||||
missingStudioID = 11
|
||||
errParentStudioID = 12
|
||||
errCustomFieldsID = 13
|
||||
)
|
||||
|
||||
var (
|
||||
studioName = "testStudio"
|
||||
url = "url"
|
||||
details = "details"
|
||||
parentStudioName = "parentStudio"
|
||||
autoTagIgnored = true
|
||||
studioName = "testStudio"
|
||||
url = "url"
|
||||
details = "details"
|
||||
parentStudioName = "parentStudio"
|
||||
autoTagIgnored = true
|
||||
emptyCustomFields = make(map[string]interface{})
|
||||
customFields = map[string]interface{}{
|
||||
"customField1": "customValue1",
|
||||
}
|
||||
)
|
||||
|
||||
var studioID = 1
|
||||
|
|
@ -91,7 +97,7 @@ func createEmptyStudio(id int) models.Studio {
|
|||
}
|
||||
}
|
||||
|
||||
func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonschema.Studio {
|
||||
func createFullJSONStudio(parentStudio, image string, aliases []string, customFields map[string]interface{}) *jsonschema.Studio {
|
||||
return &jsonschema.Studio{
|
||||
Name: studioName,
|
||||
URLs: []string{url},
|
||||
|
|
@ -109,6 +115,7 @@ func createFullJSONStudio(parentStudio, image string, aliases []string) *jsonsch
|
|||
Aliases: aliases,
|
||||
StashIDs: stashIDs,
|
||||
IgnoreAutoTag: autoTagIgnored,
|
||||
CustomFields: customFields,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -120,16 +127,18 @@ func createEmptyJSONStudio() *jsonschema.Studio {
|
|||
UpdatedAt: json.JSONTime{
|
||||
Time: updateTime,
|
||||
},
|
||||
Aliases: []string{},
|
||||
URLs: []string{},
|
||||
StashIDs: []models.StashID{},
|
||||
Aliases: []string{},
|
||||
URLs: []string{},
|
||||
StashIDs: []models.StashID{},
|
||||
CustomFields: emptyCustomFields,
|
||||
}
|
||||
}
|
||||
|
||||
type testScenario struct {
|
||||
input models.Studio
|
||||
expected *jsonschema.Studio
|
||||
err bool
|
||||
input models.Studio
|
||||
customFields map[string]interface{}
|
||||
expected *jsonschema.Studio
|
||||
err bool
|
||||
}
|
||||
|
||||
var scenarios []testScenario
|
||||
|
|
@ -138,30 +147,48 @@ func initTestTable() {
|
|||
scenarios = []testScenario{
|
||||
{
|
||||
createFullStudio(studioID, parentStudioID),
|
||||
createFullJSONStudio(parentStudioName, image, []string{"alias"}),
|
||||
emptyCustomFields,
|
||||
createFullJSONStudio(parentStudioName, image, []string{"alias"}, emptyCustomFields),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createFullStudio(customFieldsID, parentStudioID),
|
||||
customFields,
|
||||
createFullJSONStudio(parentStudioName, image, []string{"alias"}, customFields),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createEmptyStudio(noImageID),
|
||||
emptyCustomFields,
|
||||
createEmptyJSONStudio(),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createFullStudio(errImageID, parentStudioID),
|
||||
createFullJSONStudio(parentStudioName, "", []string{"alias"}),
|
||||
emptyCustomFields,
|
||||
createFullJSONStudio(parentStudioName, "", []string{"alias"}, emptyCustomFields),
|
||||
// failure to get image is not an error
|
||||
false,
|
||||
},
|
||||
{
|
||||
createFullStudio(missingParentStudioID, missingStudioID),
|
||||
createFullJSONStudio("", image, []string{"alias"}),
|
||||
emptyCustomFields,
|
||||
createFullJSONStudio("", image, []string{"alias"}, emptyCustomFields),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createFullStudio(errStudioID, errParentStudioID),
|
||||
emptyCustomFields,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
createFullStudio(errCustomFieldsID, parentStudioID),
|
||||
customFields,
|
||||
nil,
|
||||
// failure to get custom fields should cause an error
|
||||
true,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -177,6 +204,7 @@ func TestToJSON(t *testing.T) {
|
|||
db.Studio.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Studio.On("GetImage", testCtx, missingParentStudioID).Return(imageBytes, nil).Maybe()
|
||||
db.Studio.On("GetImage", testCtx, errStudioID).Return(imageBytes, nil).Maybe()
|
||||
db.Studio.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once()
|
||||
|
||||
parentStudioErr := errors.New("error getting parent studio")
|
||||
|
||||
|
|
@ -184,6 +212,15 @@ func TestToJSON(t *testing.T) {
|
|||
db.Studio.On("Find", testCtx, missingStudioID).Return(nil, nil)
|
||||
db.Studio.On("Find", testCtx, errParentStudioID).Return(nil, parentStudioErr)
|
||||
|
||||
customFieldsErr := errors.New("error getting custom fields")
|
||||
|
||||
db.Studio.On("GetCustomFields", testCtx, studioID).Return(emptyCustomFields, nil).Once()
|
||||
db.Studio.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once()
|
||||
db.Studio.On("GetCustomFields", testCtx, missingParentStudioID).Return(emptyCustomFields, nil).Once()
|
||||
db.Studio.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Studio.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Studio.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
studio := s.input
|
||||
json, err := ToJSON(testCtx, db.Studio, &studio)
|
||||
|
|
|
|||
|
|
@ -26,13 +26,15 @@ type Importer struct {
|
|||
Input jsonschema.Studio
|
||||
MissingRefBehaviour models.ImportMissingRefEnum
|
||||
|
||||
ID int
|
||||
studio models.Studio
|
||||
imageData []byte
|
||||
ID int
|
||||
studio models.Studio
|
||||
customFields models.CustomFieldMap
|
||||
imageData []byte
|
||||
}
|
||||
|
||||
func (i *Importer) PreImport(ctx context.Context) error {
|
||||
i.studio = studioJSONtoStudio(i.Input)
|
||||
i.customFields = i.Input.CustomFields
|
||||
|
||||
if err := i.populateParentStudio(ctx); err != nil {
|
||||
return err
|
||||
|
|
@ -110,7 +112,9 @@ func createTags(ctx context.Context, tagWriter models.TagFinderCreator, names []
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := tagWriter.Create(ctx, &newTag)
|
||||
err := tagWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -194,7 +198,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
|
|||
}
|
||||
|
||||
func (i *Importer) Create(ctx context.Context) (*int, error) {
|
||||
err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{Studio: &i.studio})
|
||||
err := i.ReaderWriter.Create(ctx, &models.CreateStudioInput{
|
||||
Studio: &i.studio,
|
||||
CustomFields: i.customFields,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating studio: %v", err)
|
||||
}
|
||||
|
|
@ -206,7 +213,12 @@ func (i *Importer) Create(ctx context.Context) (*int, error) {
|
|||
func (i *Importer) Update(ctx context.Context, id int) error {
|
||||
studio := i.studio
|
||||
studio.ID = id
|
||||
err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{Studio: &studio})
|
||||
err := i.ReaderWriter.Update(ctx, &models.UpdateStudioInput{
|
||||
Studio: &studio,
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: i.customFields,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating existing studio: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,7 +62,7 @@ func TestImporterPreImport(t *testing.T) {
|
|||
|
||||
assert.Nil(t, err)
|
||||
|
||||
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"})
|
||||
i.Input = *createFullJSONStudio(studioName, image, []string{"alias"}, customFields)
|
||||
i.Input.ParentStudio = ""
|
||||
|
||||
err = i.PreImport(testCtx)
|
||||
|
|
@ -71,6 +71,7 @@ func TestImporterPreImport(t *testing.T) {
|
|||
expectedStudio := createFullStudio(0, 0)
|
||||
expectedStudio.ParentID = nil
|
||||
assert.Equal(t, expectedStudio, i.studio)
|
||||
assert.Equal(t, models.CustomFieldMap(customFields), i.customFields)
|
||||
}
|
||||
|
||||
func TestImporterPreImportWithTag(t *testing.T) {
|
||||
|
|
@ -121,9 +122,9 @@ func TestImporterPreImportWithMissingTag(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Times(3)
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = existingTagID
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.CreateTagInput)
|
||||
t.Tag.ID = existingTagID
|
||||
}).Return(nil)
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
|
|
@ -156,7 +157,7 @@ func TestImporterPreImportWithMissingTagCreateErr(t *testing.T) {
|
|||
}
|
||||
|
||||
db.Tag.On("FindByNames", testCtx, []string{missingTagName}, false).Return(nil, nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.Tag")).Return(errors.New("Create error"))
|
||||
db.Tag.On("Create", testCtx, mock.AnythingOfType("*models.CreateTagInput")).Return(errors.New("Create error"))
|
||||
|
||||
err := i.PreImport(testCtx)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ type FinderAliasImageGetter interface {
|
|||
GetAliases(ctx context.Context, studioID int) ([]string, error)
|
||||
GetImage(ctx context.Context, tagID int) ([]byte, error)
|
||||
FindByChildTagID(ctx context.Context, childID int) ([]*models.Tag, error)
|
||||
GetCustomFields(ctx context.Context, id int) (map[string]interface{}, error)
|
||||
models.StashIDLoader
|
||||
}
|
||||
|
||||
|
|
@ -63,6 +64,11 @@ func ToJSON(ctx context.Context, reader FinderAliasImageGetter, tag *models.Tag)
|
|||
|
||||
newTagJSON.Parents = GetNames(parents)
|
||||
|
||||
newTagJSON.CustomFields, err = reader.GetCustomFields(ctx, tag.ID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("getting tag custom fields: %v", err)
|
||||
}
|
||||
|
||||
return &newTagJSON, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,12 +14,14 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
tagID = 1
|
||||
noImageID = 2
|
||||
errImageID = 3
|
||||
errAliasID = 4
|
||||
withParentsID = 5
|
||||
errParentsID = 6
|
||||
tagID = iota + 1
|
||||
customFieldsID
|
||||
noImageID
|
||||
errImageID
|
||||
errAliasID
|
||||
withParentsID
|
||||
errParentsID
|
||||
errCustomFieldsID
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -32,6 +34,11 @@ var (
|
|||
autoTagIgnored = true
|
||||
createTime = time.Date(2001, 01, 01, 0, 0, 0, 0, time.UTC)
|
||||
updateTime = time.Date(2002, 01, 01, 0, 0, 0, 0, time.UTC)
|
||||
|
||||
emptyCustomFields = make(map[string]interface{})
|
||||
customFields = map[string]interface{}{
|
||||
"customField1": "customValue1",
|
||||
}
|
||||
)
|
||||
|
||||
func createTag(id int) models.Tag {
|
||||
|
|
@ -47,8 +54,8 @@ func createTag(id int) models.Tag {
|
|||
}
|
||||
}
|
||||
|
||||
func createJSONTag(aliases []string, image string, parents []string) *jsonschema.Tag {
|
||||
return &jsonschema.Tag{
|
||||
func createJSONTag(aliases []string, image string, parents []string, withCustomFields bool) *jsonschema.Tag {
|
||||
ret := &jsonschema.Tag{
|
||||
Name: tagName,
|
||||
SortName: sortName,
|
||||
Favorite: true,
|
||||
|
|
@ -61,15 +68,23 @@ func createJSONTag(aliases []string, image string, parents []string) *jsonschema
|
|||
UpdatedAt: json.JSONTime{
|
||||
Time: updateTime,
|
||||
},
|
||||
Image: image,
|
||||
Parents: parents,
|
||||
Image: image,
|
||||
Parents: parents,
|
||||
CustomFields: emptyCustomFields,
|
||||
}
|
||||
|
||||
if withCustomFields {
|
||||
ret.CustomFields = customFields
|
||||
}
|
||||
|
||||
return ret
|
||||
}
|
||||
|
||||
type testScenario struct {
|
||||
tag models.Tag
|
||||
expected *jsonschema.Tag
|
||||
err bool
|
||||
tag models.Tag
|
||||
customFields map[string]interface{}
|
||||
expected *jsonschema.Tag
|
||||
err bool
|
||||
}
|
||||
|
||||
var scenarios []testScenario
|
||||
|
|
@ -78,32 +93,50 @@ func initTestTable() {
|
|||
scenarios = []testScenario{
|
||||
{
|
||||
createTag(tagID),
|
||||
createJSONTag([]string{"alias"}, image, nil),
|
||||
emptyCustomFields,
|
||||
createJSONTag([]string{"alias"}, image, nil, false),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createTag(customFieldsID),
|
||||
customFields,
|
||||
createJSONTag([]string{"alias"}, image, nil, true),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createTag(noImageID),
|
||||
createJSONTag(nil, "", nil),
|
||||
emptyCustomFields,
|
||||
createJSONTag(nil, "", nil, false),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createTag(errImageID),
|
||||
createJSONTag(nil, "", nil),
|
||||
emptyCustomFields,
|
||||
createJSONTag(nil, "", nil, false),
|
||||
// getting the image should not cause an error
|
||||
false,
|
||||
},
|
||||
{
|
||||
createTag(errAliasID),
|
||||
emptyCustomFields,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
createTag(withParentsID),
|
||||
createJSONTag(nil, image, []string{"parent"}),
|
||||
emptyCustomFields,
|
||||
createJSONTag(nil, image, []string{"parent"}, false),
|
||||
false,
|
||||
},
|
||||
{
|
||||
createTag(errParentsID),
|
||||
emptyCustomFields,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
{
|
||||
createTag(errCustomFieldsID),
|
||||
customFields,
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
|
|
@ -118,32 +151,48 @@ func TestToJSON(t *testing.T) {
|
|||
imageErr := errors.New("error getting image")
|
||||
aliasErr := errors.New("error getting aliases")
|
||||
parentsErr := errors.New("error getting parents")
|
||||
customFieldsErr := errors.New("error getting custom fields")
|
||||
|
||||
db.Tag.On("GetAliases", testCtx, tagID).Return([]string{"alias"}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, customFieldsID).Return([]string{"alias"}, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errAliasID).Return(nil, aliasErr).Once()
|
||||
db.Tag.On("GetAliases", testCtx, withParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetAliases", testCtx, errCustomFieldsID).Return(nil, nil).Once()
|
||||
|
||||
db.Tag.On("GetStashIDs", testCtx, tagID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetStashIDs", testCtx, customFieldsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetStashIDs", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetStashIDs", testCtx, errImageID).Return(nil, nil).Once()
|
||||
// errAliasID test fails before GetStashIDs is called, so no mock needed
|
||||
db.Tag.On("GetStashIDs", testCtx, withParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetStashIDs", testCtx, errParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetStashIDs", testCtx, errCustomFieldsID).Return(nil, nil).Once()
|
||||
|
||||
db.Tag.On("GetImage", testCtx, tagID).Return(imageBytes, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, customFieldsID).Return(imageBytes, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, errImageID).Return(nil, imageErr).Once()
|
||||
db.Tag.On("GetImage", testCtx, withParentsID).Return(imageBytes, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, errParentsID).Return(nil, nil).Once()
|
||||
db.Tag.On("GetImage", testCtx, errCustomFieldsID).Return(nil, nil).Once()
|
||||
|
||||
db.Tag.On("FindByChildTagID", testCtx, tagID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, customFieldsID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, noImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, withParentsID).Return([]*models.Tag{{Name: "parent"}}, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, errParentsID).Return(nil, parentsErr).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, errImageID).Return(nil, nil).Once()
|
||||
db.Tag.On("FindByChildTagID", testCtx, errCustomFieldsID).Return(nil, nil).Once()
|
||||
|
||||
db.Tag.On("GetCustomFields", testCtx, tagID).Return(emptyCustomFields, nil).Once()
|
||||
db.Tag.On("GetCustomFields", testCtx, customFieldsID).Return(customFields, nil).Once()
|
||||
db.Tag.On("GetCustomFields", testCtx, noImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Tag.On("GetCustomFields", testCtx, errImageID).Return(emptyCustomFields, nil).Once()
|
||||
db.Tag.On("GetCustomFields", testCtx, withParentsID).Return(emptyCustomFields, nil).Once()
|
||||
db.Tag.On("GetCustomFields", testCtx, errCustomFieldsID).Return(nil, customFieldsErr).Once()
|
||||
|
||||
for i, s := range scenarios {
|
||||
tag := s.tag
|
||||
|
|
|
|||
|
|
@ -31,8 +31,9 @@ type Importer struct {
|
|||
Input jsonschema.Tag
|
||||
MissingRefBehaviour models.ImportMissingRefEnum
|
||||
|
||||
tag models.Tag
|
||||
imageData []byte
|
||||
tag models.Tag
|
||||
imageData []byte
|
||||
customFields map[string]interface{}
|
||||
}
|
||||
|
||||
func (i *Importer) PreImport(ctx context.Context) error {
|
||||
|
|
@ -55,6 +56,8 @@ func (i *Importer) PreImport(ctx context.Context) error {
|
|||
}
|
||||
}
|
||||
|
||||
i.customFields = i.Input.CustomFields
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -78,6 +81,14 @@ func (i *Importer) PostImport(ctx context.Context, id int) error {
|
|||
return fmt.Errorf("error setting parents: %v", err)
|
||||
}
|
||||
|
||||
if len(i.customFields) > 0 {
|
||||
if err := i.ReaderWriter.SetCustomFields(ctx, id, models.CustomFieldsInput{
|
||||
Full: i.customFields,
|
||||
}); err != nil {
|
||||
return fmt.Errorf("error setting tag custom fields: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -101,7 +112,10 @@ func (i *Importer) FindExistingID(ctx context.Context) (*int, error) {
|
|||
}
|
||||
|
||||
func (i *Importer) Create(ctx context.Context) (*int, error) {
|
||||
err := i.ReaderWriter.Create(ctx, &i.tag)
|
||||
err := i.ReaderWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &i.tag,
|
||||
CustomFields: i.customFields,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating tag: %v", err)
|
||||
}
|
||||
|
|
@ -113,7 +127,12 @@ func (i *Importer) Create(ctx context.Context) (*int, error) {
|
|||
func (i *Importer) Update(ctx context.Context, id int) error {
|
||||
tag := i.tag
|
||||
tag.ID = id
|
||||
err := i.ReaderWriter.Update(ctx, &tag)
|
||||
err := i.ReaderWriter.Update(ctx, &models.UpdateTagInput{
|
||||
Tag: &tag,
|
||||
CustomFields: models.CustomFieldsInput{
|
||||
Full: i.customFields,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating existing tag: %v", err)
|
||||
}
|
||||
|
|
@ -157,7 +176,9 @@ func (i *Importer) createParent(ctx context.Context, name string) (int, error) {
|
|||
newTag := models.NewTag()
|
||||
newTag.Name = name
|
||||
|
||||
err := i.ReaderWriter.Create(ctx, &newTag)
|
||||
err := i.ReaderWriter.Create(ctx, &models.CreateTagInput{
|
||||
Tag: &newTag,
|
||||
})
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -154,14 +154,14 @@ func TestImporterPostImportParentMissing(t *testing.T) {
|
|||
db.Tag.On("UpdateParentTags", testCtx, ignoreID, emptyParents).Return(nil).Once()
|
||||
db.Tag.On("UpdateParentTags", testCtx, ignoreFoundID, []int{103}).Return(nil).Once()
|
||||
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
return t.Name == "Create"
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool {
|
||||
return input.Tag.Name == "Create"
|
||||
})).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = 100
|
||||
input := args.Get(1).(*models.CreateTagInput)
|
||||
input.Tag.ID = 100
|
||||
}).Return(nil).Once()
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(t *models.Tag) bool {
|
||||
return t.Name == "CreateError"
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool {
|
||||
return input.Tag.Name == "CreateError"
|
||||
})).Return(errors.New("failed creating parent")).Once()
|
||||
|
||||
i.MissingRefBehaviour = models.ImportMissingRefEnumCreate
|
||||
|
|
@ -261,11 +261,15 @@ func TestCreate(t *testing.T) {
|
|||
}
|
||||
|
||||
errCreate := errors.New("Create error")
|
||||
db.Tag.On("Create", testCtx, &tag).Run(func(args mock.Arguments) {
|
||||
t := args.Get(1).(*models.Tag)
|
||||
t.ID = tagID
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool {
|
||||
return input.Tag.Name == tag.Name
|
||||
})).Run(func(args mock.Arguments) {
|
||||
input := args.Get(1).(*models.CreateTagInput)
|
||||
input.Tag.ID = tagID
|
||||
}).Return(nil).Once()
|
||||
db.Tag.On("Create", testCtx, &tagErr).Return(errCreate).Once()
|
||||
db.Tag.On("Create", testCtx, mock.MatchedBy(func(input *models.CreateTagInput) bool {
|
||||
return input.Tag.Name == tagErr.Name
|
||||
})).Return(errCreate).Once()
|
||||
|
||||
id, err := i.Create(testCtx)
|
||||
assert.Equal(t, tagID, *id)
|
||||
|
|
@ -299,7 +303,10 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// id needs to be set for the mock input
|
||||
tag.ID = tagID
|
||||
db.Tag.On("Update", testCtx, &tag).Return(nil).Once()
|
||||
tagInput := models.UpdateTagInput{
|
||||
Tag: &tag,
|
||||
}
|
||||
db.Tag.On("Update", testCtx, &tagInput).Return(nil).Once()
|
||||
|
||||
err := i.Update(testCtx, tagID)
|
||||
assert.Nil(t, err)
|
||||
|
|
@ -308,7 +315,10 @@ func TestUpdate(t *testing.T) {
|
|||
|
||||
// need to set id separately
|
||||
tagErr.ID = errImageID
|
||||
db.Tag.On("Update", testCtx, &tagErr).Return(errUpdate).Once()
|
||||
errInput := models.UpdateTagInput{
|
||||
Tag: &tagErr,
|
||||
}
|
||||
db.Tag.On("Update", testCtx, &errInput).Return(errUpdate).Once()
|
||||
|
||||
err = i.Update(testCtx, errImageID)
|
||||
assert.NotNil(t, err)
|
||||
|
|
|
|||
Loading…
Reference in a new issue