From 2da5a67e82c2c58f15287657cbc4315fd3fe422b Mon Sep 17 00:00:00 2001 From: MickaelK Date: Fri, 29 Aug 2025 16:44:26 +1000 Subject: [PATCH] feature (plg_backend_psql): psql as a storage --- server/plugin/index.go | 1 + server/plugin/plg_backend_psql/index.go | 52 +++++--- server/plugin/plg_backend_psql/index_cat.go | 86 ++++++++++--- server/plugin/plg_backend_psql/index_ls.go | 13 +- server/plugin/plg_backend_psql/index_save.go | 111 +++++++++++++++++ server/plugin/plg_backend_psql/types.go | 13 +- server/plugin/plg_backend_psql/utils.go | 124 +++++++++++-------- server/plugin/plg_metadata_sqlite/state.go | 1 - 8 files changed, 308 insertions(+), 93 deletions(-) create mode 100644 server/plugin/plg_backend_psql/index_save.go delete mode 100644 server/plugin/plg_metadata_sqlite/state.go diff --git a/server/plugin/index.go b/server/plugin/index.go index e963ad43..e60f7d99 100644 --- a/server/plugin/index.go +++ b/server/plugin/index.go @@ -20,6 +20,7 @@ import ( _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_mysql" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nfs" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nop" + _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_psql" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_s3" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_samba" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_sftp" diff --git a/server/plugin/plg_backend_psql/index.go b/server/plugin/plg_backend_psql/index.go index b1a60740..100d9609 100644 --- a/server/plugin/plg_backend_psql/index.go +++ b/server/plugin/plg_backend_psql/index.go @@ -4,13 +4,15 @@ import ( "context" "database/sql" "fmt" - "io" + "strings" . "github.com/mickael-kerjean/filestash/server/common" _ "github.com/lib/pq" ) +var PGCache AppCache + type PSQL struct { db *sql.DB ctx context.Context @@ -18,9 +20,20 @@ type PSQL struct { func init() { Backend.Register("psql", PSQL{}) + + PGCache = NewAppCache(2, 1) + PGCache.OnEvict(func(key string, value interface{}) { + c := value.(*PSQL) + c.Close() + }) } func (this PSQL) Init(params map[string]string, app *App) (IBackend, error) { + if d := PGCache.Get(params); d != nil { + backend := d.(*PSQL) + backend.ctx = app.Context + return backend, nil + } host := params["host"] port := withDefault(params["port"], "5432") user := params["user"] @@ -44,10 +57,12 @@ func (this PSQL) Init(params map[string]string, app *App) (IBackend, error) { Log.Debug("plg_backend_psql::init err=%s", err.Error()) return nil, ErrNotValid } - return PSQL{ + backend := &PSQL{ db: db, ctx: app.Context, - }, nil + } + PGCache.Set(params, backend) + return backend, nil } func withDefault(val string, def string) string { @@ -99,28 +114,22 @@ func (this PSQL) LoginForm() Form { } } -func (this PSQL) Touch(path string) error { // TODO - this.db.Close() - return ErrNotImplemented +func (this PSQL) Touch(path string) error { + if !strings.HasSuffix(path, ".form") { + return ErrNotValid + } + return nil } -func (this PSQL) Save(path string, file io.Reader) error { // TODO - this.db.Close() - return ErrNotImplemented -} - -func (this PSQL) Rm(path string) error { // TODO - this.db.Close() - return ErrNotImplemented +func (this PSQL) Rm(path string) error { + return ErrNotAuthorized } func (this PSQL) Mkdir(path string) error { - this.db.Close() return ErrNotValid } func (this PSQL) Mv(from string, to string) error { - this.db.Close() return ErrNotValid } @@ -128,21 +137,21 @@ func (this PSQL) Meta(path string) Metadata { location, _ := getPath(path) return Metadata{ CanCreateDirectory: NewBool(false), - CanCreateFile: func(l Location) *bool { + CanCreateFile: func(l LocationRow) *bool { if l.table == "" { return NewBool(false) } return NewBool(true) }(location), CanRename: NewBool(false), - CanDelete: func(l Location) *bool { + CanDelete: func(l LocationRow) *bool { if l.table == "" { return NewBool(false) } return NewBool(true) }(location), CanMove: NewBool(false), - CanUpload: func(l Location) *bool { + CanUpload: func(l LocationRow) *bool { if l.row == "" { return NewBool(false) } @@ -152,3 +161,8 @@ func (this PSQL) Meta(path string) Metadata { HideExtension: NewBool(true), } } + +func (this PSQL) Close() error { + this.db.Close() + return nil +} diff --git a/server/plugin/plg_backend_psql/index_cat.go b/server/plugin/plg_backend_psql/index_cat.go index 8bfbbc53..7a287bc7 100644 --- a/server/plugin/plg_backend_psql/index_cat.go +++ b/server/plugin/plg_backend_psql/index_cat.go @@ -1,26 +1,29 @@ package plg_backend_psql import ( + "context" + "database/sql" + "fmt" "io" + "slices" . "github.com/mickael-kerjean/filestash/server/common" ) func (this PSQL) Cat(path string) (io.ReadCloser, error) { - defer this.db.Close() l, err := getPath(path) if err != nil { return nil, err } - columnName, err := getKey(this.ctx, this.db, l.table) + columns, columnName, err := processTable(this.ctx, this.db, l.table) if err != nil { return nil, err } rows, err := this.db.QueryContext(this.ctx, ` SELECT * - FROM `+l.table+` - WHERE `+columnName+`='`+l.row+`' - `) + FROM "`+l.table+`" + WHERE "`+columnName+`"=$1 + `, l.row) if err != nil { return nil, err } @@ -28,13 +31,12 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) { c, err := rows.Columns() if err != nil { return nil, err - } - t, err := rows.ColumnTypes() - if err != nil { - return nil, err + } else if len(columns) != len(c) { + Log.Error("plg_backend_psql::index_cat columns is not of the expected size columns[%d]=%v c[%d]=%v", len(columns), columns, len(c), c) + return nil, ErrNotValid } i := 0 - col := make([]any, len(c)) + col := make([]interface{}, len(c)) for rows.Next() { if i != 0 { return nil, ErrNotValid @@ -48,11 +50,19 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) { } } forms := make([]FormElement, len(c)) - for i, _ := range c { - f := formType(t[i].ScanType(), c[i]) - f.Name = c[i] - f.Value = col[i] - forms[i] = f + for i, _ := range columns { + forms[i] = createFormElement(col[i], columns[i]) + if slices.Contains(columns[i].Constraint, "PRIMARY KEY") { + forms[i].ReadOnly = true + } else if slices.Contains(columns[i].Constraint, "FOREIGN KEY") { + if link, err := _findRelation(this.ctx, this.db, columns[i]); err == nil { + forms[i].Description = _createDescription(columns[i], link) + if len(link.values) > 0 { + forms[i].Type = "select" + forms[i].Opts = link.values + } + } + } } b, err := Form{Elmnts: forms}.MarshalJSON() if err != nil { @@ -60,3 +70,49 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) { } return NewReadCloserFromBytes(b), nil } + +func _createDescription(el Column, link LocationColumn) string { + if slices.Contains(el.Constraint, "FOREIGN KEY") { + return fmt.Sprintf("points to <%s> → <%s>", link.table, link.column) + } + return "" +} + +func _findRelation(ctx context.Context, db *sql.DB, el Column) (LocationColumn, error) { + l := LocationColumn{} + rows, err := db.QueryContext(ctx, ` + SELECT ccu.table_name, ccu.column_name + FROM information_schema.table_constraints AS tc + JOIN information_schema.key_column_usage AS kcu USING (constraint_name) + JOIN information_schema.constraint_column_usage AS ccu USING (constraint_name) + WHERE tc.constraint_type = 'FOREIGN KEY' + AND tc.table_name = $1 + AND kcu.column_name = $2 + `, el.Table, el.Name) + if err != nil { + return l, err + } + defer rows.Close() + for rows.Next() { + if err := rows.Scan(&l.table, &l.column); err != nil { + return l, err + } + } + valueRows, err := db.QueryContext(ctx, fmt.Sprintf( + `SELECT DISTINCT "%s" FROM "%s" ORDER BY "%s" LIMIT 5000`, + l.column, l.table, l.column, + )) + if err != nil { + return l, err + } + defer valueRows.Close() + l.values = []string{} + for valueRows.Next() { + var value string + if err := valueRows.Scan(&value); err != nil { + return l, err + } + l.values = append(l.values, value) + } + return l, nil +} diff --git a/server/plugin/plg_backend_psql/index_ls.go b/server/plugin/plg_backend_psql/index_ls.go index a8d587fa..17d63c55 100644 --- a/server/plugin/plg_backend_psql/index_ls.go +++ b/server/plugin/plg_backend_psql/index_ls.go @@ -7,14 +7,17 @@ import ( ) func (this PSQL) Ls(path string) ([]os.FileInfo, error) { - defer this.db.Close() l, err := getPath(path) if err != nil { Log.Debug("pl_backend_psql::ls method=getPath err=%s", err.Error()) return nil, err } if l.table == "" { - rows, err := this.db.QueryContext(this.ctx, "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") + rows, err := this.db.QueryContext(this.ctx, ` + SELECT table_name FROM information_schema.tables + WHERE table_schema = 'public' + AND table_type = 'BASE TABLE' + `) if err != nil { Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error()) return nil, err @@ -34,12 +37,12 @@ func (this PSQL) Ls(path string) ([]os.FileInfo, error) { } return out, nil } else if l.row == "" { - key, err := getKey(this.ctx, this.db, l.table) + _, key, err := processTable(this.ctx, this.db, l.table) if err != nil { - Log.Debug("plg_backend_psql::ls method=getKey err=%s", err.Error()) + Log.Debug("plg_backend_psql::ls method=processTable err=%s", err.Error()) return nil, err } - rows, err := this.db.QueryContext(this.ctx, "SELECT "+key+" FROM "+l.table+" LIMIT 500000") + rows, err := this.db.QueryContext(this.ctx, `SELECT "`+key+`" FROM "`+l.table+`" LIMIT 500000`) if err != nil { Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error()) return nil, err diff --git a/server/plugin/plg_backend_psql/index_save.go b/server/plugin/plg_backend_psql/index_save.go new file mode 100644 index 00000000..c5f5f525 --- /dev/null +++ b/server/plugin/plg_backend_psql/index_save.go @@ -0,0 +1,111 @@ +package plg_backend_psql + +import ( + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "slices" + "strings" + + . "github.com/mickael-kerjean/filestash/server/common" +) + +func (this PSQL) Save(path string, file io.Reader) error { + l, err := getPath(path) + if err != nil { + return err + } + columns, key, err := processTable(this.ctx, this.db, l.table) + if err != nil { + return err + } + f := map[string]FormElement{} + if err := json.NewDecoder(file).Decode(&f); err != nil { + return err + } + tx, err := this.db.BeginTx(this.ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + rows, err := tx.QueryContext(this.ctx, `SELECT * FROM "`+l.table+`" WHERE "`+key+`" = $1`, l.row) + if err != nil { + return err + } + i := 0 + dbvals := make([]any, len(columns)) + for rows.Next() { + currentPtrs := make([]any, len(columns)) + for i := range dbvals { + currentPtrs[i] = &dbvals[i] + } + if serr := rows.Scan(currentPtrs...); serr != nil { + rows.Close() + err = serr + break + } else if i >= 1 { + err = ErrNotValid + break + } + i += 1 + } + rows.Close() + if i == 0 { + err = _createRow(tx, this.ctx, l.table, columns, f) + } + if err == nil && i == 1 { + err = _updateRow(tx, this.ctx, l.table, columns, f, key, l.row, dbvals) + } + if err != nil { + return err + } + return tx.Commit() +} + +func _createRow(tx *sql.Tx, ctx context.Context, table string, columns []Column, f map[string]FormElement) error { + colNames := []string{} + placeholders := []string{} + values := []interface{}{} + paramIndex := 1 + for _, col := range columns { + if formEl, exists := f[col.Name]; exists { + if slices.Contains(col.Constraint, "PRIMARY KEY") && col.Default { + continue + } + colNames = append(colNames, `"`+col.Name+`"`) + placeholders = append(placeholders, fmt.Sprintf("$%d", paramIndex)) + values = append(values, formEl.Value) + paramIndex++ + } + } + if len(colNames) == 0 { + return ErrNotValid + } + _, err := tx.ExecContext( + ctx, + `INSERT INTO "`+table+`" (`+strings.Join(colNames, ", ")+`) VALUES (`+strings.Join(placeholders, ", ")+`)`, + values..., + ) + return err +} + +func _updateRow(tx *sql.Tx, ctx context.Context, table string, columns []Column, f map[string]FormElement, keyName string, keyValue any, dbvals []any) error { + for i, col := range columns { + dbval := convertFromDB(dbvals[i]) + formval, ok := f[col.Name] + if !ok || formval.Value == dbval { + continue + } + if _, err := tx.ExecContext( + ctx, + `UPDATE "`+table+`" SET "`+col.Name+`" = $1 WHERE "`+keyName+`" = $2`, + formval.Value, keyValue, + ); err != nil { + return err + } + } + return nil +} diff --git a/server/plugin/plg_backend_psql/types.go b/server/plugin/plg_backend_psql/types.go index 4b40137a..f22fd1bc 100644 --- a/server/plugin/plg_backend_psql/types.go +++ b/server/plugin/plg_backend_psql/types.go @@ -1,12 +1,21 @@ package plg_backend_psql type Column struct { + Table string Name string Type string - Constraint string + Nullable bool + Default bool + Constraint []string } -type Location struct { +type LocationRow struct { table string row string } + +type LocationColumn struct { + table string + column string + values []string +} diff --git a/server/plugin/plg_backend_psql/utils.go b/server/plugin/plg_backend_psql/utils.go index e5353613..f542932b 100644 --- a/server/plugin/plg_backend_psql/utils.go +++ b/server/plugin/plg_backend_psql/utils.go @@ -3,20 +3,24 @@ package plg_backend_psql import ( "context" "database/sql" - "reflect" + "slices" "strings" + "time" . "github.com/mickael-kerjean/filestash/server/common" ) -func getPath(path string) (Location, error) { - l := Location{} +func getPath(path string) (LocationRow, error) { + l := LocationRow{} for i, chunk := range strings.Split(path, "/") { if i == 0 { if chunk != "" { return l, ErrNotValid } } else if i == 1 { + if strings.Contains(chunk, `"`) { + return l, ErrNotValid + } l.table = chunk } else if i == 2 { l.row = strings.TrimSuffix(chunk, ".form") @@ -27,14 +31,39 @@ func getPath(path string) (Location, error) { return l, nil } -func getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error) { +func processTable(ctx context.Context, db *sql.DB, table string) ([]Column, string, error) { + columns, err := _getColumns(ctx, db, table) + if err != nil { + return nil, "", err + } + key := "" + score := 0 + for _, column := range columns { + if c := _calculateScore(column); c > score { + key = column.Name + score = c + } + } + if key == "" { + return columns, "", ErrNotValid + } + return columns, key, nil +} + +func _getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error) { rows, err := db.QueryContext(ctx, ` - SELECT c.column_name, c.data_type, tc.constraint_type - FROM information_schema.table_constraints tc - JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) - JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema - AND tc.table_name = c.table_name AND ccu.column_name = c.column_name - WHERE tc.table_name = $1 + SELECT + c.column_name, + c.udt_name as type, + (c.is_nullable = 'YES') AS nullable, + (c.column_default IS NOT NULL) AS has_default, + coalesce(string_agg(tc.constraint_type, ', '), '') as constraint + FROM information_schema.columns AS c + LEFT JOIN information_schema.key_column_usage kcu USING (table_name, column_name) + LEFT JOIN information_schema.table_constraints tc USING (table_name, constraint_name) + WHERE c.table_name = $1 + GROUP BY c.column_name, c.is_nullable, c.udt_name, c.column_default + ORDER BY MIN(c.ordinal_position) `, table) if err != nil { return nil, err @@ -42,40 +71,23 @@ func getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error) columns := []Column{} for rows.Next() { var c Column - if err := rows.Scan(&c.Name, &c.Type, &c.Type); err != nil { + var constraints string + if err := rows.Scan(&c.Name, &c.Type, &c.Nullable, &c.Default, &constraints); err != nil { return nil, err } + c.Constraint = strings.Split(constraints, ", ") + c.Table = table columns = append(columns, c) } - return columns, nil + return columns, rows.Close() } -func getKey(ctx context.Context, db *sql.DB, table string) (string, error) { - columns, err := getColumns(ctx, db, table) - if err != nil { - return "", err - } - key := "" - score := 0 - for _, column := range columns { - if c := calculateScore(column); c > score { - key = column.Name - score = c - } - } - if key == "" { - return "", ErrNotValid - } - return key, nil -} - -func calculateScore(column Column) int { +func _calculateScore(column Column) int { scoreType := 0 scoreName := 1 - switch column.Type { - case "PRIMARY KEY": + if slices.Contains(column.Constraint, "PRIMARY KEY") { scoreType = 3 - case "UNIQUE": + } else if slices.Contains(column.Constraint, "UNIQUE") { scoreType = 2 } switch strings.ToLower(column.Name) { @@ -89,23 +101,33 @@ func calculateScore(column Column) int { return scoreType * scoreName } -func formType(rt reflect.Type, label string) FormElement { - switch rt.String() { - case "bool": - return FormElement{ - Type: "boolean", - } - case "time.Time": - return FormElement{ - Type: "datetime", - } +func convertFromDB(val any) any { + switch tmp := val.(type) { + case []byte: + return string(tmp) + case time.Time: + return tmp.UTC().Format("2006-01-02T15:04") } - if strings.Contains(strings.ToLower(label), "password") { - return FormElement{ - Type: "password", - } - } - return FormElement{ + return val +} + +func createFormElement(val any, column Column) FormElement { + f := FormElement{ Type: "text", } + switch val.(type) { + case bool: + f.Type = "boolean" + case time.Time: + f.Type = "datetime" + } + f.Value = convertFromDB(val) + + f.Name = column.Name + f.Required = !column.Nullable && !column.Default + + if strings.Contains(strings.ToLower(column.Name), "password") { + f.Type = "password" + } + return f } diff --git a/server/plugin/plg_metadata_sqlite/state.go b/server/plugin/plg_metadata_sqlite/state.go deleted file mode 100644 index be08eaa2..00000000 --- a/server/plugin/plg_metadata_sqlite/state.go +++ /dev/null @@ -1 +0,0 @@ -package plg_metadata_sqlite