feature (plg_backend_psql): psql as a storage

This commit is contained in:
MickaelK 2025-08-29 16:44:26 +10:00
parent ab92884439
commit 2da5a67e82
8 changed files with 308 additions and 93 deletions

View file

@ -20,6 +20,7 @@ import (
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_mysql" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_mysql"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nfs" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nfs"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nop" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_nop"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_psql"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_s3" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_s3"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_samba" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_samba"
_ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_sftp" _ "github.com/mickael-kerjean/filestash/server/plugin/plg_backend_sftp"

View file

@ -4,13 +4,15 @@ import (
"context" "context"
"database/sql" "database/sql"
"fmt" "fmt"
"io" "strings"
. "github.com/mickael-kerjean/filestash/server/common" . "github.com/mickael-kerjean/filestash/server/common"
_ "github.com/lib/pq" _ "github.com/lib/pq"
) )
var PGCache AppCache
type PSQL struct { type PSQL struct {
db *sql.DB db *sql.DB
ctx context.Context ctx context.Context
@ -18,9 +20,20 @@ type PSQL struct {
func init() { func init() {
Backend.Register("psql", PSQL{}) Backend.Register("psql", PSQL{})
PGCache = NewAppCache(2, 1)
PGCache.OnEvict(func(key string, value interface{}) {
c := value.(*PSQL)
c.Close()
})
} }
func (this PSQL) Init(params map[string]string, app *App) (IBackend, error) { func (this PSQL) Init(params map[string]string, app *App) (IBackend, error) {
if d := PGCache.Get(params); d != nil {
backend := d.(*PSQL)
backend.ctx = app.Context
return backend, nil
}
host := params["host"] host := params["host"]
port := withDefault(params["port"], "5432") port := withDefault(params["port"], "5432")
user := params["user"] user := params["user"]
@ -44,10 +57,12 @@ func (this PSQL) Init(params map[string]string, app *App) (IBackend, error) {
Log.Debug("plg_backend_psql::init err=%s", err.Error()) Log.Debug("plg_backend_psql::init err=%s", err.Error())
return nil, ErrNotValid return nil, ErrNotValid
} }
return PSQL{ backend := &PSQL{
db: db, db: db,
ctx: app.Context, ctx: app.Context,
}, nil }
PGCache.Set(params, backend)
return backend, nil
} }
func withDefault(val string, def string) string { func withDefault(val string, def string) string {
@ -99,28 +114,22 @@ func (this PSQL) LoginForm() Form {
} }
} }
func (this PSQL) Touch(path string) error { // TODO func (this PSQL) Touch(path string) error {
this.db.Close() if !strings.HasSuffix(path, ".form") {
return ErrNotImplemented return ErrNotValid
}
return nil
} }
func (this PSQL) Save(path string, file io.Reader) error { // TODO func (this PSQL) Rm(path string) error {
this.db.Close() return ErrNotAuthorized
return ErrNotImplemented
}
func (this PSQL) Rm(path string) error { // TODO
this.db.Close()
return ErrNotImplemented
} }
func (this PSQL) Mkdir(path string) error { func (this PSQL) Mkdir(path string) error {
this.db.Close()
return ErrNotValid return ErrNotValid
} }
func (this PSQL) Mv(from string, to string) error { func (this PSQL) Mv(from string, to string) error {
this.db.Close()
return ErrNotValid return ErrNotValid
} }
@ -128,21 +137,21 @@ func (this PSQL) Meta(path string) Metadata {
location, _ := getPath(path) location, _ := getPath(path)
return Metadata{ return Metadata{
CanCreateDirectory: NewBool(false), CanCreateDirectory: NewBool(false),
CanCreateFile: func(l Location) *bool { CanCreateFile: func(l LocationRow) *bool {
if l.table == "" { if l.table == "" {
return NewBool(false) return NewBool(false)
} }
return NewBool(true) return NewBool(true)
}(location), }(location),
CanRename: NewBool(false), CanRename: NewBool(false),
CanDelete: func(l Location) *bool { CanDelete: func(l LocationRow) *bool {
if l.table == "" { if l.table == "" {
return NewBool(false) return NewBool(false)
} }
return NewBool(true) return NewBool(true)
}(location), }(location),
CanMove: NewBool(false), CanMove: NewBool(false),
CanUpload: func(l Location) *bool { CanUpload: func(l LocationRow) *bool {
if l.row == "" { if l.row == "" {
return NewBool(false) return NewBool(false)
} }
@ -152,3 +161,8 @@ func (this PSQL) Meta(path string) Metadata {
HideExtension: NewBool(true), HideExtension: NewBool(true),
} }
} }
func (this PSQL) Close() error {
this.db.Close()
return nil
}

View file

@ -1,26 +1,29 @@
package plg_backend_psql package plg_backend_psql
import ( import (
"context"
"database/sql"
"fmt"
"io" "io"
"slices"
. "github.com/mickael-kerjean/filestash/server/common" . "github.com/mickael-kerjean/filestash/server/common"
) )
func (this PSQL) Cat(path string) (io.ReadCloser, error) { func (this PSQL) Cat(path string) (io.ReadCloser, error) {
defer this.db.Close()
l, err := getPath(path) l, err := getPath(path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
columnName, err := getKey(this.ctx, this.db, l.table) columns, columnName, err := processTable(this.ctx, this.db, l.table)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rows, err := this.db.QueryContext(this.ctx, ` rows, err := this.db.QueryContext(this.ctx, `
SELECT * SELECT *
FROM `+l.table+` FROM "`+l.table+`"
WHERE `+columnName+`='`+l.row+`' WHERE "`+columnName+`"=$1
`) `, l.row)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -28,13 +31,12 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) {
c, err := rows.Columns() c, err := rows.Columns()
if err != nil { if err != nil {
return nil, err return nil, err
} } else if len(columns) != len(c) {
t, err := rows.ColumnTypes() Log.Error("plg_backend_psql::index_cat columns is not of the expected size columns[%d]=%v c[%d]=%v", len(columns), columns, len(c), c)
if err != nil { return nil, ErrNotValid
return nil, err
} }
i := 0 i := 0
col := make([]any, len(c)) col := make([]interface{}, len(c))
for rows.Next() { for rows.Next() {
if i != 0 { if i != 0 {
return nil, ErrNotValid return nil, ErrNotValid
@ -48,11 +50,19 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) {
} }
} }
forms := make([]FormElement, len(c)) forms := make([]FormElement, len(c))
for i, _ := range c { for i, _ := range columns {
f := formType(t[i].ScanType(), c[i]) forms[i] = createFormElement(col[i], columns[i])
f.Name = c[i] if slices.Contains(columns[i].Constraint, "PRIMARY KEY") {
f.Value = col[i] forms[i].ReadOnly = true
forms[i] = f } else if slices.Contains(columns[i].Constraint, "FOREIGN KEY") {
if link, err := _findRelation(this.ctx, this.db, columns[i]); err == nil {
forms[i].Description = _createDescription(columns[i], link)
if len(link.values) > 0 {
forms[i].Type = "select"
forms[i].Opts = link.values
}
}
}
} }
b, err := Form{Elmnts: forms}.MarshalJSON() b, err := Form{Elmnts: forms}.MarshalJSON()
if err != nil { if err != nil {
@ -60,3 +70,49 @@ func (this PSQL) Cat(path string) (io.ReadCloser, error) {
} }
return NewReadCloserFromBytes(b), nil return NewReadCloserFromBytes(b), nil
} }
func _createDescription(el Column, link LocationColumn) string {
if slices.Contains(el.Constraint, "FOREIGN KEY") {
return fmt.Sprintf("points to <%s> → <%s>", link.table, link.column)
}
return ""
}
func _findRelation(ctx context.Context, db *sql.DB, el Column) (LocationColumn, error) {
l := LocationColumn{}
rows, err := db.QueryContext(ctx, `
SELECT ccu.table_name, ccu.column_name
FROM information_schema.table_constraints AS tc
JOIN information_schema.key_column_usage AS kcu USING (constraint_name)
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_name)
WHERE tc.constraint_type = 'FOREIGN KEY'
AND tc.table_name = $1
AND kcu.column_name = $2
`, el.Table, el.Name)
if err != nil {
return l, err
}
defer rows.Close()
for rows.Next() {
if err := rows.Scan(&l.table, &l.column); err != nil {
return l, err
}
}
valueRows, err := db.QueryContext(ctx, fmt.Sprintf(
`SELECT DISTINCT "%s" FROM "%s" ORDER BY "%s" LIMIT 5000`,
l.column, l.table, l.column,
))
if err != nil {
return l, err
}
defer valueRows.Close()
l.values = []string{}
for valueRows.Next() {
var value string
if err := valueRows.Scan(&value); err != nil {
return l, err
}
l.values = append(l.values, value)
}
return l, nil
}

View file

@ -7,14 +7,17 @@ import (
) )
func (this PSQL) Ls(path string) ([]os.FileInfo, error) { func (this PSQL) Ls(path string) ([]os.FileInfo, error) {
defer this.db.Close()
l, err := getPath(path) l, err := getPath(path)
if err != nil { if err != nil {
Log.Debug("pl_backend_psql::ls method=getPath err=%s", err.Error()) Log.Debug("pl_backend_psql::ls method=getPath err=%s", err.Error())
return nil, err return nil, err
} }
if l.table == "" { if l.table == "" {
rows, err := this.db.QueryContext(this.ctx, "SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") rows, err := this.db.QueryContext(this.ctx, `
SELECT table_name FROM information_schema.tables
WHERE table_schema = 'public'
AND table_type = 'BASE TABLE'
`)
if err != nil { if err != nil {
Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error()) Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error())
return nil, err return nil, err
@ -34,12 +37,12 @@ func (this PSQL) Ls(path string) ([]os.FileInfo, error) {
} }
return out, nil return out, nil
} else if l.row == "" { } else if l.row == "" {
key, err := getKey(this.ctx, this.db, l.table) _, key, err := processTable(this.ctx, this.db, l.table)
if err != nil { if err != nil {
Log.Debug("plg_backend_psql::ls method=getKey err=%s", err.Error()) Log.Debug("plg_backend_psql::ls method=processTable err=%s", err.Error())
return nil, err return nil, err
} }
rows, err := this.db.QueryContext(this.ctx, "SELECT "+key+" FROM "+l.table+" LIMIT 500000") rows, err := this.db.QueryContext(this.ctx, `SELECT "`+key+`" FROM "`+l.table+`" LIMIT 500000`)
if err != nil { if err != nil {
Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error()) Log.Debug("plg_backend_psql::ls method=query err=%s", err.Error())
return nil, err return nil, err

View file

@ -0,0 +1,111 @@
package plg_backend_psql
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"slices"
"strings"
. "github.com/mickael-kerjean/filestash/server/common"
)
func (this PSQL) Save(path string, file io.Reader) error {
l, err := getPath(path)
if err != nil {
return err
}
columns, key, err := processTable(this.ctx, this.db, l.table)
if err != nil {
return err
}
f := map[string]FormElement{}
if err := json.NewDecoder(file).Decode(&f); err != nil {
return err
}
tx, err := this.db.BeginTx(this.ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
rows, err := tx.QueryContext(this.ctx, `SELECT * FROM "`+l.table+`" WHERE "`+key+`" = $1`, l.row)
if err != nil {
return err
}
i := 0
dbvals := make([]any, len(columns))
for rows.Next() {
currentPtrs := make([]any, len(columns))
for i := range dbvals {
currentPtrs[i] = &dbvals[i]
}
if serr := rows.Scan(currentPtrs...); serr != nil {
rows.Close()
err = serr
break
} else if i >= 1 {
err = ErrNotValid
break
}
i += 1
}
rows.Close()
if i == 0 {
err = _createRow(tx, this.ctx, l.table, columns, f)
}
if err == nil && i == 1 {
err = _updateRow(tx, this.ctx, l.table, columns, f, key, l.row, dbvals)
}
if err != nil {
return err
}
return tx.Commit()
}
func _createRow(tx *sql.Tx, ctx context.Context, table string, columns []Column, f map[string]FormElement) error {
colNames := []string{}
placeholders := []string{}
values := []interface{}{}
paramIndex := 1
for _, col := range columns {
if formEl, exists := f[col.Name]; exists {
if slices.Contains(col.Constraint, "PRIMARY KEY") && col.Default {
continue
}
colNames = append(colNames, `"`+col.Name+`"`)
placeholders = append(placeholders, fmt.Sprintf("$%d", paramIndex))
values = append(values, formEl.Value)
paramIndex++
}
}
if len(colNames) == 0 {
return ErrNotValid
}
_, err := tx.ExecContext(
ctx,
`INSERT INTO "`+table+`" (`+strings.Join(colNames, ", ")+`) VALUES (`+strings.Join(placeholders, ", ")+`)`,
values...,
)
return err
}
func _updateRow(tx *sql.Tx, ctx context.Context, table string, columns []Column, f map[string]FormElement, keyName string, keyValue any, dbvals []any) error {
for i, col := range columns {
dbval := convertFromDB(dbvals[i])
formval, ok := f[col.Name]
if !ok || formval.Value == dbval {
continue
}
if _, err := tx.ExecContext(
ctx,
`UPDATE "`+table+`" SET "`+col.Name+`" = $1 WHERE "`+keyName+`" = $2`,
formval.Value, keyValue,
); err != nil {
return err
}
}
return nil
}

View file

@ -1,12 +1,21 @@
package plg_backend_psql package plg_backend_psql
type Column struct { type Column struct {
Table string
Name string Name string
Type string Type string
Constraint string Nullable bool
Default bool
Constraint []string
} }
type Location struct { type LocationRow struct {
table string table string
row string row string
} }
type LocationColumn struct {
table string
column string
values []string
}

View file

@ -3,20 +3,24 @@ package plg_backend_psql
import ( import (
"context" "context"
"database/sql" "database/sql"
"reflect" "slices"
"strings" "strings"
"time"
. "github.com/mickael-kerjean/filestash/server/common" . "github.com/mickael-kerjean/filestash/server/common"
) )
func getPath(path string) (Location, error) { func getPath(path string) (LocationRow, error) {
l := Location{} l := LocationRow{}
for i, chunk := range strings.Split(path, "/") { for i, chunk := range strings.Split(path, "/") {
if i == 0 { if i == 0 {
if chunk != "" { if chunk != "" {
return l, ErrNotValid return l, ErrNotValid
} }
} else if i == 1 { } else if i == 1 {
if strings.Contains(chunk, `"`) {
return l, ErrNotValid
}
l.table = chunk l.table = chunk
} else if i == 2 { } else if i == 2 {
l.row = strings.TrimSuffix(chunk, ".form") l.row = strings.TrimSuffix(chunk, ".form")
@ -27,14 +31,39 @@ func getPath(path string) (Location, error) {
return l, nil return l, nil
} }
func getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error) { func processTable(ctx context.Context, db *sql.DB, table string) ([]Column, string, error) {
columns, err := _getColumns(ctx, db, table)
if err != nil {
return nil, "", err
}
key := ""
score := 0
for _, column := range columns {
if c := _calculateScore(column); c > score {
key = column.Name
score = c
}
}
if key == "" {
return columns, "", ErrNotValid
}
return columns, key, nil
}
func _getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error) {
rows, err := db.QueryContext(ctx, ` rows, err := db.QueryContext(ctx, `
SELECT c.column_name, c.data_type, tc.constraint_type SELECT
FROM information_schema.table_constraints tc c.column_name,
JOIN information_schema.constraint_column_usage AS ccu USING (constraint_schema, constraint_name) c.udt_name as type,
JOIN information_schema.columns AS c ON c.table_schema = tc.constraint_schema (c.is_nullable = 'YES') AS nullable,
AND tc.table_name = c.table_name AND ccu.column_name = c.column_name (c.column_default IS NOT NULL) AS has_default,
WHERE tc.table_name = $1 coalesce(string_agg(tc.constraint_type, ', '), '') as constraint
FROM information_schema.columns AS c
LEFT JOIN information_schema.key_column_usage kcu USING (table_name, column_name)
LEFT JOIN information_schema.table_constraints tc USING (table_name, constraint_name)
WHERE c.table_name = $1
GROUP BY c.column_name, c.is_nullable, c.udt_name, c.column_default
ORDER BY MIN(c.ordinal_position)
`, table) `, table)
if err != nil { if err != nil {
return nil, err return nil, err
@ -42,40 +71,23 @@ func getColumns(ctx context.Context, db *sql.DB, table string) ([]Column, error)
columns := []Column{} columns := []Column{}
for rows.Next() { for rows.Next() {
var c Column var c Column
if err := rows.Scan(&c.Name, &c.Type, &c.Type); err != nil { var constraints string
if err := rows.Scan(&c.Name, &c.Type, &c.Nullable, &c.Default, &constraints); err != nil {
return nil, err return nil, err
} }
c.Constraint = strings.Split(constraints, ", ")
c.Table = table
columns = append(columns, c) columns = append(columns, c)
} }
return columns, nil return columns, rows.Close()
} }
func getKey(ctx context.Context, db *sql.DB, table string) (string, error) { func _calculateScore(column Column) int {
columns, err := getColumns(ctx, db, table)
if err != nil {
return "", err
}
key := ""
score := 0
for _, column := range columns {
if c := calculateScore(column); c > score {
key = column.Name
score = c
}
}
if key == "" {
return "", ErrNotValid
}
return key, nil
}
func calculateScore(column Column) int {
scoreType := 0 scoreType := 0
scoreName := 1 scoreName := 1
switch column.Type { if slices.Contains(column.Constraint, "PRIMARY KEY") {
case "PRIMARY KEY":
scoreType = 3 scoreType = 3
case "UNIQUE": } else if slices.Contains(column.Constraint, "UNIQUE") {
scoreType = 2 scoreType = 2
} }
switch strings.ToLower(column.Name) { switch strings.ToLower(column.Name) {
@ -89,23 +101,33 @@ func calculateScore(column Column) int {
return scoreType * scoreName return scoreType * scoreName
} }
func formType(rt reflect.Type, label string) FormElement { func convertFromDB(val any) any {
switch rt.String() { switch tmp := val.(type) {
case "bool": case []byte:
return FormElement{ return string(tmp)
Type: "boolean", case time.Time:
return tmp.UTC().Format("2006-01-02T15:04")
} }
case "time.Time": return val
return FormElement{
Type: "datetime",
} }
}
if strings.Contains(strings.ToLower(label), "password") { func createFormElement(val any, column Column) FormElement {
return FormElement{ f := FormElement{
Type: "password",
}
}
return FormElement{
Type: "text", Type: "text",
} }
switch val.(type) {
case bool:
f.Type = "boolean"
case time.Time:
f.Type = "datetime"
}
f.Value = convertFromDB(val)
f.Name = column.Name
f.Required = !column.Nullable && !column.Default
if strings.Contains(strings.ToLower(column.Name), "password") {
f.Type = "password"
}
return f
} }

View file

@ -1 +0,0 @@
package plg_metadata_sqlite