Update go dependencies (#3480)

* Bump golang.org/x/text from 0.3.7 to 0.3.8

Bumps [golang.org/x/text](https://github.com/golang/text) from 0.3.7 to 0.3.8.
- [Release notes](https://github.com/golang/text/releases)
- [Commits](https://github.com/golang/text/compare/v0.3.7...v0.3.8)

---
updated-dependencies:
- dependency-name: golang.org/x/text
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

* Update go dependencies

* Update x/net

---------

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
This commit is contained in:
WithoutPants 2023-02-28 08:26:14 +11:00 committed by GitHub
parent 445e0a7311
commit 30809e16fa
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
417 changed files with 32289 additions and 17730 deletions

24
go.mod
View file

@ -1,7 +1,7 @@
module github.com/stashapp/stash
require (
github.com/99designs/gqlgen v0.17.2
github.com/99designs/gqlgen v0.17.24
github.com/Yamashou/gqlgenc v0.0.6
github.com/anacrolix/dms v1.2.2
github.com/antchfx/htmlquery v1.2.5-0.20211125074323-810ee8082758
@ -15,7 +15,7 @@ require (
github.com/golang-migrate/migrate/v4 v4.15.0-beta.1
github.com/gorilla/securecookie v1.1.1
github.com/gorilla/sessions v1.2.0
github.com/gorilla/websocket v1.4.2
github.com/gorilla/websocket v1.5.0
github.com/jinzhu/copier v0.0.0-20190924061706-b57f9002281a
github.com/jmoiron/sqlx v1.3.1
github.com/json-iterator/go v1.1.12
@ -30,16 +30,16 @@ require (
github.com/spf13/afero v1.8.2 // indirect
github.com/spf13/pflag v1.0.5
github.com/spf13/viper v1.10.1
github.com/stretchr/testify v1.7.0
github.com/stretchr/testify v1.7.1
github.com/tidwall/gjson v1.9.3
github.com/tidwall/pretty v1.2.0 // indirect
github.com/vektra/mockery/v2 v2.10.0
golang.org/x/crypto v0.0.0-20220321153916-2c7772ba3064
golang.org/x/image v0.0.0-20210220032944-ac19c3e999fb
golang.org/x/net v0.0.0-20220722155237-a158d28d115b
golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211
golang.org/x/text v0.3.7
golang.org/x/net v0.7.0
golang.org/x/sys v0.5.0
golang.org/x/term v0.5.0
golang.org/x/text v0.7.0
golang.org/x/tools v0.1.12 // indirect
gopkg.in/sourcemap.v1 v1.0.5 // indirect
gopkg.in/yaml.v2 v2.4.0
@ -57,7 +57,7 @@ require (
github.com/spf13/cast v1.4.1
github.com/vearutop/statigz v1.1.6
github.com/vektah/dataloaden v0.3.0
github.com/vektah/gqlparser/v2 v2.4.1
github.com/vektah/gqlparser/v2 v2.5.1
github.com/xWTF/chardet v0.0.0-20230208095535-c780f2ac244e
gopkg.in/guregu/null.v4 v4.0.0
)
@ -83,9 +83,8 @@ require (
github.com/josharian/intern v1.0.0 // indirect
github.com/magiconair/properties v1.8.6 // indirect
github.com/mailru/easyjson v0.7.7 // indirect
github.com/matryer/moq v0.2.6 // indirect
github.com/mitchellh/go-homedir v1.1.0 // indirect
github.com/mitchellh/mapstructure v1.4.3 // indirect
github.com/mitchellh/mapstructure v1.5.0 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 // indirect
@ -100,11 +99,12 @@ require (
github.com/stretchr/objx v0.2.0 // indirect
github.com/subosito/gotenv v1.2.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/urfave/cli/v2 v2.4.0 // indirect
github.com/urfave/cli/v2 v2.8.1 // indirect
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect
go.uber.org/atomic v1.7.0 // indirect
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4 // indirect
gopkg.in/ini.v1 v1.66.4 // indirect
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)
replace git.apache.org/thrift.git => github.com/apache/thrift v0.0.0-20180902110319-2566ecd5d999

50
go.sum
View file

@ -50,8 +50,9 @@ cloud.google.com/go/storage v1.8.0/go.mod h1:Wv1Oy7z6Yz3DshWRJFhqM/UCfaWIRTdp0RX
cloud.google.com/go/storage v1.10.0/go.mod h1:FLPqc6j+Ki4BU591ie1oL6qBQGu2Bl/tZ9ullr3+Kg0=
cloud.google.com/go/storage v1.14.0/go.mod h1:GrKmX003DSIwi9o29oFT7YDnHYwZoctc3fOKtUw0Xmo=
dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7gZCb22OMCxBHrMx7a5I7Hp++hsVxbQ4BYO7hU=
github.com/99designs/gqlgen v0.17.2 h1:yczvlwMsfcVu/JtejqfrLwXuSP0yZFhmcss3caEvHw8=
github.com/99designs/gqlgen v0.17.2/go.mod h1:K5fzLKwtph+FFgh9j7nFbRUdBKvTcGnsta51fsMTn3o=
github.com/99designs/gqlgen v0.17.24 h1:pcd/HFIoSdRvyADYQG2dHvQN2KZqX/nXzlVm6TMMq7E=
github.com/99designs/gqlgen v0.17.24/go.mod h1:BMhYIhe4bp7OlCo5I2PnowSK/Wimpv/YlxfNkqZGwLo=
github.com/Azure/azure-pipeline-go v0.2.3/go.mod h1:x841ezTBIMG6O3lAcl8ATHnsOPVl2bqk7S3ta6S6u4k=
github.com/Azure/azure-storage-blob-go v0.13.0/go.mod h1:pA9kNqtjUeQF2zOSu4s//nUdBD+e64lEuc4sVnuOfNs=
github.com/Azure/go-ansiterm v0.0.0-20170929234023-d6e3b3328b78/go.mod h1:LmzpDX56iTiv29bbRTIsUNlaFfuhWRQBWjQdVyAevI8=
@ -63,6 +64,7 @@ github.com/Azure/go-autorest/autorest/mocks v0.4.1/go.mod h1:LTp+uSrOhSkaKrUy935
github.com/Azure/go-autorest/logger v0.2.1/go.mod h1:T9E3cAhj2VqvPOtCYAvby9aBXkZmbF5NWuPV8+WeEW8=
github.com/Azure/go-autorest/tracing v0.6.0/go.mod h1:+vhtPC754Xsa23ID7GlGsrdKBpUA79WCAKPPZVC2DeU=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/ClickHouse/clickhouse-go v1.4.3/go.mod h1:EaI/sW7Azgz9UATzd5ZdZHRUhHgv5+JMS9NSr2smCJI=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
@ -396,8 +398,9 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.0 h1:S7P+1Hm5V/AT9cjEcUD5uDaQSX0OE577aCXgoaKpYbQ=
github.com/gorilla/sessions v1.2.0/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/hashicorp/consul/api v1.11.0/go.mod h1:XjsvQN+RJGWI2TWy1/kqaE16HrR2J/FWgkYjdZQsX9M=
@ -556,8 +559,7 @@ github.com/markbates/oncer v0.0.0-20181203154359-bf2de49a0be2/go.mod h1:Ld9puTsI
github.com/markbates/pkger v0.15.1/go.mod h1:0JoVlrol20BSywW79rN3kdFFsE5xYM+rSCQDXbLhiuI=
github.com/markbates/safe v1.0.1/go.mod h1:nAqgmRi7cY2nqMc92/bSEeQA+R4OheNU2T1kNSCBdG0=
github.com/matryer/moq v0.2.3/go.mod h1:9RtPYjTnH1bSBIkpvtHkFN7nbWAnO7oRpdJkEIn6UtE=
github.com/matryer/moq v0.2.6 h1:X4+LF09udTsi2P+Z+1UhSb4p3K8IyiF7KSNFDR9M3M0=
github.com/matryer/moq v0.2.6/go.mod h1:kITsx543GOENm48TUAQyJ9+SAvFSr7iGQXPoth/VUBk=
github.com/matryer/moq v0.2.7/go.mod h1:kITsx543GOENm48TUAQyJ9+SAvFSr7iGQXPoth/VUBk=
github.com/mattn/go-colorable v0.0.9/go.mod h1:9vuHe8Xs5qXnSaW/c/ABM9alt+Vo+STaOChaDxuIBZU=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE=
@ -565,6 +567,7 @@ github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVc
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.9/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4=
github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg=
github.com/mattn/go-ieproxy v0.0.1/go.mod h1:pYabZ6IHcRpFh7vIaLfK7rdcWgFEb3SFJ6/gNWuh88E=
github.com/mattn/go-isatty v0.0.3/go.mod h1:M+lRXTBqGeGNdLjl/ufCoiOlB5xdOkqRJdNxMWT7Zi4=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
@ -575,6 +578,7 @@ github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcME
github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94=
github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM=
github.com/mattn/go-sqlite3 v1.9.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
@ -591,8 +595,9 @@ github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:F
github.com/mitchellh/mapstructure v0.0.0-20180220230111-00c29f56e238/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
github.com/mitchellh/mapstructure v1.2.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.4.3 h1:OVowDSCllw/YjdLkam3/sm7wEtOy59d8ndGgCcyj8cs=
github.com/mitchellh/mapstructure v1.4.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -734,8 +739,9 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1 h1:5TQK59W5E3v0r2duFAb7P95B6hEeOyEnHRa8MjYSMTY=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s=
github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw=
github.com/tidwall/gjson v1.9.3 h1:hqzS9wAHMO+KVBBkLxYdkEeeFHuqr95GfClRLKlgK0E=
@ -748,15 +754,16 @@ github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhso
github.com/tinylib/msgp v1.0.2/go.mod h1:+d+yLhGm8mzTaHzB+wgMYrodPfmZrzkirds8fDWklFE=
github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM=
github.com/urfave/cli/v2 v2.3.0/go.mod h1:LJmUH05zAU44vOAcrfzZQKsZbVcdbOG8rtL3/XcUArI=
github.com/urfave/cli/v2 v2.4.0 h1:m2pxjjDFgDxSPtO8WSdbndj17Wu2y8vOT86wE/tjr+I=
github.com/urfave/cli/v2 v2.4.0/go.mod h1:NX9W0zmTvedE5oDoOMs2RTC8RvdK98NTYZE5LbaEYPg=
github.com/urfave/cli/v2 v2.8.1 h1:CGuYNZF9IKZY/rfBe3lJpccSoIY1ytfvmgQT90cNOl4=
github.com/urfave/cli/v2 v2.8.1/go.mod h1:Z41J9TPoffeoqP0Iza0YbAhGvymRdZAd2uPmZ5JxRdY=
github.com/vearutop/statigz v1.1.6 h1:si1zvulh/6P4S/SjFticuKQ8/EgQISglaRuycj8PWso=
github.com/vearutop/statigz v1.1.6/go.mod h1:czAv7iXgPv/s+xsgXpVEhhD0NSOQ4wZPgmM/n7LANDI=
github.com/vektah/dataloaden v0.3.0 h1:ZfVN2QD6swgvp+tDqdH/OIT/wu3Dhu0cus0k5gIZS84=
github.com/vektah/dataloaden v0.3.0/go.mod h1:/HUdMve7rvxZma+2ZELQeNh88+003LL7Pf/CZ089j8U=
github.com/vektah/gqlparser/v2 v2.4.0/go.mod h1:flJWIR04IMQPGz+BXLrORkrARBxv/rtyIAFvd/MceW0=
github.com/vektah/gqlparser/v2 v2.4.1 h1:QOyEn8DAPMUMARGMeshKDkDgNmVoEaEGiDB0uWxcSlQ=
github.com/vektah/gqlparser/v2 v2.4.1/go.mod h1:flJWIR04IMQPGz+BXLrORkrARBxv/rtyIAFvd/MceW0=
github.com/vektah/gqlparser/v2 v2.5.1 h1:ZGu+bquAY23jsxDRcYpWjttRZrUz07LbiY77gUOHcr4=
github.com/vektah/gqlparser/v2 v2.5.1/go.mod h1:mPgqFBu/woKTVYWyNk8cO3kh4S/f4aRFZrvOnp3hmCs=
github.com/vektra/mockery/v2 v2.10.0 h1:MiiQWxwdq7/ET6dCXLaJzSGEN17k758H7JHS9kOdiks=
github.com/vektra/mockery/v2 v2.10.0/go.mod h1:m/WO2UzWzqgVX3nvqpRQq70I4Z7jbSCRhdmkgtp+Ab4=
github.com/willf/bitset v1.1.9/go.mod h1:RjeCKbqT1RxIR/KWY6phxZiaY1IyutSBfGjNPySAYV4=
@ -766,6 +773,8 @@ github.com/xanzy/go-gitlab v0.15.0/go.mod h1:8zdQa/ri1dfn8eS3Ir1SyfvOKlw7WBJ8DVT
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
github.com/xdg-go/scram v1.0.2/go.mod h1:1WAq6h33pAW+iRreB34OORO2Nf7qel3VV3fjBj+hCSs=
github.com/xdg-go/stringprep v1.0.2/go.mod h1:8F9zXuvzgwmyT5DUm4GUfZGDdT3W+LCvS6+da4O5kxM=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 h1:bAn7/zixMGCfxrRTfdpNzjtPYqr8smhKouy9mxVdGPU=
github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8=
github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA=
github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
@ -774,6 +783,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.0/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.1/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
gitlab.com/nyarla/go-crypt v0.0.0-20160106005555-d9a5dc2b789b/go.mod h1:T3BPAOm2cqquPa0MKWeNkmOM5RQsRhkrwMWonFMN7fE=
go.etcd.io/etcd/api/v3 v3.5.1/go.mod h1:cbVKeC6lCfl7j/8jBhAK6aIYO9XOjdptoxU/nLQcPvs=
@ -924,8 +934,9 @@ golang.org/x/net v0.0.0-20210805182204-aaa1db679c0d/go.mod h1:9nx3DQGgdP8bBQD5qx
golang.org/x/net v0.0.0-20210813160813-60bc85c4be6d/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b h1:PxfKdU9lEEDYjdIzOtC4qFWgkU2rGHdKlKowJSMN9h0=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/oauth2 v0.0.0-20180227000427-d7d64896b5ff/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181106182150-f42d05182288/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
@ -958,6 +969,7 @@ golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180224232135-f6cff0780e54/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180823144017-11551d06cbcc/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
@ -1050,11 +1062,15 @@ golang.org/x/sys v0.0.0-20211205182925-97ca703d548d/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20211210111614-af8b64212486/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220114195835-da31bd327af9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664 h1:v1W7bwXHsnLLloWYTVEdvGvA7BHMeBYsPcF0GLDxIRs=
golang.org/x/sys v0.0.0-20220808155132-1c4a2a72c664/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211 h1:JGgROgKl9N8DuW20oFS5gxc+lE67/N3FcwmBPMe7ArY=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0 h1:n2a8QNdAb0sZNpU9R1ALUXBbY+w51fCQDN+7EdxNBsY=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@ -1063,8 +1079,10 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ=
golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@ -1303,6 +1321,7 @@ google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlba
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.28.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
@ -1332,8 +1351,9 @@ gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorm.io/driver/postgres v1.0.8/go.mod h1:4eOzrI1MUfm6ObJU/UcmbXyiHSs8jSwH95G5P5dxcAg=
gorm.io/gorm v1.20.12/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=
gorm.io/gorm v1.21.4/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=

View file

@ -5,6 +5,7 @@ package main
import (
_ "github.com/99designs/gqlgen"
_ "github.com/99designs/gqlgen/graphql/introspection"
_ "github.com/Yamashou/gqlgenc"
_ "github.com/vektah/dataloaden"
_ "github.com/vektra/mockery/v2"

View file

@ -14,3 +14,5 @@
.idea/
*.test
*.out
gqlgen
*.exe

File diff suppressed because it is too large Load diff

View file

@ -22,7 +22,8 @@ Still not convinced enough to use **gqlgen**? Compare **gqlgen** with other Go g
2. Add `github.com/99designs/gqlgen` to your [project's tools.go](https://github.com/golang/go/wiki/Modules#how-can-i-track-tool-dependencies-for-a-module)
printf '// +build tools\npackage tools\nimport _ "github.com/99designs/gqlgen"' | gofmt > tools.go
printf '// +build tools\npackage tools\nimport (_ "github.com/99designs/gqlgen"\n _ "github.com/99designs/gqlgen/graphql/introspection")' | gofmt > tools.go
go mod tidy
3. Initialise gqlgen config and generate models
@ -142,9 +143,20 @@ first model in this list is used as the default type and it will always be used
There isn't any way around this, gqlgen has no way to know what you want in a given context.
### Why do my interfaces have getters? Can I disable these?
These were added in v0.17.14 to allow accessing common interface fields without casting to a concrete type.
However, certain fields, like Relay-style Connections, cannot be implemented with simple getters.
If you'd prefer to not have getters generated in your interfaces, you can add the following in your `gqlgen.yml`:
```yaml
# gqlgen.yml
omit_getters: true
```
## Other Resources
- [Christopher Biscardi @ Gophercon UK 2018](https://youtu.be/FdURVezcdcw)
- [Introducing gqlgen: a GraphQL Server Generator for Go](https://99designs.com.au/blog/engineering/gqlgen-a-graphql-server-generator-for-go/)
- [Dive into GraphQL by Iván Corrales Solera](https://medium.com/@ivan.corrales.solera/dive-into-graphql-9bfedf22e1a)
- [Sample Project built on gqlgen with Postgres by Oleg Shalygin](https://github.com/oshalygin/gqlgen-pg-todo-example)
- [Hackernews GraphQL Server with gqlgen by Shayegan Hooshyari](https://www.howtographql.com/graphql-go/0-introduction/)

View file

@ -6,9 +6,10 @@ Assuming the next version is $NEW_VERSION=v0.16.0 or something like that.
./bin/release $NEW_VERSION
```
2. git-chglog -o CHANGELOG.md
3. git commit and push the CHANGELOG.md
4. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
5. Comment on the release discussion with any really important notes (breaking changes)
3. go generate ./...; cd _examples; go generate ./...; cd ..
4. git commit and push the CHANGELOG.md
5. Go to https://github.com/99designs/gqlgen/releases and draft new release, autogenerate the release notes, and Create a discussion for this release
6. Comment on the release discussion with any really important notes (breaking changes)
I used https://github.com/git-chglog/git-chglog to automate the changelog maintenance process for now. We could just as easily use go releaser to make the whole thing automated.

View file

@ -36,5 +36,4 @@ npm install
will write the schema to `integration/schema-fetched.graphql`, compare that with `schema-expected.graphql`
CI will run this and fail the build if the two files dont match.
CI will run this and fail the build if the two files don't match.

View file

@ -2,6 +2,7 @@ package api
import (
"fmt"
"regexp"
"syscall"
"github.com/99designs/gqlgen/codegen"
@ -24,7 +25,20 @@ func Generate(cfg *config.Config, option ...Option) error {
}
plugins = append(plugins, resolvergen.New())
if cfg.Federation.IsDefined() {
plugins = append([]plugin.Plugin{federation.New()}, plugins...)
if cfg.Federation.Version == 0 { // default to using the user's choice of version, but if unset, try to sort out which federation version to use
urlRegex := regexp.MustCompile(`(?s)@link.*\(.*url:.*?"(.*?)"[^)]+\)`) // regex to grab the url of a link directive, should it exist
// check the sources, and if one is marked as federation v2, we mark the entirety to be generated using that format
for _, v := range cfg.Sources {
cfg.Federation.Version = 1
urlString := urlRegex.FindStringSubmatch(v.Input)
if urlString != nil && urlString[1] == "https://specs.apollo.dev/federation/v2.0" {
cfg.Federation.Version = 2
break
}
}
}
plugins = append([]plugin.Plugin{federation.New(cfg.Federation.Version)}, plugins...)
}
for _, o := range option {

View file

@ -73,11 +73,15 @@ func (b *builder) buildArg(obj *Object, arg *ast.ArgumentDefinition) (*FieldArgu
return &newArg, nil
}
func (b *builder) bindArgs(field *Field, params *types.Tuple) ([]*FieldArgument, error) {
var newArgs []*FieldArgument
func (b *builder) bindArgs(field *Field, sig *types.Signature, params *types.Tuple) ([]*FieldArgument, error) {
n := params.Len()
newArgs := make([]*FieldArgument, 0, len(field.Args))
// Accept variadic methods (i.e. have optional parameters).
if params.Len() > len(field.Args) && sig.Variadic() {
n = len(field.Args)
}
nextArg:
for j := 0; j < params.Len(); j++ {
for j := 0; j < n; j++ {
param := params.At(j)
for _, oldArg := range field.Args {
if strings.EqualFold(oldArg.Name, param.Name()) {

View file

@ -192,6 +192,7 @@ type TypeReference struct {
Unmarshaler *types.Func // When using external marshalling functions this will point to the Unmarshal function
IsMarshaler bool // Does the type implement graphql.Marshaler and graphql.Unmarshaler
IsContext bool // Is the Marshaler/Unmarshaller the context version; applies to either the method or interface variety.
PointersInUmarshalInput bool // Inverse values and pointers in return.
}
func (ref *TypeReference) Elem() *TypeReference {
@ -216,7 +217,6 @@ func (t *TypeReference) IsPtr() bool {
}
// fix for https://github.com/golang/go/issues/31103 may make it possible to remove this (may still be useful)
//
func (t *TypeReference) IsPtrToPtr() bool {
if p, isPtr := t.GO.(*types.Pointer); isPtr {
_, isPtr := p.Elem().(*types.Pointer)
@ -252,6 +252,39 @@ func (t *TypeReference) IsStruct() bool {
return isStruct
}
func (t *TypeReference) IsUnderlyingBasic() bool {
_, isUnderlyingBasic := t.GO.Underlying().(*types.Basic)
return isUnderlyingBasic
}
func (t *TypeReference) IsUnusualBasic() bool {
if basic, isBasic := t.GO.(*types.Basic); isBasic {
switch basic.Kind() {
case types.Int8, types.Int16, types.Uint, types.Uint8, types.Uint16, types.Uint32:
return true
default:
return false
}
}
return false
}
func (t *TypeReference) IsUnderlyingUnusualBasic() bool {
if basic, isUnderlyingBasic := t.GO.Underlying().(*types.Basic); isUnderlyingBasic {
switch basic.Kind() {
case types.Int8, types.Int16, types.Uint, types.Uint8, types.Uint16, types.Uint32:
return true
default:
return false
}
}
return false
}
func (t *TypeReference) IsScalarID() bool {
return t.Definition.Kind == ast.Scalar && t.Marshaler.Name() == "MarshalID"
}
func (t *TypeReference) IsScalar() bool {
return t.Definition.Kind == ast.Scalar
}
@ -413,6 +446,8 @@ func (b *Binder) TypeReference(schemaType *ast.Type, bindTarget types.Type) (ret
ref.GO = bindTarget
}
ref.PointersInUmarshalInput = b.cfg.ReturnPointersInUmarshalInput
return ref, nil
}

View file

@ -1,8 +1,8 @@
package config
import (
"bytes"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
@ -12,7 +12,7 @@ import (
"github.com/99designs/gqlgen/internal/code"
"github.com/vektah/gqlparser/v2"
"github.com/vektah/gqlparser/v2/ast"
"gopkg.in/yaml.v2"
"gopkg.in/yaml.v3"
)
type Config struct {
@ -26,6 +26,11 @@ type Config struct {
StructTag string `yaml:"struct_tag,omitempty"`
Directives map[string]DirectiveConfig `yaml:"directives,omitempty"`
OmitSliceElementPointers bool `yaml:"omit_slice_element_pointers,omitempty"`
OmitGetters bool `yaml:"omit_getters,omitempty"`
OmitComplexity bool `yaml:"omit_complexity,omitempty"`
StructFieldsAlwaysPointers bool `yaml:"struct_fields_always_pointers,omitempty"`
ReturnPointersInUmarshalInput bool `yaml:"return_pointers_in_unmarshalinput,omitempty"`
ResolversAlwaysReturnPointers bool `yaml:"resolvers_always_return_pointers,omitempty"`
SkipValidation bool `yaml:"skip_validation,omitempty"`
SkipModTidy bool `yaml:"skip_mod_tidy,omitempty"`
Sources []*ast.Source `yaml:"-"`
@ -46,6 +51,9 @@ func DefaultConfig() *Config {
Exec: ExecConfig{Filename: "generated.go"},
Directives: map[string]DirectiveConfig{},
Models: TypeMap{},
StructFieldsAlwaysPointers: true,
ReturnPointersInUmarshalInput: false,
ResolversAlwaysReturnPointers: true,
}
}
@ -57,7 +65,7 @@ func LoadDefaultConfig() (*Config, error) {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = ioutil.ReadFile(filename)
schemaRaw, err = os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to open schema: %w", err)
}
@ -94,12 +102,15 @@ var path2regex = strings.NewReplacer(
func LoadConfig(filename string) (*Config, error) {
config := DefaultConfig()
b, err := ioutil.ReadFile(filename)
b, err := os.ReadFile(filename)
if err != nil {
return nil, fmt.Errorf("unable to read config: %w", err)
}
if err := yaml.UnmarshalStrict(b, config); err != nil {
dec := yaml.NewDecoder(bytes.NewReader(b))
dec.KnownFields(true)
if err := dec.Decode(config); err != nil {
return nil, fmt.Errorf("unable to parse config: %w", err)
}
@ -173,7 +184,7 @@ func CompleteConfig(config *Config) error {
filename = filepath.ToSlash(filename)
var err error
var schemaRaw []byte
schemaRaw, err = ioutil.ReadFile(filename)
schemaRaw, err = os.ReadFile(filename)
if err != nil {
return fmt.Errorf("unable to open schema: %w", err)
}

View file

@ -12,6 +12,7 @@ import (
type PackageConfig struct {
Filename string `yaml:"filename,omitempty"`
Package string `yaml:"package,omitempty"`
Version int `yaml:"version,omitempty"`
}
func (c *PackageConfig) ImportPath() string {

View file

@ -2,7 +2,10 @@ package codegen
import (
"fmt"
"os"
"path/filepath"
"sort"
"strings"
"github.com/vektah/gqlparser/v2/ast"
@ -30,6 +33,26 @@ type Data struct {
QueryRoot *Object
MutationRoot *Object
SubscriptionRoot *Object
AugmentedSources []AugmentedSource
}
func (d *Data) HasEmbeddableSources() bool {
hasEmbeddableSources := false
for _, s := range d.AugmentedSources {
if s.Embeddable {
hasEmbeddableSources = true
}
}
return hasEmbeddableSources
}
// AugmentedSource contains extra information about graphql schema files which is not known directly from the Config.Sources data
type AugmentedSource struct {
// path relative to Config.Exec.Filename
RelativePath string
Embeddable bool
BuiltIn bool
Source string
}
type builder struct {
@ -147,6 +170,31 @@ func BuildData(cfg *config.Config) (*Data, error) {
// otherwise show a generic error message
return nil, fmt.Errorf("invalid types were encountered while traversing the go source code, this probably means the invalid code generated isnt correct. add try adding -v to debug")
}
aSources := []AugmentedSource{}
for _, s := range cfg.Sources {
wd, err := os.Getwd()
if err != nil {
return nil, fmt.Errorf("failed to get working directory: %w", err)
}
outputDir := cfg.Exec.Dir()
sourcePath := filepath.Join(wd, s.Name)
relative, err := filepath.Rel(outputDir, sourcePath)
if err != nil {
return nil, fmt.Errorf("failed to compute path of %s relative to %s: %w", sourcePath, outputDir, err)
}
relative = filepath.ToSlash(relative)
embeddable := true
if strings.HasPrefix(relative, "..") || s.BuiltIn {
embeddable = false
}
aSources = append(aSources, AugmentedSource{
RelativePath: relative,
Embeddable: embeddable,
BuiltIn: s.BuiltIn,
Source: s.Input,
})
}
s.AugmentedSources = aSources
return &s, nil
}

View file

@ -70,7 +70,7 @@ func (ec *executionContext) _mutationMiddleware(ctx context.Context, obj *ast.Op
{{ end }}
{{ if .Directives.LocationDirectives "SUBSCRIPTION" }}
func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func() graphql.Marshaler {
func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *ast.OperationDefinition, next func(ctx context.Context) (interface{}, error)) func(ctx context.Context) graphql.Marshaler {
for _, d := range obj.Directives {
switch d.Name {
{{- range $directive := .Directives.LocationDirectives "SUBSCRIPTION" }}
@ -80,7 +80,7 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as
args, err := ec.{{ $directive.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return func() graphql.Marshaler {
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}
@ -98,15 +98,15 @@ func (ec *executionContext) _subscriptionMiddleware(ctx context.Context, obj *as
tmp, err := next(ctx)
if err != nil {
ec.Error(ctx, err)
return func() graphql.Marshaler {
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}
if data, ok := tmp.(func() graphql.Marshaler); ok {
if data, ok := tmp.(func(ctx context.Context) graphql.Marshaler); ok {
return data
}
ec.Errorf(ctx, `unexpected type %T from directive, should be graphql.Marshaler`, tmp)
return func() graphql.Marshaler {
return func(ctx context.Context) graphql.Marshaler {
return graphql.Null
}
}

View file

@ -3,6 +3,7 @@ package codegen
import (
"errors"
"fmt"
goast "go/ast"
"go/types"
"log"
"reflect"
@ -12,6 +13,8 @@ import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/vektah/gqlparser/v2/ast"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
type Field struct {
@ -71,7 +74,7 @@ func (b *builder) buildField(obj *Object, field *ast.FieldDefinition) (*Field, e
log.Println(err.Error())
}
if f.IsResolver && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
if f.IsResolver && b.Config.ResolversAlwaysReturnPointers && !f.TypeReference.IsPtr() && f.TypeReference.IsStruct() {
f.TypeReference = b.Binder.PointerTo(f.TypeReference)
}
@ -179,8 +182,8 @@ func (b *builder) bindField(obj *Object, f *Field) (errret error) {
params = types.NewTuple(vars...)
}
// Try to match target function's arguments with GraphQL field arguments
newArgs, err := b.bindArgs(f, params)
// Try to match target function's arguments with GraphQL field arguments.
newArgs, err := b.bindArgs(f, sig, params)
if err != nil {
return fmt.Errorf("%s:%d: %w", pos.Filename, pos.Line, err)
}
@ -469,10 +472,11 @@ func (f *Field) GoNameUnexported() string {
}
func (f *Field) ShortInvocation() string {
caser := cases.Title(language.English, cases.NoLower)
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("%s().%s(ctx, &it, data)", strings.Title(f.Object.Definition.Name), f.GoFieldName)
return fmt.Sprintf("%s().%s(ctx, &it, data)", caser.String(f.Object.Definition.Name), f.GoFieldName)
}
return fmt.Sprintf("%s().%s(%s)", strings.Title(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
return fmt.Sprintf("%s().%s(%s)", caser.String(f.Object.Definition.Name), f.GoFieldName, f.CallArgs())
}
func (f *Field) ArgsFunc() string {
@ -483,6 +487,14 @@ func (f *Field) ArgsFunc() string {
return "field_" + f.Object.Definition.Name + "_" + f.Name + "_args"
}
func (f *Field) FieldContextFunc() string {
return "fieldContext_" + f.Object.Definition.Name + "_" + f.Name
}
func (f *Field) ChildFieldContextFunc(name string) string {
return "fieldContext_" + f.TypeReference.Definition.Name + "_" + name
}
func (f *Field) ResolverType() string {
if !f.IsResolver {
return ""
@ -492,6 +504,12 @@ func (f *Field) ResolverType() string {
}
func (f *Field) ShortResolverDeclaration() string {
return f.ShortResolverSignature(nil)
}
// ShortResolverSignature is identical to ShortResolverDeclaration,
// but respects previous naming (return) conventions, if any.
func (f *Field) ShortResolverSignature(ft *goast.FuncType) string {
if f.Object.Kind == ast.InputObject {
return fmt.Sprintf("(ctx context.Context, obj %s, data %s) error",
templates.CurrentImports.LookupType(f.Object.Reference()),
@ -512,8 +530,17 @@ func (f *Field) ShortResolverDeclaration() string {
if f.Object.Stream {
result = "<-chan " + result
}
res += fmt.Sprintf(") (%s, error)", result)
// Named return.
var namedV, namedE string
if ft != nil {
if ft.Results != nil && len(ft.Results.List) > 0 && len(ft.Results.List[0].Names) > 0 {
namedV = ft.Results.List[0].Names[0].Name
}
if ft.Results != nil && len(ft.Results.List) > 1 && len(ft.Results.List[1].Names) > 0 {
namedE = ft.Results.List[1].Names[0].Name
}
}
res += fmt.Sprintf(") (%s %s, %s error)", namedV, result, namedE)
return res
}
@ -549,7 +576,20 @@ func (f *Field) CallArgs() string {
}
for _, arg := range f.Args {
args = append(args, "args["+strconv.Quote(arg.Name)+"].("+templates.CurrentImports.LookupType(arg.TypeReference.GO)+")")
tmp := "fc.Args[" + strconv.Quote(arg.Name) + "].(" + templates.CurrentImports.LookupType(arg.TypeReference.GO) + ")"
if iface, ok := arg.TypeReference.GO.(*types.Interface); ok && iface.Empty() {
tmp = fmt.Sprintf(`
func () interface{} {
if fc.Args["%s"] == nil {
return nil
}
return fc.Args["%s"].(interface{})
}()`, arg.Name, arg.Name,
)
}
args = append(args, tmp)
}
return strings.Join(args, ", ")

View file

@ -1,34 +1,21 @@
{{- range $object := .Objects }}{{- range $field := $object.Fields }}
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(){{ end }}graphql.Marshaler) {
func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Context, field graphql.CollectedField{{ if not $object.Root }}, obj {{$object.Reference | ref}}{{end}}) (ret {{ if $object.Stream }}func(ctx context.Context){{ end }}graphql.Marshaler) {
{{- $null := "graphql.Null" }}
{{- if $object.Stream }}
{{- $null = "nil" }}
{{- end }}
fc, err := ec.{{ $field.FieldContextFunc }}(ctx, field)
if err != nil {
return {{ $null }}
}
ctx = graphql.WithFieldContext(ctx, fc)
defer func () {
if r := recover(); r != nil {
ec.Error(ctx, ec.Recover(ctx, r))
ret = {{ $null }}
}
}()
fc := &graphql.FieldContext{
Object: {{$object.Name|quote}},
Field: field,
Args: nil,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
IsResolver: {{ $field.IsResolver }},
}
ctx = graphql.WithFieldContext(ctx, fc)
{{- if $field.Args }}
rawArgs := field.ArgumentMap(ec.Variables)
args, err := ec.{{ $field.ArgsFunc }}(ctx,rawArgs)
if err != nil {
ec.Error(ctx, err)
return {{ $null }}
}
fc.Args = args
{{- end }}
{{- if $.AllDirectives.LocationDirectives "FIELD" }}
resTmp := ec._fieldMiddleware(ctx, {{if $object.Root}}nil{{else}}obj{{end}}, func(rctx context.Context) (interface{}, error) {
{{ template "field" $field }}
@ -39,7 +26,9 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex
})
if err != nil {
ec.Error(ctx, err)
{{- if not $object.Root }}
return {{ $null }}
{{- end }}
}
{{- end }}
if resTmp == nil {
@ -51,8 +40,9 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex
return {{ $null }}
}
{{- if $object.Stream }}
return func() graphql.Marshaler {
res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}})
return func(ctx context.Context) graphql.Marshaler {
select {
case res, ok := <-resTmp.(<-chan {{$field.TypeReference.GO | ref}}):
if !ok {
return nil
}
@ -63,6 +53,9 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex
ec.{{ $field.TypeReference.MarshalFunc }}(ctx, field.Selections, res).MarshalGQL(w)
w.Write([]byte{'}'})
})
case <-ctx.Done():
return nil
}
}
{{- else }}
res := resTmp.({{$field.TypeReference.GO | ref}})
@ -71,6 +64,44 @@ func (ec *executionContext) _{{$object.Name}}_{{$field.Name}}(ctx context.Contex
{{- end }}
}
func (ec *executionContext) {{ $field.FieldContextFunc }}(ctx context.Context, field graphql.CollectedField) (fc *graphql.FieldContext, err error) {
fc = &graphql.FieldContext{
Object: {{quote $field.Object.Name}},
Field: field,
IsMethod: {{or $field.IsMethod $field.IsResolver}},
IsResolver: {{ $field.IsResolver }},
Child: func (ctx context.Context, field graphql.CollectedField) (*graphql.FieldContext, error) {
{{- if not $field.TypeReference.Definition.Fields }}
return nil, errors.New("field of type {{ $field.TypeReference.Definition.Name }} does not have child fields")
{{- else if ne $field.TypeReference.Definition.Kind "OBJECT" }}
return nil, errors.New("FieldContext.Child cannot be called on type {{ $field.TypeReference.Definition.Kind }}")
{{- else }}
switch field.Name {
{{- range $f := $field.TypeReference.Definition.Fields }}
case "{{ $f.Name }}":
return ec.{{ $field.ChildFieldContextFunc $f.Name }}(ctx, field)
{{- end }}
}
return nil, fmt.Errorf("no field named %q was found under type {{ $field.TypeReference.Definition.Name }}", field.Name)
{{- end }}
},
}
{{- if $field.Args }}
defer func () {
if r := recover(); r != nil {
err = ec.Recover(ctx, r)
ec.Error(ctx, err)
}
}()
ctx = graphql.WithFieldContext(ctx, fc)
if fc.Args, err = ec.{{ $field.ArgsFunc }}(ctx, field.ArgumentMap(ec.Variables)); err != nil {
ec.Error(ctx, err)
return
}
{{- end }}
return fc, nil
}
{{- end }}{{- end}}
{{ define "field" }}

View file

@ -1,9 +1,10 @@
package codegen
import (
"embed"
"errors"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"runtime"
"strings"
@ -13,6 +14,9 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)
//go:embed *.gotpl
var codegenTemplates embed.FS
func GenerateCode(data *Data) error {
if !data.Config.Exec.IsDefined() {
return fmt.Errorf("missing exec config")
@ -36,6 +40,7 @@ func generateSingleFile(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}
@ -82,6 +87,7 @@ func generatePerSchema(data *Data) error {
RegionTags: true,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
if err != nil {
return err
@ -131,7 +137,7 @@ func generateRootFile(data *Data) error {
_, thisFile, _, _ := runtime.Caller(0)
rootDir := filepath.Dir(thisFile)
templatePath := filepath.Join(rootDir, "root_.gotpl")
templateBytes, err := ioutil.ReadFile(templatePath)
templateBytes, err := os.ReadFile(templatePath)
if err != nil {
return err
}
@ -145,6 +151,7 @@ func generateRootFile(data *Data) error {
RegionTags: false,
GeneratedHeader: true,
Packages: data.Config.Packages,
TemplateFS: codegenTemplates,
})
}

View file

@ -7,6 +7,7 @@
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "embed" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
@ -50,6 +51,7 @@
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
@ -62,6 +64,7 @@
}
{{- end }}
{{ end }}
{{- end }}
}
{{ end }}
@ -103,6 +106,7 @@
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
{{ if not .Config.OmitComplexity -}}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
@ -129,12 +133,20 @@
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e}
inputUnmarshalMap := graphql.BuildUnmarshalerMap(
{{- range $input := .Inputs -}}
{{ if not $input.HasUnmarshal }}
ec.unmarshalInput{{ $input.Name }},
{{- end }}
{{- end }}
)
first := true
switch rc.Operation.Operation {
@ -142,6 +154,7 @@
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "QUERY" -}}
data := ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
@ -162,6 +175,7 @@
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
@ -190,7 +204,7 @@
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next()
data := next(ctx)
if data == nil {
return nil
@ -226,9 +240,22 @@
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
{{if .HasEmbeddableSources }}
//go:embed{{- range $source := .AugmentedSources }}{{if $source.Embeddable}} {{$source.RelativePath|quote}}{{end}}{{- end }}
var sourcesFS embed.FS
func sourceData(filename string) string {
data, err := sourcesFS.ReadFile(filename)
if err != nil {
panic(fmt.Sprintf("codegen problem: %s not available", filename))
}
return string(data)
}
{{- end }}
var sources = []*ast.Source{
{{- range $source := .Config.Sources }}
{Name: {{$source.Name|quote}}, Input: {{$source.Input|rawQuote}}, BuiltIn: {{$source.BuiltIn}}},
{{- range $source := .AugmentedSources }}
{Name: {{$source.RelativePath|quote}}, Input: {{if (not $source.Embeddable)}}{{$source.Source|rawQuote}}{{else}}sourceData({{$source.RelativePath|quote}}){{end}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)

View file

@ -1,6 +1,10 @@
{{- range $input := .Inputs }}
{{- if not .HasUnmarshal }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{.Type | ref}}, error) {
{{- $it := "it" }}
{{- if .PointersInUmarshalInput }}
{{- $it = "&it" }}
{{- end }}
func (ec *executionContext) unmarshalInput{{ .Name }}(ctx context.Context, obj interface{}) ({{ if .PointersInUmarshalInput }}*{{ end }}{{.Type | ref}}, error) {
var it {{.Type | ref}}
asMap := map[string]interface{}{}
for k, v := range obj.(map[string]interface{}) {
@ -14,7 +18,12 @@
{{- end}}
{{- end }}
for k, v := range asMap {
fieldsInOrder := [...]string{ {{ range .Fields }}{{ quote .Name }},{{ end }} }
for _, k := range fieldsInOrder {
v, ok := asMap[k]
if !ok {
continue
}
switch k {
{{- range $field := .Fields }}
case {{$field.Name|quote}}:
@ -26,12 +35,12 @@
{{ template "implDirectives" $field }}
tmp, err := directive{{$field.ImplDirectives|len}}(ctx)
if err != nil {
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
if data, ok := tmp.({{ $field.TypeReference.GO | ref }}) ; ok {
{{- if $field.IsResolver }}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}} = data
@ -44,21 +53,21 @@
{{- end }}
} else {
err := fmt.Errorf(`unexpected type %T from directive, should be {{ $field.TypeReference.GO }}`, tmp)
return it, graphql.ErrorOnPath(ctx, err)
return {{$it}}, graphql.ErrorOnPath(ctx, err)
}
{{- else }}
{{- if $field.IsResolver }}
data, err := ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
if err = ec.resolvers.{{ $field.ShortInvocation }}; err != nil {
return it, err
return {{$it}}, err
}
{{- else }}
it.{{$field.GoFieldName}}, err = ec.{{ $field.TypeReference.UnmarshalFunc }}(ctx, v)
if err != nil {
return it, err
return {{$it}}, err
}
{{- end }}
{{- end }}
@ -66,7 +75,7 @@
}
}
return it, nil
return {{$it}}, nil
}
{{- end }}
{{ end }}

View file

@ -9,6 +9,8 @@ import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/vektah/gqlparser/v2/ast"
"golang.org/x/text/cases"
"golang.org/x/text/language"
)
type GoFieldType int
@ -31,6 +33,7 @@ type Object struct {
DisableConcurrency bool
Stream bool
Directives []*Directive
PointersInUmarshalInput bool
}
func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
@ -38,15 +41,16 @@ func (b *builder) buildObject(typ *ast.Definition) (*Object, error) {
if err != nil {
return nil, fmt.Errorf("%s: %w", typ.Name, err)
}
caser := cases.Title(language.English, cases.NoLower)
obj := &Object{
Definition: typ,
Root: b.Schema.Query == typ || b.Schema.Mutation == typ || b.Schema.Subscription == typ,
DisableConcurrency: typ == b.Schema.Mutation,
Stream: typ == b.Schema.Subscription,
Directives: dirs,
PointersInUmarshalInput: b.Config.ReturnPointersInUmarshalInput,
ResolverInterface: types.NewNamed(
types.NewTypeName(0, b.Config.Exec.Pkg(), strings.Title(typ.Name)+"Resolver", nil),
types.NewTypeName(0, b.Config.Exec.Pkg(), caser.String(typ.Name)+"Resolver", nil),
nil,
nil,
),

View file

@ -3,7 +3,7 @@
var {{ $object.Name|lcFirst}}Implementors = {{$object.Implementors}}
{{- if .Stream }}
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler {
func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.SelectionSet) func(ctx context.Context) graphql.Marshaler {
fields := graphql.CollectFields(ec.OperationContext, sel, {{$object.Name|lcFirst}}Implementors)
ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{
Object: {{$object.Name|quote}},
@ -31,7 +31,9 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
})
{{end}}
out := graphql.NewFieldSet(fields)
{{- if not $object.Root }}
var invalids uint32
{{- end }}
for i, field := range fields {
{{- if $object.Root }}
innerCtx := graphql.WithRootFieldContext(ctx, &graphql.RootFieldContext{
@ -54,6 +56,7 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
}
}()
res = ec._{{$object.Name}}_{{$field.Name}}(ctx, field{{if not $object.Root}}, obj{{end}})
{{- if not $object.Root }}
{{- if $field.TypeReference.GQL.NonNull }}
if res == graphql.Null {
{{- if $object.IsConcurrent }}
@ -63,6 +66,7 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
{{- end }}
}
{{- end }}
{{- end }}
return res
}
@ -80,15 +84,15 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
{{end}}
})
{{- else }}
innerFunc := func(ctx context.Context) (res graphql.Marshaler) {
return ec._{{$object.Name}}_{{$field.Name}}(ctx, field{{if not $object.Root}}, obj{{end}})
}
{{if $object.Root}}
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, innerFunc)
out.Values[i] = ec.OperationContext.RootResolverMiddleware(innerCtx, func(ctx context.Context) (res graphql.Marshaler) {
return ec._{{$object.Name}}_{{$field.Name}}(ctx, field)
})
{{else}}
out.Values[i] = innerFunc(ctx)
out.Values[i] = ec._{{$object.Name}}_{{$field.Name}}(ctx, field, obj)
{{end}}
{{- if not $object.Root }}
{{- if $field.TypeReference.GQL.NonNull }}
if out.Values[i] == graphql.Null {
{{- if $object.IsConcurrent }}
@ -100,12 +104,15 @@ func (ec *executionContext) _{{$object.Name}}(ctx context.Context, sel ast.Selec
{{- end }}
{{- end }}
{{- end }}
{{- end }}
default:
panic("unknown field " + strconv.Quote(field.Name))
}
}
out.Dispatch()
{{- if not $object.Root }}
if invalids > 0 { return graphql.Null }
{{- end }}
return out
}
{{- end }}

View file

@ -7,6 +7,7 @@
{{ reserveImport "sync/atomic" }}
{{ reserveImport "errors" }}
{{ reserveImport "bytes" }}
{{ reserveImport "embed" }}
{{ reserveImport "github.com/vektah/gqlparser/v2" "gqlparser" }}
{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }}
@ -34,6 +35,11 @@ type ResolverRoot interface {
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
{{- range $object := .Inputs -}}
{{ if $object.HasResolvers -}}
{{ucFirst $object.Name}}() {{ucFirst $object.Name}}Resolver
{{ end }}
{{- end }}
}
type DirectiveRoot struct {
@ -43,6 +49,7 @@ type DirectiveRoot struct {
}
type ComplexityRoot struct {
{{- if not .Config.OmitComplexity }}
{{ range $object := .Objects }}
{{ if not $object.IsReserved -}}
{{ ucFirst $object.Name }} struct {
@ -55,6 +62,7 @@ type ComplexityRoot struct {
}
{{- end }}
{{ end }}
{{- end }}
}
type executableSchema struct {
@ -70,6 +78,7 @@ func (e *executableSchema) Schema() *ast.Schema {
func (e *executableSchema) Complexity(typeName, field string, childComplexity int, rawArgs map[string]interface{}) (int, bool) {
ec := executionContext{nil, e}
_ = ec
{{- if not .Config.OmitComplexity }}
switch typeName + "." + field {
{{ range $object := .Objects }}
{{ if not $object.IsReserved }}
@ -96,12 +105,20 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in
{{ end }}
{{ end }}
}
{{- end }}
return 0, false
}
func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
rc := graphql.GetOperationContext(ctx)
ec := executionContext{rc, e}
inputUnmarshalMap := graphql.BuildUnmarshalerMap(
{{- range $input := .Inputs -}}
{{ if not $input.HasUnmarshal }}
ec.unmarshalInput{{ $input.Name }},
{{- end }}
{{- end }}
)
first := true
switch rc.Operation.Operation {
@ -109,6 +126,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "QUERY" -}}
data := ec._queryMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.QueryRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
@ -129,6 +147,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
return func(ctx context.Context) *graphql.Response {
if !first { return nil }
first = false
ctx = graphql.WithUnmarshalerMap(ctx, inputUnmarshalMap)
{{ if .Directives.LocationDirectives "MUTATION" -}}
data := ec._mutationMiddleware(ctx, rc.Operation, func(ctx context.Context) (interface{}, error){
return ec._{{.MutationRoot.Name}}(ctx, rc.Operation.SelectionSet), nil
@ -157,7 +176,7 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler {
var buf bytes.Buffer
return func(ctx context.Context) *graphql.Response {
buf.Reset()
data := next()
data := next(ctx)
if data == nil {
return nil
@ -193,9 +212,23 @@ func (ec *executionContext) introspectType(name string) (*introspection.Type, er
return introspection.WrapTypeFromDef(parsedSchema, parsedSchema.Types[name]), nil
}
{{if .HasEmbeddableSources }}
//go:embed{{- range $source := .AugmentedSources }}{{if $source.Embeddable}} {{$source.RelativePath|quote}}{{end}}{{- end }}
var sourcesFS embed.FS
func sourceData(filename string) string {
data, err := sourcesFS.ReadFile(filename)
if err != nil {
panic(fmt.Sprintf("codegen problem: %s not available", filename))
}
return string(data)
}
{{- end}}
var sources = []*ast.Source{
{{- range $source := .Config.Sources }}
{Name: {{$source.Name|quote}}, Input: {{$source.Input|rawQuote}}, BuiltIn: {{$source.BuiltIn}}},
{{- range $source := .AugmentedSources }}
{Name: {{$source.RelativePath|quote}}, Input: {{if (not $source.Embeddable)}}{{$source.Source|rawQuote}}{{else}}sourceData({{$source.RelativePath|quote}}){{end}}, BuiltIn: {{$source.BuiltIn}}},
{{- end }}
}
var parsedSchema = gqlparser.MustLoadSchema(sources...)

View file

@ -45,7 +45,7 @@ func (s *Imports) Reserve(path string, aliases ...string) (string, error) {
panic("empty ambient import")
}
// if we are referencing our own package we dont need an import
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return "", nil
}
@ -85,7 +85,7 @@ func (s *Imports) Lookup(path string) string {
path = code.NormalizeVendor(path)
// if we are referencing our own package we dont need an import
// if we are referencing our own package we don't need an import
if code.ImportPathForDir(s.destDir) == path {
return ""
}

View file

@ -4,14 +4,16 @@ import (
"bytes"
"fmt"
"go/types"
"io/ioutil"
"io/fs"
"os"
"path/filepath"
"reflect"
"regexp"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"text/template"
"unicode"
@ -36,6 +38,11 @@ type Options struct {
// the plugin processor will look for .gotpl files
// in the same directory of where you wrote the plugin.
Template string
// Use the go:embed API to collect all the template files you want to pass into Render
// this is an alternative to passing the Template option
TemplateFS fs.FS
// Filename is the name of the file that will be
// written to the system disk once the template is rendered.
Filename string
@ -53,6 +60,12 @@ type Options struct {
Packages *code.Packages
}
var (
modelNamesMu sync.Mutex
modelNames = make(map[string]string, 0)
goNameRe = regexp.MustCompile("[^a-zA-Z0-9_]")
)
// Render renders a gql plugin template from the given Options. Render is an
// abstraction of the text/template package that makes it easier to write gqlgen
// plugins. If Options.Template is empty, the Render function will look for `.gotpl`
@ -63,55 +76,27 @@ func Render(cfg Options) error {
}
CurrentImports = &Imports{packages: cfg.Packages, destDir: filepath.Dir(cfg.Filename)}
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
funcs := Funcs()
for n, f := range cfg.Funcs {
funcs[n] = f
}
t := template.New("").Funcs(funcs)
var roots []string
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return fmt.Errorf("error with provided template: %w", err)
}
roots = append(roots, "template.gotpl")
} else {
// load all the templates in the directory
err := filepath.Walk(rootDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
name := filepath.ToSlash(strings.TrimPrefix(path, rootDir+string(os.PathSeparator)))
if !strings.HasSuffix(info.Name(), ".gotpl") {
return nil
}
// omit any templates with "_" at the end of their name, which are meant for specific contexts only
if strings.HasSuffix(info.Name(), "_.gotpl") {
return nil
}
b, err := ioutil.ReadFile(path)
t, err := parseTemplates(cfg, t)
if err != nil {
return err
}
t, err = t.New(name).Parse(string(b))
if err != nil {
return fmt.Errorf("%s: %w", cfg.Filename, err)
roots := make([]string, 0, len(t.Templates()))
for _, template := range t.Templates() {
// templates that end with _.gotpl are special files we don't want to include
if strings.HasSuffix(template.Name(), "_.gotpl") ||
// filter out templates added with {{ template xxx }} syntax inside the template file
!strings.HasSuffix(template.Name(), ".gotpl") {
continue
}
roots = append(roots, name)
return nil
})
if err != nil {
return fmt.Errorf("locating templates: %w", err)
}
roots = append(roots, template.Name())
}
// then execute all the important looking ones in order, adding them to the same file
@ -125,6 +110,7 @@ func Render(cfg Options) error {
}
return roots[i] < roots[j]
})
var buf bytes.Buffer
for _, root := range roots {
if cfg.RegionTags {
@ -156,7 +142,7 @@ func Render(cfg Options) error {
result.WriteString("import (\n")
result.WriteString(CurrentImports.String())
result.WriteString(")\n")
_, err := buf.WriteTo(&result)
_, err = buf.WriteTo(&result)
if err != nil {
return err
}
@ -171,6 +157,34 @@ func Render(cfg Options) error {
return nil
}
func parseTemplates(cfg Options, t *template.Template) (*template.Template, error) {
if cfg.Template != "" {
var err error
t, err = t.New("template.gotpl").Parse(cfg.Template)
if err != nil {
return nil, fmt.Errorf("error with provided template: %w", err)
}
return t, nil
}
var fileSystem fs.FS
if cfg.TemplateFS != nil {
fileSystem = cfg.TemplateFS
} else {
// load path relative to calling source file
_, callerFile, _, _ := runtime.Caller(1)
rootDir := filepath.Dir(callerFile)
fileSystem = os.DirFS(rootDir)
}
t, err := t.ParseFS(fileSystem, "*.gotpl")
if err != nil {
return nil, fmt.Errorf("locating templates: %w", err)
}
return t, nil
}
func center(width int, pad string, s string) string {
if len(s)+2 > width {
return s
@ -196,6 +210,8 @@ func Funcs() template.FuncMap {
"lookupImport": CurrentImports.Lookup,
"go": ToGo,
"goPrivate": ToGoPrivate,
"goModelName": ToGoModelName,
"goPrivateModelName": ToGoPrivateModelName,
"add": func(a, b int) int {
return a + b
},
@ -285,25 +301,154 @@ func Call(p *types.Func) string {
return pkg + p.Name()
}
func resetModelNames() {
modelNamesMu.Lock()
defer modelNamesMu.Unlock()
modelNames = make(map[string]string, 0)
}
func buildGoModelNameKey(parts []string) string {
const sep = ":"
return strings.Join(parts, sep)
}
func goModelName(primaryToGoFunc func(string) string, parts []string) string {
modelNamesMu.Lock()
defer modelNamesMu.Unlock()
var (
goNameKey string
partLen int
nameExists = func(n string) bool {
for _, v := range modelNames {
if n == v {
return true
}
}
return false
}
applyToGoFunc = func(parts []string) string {
var out string
switch len(parts) {
case 0:
return ""
case 1:
return primaryToGoFunc(parts[0])
default:
out = primaryToGoFunc(parts[0])
}
for _, p := range parts[1:] {
out = fmt.Sprintf("%s%s", out, ToGo(p))
}
return out
}
applyValidGoName = func(parts []string) string {
var out string
for _, p := range parts {
out = fmt.Sprintf("%s%s", out, replaceInvalidCharacters(p))
}
return out
}
)
// build key for this entity
goNameKey = buildGoModelNameKey(parts)
// determine if we've seen this entity before, and reuse if so
if goName, ok := modelNames[goNameKey]; ok {
return goName
}
// attempt first pass
if goName := applyToGoFunc(parts); !nameExists(goName) {
modelNames[goNameKey] = goName
return goName
}
// determine number of parts
partLen = len(parts)
// if there is only 1 part, append incrementing number until no conflict
if partLen == 1 {
base := applyToGoFunc(parts)
for i := 0; ; i++ {
tmp := fmt.Sprintf("%s%d", base, i)
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
}
// best effort "pretty" name
for i := partLen - 1; i >= 1; i-- {
tmp := fmt.Sprintf("%s%s", applyToGoFunc(parts[0:i]), applyValidGoName(parts[i:]))
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
// finally, fallback to just adding an incrementing number
base := applyToGoFunc(parts)
for i := 0; ; i++ {
tmp := fmt.Sprintf("%s%d", base, i)
if !nameExists(tmp) {
modelNames[goNameKey] = tmp
return tmp
}
}
}
func ToGoModelName(parts ...string) string {
return goModelName(ToGo, parts)
}
func ToGoPrivateModelName(parts ...string) string {
return goModelName(ToGoPrivate, parts)
}
func replaceInvalidCharacters(in string) string {
return goNameRe.ReplaceAllLiteralString(in, "_")
}
func wordWalkerFunc(private bool, nameRunes *[]rune) func(*wordInfo) {
return func(info *wordInfo) {
word := info.Word
switch {
case private && info.WordOffset == 0:
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// ID → id, CAMEL → camel
word = strings.ToLower(info.Word)
} else {
// ITicket → iTicket
word = LcFirst(info.Word)
}
case info.MatchCommonInitial:
word = strings.ToUpper(word)
case !info.HasCommonInitial && (strings.ToUpper(word) == word || strings.ToLower(word) == word):
// FOO or foo → Foo
// FOo → FOo
word = UcFirst(strings.ToLower(word))
}
*nameRunes = append(*nameRunes, []rune(word)...)
}
}
func ToGo(name string) string {
if name == "_" {
return "_"
}
runes := make([]rune, 0, len(name))
wordWalker(name, func(info *wordInfo) {
word := info.Word
if info.MatchCommonInitial {
word = strings.ToUpper(word)
} else if !info.HasCommonInitial {
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// FOO or foo → Foo
// FOo → FOo
word = UcFirst(strings.ToLower(word))
}
}
runes = append(runes, []rune(word)...)
})
wordWalker(name, wordWalkerFunc(false, &runes))
return string(runes)
}
@ -314,31 +459,13 @@ func ToGoPrivate(name string) string {
}
runes := make([]rune, 0, len(name))
first := true
wordWalker(name, func(info *wordInfo) {
word := info.Word
switch {
case first:
if strings.ToUpper(word) == word || strings.ToLower(word) == word {
// ID → id, CAMEL → camel
word = strings.ToLower(info.Word)
} else {
// ITicket → iTicket
word = LcFirst(info.Word)
}
first = false
case info.MatchCommonInitial:
word = strings.ToUpper(word)
case !info.HasCommonInitial:
word = UcFirst(strings.ToLower(word))
}
runes = append(runes, []rune(word)...)
})
wordWalker(name, wordWalkerFunc(true, &runes))
return sanitizeKeywords(string(runes))
}
type wordInfo struct {
WordOffset int
Word string
MatchCommonInitial bool
HasCommonInitial bool
@ -348,7 +475,7 @@ type wordInfo struct {
// https://github.com/golang/lint/blob/06c8688daad7faa9da5a0c2f163a3d14aac986ca/lint.go#L679
func wordWalker(str string, f func(*wordInfo)) {
runes := []rune(strings.TrimFunc(str, isDelimiter))
w, i := 0, 0 // index of start of word, scan
w, i, wo := 0, 0, 0 // index of start of word, scan, word offset
hasCommonInitial := false
for i+1 <= len(runes) {
eow := false // whether we hit the end of a word
@ -396,12 +523,14 @@ func wordWalker(str string, f func(*wordInfo)) {
}
f(&wordInfo{
WordOffset: wo,
Word: word,
MatchCommonInitial: matchCommonInitial,
HasCommonInitial: hasCommonInitial,
})
hasCommonInitial = false
w = i
wo++
}
}
@ -576,7 +705,7 @@ func resolveName(name string, skip int) string {
func render(filename string, tpldata interface{}) (*bytes.Buffer, error) {
t := template.New("").Funcs(Funcs())
b, err := ioutil.ReadFile(filename)
b, err := os.ReadFile(filename)
if err != nil {
return nil, err
}
@ -602,7 +731,7 @@ func write(filename string, b []byte, packages *code.Packages) error {
formatted = b
}
err = ioutil.WriteFile(filename, formatted, 0o644)
err = os.WriteFile(filename, formatted, 0o644)
if err != nil {
return fmt.Errorf("failed to write %s: %w", filename, err)
}

View file

@ -0,0 +1 @@
this is my test package

View file

@ -0,0 +1 @@
this will not be included

View file

@ -56,6 +56,10 @@
return *res, graphql.ErrorOnPath(ctx, err)
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else if or $type.IsUnusualBasic $type.IsUnderlyingUnusualBasic }}
return {{ $type.GO | ref }}(res), graphql.ErrorOnPath(ctx, err)
{{- else if and $type.IsNamed $type.Definition.BuiltIn (not $type.IsScalarID) }}
return {{ $type.GO | ref }}(res), graphql.ErrorOnPath(ctx, err)
{{- else}}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
@ -75,9 +79,11 @@
return res, graphql.ErrorOnPath(ctx, err)
{{- else }}
res, err := ec.unmarshalInput{{ $type.GQL.Name }}(ctx, v)
{{- if $type.IsNilable }}
{{- if and $type.IsNilable (not $type.PointersInUmarshalInput) }}
return &res, graphql.ErrorOnPath(ctx, err)
{{- else}}
{{- else if and (not $type.IsNilable) $type.PointersInUmarshalInput }}
return *res, graphql.ErrorOnPath(ctx, err)
{{- else }}
return res, graphql.ErrorOnPath(ctx, err)
{{- end }}
{{- end }}
@ -151,7 +157,7 @@
if v == nil {
{{- if $type.GQL.NonNull }}
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
}
{{- end }}
return graphql.Null
@ -170,11 +176,17 @@
{{- else if and (not $type.IsTargetNilable) $type.IsNilable }}
{{- $v = "*v" }}
{{- end }}
{{- if or $type.IsUnusualBasic $type.IsUnderlyingUnusualBasic }}
res := {{ $type.Marshaler | call }}({{ $type.Target | ref }}({{ $v }}))
{{- else if and $type.IsNamed $type.Definition.BuiltIn (not $type.IsScalarID) }}
res := {{ $type.Marshaler | call }}({{- if and $type.GO.Underlying $type.IsUnderlyingBasic }}{{ $type.GO.Underlying | ref }}({{ $v }}){{else}}{{ $v }}{{- end }})
{{- else }}
res := {{ $type.Marshaler | call }}({{- if $type.CastType }}{{ $type.CastType | ref }}({{ $v }}){{else}}{{ $v }}{{- end }})
{{- end }}
{{- if $type.GQL.NonNull }}
if res == graphql.Null {
if !graphql.HasFieldError(ctx, graphql.GetFieldContext(ctx)) {
ec.Errorf(ctx, "must not be null")
ec.Errorf(ctx, "the requested element is null which the schema does not allow")
}
}
{{- end }}

View file

@ -41,6 +41,7 @@ func findGoInterface(def types.Type) (*types.Interface, error) {
func equalFieldName(source, target string) bool {
source = strings.ReplaceAll(source, "_", "")
source = strings.ReplaceAll(source, ",omitempty", "")
target = strings.ReplaceAll(target, "_", "")
return strings.EqualFold(source, target)
}

View file

@ -30,6 +30,25 @@ type FieldContext struct {
IsMethod bool
// IsResolver indicates if the field has a user-specified resolver
IsResolver bool
// Child allows getting a child FieldContext by its field collection description.
// Note that, the returned child FieldContext represents the context as it was
// before the execution of the field resolver. For example:
//
// srv.AroundFields(func(ctx context.Context, next graphql.Resolver) (interface{}, error) {
// fc := graphql.GetFieldContext(ctx)
// op := graphql.GetOperationContext(ctx)
// collected := graphql.CollectFields(opCtx, fc.Field.Selections, []string{"User"})
//
// child, err := fc.Child(ctx, collected[0])
// if err != nil {
// return nil, err
// }
// fmt.Println("child context %q with args: %v", child.Field.Name, child.Args)
//
// return next(ctx)
// })
//
Child func(context.Context, CollectedField) (*FieldContext, error)
}
type FieldStats struct {

View file

@ -3,8 +3,10 @@ package graphql
import (
"context"
"errors"
"net/http"
"github.com/vektah/gqlparser/v2/ast"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// Deprecated: Please update all references to OperationContext instead
@ -15,6 +17,7 @@ type OperationContext struct {
Variables map[string]interface{}
OperationName string
Doc *ast.QueryDocument
Headers http.Header
Operation *ast.OperationDefinition
DisableIntrospection bool
@ -104,9 +107,16 @@ func (c *OperationContext) Errorf(ctx context.Context, format string, args ...in
AddErrorf(ctx, format, args...)
}
// Error sends an error to the client, passing it through the formatter.
// Deprecated: use graphql.AddError(ctx, err) instead
// Error add error or multiple errors (if underlaying type is gqlerror.List) into the stack.
// Then it will be sends to the client, passing it through the formatter.
func (c *OperationContext) Error(ctx context.Context, err error) {
if errList, ok := err.(gqlerror.List); ok {
for _, e := range errList {
AddError(ctx, e)
}
return
}
AddError(ctx, err)
}

View file

@ -23,18 +23,27 @@ var codeType = map[string]ErrorKind{
ParseFailed: KindProtocol,
}
// RegisterErrorType should be called by extensions that want to customize the http status codes for errors they return
// RegisterErrorType should be called by extensions that want to customize the http status codes for
// errors they return
func RegisterErrorType(code string, kind ErrorKind) {
codeType[code] = kind
}
// Set the error code on a given graphql error extension
func Set(err *gqlerror.Error, value string) {
if err.Extensions == nil {
err.Extensions = map[string]interface{}{}
func Set(err error, value string) {
if err == nil {
return
}
gqlErr, ok := err.(*gqlerror.Error)
if !ok {
return
}
err.Extensions["code"] = value
if gqlErr.Extensions == nil {
gqlErr.Extensions = map[string]interface{}{}
}
gqlErr.Extensions["code"] = value
}
// get the kind of the first non User error, defaults to User if no errors have a custom extension

View file

@ -26,7 +26,8 @@ func ErrorOnPath(ctx context.Context, err error) error {
if gqlErr.Path == nil {
gqlErr.Path = GetPath(ctx)
}
return gqlErr
// Return the original error to avoid losing any attached annotation
return err
}
return gqlerror.WrapPath(GetPath(ctx), err)
}

View file

@ -118,6 +118,11 @@ func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, o
return &(*c)[i]
}
}
for _, ifc := range cf.ObjectDefinition.Interfaces {
if ifc == objectDefinition.Name {
return &(*c)[i]
}
}
}
}

View file

@ -37,7 +37,10 @@ func New(es graphql.ExecutableSchema) *Executor {
return e
}
func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.RawParams) (*graphql.OperationContext, gqlerror.List) {
func (e *Executor) CreateOperationContext(
ctx context.Context,
params *graphql.RawParams,
) (*graphql.OperationContext, gqlerror.List) {
rc := &graphql.OperationContext{
DisableIntrospection: true,
RecoverFunc: e.recoverFunc,
@ -58,6 +61,7 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R
rc.RawQuery = params.Query
rc.OperationName = params.OperationName
rc.Headers = params.Headers
var listErr gqlerror.List
rc.Doc, listErr = e.parseQuery(ctx, &rc.Stats, params.Query)
@ -67,15 +71,21 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R
rc.Operation = rc.Doc.Operations.ForName(params.OperationName)
if rc.Operation == nil {
return rc, gqlerror.List{gqlerror.Errorf("operation %s not found", params.OperationName)}
}
var err *gqlerror.Error
rc.Variables, err = validator.VariableValues(e.es.Schema(), rc.Operation, params.Variables)
if err != nil {
err := gqlerror.Errorf("operation %s not found", params.OperationName)
errcode.Set(err, errcode.ValidationFailed)
return rc, gqlerror.List{err}
}
var err error
rc.Variables, err = validator.VariableValues(e.es.Schema(), rc.Operation, params.Variables)
if err != nil {
gqlErr, ok := err.(*gqlerror.Error)
if ok {
errcode.Set(gqlErr, errcode.ValidationFailed)
return rc, gqlerror.List{gqlErr}
}
}
rc.Stats.Validation.End = graphql.Now()
for _, p := range e.ext.operationContextMutators {
@ -87,7 +97,10 @@ func (e *Executor) CreateOperationContext(ctx context.Context, params *graphql.R
return rc, nil
}
func (e *Executor) DispatchOperation(ctx context.Context, rc *graphql.OperationContext) (graphql.ResponseHandler, context.Context) {
func (e *Executor) DispatchOperation(
ctx context.Context,
rc *graphql.OperationContext,
) (graphql.ResponseHandler, context.Context) {
ctx = graphql.WithOperationContext(ctx, rc)
var innerCtx context.Context
@ -130,7 +143,7 @@ func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graph
resp := e.ext.responseMiddleware(ctx, func(ctx context.Context) *graphql.Response {
resp := &graphql.Response{
Errors: list,
Errors: graphql.GetErrors(ctx),
}
resp.Extensions = graphql.GetExtensions(ctx)
return resp
@ -139,7 +152,7 @@ func (e *Executor) DispatchError(ctx context.Context, list gqlerror.List) *graph
return resp
}
func (e *Executor) PresentRecoveredError(ctx context.Context, err interface{}) *gqlerror.Error {
func (e *Executor) PresentRecoveredError(ctx context.Context, err interface{}) error {
return e.errorPresenter(ctx, e.recoverFunc(ctx, err))
}
@ -157,9 +170,14 @@ func (e *Executor) SetRecoverFunc(f graphql.RecoverFunc) {
// parseQuery decodes the incoming query and validates it, pulling from cache if present.
//
// NOTE: This should NOT look at variables, they will change per request. It should only parse and validate
// NOTE: This should NOT look at variables, they will change per request. It should only parse and
// validate
// the raw query string.
func (e *Executor) parseQuery(ctx context.Context, stats *graphql.Stats, query string) (*ast.QueryDocument, gqlerror.List) {
func (e *Executor) parseQuery(
ctx context.Context,
stats *graphql.Stats,
query string,
) (*ast.QueryDocument, gqlerror.List) {
stats.Parsing.Start = graphql.Now()
if doc, ok := e.queryCache.Get(ctx, query); ok {
@ -172,12 +190,23 @@ func (e *Executor) parseQuery(ctx context.Context, stats *graphql.Stats, query s
doc, err := parser.ParseQuery(&ast.Source{Input: query})
if err != nil {
errcode.Set(err, errcode.ParseFailed)
return nil, gqlerror.List{err}
gqlErr, ok := err.(*gqlerror.Error)
if ok {
errcode.Set(gqlErr, errcode.ParseFailed)
return nil, gqlerror.List{gqlErr}
}
}
stats.Parsing.End = graphql.Now()
stats.Validation.Start = graphql.Now()
if len(doc.Operations) == 0 {
err = gqlerror.Errorf("no operation provided")
gqlErr, _ := err.(*gqlerror.Error)
errcode.Set(err, errcode.ValidationFailed)
return nil, gqlerror.List{gqlErr}
}
listErr := validator.Validate(e.es.Schema(), doc)
if len(listErr) != 0 {
for _, e := range listErr {

View file

@ -27,6 +27,7 @@ type (
OperationName string `json:"operationName"`
Variables map[string]interface{} `json:"variables"`
Extensions map[string]interface{} `json:"extensions"`
Headers http.Header `json:"headers"`
ReadTime TraceTiming `json:"-"`
}

View file

@ -102,7 +102,8 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
defer func() {
if err := recover(); err != nil {
err := s.exec.PresentRecoveredError(r.Context(), err)
resp := &graphql.Response{Errors: []*gqlerror.Error{err}}
gqlErr, _ := err.(*gqlerror.Error)
resp := &graphql.Response{Errors: []*gqlerror.Error{gqlErr}}
b, _ := json.Marshal(resp)
w.WriteHeader(http.StatusUnprocessableEntity)
w.Write(b)

View file

@ -3,11 +3,9 @@ package transport
import (
"encoding/json"
"io"
"io/ioutil"
"mime"
"net/http"
"os"
"strings"
"github.com/99designs/gqlgen/graphql"
)
@ -64,63 +62,68 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
return
}
r.Body = http.MaxBytesReader(w, r.Body, f.maxUploadSize())
if err = r.ParseMultipartForm(f.maxMemory()); err != nil {
defer r.Body.Close()
mr, err := r.MultipartReader()
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
if strings.Contains(err.Error(), "request body too large") {
writeJsonError(w, "failed to parse multipart form, request body too large")
return
}
writeJsonError(w, "failed to parse multipart form")
return
}
defer r.Body.Close()
part, err := mr.NextPart()
if err != nil || part.FormName() != "operations" {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "first part must be operations")
return
}
var params graphql.RawParams
if err = jsonDecode(strings.NewReader(r.Form.Get("operations")), &params); err != nil {
if err = jsonDecode(part, &params); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "operations form field could not be decoded")
return
}
part, err = mr.NextPart()
if err != nil || part.FormName() != "map" {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "second part must be map")
return
}
uploadsMap := map[string][]string{}
if err = json.Unmarshal([]byte(r.Form.Get("map")), &uploadsMap); err != nil {
if err = json.NewDecoder(part).Decode(&uploadsMap); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonError(w, "map form field could not be decoded")
return
}
var upload graphql.Upload
for key, paths := range uploadsMap {
for {
part, err = mr.NextPart()
if err == io.EOF {
break
} else if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to parse part")
return
}
key := part.FormName()
filename := part.FileName()
contentType := part.Header.Get("Content-Type")
paths := uploadsMap[key]
if len(paths) == 0 {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "invalid empty operations paths list for key %s", key)
return
}
file, header, err := r.FormFile(key)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to get key %s from form", key)
return
}
defer file.Close()
delete(uploadsMap, key)
if len(paths) == 1 {
upload = graphql.Upload{
File: file,
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
}
if err := params.AddUpload(upload, key, paths[0]); err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonGraphqlError(w, err)
return
}
} else {
var upload graphql.Upload
if r.ContentLength < f.maxMemory() {
fileBytes, err := ioutil.ReadAll(file)
fileBytes, err := io.ReadAll(part)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to read file for key %s", key)
@ -128,10 +131,10 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
}
for _, path := range paths {
upload = graphql.Upload{
File: &bytesReader{s: &fileBytes, i: 0, prevRune: -1},
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
File: &bytesReader{s: &fileBytes, i: 0},
Size: int64(len(fileBytes)),
Filename: filename,
ContentType: contentType,
}
if err := params.AddUpload(upload, key, path); err != nil {
@ -141,7 +144,7 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
}
}
} else {
tmpFile, err := ioutil.TempFile(os.TempDir(), "gqlgen-")
tmpFile, err := os.CreateTemp(os.TempDir(), "gqlgen-")
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to create temp file for key %s", key)
@ -151,7 +154,7 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
defer func() {
_ = os.Remove(tmpName)
}()
_, err = io.Copy(tmpFile, file)
fileSize, err := io.Copy(tmpFile, part)
if err != nil {
w.WriteHeader(http.StatusUnprocessableEntity)
if err := tmpFile.Close(); err != nil {
@ -176,9 +179,9 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
defer pathTmpFile.Close()
upload = graphql.Upload{
File: pathTmpFile,
Size: header.Size,
Filename: header.Filename,
ContentType: header.Header.Get("Content-Type"),
Size: fileSize,
Filename: filename,
ContentType: contentType,
}
if err := params.AddUpload(upload, key, path); err != nil {
@ -189,8 +192,15 @@ func (f MultipartForm) Do(w http.ResponseWriter, r *http.Request, exec graphql.G
}
}
}
for key := range uploadsMap {
w.WriteHeader(http.StatusUnprocessableEntity)
writeJsonErrorf(w, "failed to get key %s from form", key)
return
}
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),

View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"strings"
"github.com/99designs/gqlgen/graphql"
@ -27,15 +28,22 @@ func (h GET) Supports(r *http.Request) bool {
}
func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
query, err := url.ParseQuery(r.URL.RawQuery)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, err.Error())
return
}
w.Header().Set("Content-Type", "application/json")
raw := &graphql.RawParams{
Query: r.URL.Query().Get("query"),
OperationName: r.URL.Query().Get("operationName"),
Query: query.Get("query"),
OperationName: query.Get("operationName"),
Headers: r.Header,
}
raw.ReadTime.Start = graphql.Now()
if variables := r.URL.Query().Get("variables"); variables != "" {
if variables := query.Get("variables"); variables != "" {
if err := jsonDecode(strings.NewReader(variables), &raw.Variables); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "variables could not be decoded")
@ -43,7 +51,7 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut
}
}
if extensions := r.URL.Query().Get("extensions"); extensions != "" {
if extensions := query.Get("extensions"); extensions != "" {
if err := jsonDecode(strings.NewReader(extensions), &raw.Extensions); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonError(w, "extensions could not be decoded")
@ -53,10 +61,10 @@ func (h GET) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecut
raw.ReadTime.End = graphql.Now()
rc, err := exec.CreateOperationContext(r.Context(), raw)
if err != nil {
w.WriteHeader(statusFor(err))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
rc, gqlError := exec.CreateOperationContext(r.Context(), raw)
if gqlError != nil {
w.WriteHeader(statusFor(gqlError))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), gqlError)
writeJson(w, resp)
return
}

View file

@ -1,8 +1,14 @@
package transport
import (
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
@ -26,28 +32,59 @@ func (h POST) Supports(r *http.Request) bool {
return r.Method == "POST" && mediaType == "application/json"
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
w.Header().Set("Content-Type", "application/json")
var params *graphql.RawParams
start := graphql.Now()
if err := jsonDecode(r.Body, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
writeJsonErrorf(w, "json body could not be decoded: "+err.Error())
return
func getRequestBody(r *http.Request) (string, error) {
if r == nil || r.Body == nil {
return "", nil
}
body, err := io.ReadAll(r.Body)
if err != nil {
return "", fmt.Errorf("unable to get Request Body %w", err)
}
return string(body), nil
}
func (h POST) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
w.Header().Set("Content-Type", "application/json")
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
rc, err := exec.CreateOperationContext(r.Context(), params)
bodyString, err := getRequestBody(r)
if err != nil {
w.WriteHeader(statusFor(err))
resp := exec.DispatchError(graphql.WithOperationContext(r.Context(), rc), err)
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}
responses, ctx := exec.DispatchOperation(r.Context(), rc)
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
var responses graphql.ResponseHandler
responses, ctx = exec.DispatchOperation(ctx, rc)
writeJson(w, responses(ctx))
}

View file

@ -2,12 +2,16 @@ package transport
import (
"net/http"
"strings"
"github.com/99designs/gqlgen/graphql"
)
// Options responds to http OPTIONS and HEAD requests
type Options struct{}
type Options struct {
// AllowedMethods is a list of allowed HTTP methods.
AllowedMethods []string
}
var _ graphql.Transport = Options{}
@ -18,9 +22,16 @@ func (o Options) Supports(r *http.Request) bool {
func (o Options) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
switch r.Method {
case http.MethodOptions:
w.Header().Set("Allow", "OPTIONS, GET, POST")
w.Header().Set("Allow", o.allowedMethods())
w.WriteHeader(http.StatusOK)
case http.MethodHead:
w.WriteHeader(http.StatusMethodNotAllowed)
}
}
func (o Options) allowedMethods() string {
if len(o.AllowedMethods) == 0 {
return "OPTIONS, GET, POST"
}
return strings.Join(o.AllowedMethods, ", ")
}

View file

@ -8,7 +8,6 @@ import (
type bytesReader struct {
s *[]byte
i int64 // current reading index
prevRune int // index of previous rune; or < 0
}
func (r *bytesReader) Read(b []byte) (n int, err error) {
@ -18,8 +17,29 @@ func (r *bytesReader) Read(b []byte) (n int, err error) {
if r.i >= int64(len(*r.s)) {
return 0, io.EOF
}
r.prevRune = -1
n = copy(b, (*r.s)[r.i:])
r.i += int64(n)
return
}
func (r *bytesReader) Seek(offset int64, whence int) (int64, error) {
if r.s == nil {
return 0, errors.New("byte slice pointer is nil")
}
var abs int64
switch whence {
case io.SeekStart:
abs = offset
case io.SeekCurrent:
abs = r.i + offset
case io.SeekEnd:
abs = int64(len(*r.s)) + offset
default:
return 0, errors.New("invalid whence")
}
if abs < 0 {
return 0, errors.New("negative position")
}
r.i = abs
return abs, nil
}

View file

@ -0,0 +1,110 @@
package transport
import (
"encoding/json"
"fmt"
"io"
"log"
"mime"
"net/http"
"strings"
"github.com/vektah/gqlparser/v2/gqlerror"
"github.com/99designs/gqlgen/graphql"
)
type SSE struct{}
var _ graphql.Transport = SSE{}
func (t SSE) Supports(r *http.Request) bool {
if !strings.Contains(r.Header.Get("Accept"), "text/event-stream") {
return false
}
mediaType, _, err := mime.ParseMediaType(r.Header.Get("Content-Type"))
if err != nil {
return false
}
return r.Method == http.MethodPost && mediaType == "application/json"
}
func (t SSE) Do(w http.ResponseWriter, r *http.Request, exec graphql.GraphExecutor) {
ctx := r.Context()
flusher, ok := w.(http.Flusher)
if !ok {
SendErrorf(w, http.StatusInternalServerError, "streaming unsupported")
return
}
defer flusher.Flush()
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Connection", "keep-alive")
w.Header().Set("Content-Type", "application/json")
params := &graphql.RawParams{}
start := graphql.Now()
params.Headers = r.Header
params.ReadTime = graphql.TraceTiming{
Start: start,
End: graphql.Now(),
}
bodyString, err := getRequestBody(r)
if err != nil {
gqlErr := gqlerror.Errorf("could not get json request body: %+v", err)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("could not get json request body: %+v", err.Error())
writeJson(w, resp)
return
}
bodyReader := io.NopCloser(strings.NewReader(bodyString))
if err = jsonDecode(bodyReader, &params); err != nil {
w.WriteHeader(http.StatusBadRequest)
gqlErr := gqlerror.Errorf(
"json request body could not be decoded: %+v body:%s",
err,
bodyString,
)
resp := exec.DispatchError(ctx, gqlerror.List{gqlErr})
log.Printf("decoding error: %+v body:%s", err.Error(), bodyString)
writeJson(w, resp)
return
}
rc, OpErr := exec.CreateOperationContext(ctx, params)
if OpErr != nil {
w.WriteHeader(statusFor(OpErr))
resp := exec.DispatchError(graphql.WithOperationContext(ctx, rc), OpErr)
writeJson(w, resp)
return
}
ctx = graphql.WithOperationContext(ctx, rc)
w.Header().Set("Content-Type", "text/event-stream")
fmt.Fprint(w, ":\n\n")
flusher.Flush()
responses, ctx := exec.DispatchOperation(ctx, rc)
for {
response := responses(ctx)
if response == nil {
break
}
writeJsonWithSSE(w, response)
flusher.Flush()
}
fmt.Fprint(w, "event: complete\n\n")
}
func writeJsonWithSSE(w io.Writer, response *graphql.Response) {
b, err := json.Marshal(response)
if err != nil {
panic(err)
}
fmt.Fprintf(w, "event: next\ndata: %s\n\n", b)
}

View file

@ -49,7 +49,24 @@ type (
var errReadTimeout = errors.New("read timeout")
var _ graphql.Transport = Websocket{}
type WebsocketError struct {
Err error
// IsReadError flags whether the error occurred on read or write to the websocket
IsReadError bool
}
func (e WebsocketError) Error() string {
if e.IsReadError {
return fmt.Sprintf("websocket read: %v", e.Err)
}
return fmt.Sprintf("websocket write: %v", e.Err)
}
var (
_ graphql.Transport = Websocket{}
_ error = WebsocketError{}
)
func (t Websocket) Supports(r *http.Request) bool {
return r.Header.Get("Upgrade") != ""
@ -94,9 +111,12 @@ func (t Websocket) Do(w http.ResponseWriter, r *http.Request, exec graphql.Graph
conn.run()
}
func (c *wsConnection) handlePossibleError(err error) {
func (c *wsConnection) handlePossibleError(err error, isReadError bool) {
if c.ErrorFunc != nil && err != nil {
c.ErrorFunc(c.ctx, err)
c.ErrorFunc(c.ctx, WebsocketError{
Err: err,
IsReadError: isReadError,
})
}
}
@ -181,7 +201,7 @@ func (c *wsConnection) init() bool {
func (c *wsConnection) write(msg *message) {
c.mu.Lock()
c.handlePossibleError(c.me.Send(msg))
c.handlePossibleError(c.me.Send(msg), false)
c.mu.Unlock()
}
@ -227,7 +247,7 @@ func (c *wsConnection) run() {
if err != nil {
// If the connection got closed by us, don't report the error
if !errors.Is(err, net.ErrClosed) {
c.handlePossibleError(err)
c.handlePossibleError(err, true)
}
return
}
@ -330,6 +350,7 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
c.mu.Unlock()
go func() {
ctx = withSubscriptionErrorContext(ctx)
defer func() {
if r := recover(); r != nil {
err := rc.Recover(ctx, r)
@ -342,7 +363,11 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
}
c.sendError(msg.id, gqlerr)
}
if errs := getSubscriptionError(ctx); len(errs) != 0 {
c.sendError(msg.id, errs...)
} else {
c.complete(msg.id)
}
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
@ -358,12 +383,8 @@ func (c *wsConnection) subscribe(start time.Time, msg *message) {
c.sendResponse(msg.id, response)
}
c.complete(msg.id)
c.mu.Lock()
delete(c.active, msg.id)
c.mu.Unlock()
cancel()
// complete and context cancel comes from the defer
}()
}

View file

@ -7,7 +7,7 @@ import (
"github.com/gorilla/websocket"
)
// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
const (
graphqltransportwsSubprotocol = "graphql-transport-ws"

View file

@ -7,7 +7,7 @@ import (
"github.com/gorilla/websocket"
)
// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
// https://github.com/enisdenjo/graphql-ws/blob/master/PROTOCOL.md
const (
graphqlwsSubprotocol = "graphql-ws"

View file

@ -0,0 +1,69 @@
package transport
import (
"context"
"github.com/vektah/gqlparser/v2/gqlerror"
)
// A private key for context that only this package can access. This is important
// to prevent collisions between different context uses
var wsSubscriptionErrorCtxKey = &wsSubscriptionErrorContextKey{"subscription-error"}
type wsSubscriptionErrorContextKey struct {
name string
}
type subscriptionError struct {
errs []*gqlerror.Error
}
// AddSubscriptionError is used to let websocket return an error message after subscription resolver returns a channel.
// for example:
//
// func (r *subscriptionResolver) Method(ctx context.Context) (<-chan *model.Message, error) {
// ch := make(chan *model.Message)
// go func() {
// defer func() {
// close(ch)
// }
// // some kind of block processing (e.g.: gRPC client streaming)
// stream, err := gRPCClientStreamRequest(ctx)
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// for {
// m, err := stream.Recv()
// if err == io.EOF {
// return
// }
// if err != nil {
// transport.AddSubscriptionError(ctx, err)
// return // must return and close channel so websocket can send error back
// }
// ch <- m
// }
// }()
//
// return ch, nil
// }
//
// see https://github.com/99designs/gqlgen/pull/2506 for more details
func AddSubscriptionError(ctx context.Context, err *gqlerror.Error) {
subscriptionErrStruct := getSubscriptionErrorStruct(ctx)
subscriptionErrStruct.errs = append(subscriptionErrStruct.errs, err)
}
func withSubscriptionErrorContext(ctx context.Context) context.Context {
return context.WithValue(ctx, wsSubscriptionErrorCtxKey, &subscriptionError{})
}
func getSubscriptionErrorStruct(ctx context.Context) *subscriptionError {
v, _ := ctx.Value(wsSubscriptionErrorCtxKey).(*subscriptionError)
return v
}
func getSubscriptionError(ctx context.Context) []*gqlerror.Error {
return getSubscriptionErrorStruct(ctx).errs
}

55
vendor/github.com/99designs/gqlgen/graphql/input.go generated vendored Normal file
View file

@ -0,0 +1,55 @@
package graphql
import (
"context"
"errors"
"reflect"
)
const unmarshalInputCtx key = "unmarshal_input_context"
// BuildUnmarshalerMap returns a map of unmarshal functions of the ExecutableContext
// to use with the WithUnmarshalerMap function.
func BuildUnmarshalerMap(unmarshaler ...interface{}) map[reflect.Type]reflect.Value {
maps := make(map[reflect.Type]reflect.Value)
for _, v := range unmarshaler {
ft := reflect.TypeOf(v)
if ft.Kind() == reflect.Func {
maps[ft.Out(0)] = reflect.ValueOf(v)
}
}
return maps
}
// WithUnmarshalerMap returns a new context with a map from input types to their unmarshaler functions.
func WithUnmarshalerMap(ctx context.Context, maps map[reflect.Type]reflect.Value) context.Context {
return context.WithValue(ctx, unmarshalInputCtx, maps)
}
// UnmarshalInputFromContext allows unmarshaling input object from a context.
func UnmarshalInputFromContext(ctx context.Context, raw, v interface{}) error {
m, ok := ctx.Value(unmarshalInputCtx).(map[reflect.Type]reflect.Value)
if m == nil || !ok {
return errors.New("graphql: the input context is empty")
}
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr || rv.IsNil() {
return errors.New("graphql: input must be a non-nil pointer")
}
if fn, ok := m[rv.Elem().Type()]; ok {
res := fn.Call([]reflect.Value{
reflect.ValueOf(ctx),
reflect.ValueOf(raw),
})
if err := res[1].Interface(); err != nil {
return err.(error)
}
rv.Elem().Set(res[0])
return nil
}
return errors.New("graphql: no unmarshal function found")
}

View file

@ -0,0 +1,84 @@
package playground
import (
"html/template"
"net/http"
)
var altairPage = template.Must(template.New("altair").Parse(`<!doctype html>
<html>
<head>
<meta charset="utf-8">
<title>{{.title}}</title>
<base href="https://cdn.jsdelivr.net/npm/altair-static@{{.version}}/build/dist/">
<meta name="viewport" content="width=device-width,initial-scale=1">
<link rel="icon" type="image/x-icon" href="favicon.ico">
<link href="styles.css" rel="stylesheet" crossorigin="anonymous" integrity="{{.cssSRI}}"/>
</head>
<body>
<app-root>
<style>
.loading-screen {
display: none;
}
</style>
<div class="loading-screen styled">
<div class="loading-screen-inner">
<div class="loading-screen-logo-container">
<img src="assets/img/logo_350.svg" alt="Altair">
</div>
<div class="loading-screen-loading-indicator">
<span class="loading-indicator-dot"></span>
<span class="loading-indicator-dot"></span>
<span class="loading-indicator-dot"></span>
</div>
</div>
</div>
</app-root>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.mainSRI}}" src="main.js"></script>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.polyfillsSRI}}" src="polyfills.js"></script>
<script rel="preload" as="script" type="text/javascript" crossorigin="anonymous" integrity="{{.runtimeSRI}}" src="runtime.js"></script>
<script>
{{- if .endpointIsAbsolute}}
const url = {{.endpoint}};
const subscriptionUrl = {{.subscriptionEndpoint}};
{{- else}}
const url = location.protocol + '//' + location.host + {{.endpoint}};
const wsProto = location.protocol == 'https:' ? 'wss:' : 'ws:';
const subscriptionUrl = wsProto + '//' + location.host + {{.endpoint}};
{{- end}}
var altairOptions = {
endpointURL: url,
subscriptionsEndpoint: subscriptionUrl,
};
window.addEventListener("load", function() {
AltairGraphQL.init(altairOptions);
});
</script>
</body>
</html>`))
// AltairHandler responsible for setting up the altair playground
func AltairHandler(title, endpoint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
err := altairPage.Execute(w, map[string]interface{}{
"title": title,
"endpoint": endpoint,
"endpointIsAbsolute": endpointHasScheme(endpoint),
"subscriptionEndpoint": getSubscriptionEndpoint(endpoint),
"version": "5.0.5",
"cssSRI": "sha256-kZ35e5mdMYN5ALEbnsrA2CLn85Oe4hBodfsih9BqNxs=",
"mainSRI": "sha256-nWdVTcGTlBDV1L04UQnqod+AJedzBCnKHv6Ct65liHE=",
"polyfillsSRI": "sha256-1aVEg2sROcCQ/RxU3AlcPaRZhZdIWA92q2M+mdd/R4c=",
"runtimeSRI": "sha256-cK2XhXqQr0WS1Z5eKNdac0rJxTD6miC3ubd+aEVMQDk=",
})
if err != nil {
panic(err)
}
}
}

View file

@ -3,22 +3,26 @@ package playground
import (
"html/template"
"net/http"
"net/url"
)
var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{{.title}}</title>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/graphiql@{{.version}}/graphiql.min.css"
integrity="{{.cssSRI}}"
crossorigin="anonymous"
/>
</head>
<body style="margin: 0;">
<div id="graphiql" style="height: 100vh;"></div>
<style>
body {
height: 100%;
margin: 0;
width: 100%;
overflow: hidden;
}
#graphiql {
height: 100vh;
}
</style>
<script
src="https://cdn.jsdelivr.net/npm/react@17.0.2/umd/react.production.min.js"
integrity="{{.reactSRI}}"
@ -29,6 +33,16 @@ var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
integrity="{{.reactDOMSRI}}"
crossorigin="anonymous"
></script>
<link
rel="stylesheet"
href="https://cdn.jsdelivr.net/npm/graphiql@{{.version}}/graphiql.min.css"
integrity="{{.cssSRI}}"
crossorigin="anonymous"
/>
</head>
<body>
<div id="graphiql">Loading...</div>
<script
src="https://cdn.jsdelivr.net/npm/graphiql@{{.version}}/graphiql.min.js"
integrity="{{.jsSRI}}"
@ -36,15 +50,20 @@ var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
></script>
<script>
const url = location.protocol + '//' + location.host + '{{.endpoint}}';
{{- if .endpointIsAbsolute}}
const url = {{.endpoint}};
const subscriptionUrl = {{.subscriptionEndpoint}};
{{- else}}
const url = location.protocol + '//' + location.host + {{.endpoint}};
const wsProto = location.protocol == 'https:' ? 'wss:' : 'ws:';
const subscriptionUrl = wsProto + '//' + location.host + '{{.endpoint}}';
const subscriptionUrl = wsProto + '//' + location.host + {{.endpoint}};
{{- end}}
const fetcher = GraphiQL.createFetcher({ url, subscriptionUrl });
ReactDOM.render(
React.createElement(GraphiQL, {
fetcher: fetcher,
headerEditorEnabled: true,
isHeadersEditorEnabled: true,
shouldPersistHeaders: true
}),
document.getElementById('graphiql'),
@ -54,15 +73,18 @@ var page = template.Must(template.New("graphiql").Parse(`<!DOCTYPE html>
</html>
`))
// Handler responsible for setting up the playground
func Handler(title string, endpoint string) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "text/html")
err := page.Execute(w, map[string]string{
w.Header().Add("Content-Type", "text/html; charset=UTF-8")
err := page.Execute(w, map[string]interface{}{
"title": title,
"endpoint": endpoint,
"version": "1.5.16",
"cssSRI": "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI=",
"jsSRI": "sha256-uHp12yvpXC4PC9+6JmITxKuLYwjlW9crq9ywPE5Rxco=",
"endpointIsAbsolute": endpointHasScheme(endpoint),
"subscriptionEndpoint": getSubscriptionEndpoint(endpoint),
"version": "2.0.7",
"cssSRI": "sha256-gQryfbGYeYFxnJYnfPStPYFt0+uv8RP8Dm++eh00G9c=",
"jsSRI": "sha256-qQ6pw7LwTLC+GfzN+cJsYXfVWRKH9O5o7+5H96gTJhQ=",
"reactSRI": "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=",
"reactDOMSRI": "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=",
})
@ -71,3 +93,27 @@ func Handler(title string, endpoint string) http.HandlerFunc {
}
}
}
// endpointHasScheme checks if the endpoint has a scheme.
func endpointHasScheme(endpoint string) bool {
u, err := url.Parse(endpoint)
return err == nil && u.Scheme != ""
}
// getSubscriptionEndpoint returns the subscription endpoint for the given
// endpoint if it is parsable as a URL, or an empty string.
func getSubscriptionEndpoint(endpoint string) string {
u, err := url.Parse(endpoint)
if err != nil {
return ""
}
switch u.Scheme {
case "https":
u.Scheme = "wss"
default:
u.Scheme = "ws"
}
return u.String()
}

View file

@ -12,7 +12,7 @@ type Stats struct {
Parsing TraceTiming
Validation TraceTiming
// Stats collected by handler extensions. Dont use directly, the extension should provide a type safe way to
// Stats collected by handler extensions. Don't use directly, the extension should provide a type safe way to
// access this.
extension map[string]interface{}
}
@ -26,7 +26,7 @@ var ctxTraceStart key = "trace_start"
// StartOperationTrace captures the current time and stores it in context. This will eventually be added to request
// context but we want to grab it as soon as possible. For transports that can only handle a single graphql query
// per http requests you dont need to call this at all, the server will do it for you. For transports that handle
// per http requests you don't need to call this at all, the server will do it for you. For transports that handle
// multiple (eg batching, subscriptions) this should be called before decoding each request.
func StartOperationTrace(ctx context.Context) context.Context {
return context.WithValue(ctx, ctxTraceStart, Now())

View file

@ -1,6 +1,7 @@
package graphql
import (
"encoding/json"
"fmt"
"io"
"strconv"
@ -55,7 +56,9 @@ func UnmarshalString(v interface{}) (string, error) {
case int64:
return strconv.FormatInt(v, 10), nil
case float64:
return fmt.Sprintf("%f", v), nil
return strconv.FormatFloat(v, 'f', -1, 64), nil
case json.Number:
return string(v), nil
case bool:
if v {
return "true", nil

View file

@ -6,7 +6,7 @@ import (
)
type Upload struct {
File io.Reader
File io.ReadSeeker
Filename string
Size int64
ContentType string

View file

@ -1,3 +1,3 @@
package graphql
const Version = "v0.17.2"
const Version = "v0.17.24"

View file

@ -4,13 +4,13 @@ schema:
# Where should the generated server code go?
exec:
filename: graph/generated/generated.go
package: generated
filename: graph/generated.go
package: graph
# Uncomment to enable federation
# federation:
# filename: graph/generated/federation.go
# package: generated
# filename: graph/federation.go
# package: graph
# Where should any generated models go?
model:
@ -29,6 +29,13 @@ resolver:
# Optional: turn on to use []Thing instead of []*Thing
# omit_slice_element_pointers: false
# Optional: turn off to make struct-type struct fields not use pointers
# e.g. type Thing struct { FieldA OtherThing } instead of { FieldA *OtherThing }
# struct_fields_always_pointers: true
# Optional: turn off to make resolvers return values instead of pointers for structs
# resolvers_always_return_pointers: true
# Optional: set to speed up generation time by not performing a final validation pass.
# skip_validation: true

View file

@ -39,12 +39,16 @@ func CompatibleTypes(expected types.Type, actual types.Type) error {
}
case *types.Basic:
if actual, ok := actual.(*types.Basic); ok {
if actual.Kind() != expected.Kind() {
return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actual.Name())
if actualBasic, ok := actual.(*types.Basic); ok {
if similarBasicKind(actualBasic.Kind()) != expected.Kind() {
return fmt.Errorf("basic kind differs, %s != %s", expected.Name(), actualBasic.Name())
}
return nil
} else if actual, ok := actual.(*types.Named); ok {
if underlyingBasic, ok := actual.Underlying().(*types.Basic); ok {
return CompatibleTypes(expected, underlyingBasic)
}
}
case *types.Struct:
@ -159,3 +163,14 @@ func CompatibleTypes(expected types.Type, actual types.Type) error {
return fmt.Errorf("type mismatch %T != %T", expected, actual)
}
func similarBasicKind(kind types.BasicKind) types.BasicKind {
switch kind {
case types.Int8, types.Int16:
return types.Int64
case types.Uint, types.Uint8, types.Uint16, types.Uint32: // exclude Uint64: it still needs scalar with custom marshalling/unmarshalling because it is bigger then int64
return types.Int64
default:
return kind
}
}

View file

@ -1,10 +1,12 @@
package code
import (
"bufio"
"fmt"
"go/build"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
@ -26,7 +28,7 @@ func NameForDir(dir string) string {
if err != nil {
return SanitizePackageName(filepath.Base(dir))
}
files, err := ioutil.ReadDir(dir)
files, err := os.ReadDir(dir)
if err != nil {
return SanitizePackageName(filepath.Base(dir))
}
@ -73,8 +75,8 @@ func goModuleRoot(dir string) (string, bool) {
break
}
if content, err := ioutil.ReadFile(filepath.Join(modDir, "go.mod")); err == nil {
moduleName := string(modregex.FindSubmatch(content)[1])
if content, err := os.ReadFile(filepath.Join(modDir, "go.mod")); err == nil {
moduleName := extractModuleName(content)
result = goModuleSearchResult{
path: moduleName,
goModPath: modDir,
@ -126,6 +128,27 @@ func goModuleRoot(dir string) (string, bool) {
return res.path, true
}
func extractModuleName(content []byte) string {
for {
advance, tkn, err := bufio.ScanLines(content, false)
if err != nil {
panic(fmt.Errorf("error parsing mod file: %w", err))
}
if advance == 0 {
break
}
s := strings.Trim(string(tkn), " \t")
if len(s) != 0 && !strings.HasPrefix(s, "//") {
break
}
if advance <= len(content) {
content = content[advance:]
}
}
moduleName := string(modregex.FindSubmatch(content)[1])
return moduleName
}
// ImportPathForDir takes a path and returns a golang import path for the package
func ImportPathForDir(dir string) (res string) {
dir, err := filepath.Abs(dir)

View file

@ -151,7 +151,7 @@ func (p *Packages) NameForPackage(importPath string) string {
pkg := p.packages[importPath]
if pkg == nil {
// otherwise do a name only lookup for it but dont put it in the package cache.
// otherwise do a name only lookup for it but don't put it in the package cache.
p.numNameCalls++
pkgs, err := packages.Load(&packages.Config{Mode: packages.NeedName}, importPath)
if err != nil {

View file

@ -5,7 +5,7 @@ import (
"fmt"
"go/ast"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"strconv"
"strings"
@ -56,7 +56,7 @@ func (r *Rewriter) getSource(start, end token.Pos) string {
func (r *Rewriter) getFile(filename string) string {
if _, ok := r.files[filename]; !ok {
b, err := ioutil.ReadFile(filename)
b, err := os.ReadFile(filename)
if err != nil {
panic(fmt.Errorf("unable to load file, already exists: %w", err))
}
@ -68,7 +68,7 @@ func (r *Rewriter) getFile(filename string) string {
return r.files[filename]
}
func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
func (r *Rewriter) GetPrevDecl(structname string, methodname string) *ast.FuncDecl {
for _, f := range r.pkg.Syntax {
for _, d := range f.Decls {
d, isFunc := d.(*ast.FuncDecl)
@ -89,17 +89,29 @@ func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
if !ok {
continue
}
if ident.Name != structname {
continue
}
r.copied[d] = true
return d
}
}
return nil
}
func (r *Rewriter) GetMethodComment(structname string, methodname string) string {
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return d.Doc.Text()
}
return ""
}
func (r *Rewriter) GetMethodBody(structname string, methodname string) string {
d := r.GetPrevDecl(structname, methodname)
if d != nil {
return r.getSource(d.Body.Pos()+1, d.Body.End()-1)
}
}
return ""
}

View file

@ -6,8 +6,8 @@ import (
"errors"
"fmt"
"html/template"
"io"
"io/fs"
"io/ioutil"
"log"
"os"
"path/filepath"
@ -39,11 +39,33 @@ func fileExists(filename string) bool {
return !errors.Is(err, fs.ErrNotExist)
}
// see Go source code:
// https://github.com/golang/go/blob/f57ebed35132d02e5cf016f324853217fb545e91/src/cmd/go/internal/modload/init.go#L1283
func findModuleRoot(dir string) (roots string) {
if dir == "" {
panic("dir not set")
}
dir = filepath.Clean(dir)
// Look for enclosing go.mod.
for {
if fi, err := os.Stat(filepath.Join(dir, "go.mod")); err == nil && !fi.IsDir() {
return dir
}
d := filepath.Dir(dir)
if d == dir { // the parent of the root is itself, so we can go no further
break
}
dir = d
}
return ""
}
func initFile(filename, contents string) error {
if err := os.MkdirAll(filepath.Dir(filename), 0o755); err != nil {
return fmt.Errorf("unable to create directory for file '%s': %w\n", filename, err)
}
if err := ioutil.WriteFile(filename, []byte(contents), 0o644); err != nil {
if err := os.WriteFile(filename, []byte(contents), 0o644); err != nil {
return fmt.Errorf("unable to write file '%s': %w\n", filename, err)
}
@ -56,17 +78,36 @@ var initCmd = &cli.Command{
Flags: []cli.Flag{
&cli.BoolFlag{Name: "verbose, v", Usage: "show logs"},
&cli.StringFlag{Name: "config, c", Usage: "the config filename", Value: "gqlgen.yml"},
&cli.StringFlag{Name: "server", Usage: "where to write the server stub to", Value: "server.go"},
&cli.StringFlag{Name: "schema", Usage: "where to write the schema stub to", Value: "graph/schema.graphqls"},
&cli.StringFlag{
Name: "server",
Usage: "where to write the server stub to",
Value: "server.go",
},
&cli.StringFlag{
Name: "schema",
Usage: "where to write the schema stub to",
Value: "graph/schema.graphqls",
},
},
Action: func(ctx *cli.Context) error {
configFilename := ctx.String("config")
serverFilename := ctx.String("server")
schemaFilename := ctx.String("schema")
pkgName := code.ImportPathForDir(".")
cwd, err := os.Getwd()
if err != nil {
log.Println(err)
return fmt.Errorf("unable to determine current directory:%w", err)
}
pkgName := code.ImportPathForDir(cwd)
if pkgName == "" {
return fmt.Errorf("unable to determine import path for current directory, you probably need to run 'go mod init' first")
return fmt.Errorf(
"unable to determine import path for current directory, you probably need to run 'go mod init' first",
)
}
modRoot := findModuleRoot(cwd)
if modRoot == "" {
return fmt.Errorf("go.mod is missing. Please, do 'go mod init' first\n")
}
// check schema and config don't already exist
@ -75,7 +116,7 @@ var initCmd = &cli.Command{
return fmt.Errorf("%s already exists", filename)
}
}
_, err := config.LoadConfigFromDefaultLocations()
_, err = config.LoadConfigFromDefaultLocations()
if err == nil {
return fmt.Errorf("gqlgen.yml already exists in a parent directory\n")
}
@ -171,7 +212,7 @@ func main() {
if context.Bool("verbose") {
log.SetFlags(0)
} else {
log.SetOutput(ioutil.Discard)
log.SetOutput(io.Discard)
}
return nil
}

View file

@ -0,0 +1,109 @@
package federation
import (
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/plugin/federation/fieldset"
"github.com/vektah/gqlparser/v2/ast"
)
// Entity represents a federated type
// that was declared in the GQL schema.
type Entity struct {
Name string // The same name as the type declaration
Def *ast.Definition
Resolvers []*EntityResolver
Requires []*Requires
Multi bool
}
type EntityResolver struct {
ResolverName string // The resolver name, such as FindUserByID
KeyFields []*KeyField // The fields declared in @key.
InputType string // The Go generated input type for multi entity resolvers
}
type KeyField struct {
Definition *ast.FieldDefinition
Field fieldset.Field // len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
// Requires represents an @requires clause
type Requires struct {
Name string // the name of the field
Field fieldset.Field // source Field, len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
func (e *Entity) allFieldsAreExternal(federationVersion int) bool {
for _, field := range e.Def.Fields {
if !e.isFieldImplicitlyExternal(field, federationVersion) && field.Directives.ForName("external") == nil {
return false
}
}
return true
}
// In federation v2, key fields are implicitly external.
func (e *Entity) isFieldImplicitlyExternal(field *ast.FieldDefinition, federationVersion int) bool {
// Key fields are only implicitly external in Federation 2
if federationVersion != 2 {
return false
}
// TODO: From the spec, it seems like if an entity is not resolvable then it should not only not have a resolver, but should not appear in the _Entitiy union.
// The current implementation is a less drastic departure from the previous behavior, but should probably be reviewed.
// See https://www.apollographql.com/docs/federation/subgraph-spec/
if e.isResolvable() {
return false
}
// If the field is a key field, it is implicitly external
if e.isKeyField(field) {
return true
}
return false
}
// Determine if the entity is resolvable.
func (e *Entity) isResolvable() bool {
key := e.Def.Directives.ForName("key")
if key == nil {
// If there is no key directive, the entity is resolvable.
return true
}
resolvable := key.Arguments.ForName("resolvable")
if resolvable == nil {
// If there is no resolvable argument, the entity is resolvable.
return true
}
// only if resolvable: false has been set on the @key directive do we consider the entity non-resolvable.
return resolvable.Value.Raw != "false"
}
// Determine if a field is part of the entities key.
func (e *Entity) isKeyField(field *ast.FieldDefinition) bool {
for _, keyField := range e.keyFields() {
if keyField == field.Name {
return true
}
}
return false
}
// Get the key fields for this entity.
func (e *Entity) keyFields() []string {
key := e.Def.Directives.ForName("key")
if key == nil {
return []string{}
}
fields := key.Arguments.ForName("fields")
if fields == nil {
return []string{}
}
fieldSet := fieldset.New(fields.Value.Raw, nil)
keyFields := make([]string, len(fieldSet))
for i, field := range fieldSet {
keyFields[i] = field[0]
}
return keyFields
}

View file

@ -1,6 +1,7 @@
package federation
import (
_ "embed"
"fmt"
"sort"
"strings"
@ -14,14 +15,22 @@ import (
"github.com/99designs/gqlgen/plugin/federation/fieldset"
)
//go:embed federation.gotpl
var federationTemplate string
type federation struct {
Entities []*Entity
Version int
}
// New returns a federation plugin that injects
// federated directives and types into the schema
func New() plugin.Plugin {
return &federation{}
func New(version int) plugin.Plugin {
if version == 0 {
version = 1
}
return &federation{Version: version}
}
// Name returns the plugin name
@ -51,6 +60,7 @@ func (f *federation) MutateConfig(cfg *config.Config) error {
Model: config.StringList{"github.com/99designs/gqlgen/graphql.Map"},
},
}
for typeName, entry := range builtins {
if cfg.Models.Exists(typeName) {
return fmt.Errorf("%v already exists which must be reserved when Federation is enabled", typeName)
@ -63,22 +73,46 @@ func (f *federation) MutateConfig(cfg *config.Config) error {
cfg.Directives["key"] = config.DirectiveConfig{SkipRuntime: true}
cfg.Directives["extends"] = config.DirectiveConfig{SkipRuntime: true}
// Federation 2 specific directives
if f.Version == 2 {
cfg.Directives["shareable"] = config.DirectiveConfig{SkipRuntime: true}
cfg.Directives["link"] = config.DirectiveConfig{SkipRuntime: true}
cfg.Directives["tag"] = config.DirectiveConfig{SkipRuntime: true}
cfg.Directives["override"] = config.DirectiveConfig{SkipRuntime: true}
cfg.Directives["inaccessible"] = config.DirectiveConfig{SkipRuntime: true}
}
return nil
}
func (f *federation) InjectSourceEarly() *ast.Source {
input := `
scalar _Any
scalar _FieldSet
directive @external on FIELD_DEFINITION
directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
directive @extends on OBJECT | INTERFACE
`
// add version-specific changes on key directive, as well as adding the new directives for federation 2
if f.Version == 1 {
input += `
directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
`
} else if f.Version == 2 {
input += `
directive @key(fields: _FieldSet!, resolvable: Boolean = true) repeatable on OBJECT | INTERFACE
directive @link(import: [String!], url: String!) repeatable on SCHEMA
directive @shareable on OBJECT | FIELD_DEFINITION
directive @tag(name: String!) repeatable on FIELD_DEFINITION | INTERFACE | OBJECT | UNION | ARGUMENT_DEFINITION | SCALAR | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
directive @override(from: String!) on FIELD_DEFINITION
directive @inaccessible on SCALAR | OBJECT | FIELD_DEFINITION | ARGUMENT_DEFINITION | INTERFACE | UNION | ENUM | ENUM_VALUE | INPUT_OBJECT | INPUT_FIELD_DEFINITION
`
}
return &ast.Source{
Name: "federation/directives.graphql",
Input: `
scalar _Any
scalar _FieldSet
directive @external on FIELD_DEFINITION
directive @requires(fields: _FieldSet!) on FIELD_DEFINITION
directive @provides(fields: _FieldSet!) on FIELD_DEFINITION
directive @key(fields: _FieldSet!) repeatable on OBJECT | INTERFACE
directive @extends on OBJECT | INTERFACE
`,
Input: input,
BuiltIn: true,
}
}
@ -164,44 +198,6 @@ type Entity {
}
}
// Entity represents a federated type
// that was declared in the GQL schema.
type Entity struct {
Name string // The same name as the type declaration
Def *ast.Definition
Resolvers []*EntityResolver
Requires []*Requires
Multi bool
}
type EntityResolver struct {
ResolverName string // The resolver name, such as FindUserByID
KeyFields []*KeyField // The fields declared in @key.
InputType string // The Go generated input type for multi entity resolvers
}
type KeyField struct {
Definition *ast.FieldDefinition
Field fieldset.Field // len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
// Requires represents an @requires clause
type Requires struct {
Name string // the name of the field
Field fieldset.Field // source Field, len > 1 for nested fields
Type *config.TypeReference // The Go representation of that field type
}
func (e *Entity) allFieldsAreExternal() bool {
for _, field := range e.Def.Fields {
if field.Directives.ForName("external") == nil {
return false
}
}
return true
}
func (f *federation) GenerateCode(data *codegen.Data) error {
if len(f.Entities) > 0 {
if data.Objects.ByName("Entity") != nil {
@ -244,6 +240,7 @@ func (f *federation) GenerateCode(data *codegen.Data) error {
Data: f,
GeneratedHeader: true,
Packages: data.Config.Packages,
Template: federationTemplate,
})
}
@ -288,12 +285,23 @@ func (f *federation) setEntities(schema *ast.Schema) {
// extend TypeDefinedInOtherService @key(fields: "id") {
// id: ID @external
// }
if !e.allFieldsAreExternal() {
if !e.allFieldsAreExternal(f.Version) {
for _, dir := range keys {
if len(dir.Arguments) != 1 || dir.Arguments[0].Name != "fields" {
panic("Exactly one `fields` argument needed for @key declaration.")
if len(dir.Arguments) > 2 {
panic("More than two arguments provided for @key declaration.")
}
arg := dir.Arguments[0]
var arg *ast.Argument
// since keys are able to now have multiple arguments, we need to check both possible for a possible @key(fields="" fields="")
for _, a := range dir.Arguments {
if a.Name == "fields" {
if arg != nil {
panic("More than one `fields` provided for @key declaration.")
}
arg = a
}
}
keyFieldSet := fieldset.New(arg.Value.Raw, nil)
keyFields := make([]*KeyField, len(keyFieldSet))

View file

@ -11,15 +11,12 @@ import (
// Set represents a FieldSet that is used in federation directives @key and @requires.
// Would be happier to reuse FieldSet parsing from gqlparser, but this suits for now.
//
type Set []Field
// Field represents a single field in a FieldSet
//
type Field []string
// New parses a FieldSet string into a TinyFieldSet.
//
func New(raw string, prefix []string) Set {
if !strings.Contains(raw, "{") {
return parseUnnestedKeyFieldSet(raw, prefix)
@ -48,7 +45,6 @@ func New(raw string, prefix []string) Set {
}
// FieldDefinition looks up a field in the type.
//
func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *ast.FieldDefinition {
objType := schemaType
def := objType.Fields.ForName(f[0])
@ -74,7 +70,6 @@ func (f Field) FieldDefinition(schemaType *ast.Definition, schema *ast.Schema) *
}
// TypeReference looks up the type of a field.
//
func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *codegen.Field {
var def *codegen.Field
@ -89,7 +84,6 @@ func (f Field) TypeReference(obj *codegen.Object, objects codegen.Objects) *code
}
// ToGo converts a (possibly nested) field into a proper public Go name.
//
func (f Field) ToGo() string {
var ret string
@ -100,7 +94,6 @@ func (f Field) ToGo() string {
}
// ToGoPrivate converts a (possibly nested) field into a proper private Go name.
//
func (f Field) ToGoPrivate() string {
var ret string
@ -115,13 +108,11 @@ func (f Field) ToGoPrivate() string {
}
// Join concatenates the field parts with a string separator between. Useful in templates.
//
func (f Field) Join(str string) string {
return strings.Join(f, str)
}
// JoinGo concatenates the Go name of field parts with a string separator between. Useful in templates.
//
func (f Field) JoinGo(str string) string {
strs := []string{}
@ -138,7 +129,6 @@ func (f Field) LastIndex() int {
// local functions
// parseUnnestedKeyFieldSet // handles simple case where none of the fields are nested.
//
func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
ret := Set{}
@ -150,7 +140,6 @@ func parseUnnestedKeyFieldSet(raw string, prefix []string) Set {
}
// extractSubs splits out and trims sub-expressions from before, inside, and after "{}".
//
func extractSubs(str string) (string, string, string) {
start := strings.Index(str, "{")
end := matchingBracketIndex(str, start)
@ -162,7 +151,6 @@ func extractSubs(str string) (string, string, string) {
}
// matchingBracketIndex returns the index of the closing bracket, assuming an open bracket at start.
//
func matchingBracketIndex(str string, start int) int {
if start < 0 || len(str) <= start+1 {
return -1

View file

@ -1,10 +1,12 @@
package modelgen
import (
_ "embed"
"fmt"
"go/types"
"sort"
"strings"
"text/template"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
@ -12,16 +14,26 @@ import (
"github.com/vektah/gqlparser/v2/ast"
)
type BuildMutateHook = func(b *ModelBuild) *ModelBuild
//go:embed models.gotpl
var modelTemplate string
type FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
type (
BuildMutateHook = func(b *ModelBuild) *ModelBuild
FieldMutateHook = func(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error)
)
// defaultFieldMutateHook is the default hook for the Plugin which applies the GoTagFieldHook.
func defaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
// DefaultFieldMutateHook is the default hook for the Plugin which applies the GoFieldHook and GoTagFieldHook.
func DefaultFieldMutateHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
var err error
f, err = GoFieldHook(td, fd, f)
if err != nil {
return f, err
}
return GoTagFieldHook(td, fd, f)
}
func defaultBuildMutateHook(b *ModelBuild) *ModelBuild {
// DefaultBuildMutateHook is the default hook for the Plugin which mutate ModelBuild.
func DefaultBuildMutateHook(b *ModelBuild) *ModelBuild {
return b
}
@ -36,6 +48,7 @@ type ModelBuild struct {
type Interface struct {
Description string
Name string
Fields []*Field
Implements []string
}
@ -48,7 +61,10 @@ type Object struct {
type Field struct {
Description string
// Name is the field's name as it appears in the schema
Name string
// GoName is the field's name as it appears in the generated Go code
GoName string
Type types.Type
Tag string
}
@ -66,8 +82,8 @@ type EnumValue struct {
func New() plugin.Plugin {
return &Plugin{
MutateHook: defaultBuildMutateHook,
FieldHook: defaultFieldMutateHook,
MutateHook: DefaultBuildMutateHook,
FieldHook: DefaultFieldMutateHook,
}
}
@ -83,8 +99,6 @@ func (m *Plugin) Name() string {
}
func (m *Plugin) MutateConfig(cfg *config.Config) error {
binder := cfg.NewBinder()
b := &ModelBuild{
PackageName: cfg.Model.Package,
}
@ -95,10 +109,20 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
switch schemaType.Kind {
case ast.Interface, ast.Union:
var fields []*Field
var err error
if !cfg.OmitGetters {
fields, err = m.generateFields(cfg, schemaType)
if err != nil {
return err
}
}
it := &Interface{
Description: schemaType.Description,
Name: schemaType.Name,
Implements: schemaType.Interfaces,
Fields: fields,
}
b.Interfaces = append(b.Interfaces, it)
@ -106,9 +130,16 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
if schemaType == cfg.Schema.Query || schemaType == cfg.Schema.Mutation || schemaType == cfg.Schema.Subscription {
continue
}
fields, err := m.generateFields(cfg, schemaType)
if err != nil {
return err
}
it := &Object{
Description: schemaType.Description,
Name: schemaType.Name,
Fields: fields,
}
// If Interface A implements interface B, and Interface C also implements interface B
@ -129,6 +160,150 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
}
b.Models = append(b.Models, it)
case ast.Enum:
it := &Enum{
Name: schemaType.Name,
Description: schemaType.Description,
}
for _, v := range schemaType.EnumValues {
it.Values = append(it.Values, &EnumValue{
Name: v.Name,
Description: v.Description,
})
}
b.Enums = append(b.Enums, it)
case ast.Scalar:
b.Scalars = append(b.Scalars, schemaType.Name)
}
}
sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
// if we are not just turning all struct-type fields in generated structs into pointers, we need to at least
// check for cyclical relationships and recursive structs
if !cfg.StructFieldsAlwaysPointers {
findAndHandleCyclicalRelationships(b)
}
for _, it := range b.Enums {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Models {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Interfaces {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Scalars {
cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
}
if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
return nil
}
if m.MutateHook != nil {
b = m.MutateHook(b)
}
getInterfaceByName := func(name string) *Interface {
// Allow looking up interfaces, so template can generate getters for each field
for _, i := range b.Interfaces {
if i.Name == name {
return i
}
}
return nil
}
gettersGenerated := make(map[string]map[string]struct{})
generateGetter := func(model *Object, field *Field) string {
if model == nil || field == nil {
return ""
}
// Let templates check if a given getter has been generated already
typeGetters, exists := gettersGenerated[model.Name]
if !exists {
typeGetters = make(map[string]struct{})
gettersGenerated[model.Name] = typeGetters
}
_, exists = typeGetters[field.GoName]
typeGetters[field.GoName] = struct{}{}
if exists {
return ""
}
_, interfaceFieldTypeIsPointer := field.Type.(*types.Pointer)
var structFieldTypeIsPointer bool
for _, f := range model.Fields {
if f.GoName == field.GoName {
_, structFieldTypeIsPointer = f.Type.(*types.Pointer)
break
}
}
goType := templates.CurrentImports.LookupType(field.Type)
if strings.HasPrefix(goType, "[]") {
getter := fmt.Sprintf("func (this %s) Get%s() %s {\n", templates.ToGo(model.Name), field.GoName, goType)
getter += fmt.Sprintf("\tif this.%s == nil { return nil }\n", field.GoName)
getter += fmt.Sprintf("\tinterfaceSlice := make(%s, 0, len(this.%s))\n", goType, field.GoName)
getter += fmt.Sprintf("\tfor _, concrete := range this.%s { interfaceSlice = append(interfaceSlice, ", field.GoName)
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
}
getter += "concrete) }\n"
getter += "\treturn interfaceSlice\n"
getter += "}"
return getter
} else {
getter := fmt.Sprintf("func (this %s) Get%s() %s { return ", templates.ToGo(model.Name), field.GoName, goType)
if interfaceFieldTypeIsPointer && !structFieldTypeIsPointer {
getter += "&"
} else if !interfaceFieldTypeIsPointer && structFieldTypeIsPointer {
getter += "*"
}
getter += fmt.Sprintf("this.%s }", field.GoName)
return getter
}
}
funcMap := template.FuncMap{
"getInterfaceByName": getInterfaceByName,
"generateGetter": generateGetter,
}
err := templates.Render(templates.Options{
PackageName: cfg.Model.Package,
Filename: cfg.Model.Filename,
Data: b,
GeneratedHeader: true,
Packages: cfg.Packages,
Template: modelTemplate,
Funcs: funcMap,
})
if err != nil {
return err
}
// We may have generated code in a package we already loaded, so we reload all packages
// to allow packages to be compared correctly
cfg.ReloadAllPackages()
return nil
}
func (m *Plugin) generateFields(cfg *config.Config, schemaType *ast.Definition) ([]*Field, error) {
binder := cfg.NewBinder()
fields := make([]*Field, 0)
for _, field := range schemaType.Fields {
var typ types.Type
fieldDef := cfg.Schema.Types[field.Type.Name()]
@ -137,7 +312,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
var err error
typ, err = binder.FindTypeFromName(cfg.Models[field.Type.Name()].Model[0])
if err != nil {
return err
return nil, err
}
} else {
switch fieldDef.Kind {
@ -178,19 +353,22 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
}
}
name := field.Name
name := templates.ToGo(field.Name)
if nameOveride := cfg.Models[schemaType.Name].Fields[field.Name].FieldName; nameOveride != "" {
name = nameOveride
}
typ = binder.CopyModifiersFromAst(field.Type, typ)
if cfg.StructFieldsAlwaysPointers {
if isStruct(typ) && (fieldDef.Kind == ast.Object || fieldDef.Kind == ast.InputObject) {
typ = types.NewPointer(typ)
}
}
f := &Field{
Name: name,
Name: field.Name,
GoName: name,
Type: typ,
Description: field.Description,
Tag: `json:"` + field.Name + `"`,
@ -199,74 +377,15 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error {
if m.FieldHook != nil {
mf, err := m.FieldHook(schemaType, field, f)
if err != nil {
return fmt.Errorf("generror: field %v.%v: %w", it.Name, field.Name, err)
return nil, fmt.Errorf("generror: field %v.%v: %w", schemaType.Name, field.Name, err)
}
f = mf
}
it.Fields = append(it.Fields, f)
fields = append(fields, f)
}
b.Models = append(b.Models, it)
case ast.Enum:
it := &Enum{
Name: schemaType.Name,
Description: schemaType.Description,
}
for _, v := range schemaType.EnumValues {
it.Values = append(it.Values, &EnumValue{
Name: v.Name,
Description: v.Description,
})
}
b.Enums = append(b.Enums, it)
case ast.Scalar:
b.Scalars = append(b.Scalars, schemaType.Name)
}
}
sort.Slice(b.Enums, func(i, j int) bool { return b.Enums[i].Name < b.Enums[j].Name })
sort.Slice(b.Models, func(i, j int) bool { return b.Models[i].Name < b.Models[j].Name })
sort.Slice(b.Interfaces, func(i, j int) bool { return b.Interfaces[i].Name < b.Interfaces[j].Name })
for _, it := range b.Enums {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Models {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Interfaces {
cfg.Models.Add(it.Name, cfg.Model.ImportPath()+"."+templates.ToGo(it.Name))
}
for _, it := range b.Scalars {
cfg.Models.Add(it, "github.com/99designs/gqlgen/graphql.String")
}
if len(b.Models) == 0 && len(b.Enums) == 0 && len(b.Interfaces) == 0 && len(b.Scalars) == 0 {
return nil
}
if m.MutateHook != nil {
b = m.MutateHook(b)
}
err := templates.Render(templates.Options{
PackageName: cfg.Model.Package,
Filename: cfg.Model.Filename,
Data: b,
GeneratedHeader: true,
Packages: cfg.Packages,
})
if err != nil {
return err
}
// We may have generated code in a package we already loaded, so we reload all packages
// to allow packages to be compared correctly
cfg.ReloadAllPackages()
return nil
return fields, nil
}
// GoTagFieldHook applies the goTag directive to the generated Field f. When applying the Tag to the field, the field
@ -299,7 +418,73 @@ func GoTagFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Fie
return f, nil
}
// GoFieldHook applies the goField directive to the generated Field f.
func GoFieldHook(td *ast.Definition, fd *ast.FieldDefinition, f *Field) (*Field, error) {
args := make([]string, 0)
_ = args
for _, goField := range fd.Directives.ForNames("goField") {
if arg := goField.Arguments.ForName("name"); arg != nil {
if k, err := arg.Value.Value(nil); err == nil {
f.GoName = k.(string)
}
}
}
return f, nil
}
func isStruct(t types.Type) bool {
_, is := t.Underlying().(*types.Struct)
return is
}
// findAndHandleCyclicalRelationships checks for cyclical relationships between generated structs and replaces them
// with pointers. These relationships will produce compilation errors if they are not pointers.
// Also handles recursive structs.
func findAndHandleCyclicalRelationships(b *ModelBuild) {
for ii, structA := range b.Models {
for _, fieldA := range structA.Fields {
if strings.Contains(fieldA.Type.String(), "NotCyclicalA") {
fmt.Print()
}
if !isStruct(fieldA.Type) {
continue
}
// the field Type string will be in the form "github.com/99designs/gqlgen/codegen/testserver/followschema.LoopA"
// we only want the part after the last dot: "LoopA"
// this could lead to false positives, as we are only checking the name of the struct type, but these
// should be extremely rare, if it is even possible at all.
fieldAStructNameParts := strings.Split(fieldA.Type.String(), ".")
fieldAStructName := fieldAStructNameParts[len(fieldAStructNameParts)-1]
// find this struct type amongst the generated structs
for jj, structB := range b.Models {
if structB.Name != fieldAStructName {
continue
}
// check if structB contains a cyclical reference back to structA
var cyclicalReferenceFound bool
for _, fieldB := range structB.Fields {
if !isStruct(fieldB.Type) {
continue
}
fieldBStructNameParts := strings.Split(fieldB.Type.String(), ".")
fieldBStructName := fieldBStructNameParts[len(fieldBStructNameParts)-1]
if fieldBStructName == structA.Name {
cyclicalReferenceFound = true
fieldB.Type = types.NewPointer(fieldB.Type)
// keep looping in case this struct has additional fields of this type
}
}
// if this is a recursive struct (i.e. structA == structB), ensure that we only change this field to a pointer once
if cyclicalReferenceFound && ii != jj {
fieldA.Type = types.NewPointer(fieldA.Type)
break
}
}
}
}
}

View file

@ -14,74 +14,88 @@
{{- range $model := .Interfaces }}
{{ with .Description }} {{.|prefixLines "// "}} {{ end }}
type {{.Name|go }} interface {
type {{ goModelName .Name }} interface {
{{- range $impl := .Implements }}
{{ $impl|go }}
Is{{ goModelName $impl }}()
{{- end }}
Is{{ goModelName .Name }}()
{{- range $field := .Fields }}
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
Get{{ $field.GoName }}() {{ $field.Type | ref }}
{{- end }}
Is{{.Name|go }}()
}
{{- end }}
{{ range $model := .Models }}
{{with .Description }} {{.|prefixLines "// "}} {{end}}
type {{ .Name|go }} struct {
type {{ goModelName .Name }} struct {
{{- range $field := .Fields }}
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
{{ $field.Name|go }} {{$field.Type | ref}} `{{$field.Tag}}`
{{ $field.GoName }} {{$field.Type | ref}} `{{$field.Tag}}`
{{- end }}
}
{{- range $iface := .Implements }}
func ({{ $model.Name|go }}) Is{{ $iface|go }}() {}
{{ range .Implements }}
func ({{ goModelName $model.Name }}) Is{{ goModelName . }}() {}
{{- with getInterfaceByName . }}
{{- range .Fields }}
{{- with .Description }}
{{.|prefixLines "// "}}
{{- end}}
{{ generateGetter $model . }}
{{- end }}
{{- end }}
{{ end }}
{{- end}}
{{ range $enum := .Enums }}
{{ with .Description }} {{.|prefixLines "// "}} {{end}}
type {{.Name|go }} string
type {{ goModelName .Name }} string
const (
{{- range $value := .Values}}
{{- with .Description}}
{{.|prefixLines "// "}}
{{- end}}
{{ $enum.Name|go }}{{ .Name|go }} {{$enum.Name|go }} = {{.Name|quote}}
{{ goModelName $enum.Name .Name }} {{ goModelName $enum.Name }} = {{ .Name|quote }}
{{- end }}
)
var All{{.Name|go }} = []{{ .Name|go }}{
var All{{ goModelName .Name }} = []{{ goModelName .Name }}{
{{- range $value := .Values}}
{{$enum.Name|go }}{{ .Name|go }},
{{ goModelName $enum.Name .Name }},
{{- end }}
}
func (e {{.Name|go }}) IsValid() bool {
func (e {{ goModelName .Name }}) IsValid() bool {
switch e {
case {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ $enum.Name|go }}{{ $element.Name|go }}{{end}}:
case {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ goModelName $enum.Name $element.Name }}{{end}}:
return true
}
return false
}
func (e {{.Name|go }}) String() string {
func (e {{ goModelName .Name }}) String() string {
return string(e)
}
func (e *{{.Name|go }}) UnmarshalGQL(v interface{}) error {
func (e *{{ goModelName .Name }}) UnmarshalGQL(v interface{}) error {
str, ok := v.(string)
if !ok {
return fmt.Errorf("enums must be strings")
}
*e = {{ .Name|go }}(str)
*e = {{ goModelName .Name }}(str)
if !e.IsValid() {
return fmt.Errorf("%s is not a valid {{ .Name }}", str)
}
return nil
}
func (e {{.Name|go }}) MarshalGQL(w io.Writer) {
func (e {{ goModelName .Name }}) MarshalGQL(w io.Writer) {
fmt.Fprint(w, strconv.Quote(e.String()))
}

View file

@ -1,19 +1,29 @@
package resolvergen
import (
_ "embed"
"errors"
"fmt"
"go/ast"
"io/fs"
"os"
"path/filepath"
"strings"
"golang.org/x/text/cases"
"golang.org/x/text/language"
"github.com/99designs/gqlgen/codegen"
"github.com/99designs/gqlgen/codegen/config"
"github.com/99designs/gqlgen/codegen/templates"
"github.com/99designs/gqlgen/graphql"
"github.com/99designs/gqlgen/internal/rewrite"
"github.com/99designs/gqlgen/plugin"
)
//go:embed resolver.gotpl
var resolverTemplate string
func New() plugin.Plugin {
return &Plugin{}
}
@ -45,7 +55,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
file := File{}
if _, err := os.Stat(data.Config.Resolver.Filename); err == nil {
// file already exists and we dont support updating resolvers with layout = single so just return
// file already exists and we do not support updating resolvers with layout = single so just return
return nil
}
@ -58,7 +68,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
continue
}
resolver := Resolver{o, f, `panic("not implemented")`}
resolver := Resolver{o, f, nil, "// foo", `panic("not implemented")`}
file.Resolvers = append(file.Resolvers, &resolver)
}
}
@ -76,6 +86,7 @@ func (m *Plugin) generateSingleFile(data *codegen.Data) error {
Filename: data.Config.Resolver.Filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
}
@ -98,8 +109,9 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
files[fn] = &File{}
}
caser := cases.Title(language.English, cases.NoLower)
rewriter.MarkStructCopied(templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type))
rewriter.GetMethodBody(data.Config.Resolver.Type, strings.Title(o.Name))
rewriter.GetMethodBody(data.Config.Resolver.Type, caser.String(o.Name))
files[fn].Objects = append(files[fn].Objects, o)
}
for _, f := range o.Fields {
@ -108,12 +120,16 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
}
structName := templates.LcFirst(o.Name) + templates.UcFirst(data.Config.Resolver.Type)
comment := strings.TrimSpace(strings.TrimLeft(rewriter.GetMethodComment(structName, f.GoFieldName), `\`))
if comment == "" {
comment = fmt.Sprintf("%v is the resolver for the %v field.", f.GoFieldName, f.Name)
}
implementation := strings.TrimSpace(rewriter.GetMethodBody(structName, f.GoFieldName))
if implementation == "" {
implementation = `panic(fmt.Errorf("not implemented"))`
implementation = fmt.Sprintf("panic(fmt.Errorf(\"not implemented: %v - %v\"))", f.GoFieldName, f.Name)
}
resolver := Resolver{o, f, implementation}
resolver := Resolver{o, f, rewriter.GetPrevDecl(structName, f.GoFieldName), comment, implementation}
fn := gqlToResolverName(data.Config.Resolver.Dir(), f.Position.Src.Name, data.Config.Resolver.FilenameTemplate)
if files[fn] == nil {
files[fn] = &File{}
@ -139,10 +155,12 @@ func (m *Plugin) generatePerSchema(data *codegen.Data) error {
PackageName: data.Config.Resolver.Package,
FileNotice: `
// This file will be automatically regenerated based on the schema, any resolver implementations
// will be copied through when generating and any unknown code will be moved to the end.`,
// will be copied through when generating and any unknown code will be moved to the end.
// Code generated by github.com/99designs/gqlgen version ` + graphql.Version,
Filename: filename,
Data: resolverBuild,
Packages: data.Config.Packages,
Template: resolverTemplate,
})
if err != nil {
return err
@ -198,6 +216,8 @@ func (f *File) Imports() string {
type Resolver struct {
Object *codegen.Object
Field *codegen.Field
PrevDecl *ast.FuncDecl
Comment string
Implementation string
}

View file

@ -19,7 +19,8 @@
{{ end }}
{{ range $resolver := .Resolvers -}}
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ $resolver.Field.ShortResolverDeclaration }} {
// {{ $resolver.Comment }}
func (r *{{lcFirst $resolver.Object.Name}}{{ucFirst $.ResolverType}}) {{$resolver.Field.GoFieldName}}{{ with $resolver.PrevDecl }}{{ $resolver.Field.ShortResolverSignature .Type }}{{ else }}{{ $resolver.Field.ShortResolverDeclaration }}{{ end }}{
{{ $resolver.Implementation }}
}

View file

@ -1,6 +1,7 @@
package servergen
import (
_ "embed"
"errors"
"io/fs"
"log"
@ -11,6 +12,9 @@ import (
"github.com/99designs/gqlgen/plugin"
)
//go:embed server.gotpl
var serverTemplate string
func New(filename string) plugin.Plugin {
return &Plugin{filename}
}
@ -37,6 +41,7 @@ func (m *Plugin) GenerateCode(data *codegen.Data) error {
Filename: m.filename,
Data: serverBuild,
Packages: data.Config.Packages,
Template: serverTemplate,
})
}

View file

@ -1,8 +0,0 @@
//go:build tools
// +build tools
package main
import (
_ "github.com/matryer/moq"
)

View file

@ -6,6 +6,13 @@
Gorilla WebSocket is a [Go](http://golang.org/) implementation of the
[WebSocket](http://www.rfc-editor.org/rfc/rfc6455.txt) protocol.
---
⚠️ **[The Gorilla WebSocket Package is looking for a new maintainer](https://github.com/gorilla/websocket/issues/370)**
---
### Documentation
* [API Reference](https://pkg.go.dev/github.com/gorilla/websocket?tab=doc)
@ -30,35 +37,3 @@ The Gorilla WebSocket package passes the server tests in the [Autobahn Test
Suite](https://github.com/crossbario/autobahn-testsuite) using the application in the [examples/autobahn
subdirectory](https://github.com/gorilla/websocket/tree/master/examples/autobahn).
### Gorilla WebSocket compared with other packages
<table>
<tr>
<th></th>
<th><a href="http://godoc.org/github.com/gorilla/websocket">github.com/gorilla</a></th>
<th><a href="http://godoc.org/golang.org/x/net/websocket">golang.org/x/net</a></th>
</tr>
<tr>
<tr><td colspan="3"><a href="http://tools.ietf.org/html/rfc6455">RFC 6455</a> Features</td></tr>
<tr><td>Passes <a href="https://github.com/crossbario/autobahn-testsuite">Autobahn Test Suite</a></td><td><a href="https://github.com/gorilla/websocket/tree/master/examples/autobahn">Yes</a></td><td>No</td></tr>
<tr><td>Receive <a href="https://tools.ietf.org/html/rfc6455#section-5.4">fragmented</a> message<td>Yes</td><td><a href="https://code.google.com/p/go/issues/detail?id=7632">No</a>, see note 1</td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.1">close</a> message</td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td><a href="https://code.google.com/p/go/issues/detail?id=4588">No</a></td></tr>
<tr><td>Send <a href="https://tools.ietf.org/html/rfc6455#section-5.5.2">pings</a> and receive <a href="https://tools.ietf.org/html/rfc6455#section-5.5.3">pongs</a></td><td><a href="http://godoc.org/github.com/gorilla/websocket#hdr-Control_Messages">Yes</a></td><td>No</td></tr>
<tr><td>Get the <a href="https://tools.ietf.org/html/rfc6455#section-5.6">type</a> of a received data message</td><td>Yes</td><td>Yes, see note 2</td></tr>
<tr><td colspan="3">Other Features</tr></td>
<tr><td><a href="https://tools.ietf.org/html/rfc7692">Compression Extensions</a></td><td>Experimental</td><td>No</td></tr>
<tr><td>Read message using io.Reader</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextReader">Yes</a></td><td>No, see note 3</td></tr>
<tr><td>Write message using io.WriteCloser</td><td><a href="http://godoc.org/github.com/gorilla/websocket#Conn.NextWriter">Yes</a></td><td>No, see note 3</td></tr>
</table>
Notes:
1. Large messages are fragmented in [Chrome's new WebSocket implementation](http://www.ietf.org/mail-archive/web/hybi/current/msg10503.html).
2. The application can get the type of a received data message by implementing
a [Codec marshal](http://godoc.org/golang.org/x/net/websocket#Codec.Marshal)
function.
3. The go.net io.Reader and io.Writer operate across WebSocket frame boundaries.
Read returns when the input buffer is full or a frame boundary is
encountered. Each call to Write sends a single frame message. The Gorilla
io.Reader and io.WriteCloser operate on a single WebSocket message.

View file

@ -48,15 +48,23 @@ func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufS
}
// A Dialer contains options for connecting to WebSocket server.
//
// It is safe to call Dialer's methods concurrently.
type Dialer struct {
// NetDial specifies the dial function for creating TCP connections. If
// NetDial is nil, net.Dial is used.
NetDial func(network, addr string) (net.Conn, error)
// NetDialContext specifies the dial function for creating TCP connections. If
// NetDialContext is nil, net.DialContext is used.
// NetDialContext is nil, NetDial is used.
NetDialContext func(ctx context.Context, network, addr string) (net.Conn, error)
// NetDialTLSContext specifies the dial function for creating TLS/TCP connections. If
// NetDialTLSContext is nil, NetDialContext is used.
// If NetDialTLSContext is set, Dial assumes the TLS handshake is done there and
// TLSClientConfig is ignored.
NetDialTLSContext func(ctx context.Context, network, addr string) (net.Conn, error)
// Proxy specifies a function to return a proxy for a given
// Request. If the function returns a non-nil error, the
// request is aborted with the provided error.
@ -65,6 +73,8 @@ type Dialer struct {
// TLSClientConfig specifies the TLS configuration to use with tls.Client.
// If nil, the default configuration is used.
// If either NetDialTLS or NetDialTLSContext are set, Dial assumes the TLS handshake
// is done there and TLSClientConfig is ignored.
TLSClientConfig *tls.Config
// HandshakeTimeout specifies the duration for the handshake to complete.
@ -176,7 +186,7 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
req := &http.Request{
Method: "GET",
Method: http.MethodGet,
URL: u,
Proto: "HTTP/1.1",
ProtoMajor: 1,
@ -237,13 +247,32 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
// Get network dial function.
var netDial func(network, add string) (net.Conn, error)
switch u.Scheme {
case "http":
if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
} else {
}
case "https":
if d.NetDialTLSContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialTLSContext(ctx, network, addr)
}
} else if d.NetDialContext != nil {
netDial = func(network, addr string) (net.Conn, error) {
return d.NetDialContext(ctx, network, addr)
}
} else if d.NetDial != nil {
netDial = d.NetDial
}
default:
return nil, nil, errMalformedURL
}
if netDial == nil {
netDialer := &net.Dialer{}
netDial = func(network, addr string) (net.Conn, error) {
return netDialer.DialContext(ctx, network, addr)
@ -304,7 +333,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
}()
if u.Scheme == "https" {
if u.Scheme == "https" && d.NetDialTLSContext == nil {
// If NetDialTLSContext is set, assume that the TLS handshake has already been done
cfg := cloneTLSConfig(d.TLSClientConfig)
if cfg.ServerName == "" {
cfg.ServerName = hostNoPort
@ -312,11 +343,12 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
tlsConn := tls.Client(netConn, cfg)
netConn = tlsConn
var err error
if trace != nil {
err = doHandshakeWithTrace(trace, tlsConn, cfg)
} else {
err = doHandshake(tlsConn, cfg)
if trace != nil && trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(ctx, tlsConn, cfg)
if trace != nil && trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
if err != nil {
@ -348,8 +380,8 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
}
if resp.StatusCode != 101 ||
!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
!tokenListContainsValue(resp.Header, "Upgrade", "websocket") ||
!tokenListContainsValue(resp.Header, "Connection", "upgrade") ||
resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
// Before closing the network connection on return from this
// function, slurp up some of the response to aid application
@ -382,14 +414,9 @@ func (d *Dialer) DialContext(ctx context.Context, urlStr string, requestHeader h
return conn, resp, nil
}
func doHandshake(tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
return cfg.Clone()
}

View file

@ -1,16 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "crypto/tls"
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return cfg.Clone()
}

View file

@ -1,38 +0,0 @@
// Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
import "crypto/tls"
// cloneTLSConfig clones all public fields except the fields
// SessionTicketsDisabled and SessionTicketKey. This avoids copying the
// sync.Mutex in the sync.Once and makes it safe to call cloneTLSConfig on a
// config in active use.
func cloneTLSConfig(cfg *tls.Config) *tls.Config {
if cfg == nil {
return &tls.Config{}
}
return &tls.Config{
Rand: cfg.Rand,
Time: cfg.Time,
Certificates: cfg.Certificates,
NameToCertificate: cfg.NameToCertificate,
GetCertificate: cfg.GetCertificate,
RootCAs: cfg.RootCAs,
NextProtos: cfg.NextProtos,
ServerName: cfg.ServerName,
ClientAuth: cfg.ClientAuth,
ClientCAs: cfg.ClientCAs,
InsecureSkipVerify: cfg.InsecureSkipVerify,
CipherSuites: cfg.CipherSuites,
PreferServerCipherSuites: cfg.PreferServerCipherSuites,
ClientSessionCache: cfg.ClientSessionCache,
MinVersion: cfg.MinVersion,
MaxVersion: cfg.MaxVersion,
CurvePreferences: cfg.CurvePreferences,
}
}

View file

@ -13,6 +13,7 @@ import (
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
@ -401,6 +402,12 @@ func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error
return nil
}
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}
// WriteControl writes a control message with the given deadline. The allowed
// message types are CloseMessage, PingMessage and PongMessage.
func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error {
@ -794,47 +801,69 @@ func (c *Conn) advanceFrame() (int, error) {
}
// 2. Read and parse first two bytes of frame header.
// To aid debugging, collect and report all errors in the first two bytes
// of the header.
var errors []string
p, err := c.read(2)
if err != nil {
return noFrame, err
}
final := p[0]&finalBit != 0
frameType := int(p[0] & 0xf)
final := p[0]&finalBit != 0
rsv1 := p[0]&rsv1Bit != 0
rsv2 := p[0]&rsv2Bit != 0
rsv3 := p[0]&rsv3Bit != 0
mask := p[1]&maskBit != 0
c.setReadRemaining(int64(p[1] & 0x7f))
c.readDecompress = false
if c.newDecompressionReader != nil && (p[0]&rsv1Bit) != 0 {
if rsv1 {
if c.newDecompressionReader != nil {
c.readDecompress = true
p[0] &^= rsv1Bit
} else {
errors = append(errors, "RSV1 set")
}
}
if rsv := p[0] & (rsv1Bit | rsv2Bit | rsv3Bit); rsv != 0 {
return noFrame, c.handleProtocolError("unexpected reserved bits 0x" + strconv.FormatInt(int64(rsv), 16))
if rsv2 {
errors = append(errors, "RSV2 set")
}
if rsv3 {
errors = append(errors, "RSV3 set")
}
switch frameType {
case CloseMessage, PingMessage, PongMessage:
if c.readRemaining > maxControlFramePayloadSize {
return noFrame, c.handleProtocolError("control frame length > 125")
errors = append(errors, "len > 125 for control")
}
if !final {
return noFrame, c.handleProtocolError("control frame not final")
errors = append(errors, "FIN not set on control")
}
case TextMessage, BinaryMessage:
if !c.readFinal {
return noFrame, c.handleProtocolError("message start before final message frame")
errors = append(errors, "data before FIN")
}
c.readFinal = final
case continuationFrame:
if c.readFinal {
return noFrame, c.handleProtocolError("continuation after final message frame")
errors = append(errors, "continuation after FIN")
}
c.readFinal = final
default:
return noFrame, c.handleProtocolError("unknown opcode " + strconv.Itoa(frameType))
errors = append(errors, "bad opcode "+strconv.Itoa(frameType))
}
if mask != c.isServer {
errors = append(errors, "bad MASK")
}
if len(errors) > 0 {
return noFrame, c.handleProtocolError(strings.Join(errors, ", "))
}
// 3. Read and parse frame length as per
@ -872,10 +901,6 @@ func (c *Conn) advanceFrame() (int, error) {
// 4. Handle frame masking.
if mask != c.isServer {
return noFrame, c.handleProtocolError("incorrect mask flag")
}
if mask {
c.readMaskPos = 0
p, err := c.read(len(c.readMaskKey))
@ -935,7 +960,7 @@ func (c *Conn) advanceFrame() (int, error) {
if len(payload) >= 2 {
closeCode = int(binary.BigEndian.Uint16(payload))
if !isValidReceivedCloseCode(closeCode) {
return noFrame, c.handleProtocolError("invalid close code")
return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode))
}
closeText = string(payload[2:])
if !utf8.ValidString(closeText) {
@ -952,7 +977,11 @@ func (c *Conn) advanceFrame() (int, error) {
}
func (c *Conn) handleProtocolError(message string) error {
c.WriteControl(CloseMessage, FormatCloseMessage(CloseProtocolError, message), time.Now().Add(writeWait))
data := FormatCloseMessage(CloseProtocolError, message)
if len(data) > maxControlFramePayloadSize {
data = data[:maxControlFramePayloadSize]
}
c.WriteControl(CloseMessage, data, time.Now().Add(writeWait))
return errors.New("websocket: " + message)
}

View file

@ -1,15 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build go1.8
package websocket
import "net"
func (c *Conn) writeBufs(bufs ...[]byte) error {
b := net.Buffers(bufs)
_, err := b.WriteTo(c.conn)
return err
}

View file

@ -1,18 +0,0 @@
// Copyright 2016 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build !go1.8
package websocket
func (c *Conn) writeBufs(bufs ...[]byte) error {
for _, buf := range bufs {
if len(buf) > 0 {
if _, err := c.conn.Write(buf); err != nil {
return err
}
}
}
return nil
}

View file

@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build !appengine
// +build !appengine
package websocket

View file

@ -2,6 +2,7 @@
// this source code is governed by a BSD-style license that can be found in the
// LICENSE file.
//go:build appengine
// +build appengine
package websocket

View file

@ -48,7 +48,7 @@ func (hpd *httpProxyDialer) Dial(network string, addr string) (net.Conn, error)
}
connectReq := &http.Request{
Method: "CONNECT",
Method: http.MethodConnect,
URL: &url.URL{Opaque: addr},
Host: addr,
Header: connectHeader,

View file

@ -23,6 +23,8 @@ func (e HandshakeError) Error() string { return e.message }
// Upgrader specifies parameters for upgrading an HTTP connection to a
// WebSocket connection.
//
// It is safe to call Upgrader's methods concurrently.
type Upgrader struct {
// HandshakeTimeout specifies the duration for the handshake to complete.
HandshakeTimeout time.Duration
@ -115,8 +117,8 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header
// Upgrade upgrades the HTTP server connection to the WebSocket protocol.
//
// The responseHeader is included in the response to the client's upgrade
// request. Use the responseHeader to specify cookies (Set-Cookie) and the
// application negotiated subprotocol (Sec-WebSocket-Protocol).
// request. Use the responseHeader to specify cookies (Set-Cookie). To specify
// subprotocols supported by the server, set Upgrader.Subprotocols directly.
//
// If the upgrade fails, then Upgrade replies to the client with an HTTP error
// response.
@ -131,7 +133,7 @@ func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeade
return u.returnError(w, r, http.StatusBadRequest, badHandshake+"'websocket' token not found in 'Upgrade' header")
}
if r.Method != "GET" {
if r.Method != http.MethodGet {
return u.returnError(w, r, http.StatusMethodNotAllowed, badHandshake+"request method is not GET")
}

21
vendor/github.com/gorilla/websocket/tls_handshake.go generated vendored Normal file
View file

@ -0,0 +1,21 @@
//go:build go1.17
// +build go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.HandshakeContext(ctx); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View file

@ -0,0 +1,21 @@
//go:build !go1.17
// +build !go1.17
package websocket
import (
"context"
"crypto/tls"
)
func doHandshake(ctx context.Context, tlsConn *tls.Conn, cfg *tls.Config) error {
if err := tlsConn.Handshake(); err != nil {
return err
}
if !cfg.InsecureSkipVerify {
if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
return err
}
}
return nil
}

View file

@ -1,19 +0,0 @@
// +build go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
if trace.TLSHandshakeStart != nil {
trace.TLSHandshakeStart()
}
err := doHandshake(tlsConn, cfg)
if trace.TLSHandshakeDone != nil {
trace.TLSHandshakeDone(tlsConn.ConnectionState(), err)
}
return err
}

View file

@ -1,12 +0,0 @@
// +build !go1.8
package websocket
import (
"crypto/tls"
"net/http/httptrace"
)
func doHandshakeWithTrace(trace *httptrace.ClientTrace, tlsConn *tls.Conn, cfg *tls.Config) error {
return doHandshake(tlsConn, cfg)
}

View file

@ -1,29 +0,0 @@
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
*.prof
.vscode
.idea
.playground
dist/

View file

@ -1,34 +0,0 @@
# This is an example goreleaser.yaml file with some sane defaults.
# Make sure to check the documentation at http://goreleaser.com
builds:
- env:
- CGO_ENABLED=0
goos:
- darwin
- windows
- linux
goarch:
- amd64
- arm
- arm64
ldflags:
- -X main.Version={{.Version}}
archives:
- replacements:
darwin: macOS
linux: Linux
windows: Windows
386: i386
amd64: x86_64
universal_binaries:
- replace: false
checksum:
name_template: 'checksums.txt'
snapshot:
name_template: "{{ .Tag }}"
changelog:
sort: asc
filters:
exclude:
- '^docs:'
- '^test:'

View file

@ -1,21 +0,0 @@
MIT License
Copyright (c) 2016 Mat Ryer and David Hernandez
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View file

@ -1,135 +0,0 @@
![moq logo](moq-logo-small.png) [![build](https://github.com/matryer/moq/workflows/build/badge.svg)](https://github.com/matryer/moq/actions?query=branch%3Amaster) [![Go Report Card](https://goreportcard.com/badge/github.com/matryer/moq)](https://goreportcard.com/report/github.com/matryer/moq)
Interface mocking tool for go generate.
### What is Moq?
Moq is a tool that generates a struct from any interface. The struct can be used in test code as a mock of the interface.
![Preview](preview.png)
above: Moq generates the code on the right.
You can read more in the [Meet Moq blog post](http://bit.ly/meetmoq).
### Installing
To start using latest released version of Moq, just run:
#### Go version < 1.16
```
$ go get github.com/matryer/moq
```
#### Go 1.16+
```
$ go install github.com/matryer/moq@latest
```
### Usage
```
moq [flags] source-dir interface [interface2 [interface3 [...]]]
-fmt string
go pretty-printer: gofmt, goimports or noop (default gofmt)
-out string
output file (default stdout)
-pkg string
package name (default will infer)
-stub
return zero values when no mock implementation is provided, do not panic
-skip-ensure
suppress mock implementation check, avoid import cycle if mocks
generated outside of the tested package
Specifying an alias for the mock is also supported with the format 'interface:alias'
Example: moq -pkg different . MyInterface:MyMock
```
**NOTE:** `source-dir` is the directory where the source code (definition) of the target interface is located.
It needs to be a path to a directory and not the import statement for a Go package.
In a command line:
```
$ moq -out mocks_test.go . MyInterface
```
In code (for go generate):
```go
package my
//go:generate moq -out myinterface_moq_test.go . MyInterface
type MyInterface interface {
Method1() error
Method2(i int)
}
```
Then run `go generate` for your package.
### How to use it
Mocking interfaces is a nice way to write unit tests where you can easily control the behaviour of the mocked object.
Moq creates a struct that has a function field for each method, which you can declare in your test code.
In this example, Moq generated the `EmailSenderMock` type:
```go
func TestCompleteSignup(t *testing.T) {
var sentTo string
mockedEmailSender = &EmailSenderMock{
SendFunc: func(to, subject, body string) error {
sentTo = to
return nil
},
}
CompleteSignUp("me@email.com", mockedEmailSender)
callsToSend := len(mockedEmailSender.SendCalls())
if callsToSend != 1 {
t.Errorf("Send was called %d times", callsToSend)
}
if sentTo != "me@email.com" {
t.Errorf("unexpected recipient: %s", sentTo)
}
}
func CompleteSignUp(to string, sender EmailSender) {
// TODO: this
}
```
The mocked structure implements the interface, where each method calls the associated function field.
## Tips
* Keep mocked logic inside the test that is using it
* Only mock the fields you need
* It will panic if a nil function gets called
* Name arguments in the interface for a better experience
* Use closured variables inside your test function to capture details about the calls to the methods
* Use `.MethodCalls()` to track the calls
* Use `go:generate` to invoke the `moq` command
* If Moq fails with a `go/format` error, it indicates the generated code was not valid.
You can run the same command with `-fmt noop` to print the generated source code without attempting to format it.
This can aid in debugging the root cause.
## License
The Moq project (and all code) is licensed under the [MIT License](LICENSE).
Moq was created by [Mat Ryer](https://twitter.com/matryer) and [David Hernandez](https://github.com/dahernan), with ideas lovingly stolen from [Ernesto Jimenez](https://github.com/ernesto-jimenez). Featuring a major refactor by @sudo-suhas, as well as lots of other contributors.
The Moq logo was created by [Chris Ryer](http://chrisryer.co.uk) and is licensed under the [Creative Commons Attribution 3.0 License](https://creativecommons.org/licenses/by/3.0/).

View file

@ -1,135 +0,0 @@
package registry
import (
"go/types"
"strconv"
)
// MethodScope is the sub-registry for allocating variables present in
// the method scope.
//
// It should be created using a registry instance.
type MethodScope struct {
registry *Registry
moqPkgPath string
vars []*Var
conflicted map[string]bool
}
// AddVar allocates a variable instance and adds it to the method scope.
//
// Variables names are generated if required and are ensured to be
// without conflict with other variables and imported packages. It also
// adds the relevant imports to the registry for each added variable.
func (m *MethodScope) AddVar(vr *types.Var, suffix string) *Var {
imports := make(map[string]*Package)
m.populateImports(vr.Type(), imports)
m.resolveImportVarConflicts(imports)
name := varName(vr, suffix)
// Ensure that the var name does not conflict with a package import.
if _, ok := m.registry.searchImport(name); ok {
name += "MoqParam"
}
if _, ok := m.searchVar(name); ok || m.conflicted[name] {
name = m.resolveVarNameConflict(name)
}
v := Var{
vr: vr,
imports: imports,
moqPkgPath: m.moqPkgPath,
Name: name,
}
m.vars = append(m.vars, &v)
return &v
}
func (m *MethodScope) resolveVarNameConflict(suggested string) string {
for n := 1; ; n++ {
_, ok := m.searchVar(suggested + strconv.Itoa(n))
if ok {
continue
}
if n == 1 {
conflict, _ := m.searchVar(suggested)
conflict.Name += "1"
m.conflicted[suggested] = true
n++
}
return suggested + strconv.Itoa(n)
}
}
func (m MethodScope) searchVar(name string) (*Var, bool) {
for _, v := range m.vars {
if v.Name == name {
return v, true
}
}
return nil, false
}
// populateImports extracts all the package imports for a given type
// recursively. The imported packages by a single type can be more than
// one (ex: map[a.Type]b.Type).
func (m MethodScope) populateImports(t types.Type, imports map[string]*Package) {
switch t := t.(type) {
case *types.Named:
if pkg := t.Obj().Pkg(); pkg != nil {
imports[stripVendorPath(pkg.Path())] = m.registry.AddImport(pkg)
}
case *types.Array:
m.populateImports(t.Elem(), imports)
case *types.Slice:
m.populateImports(t.Elem(), imports)
case *types.Signature:
for i := 0; i < t.Params().Len(); i++ {
m.populateImports(t.Params().At(i).Type(), imports)
}
for i := 0; i < t.Results().Len(); i++ {
m.populateImports(t.Results().At(i).Type(), imports)
}
case *types.Map:
m.populateImports(t.Key(), imports)
m.populateImports(t.Elem(), imports)
case *types.Chan:
m.populateImports(t.Elem(), imports)
case *types.Pointer:
m.populateImports(t.Elem(), imports)
case *types.Struct: // anonymous struct
for i := 0; i < t.NumFields(); i++ {
m.populateImports(t.Field(i).Type(), imports)
}
case *types.Interface: // anonymous interface
for i := 0; i < t.NumExplicitMethods(); i++ {
m.populateImports(t.ExplicitMethod(i).Type(), imports)
}
for i := 0; i < t.NumEmbeddeds(); i++ {
m.populateImports(t.EmbeddedType(i), imports)
}
}
}
// resolveImportVarConflicts ensures that all the newly added imports do not
// conflict with any of the existing vars.
func (m MethodScope) resolveImportVarConflicts(imports map[string]*Package) {
// Ensure that all the newly added imports do not conflict with any of the
// existing vars.
for _, imprt := range imports {
if v, ok := m.searchVar(imprt.Qualifier()); ok {
v.Name += "MoqParam"
}
}
}

View file

@ -1,93 +0,0 @@
package registry
import (
"go/types"
"path"
"strings"
)
// Package represents an imported package.
type Package struct {
pkg *types.Package
Alias string
}
// NewPackage creates a new instance of Package.
func NewPackage(pkg *types.Package) *Package { return &Package{pkg: pkg} }
// Qualifier returns the qualifier which must be used to refer to types
// declared in the package.
func (p *Package) Qualifier() string {
if p == nil {
return ""
}
if p.Alias != "" {
return p.Alias
}
return p.pkg.Name()
}
// Path is the full package import path (without vendor).
func (p *Package) Path() string {
if p == nil {
return ""
}
return stripVendorPath(p.pkg.Path())
}
var replacer = strings.NewReplacer(
"go-", "",
"-go", "",
"-", "",
"_", "",
".", "",
"@", "",
"+", "",
"~", "",
)
// uniqueName generates a unique name for a package by concatenating
// path components. The generated name is guaranteed to unique with an
// appropriate level because the full package import paths themselves
// are unique.
func (p Package) uniqueName(lvl int) string {
pp := strings.Split(p.Path(), "/")
reverse(pp)
var name string
for i := 0; i < min(len(pp), lvl+1); i++ {
name = strings.ToLower(replacer.Replace(pp[i])) + name
}
return name
}
// stripVendorPath strips the vendor dir prefix from a package path.
// For example we might encounter an absolute path like
// github.com/foo/bar/vendor/github.com/pkg/errors which is resolved
// to github.com/pkg/errors.
func stripVendorPath(p string) string {
parts := strings.Split(p, "/vendor/")
if len(parts) == 1 {
return p
}
return strings.TrimLeft(path.Join(parts[1:]...), "/")
}
func min(a, b int) int {
if a < b {
return a
}
return b
}
func reverse(a []string) {
for i := len(a)/2 - 1; i >= 0; i-- {
opp := len(a) - 1 - i
a[i], a[opp] = a[opp], a[i]
}
}

View file

@ -1,190 +0,0 @@
package registry
import (
"errors"
"fmt"
"go/types"
"path/filepath"
"sort"
"strings"
"golang.org/x/tools/go/packages"
)
// Registry encapsulates types information for the source and mock
// destination package. For the mock package, it tracks the list of
// imports and ensures there are no conflicts in the imported package
// qualifiers.
type Registry struct {
srcPkg *packages.Package
moqPkgPath string
aliases map[string]string
imports map[string]*Package
}
// New loads the source package info and returns a new instance of
// Registry.
func New(srcDir, moqPkg string) (*Registry, error) {
srcPkg, err := pkgInfoFromPath(
srcDir, packages.NeedName|packages.NeedSyntax|packages.NeedTypes|packages.NeedTypesInfo,
)
if err != nil {
return nil, fmt.Errorf("couldn't load source package: %s", err)
}
return &Registry{
srcPkg: srcPkg,
moqPkgPath: findPkgPath(moqPkg, srcPkg),
aliases: parseImportsAliases(srcPkg),
imports: make(map[string]*Package),
}, nil
}
// SrcPkg returns the types info for the source package.
func (r Registry) SrcPkg() *types.Package {
return r.srcPkg.Types
}
// SrcPkgName returns the name of the source package.
func (r Registry) SrcPkgName() string {
return r.srcPkg.Name
}
// LookupInterface returns the underlying interface definition of the
// given interface name.
func (r Registry) LookupInterface(name string) (*types.Interface, error) {
obj := r.SrcPkg().Scope().Lookup(name)
if obj == nil {
return nil, fmt.Errorf("interface not found: %s", name)
}
if !types.IsInterface(obj.Type()) {
return nil, fmt.Errorf("%s (%s) is not an interface", name, obj.Type())
}
return obj.Type().Underlying().(*types.Interface).Complete(), nil
}
// MethodScope returns a new MethodScope.
func (r *Registry) MethodScope() *MethodScope {
return &MethodScope{
registry: r,
moqPkgPath: r.moqPkgPath,
conflicted: map[string]bool{},
}
}
// AddImport adds the given package to the set of imports. It generates a
// suitable alias if there are any conflicts with previously imported
// packages.
func (r *Registry) AddImport(pkg *types.Package) *Package {
path := stripVendorPath(pkg.Path())
if path == r.moqPkgPath {
return nil
}
if imprt, ok := r.imports[path]; ok {
return imprt
}
imprt := Package{pkg: pkg, Alias: r.aliases[path]}
if conflict, ok := r.searchImport(imprt.Qualifier()); ok {
resolveImportConflict(&imprt, conflict, 0)
}
r.imports[path] = &imprt
return &imprt
}
// Imports returns the list of imported packages. The list is sorted by
// path.
func (r Registry) Imports() []*Package {
imports := make([]*Package, 0, len(r.imports))
for _, imprt := range r.imports {
imports = append(imports, imprt)
}
sort.Slice(imports, func(i, j int) bool {
return imports[i].Path() < imports[j].Path()
})
return imports
}
func (r Registry) searchImport(name string) (*Package, bool) {
for _, imprt := range r.imports {
if imprt.Qualifier() == name {
return imprt, true
}
}
return nil, false
}
func pkgInfoFromPath(srcDir string, mode packages.LoadMode) (*packages.Package, error) {
pkgs, err := packages.Load(&packages.Config{
Mode: mode,
Dir: srcDir,
})
if err != nil {
return nil, err
}
if len(pkgs) == 0 {
return nil, errors.New("package not found")
}
if len(pkgs) > 1 {
return nil, errors.New("found more than one package")
}
if errs := pkgs[0].Errors; len(errs) != 0 {
if len(errs) == 1 {
return nil, errs[0]
}
return nil, fmt.Errorf("%s (and %d more errors)", errs[0], len(errs)-1)
}
return pkgs[0], nil
}
func findPkgPath(pkgInputVal string, srcPkg *packages.Package) string {
if pkgInputVal == "" {
return srcPkg.PkgPath
}
if pkgInDir(srcPkg.PkgPath, pkgInputVal) {
return srcPkg.PkgPath
}
subdirectoryPath := filepath.Join(srcPkg.PkgPath, pkgInputVal)
if pkgInDir(subdirectoryPath, pkgInputVal) {
return subdirectoryPath
}
return ""
}
func pkgInDir(pkgName, dir string) bool {
currentPkg, err := pkgInfoFromPath(dir, packages.NeedName)
if err != nil {
return false
}
return currentPkg.Name == pkgName || currentPkg.Name+"_test" == pkgName
}
func parseImportsAliases(pkg *packages.Package) map[string]string {
aliases := make(map[string]string)
for _, syntax := range pkg.Syntax {
for _, imprt := range syntax.Imports {
if imprt.Name != nil && imprt.Name.Name != "." && imprt.Name.Name != "_" {
aliases[strings.Trim(imprt.Path.Value, `"`)] = imprt.Name.Name
}
}
}
return aliases
}
// resolveImportConflict generates and assigns a unique alias for
// packages with conflicting qualifiers.
func resolveImportConflict(a, b *Package, lvl int) {
u1, u2 := a.uniqueName(lvl), b.uniqueName(lvl)
if u1 != u2 {
a.Alias, b.Alias = u1, u2
return
}
resolveImportConflict(a, b, lvl+1)
}

View file

@ -1,146 +0,0 @@
package registry
import (
"go/types"
"strings"
)
// Var represents a method variable/parameter.
//
// It should be created using a method scope instance.
type Var struct {
vr *types.Var
imports map[string]*Package
moqPkgPath string
Name string
}
// IsSlice returns whether the type (or the underlying type) is a slice.
func (v Var) IsSlice() bool {
_, ok := v.vr.Type().Underlying().(*types.Slice)
return ok
}
// TypeString returns the variable type with the package qualifier in the
// format 'pkg.Type'.
func (v Var) TypeString() string {
return types.TypeString(v.vr.Type(), v.packageQualifier)
}
// packageQualifier is a types.Qualifier.
func (v Var) packageQualifier(pkg *types.Package) string {
path := stripVendorPath(pkg.Path())
if v.moqPkgPath != "" && v.moqPkgPath == path {
return ""
}
return v.imports[path].Qualifier()
}
func varName(vr *types.Var, suffix string) string {
name := vr.Name()
if name != "" && name != "_" {
return name + suffix
}
name = varNameForType(vr.Type()) + suffix
switch name {
case "mock", "callInfo", "break", "default", "func", "interface", "select", "case", "defer", "go", "map", "struct",
"chan", "else", "goto", "package", "switch", "const", "fallthrough", "if", "range", "type", "continue", "for",
"import", "return", "var",
// avoid shadowing basic types
"string", "bool", "byte", "rune", "uintptr",
"int", "int8", "int16", "int32", "int64",
"uint", "uint8", "uint16", "uint32", "uint64",
"float32", "float64", "complex64", "complex128":
name += "MoqParam"
}
return name
}
// varNameForType generates a name for the variable using the type
// information.
//
// Examples:
// - string -> s
// - int -> n
// - chan int -> intCh
// - []a.MyType -> myTypes
// - map[string]int -> stringToInt
// - error -> err
// - a.MyType -> myType
func varNameForType(t types.Type) string {
nestedType := func(t types.Type) string {
if t, ok := t.(*types.Basic); ok {
return deCapitalise(t.String())
}
return varNameForType(t)
}
switch t := t.(type) {
case *types.Named:
if t.Obj().Name() == "error" {
return "err"
}
name := deCapitalise(t.Obj().Name())
if name == t.Obj().Name() {
name += "MoqParam"
}
return name
case *types.Basic:
return basicTypeVarName(t)
case *types.Array:
return nestedType(t.Elem()) + "s"
case *types.Slice:
return nestedType(t.Elem()) + "s"
case *types.Struct: // anonymous struct
return "val"
case *types.Pointer:
return varNameForType(t.Elem())
case *types.Signature:
return "fn"
case *types.Interface: // anonymous interface
return "ifaceVal"
case *types.Map:
return nestedType(t.Key()) + "To" + capitalise(nestedType(t.Elem()))
case *types.Chan:
return nestedType(t.Elem()) + "Ch"
}
return "v"
}
func basicTypeVarName(b *types.Basic) string {
switch b.Info() {
case types.IsBoolean:
return "b"
case types.IsInteger:
return "n"
case types.IsFloat:
return "f"
case types.IsString:
return "s"
}
return "v"
}
func capitalise(s string) string { return strings.ToUpper(s[:1]) + s[1:] }
func deCapitalise(s string) string { return strings.ToLower(s[:1]) + s[1:] }

View file

@ -1,190 +0,0 @@
package template
import (
"io"
"strings"
"text/template"
"github.com/matryer/moq/internal/registry"
)
// Template is the Moq template. It is capable of generating the Moq
// implementation for the given template.Data.
type Template struct {
tmpl *template.Template
}
// New returns a new instance of Template.
func New() (Template, error) {
tmpl, err := template.New("moq").Funcs(templateFuncs).Parse(moqTemplate)
if err != nil {
return Template{}, err
}
return Template{tmpl: tmpl}, nil
}
// Execute generates and writes the Moq implementation for the given
// data.
func (t Template) Execute(w io.Writer, data Data) error {
return t.tmpl.Execute(w, data)
}
// moqTemplate is the template for mocked code.
// language=GoTemplate
var moqTemplate = `// Code generated by moq; DO NOT EDIT.
// github.com/matryer/moq
package {{.PkgName}}
import (
{{- range .Imports}}
{{. | ImportStatement}}
{{- end}}
)
{{range $i, $mock := .Mocks -}}
{{- if not $.SkipEnsure -}}
// Ensure, that {{.MockName}} does implement {{$.SrcPkgQualifier}}{{.InterfaceName}}.
// If this is not the case, regenerate this file with moq.
var _ {{$.SrcPkgQualifier}}{{.InterfaceName}} = &{{.MockName}}{}
{{- end}}
// {{.MockName}} is a mock implementation of {{$.SrcPkgQualifier}}{{.InterfaceName}}.
//
// func TestSomethingThatUses{{.InterfaceName}}(t *testing.T) {
//
// // make and configure a mocked {{$.SrcPkgQualifier}}{{.InterfaceName}}
// mocked{{.InterfaceName}} := &{{.MockName}}{
{{- range .Methods}}
// {{.Name}}Func: func({{.ArgList}}) {{.ReturnArgTypeList}} {
// panic("mock out the {{.Name}} method")
// },
{{- end}}
// }
//
// // use mocked{{.InterfaceName}} in code that requires {{$.SrcPkgQualifier}}{{.InterfaceName}}
// // and then make assertions.
//
// }
type {{.MockName}} struct {
{{- range .Methods}}
// {{.Name}}Func mocks the {{.Name}} method.
{{.Name}}Func func({{.ArgList}}) {{.ReturnArgTypeList}}
{{end}}
// calls tracks calls to the methods.
calls struct {
{{- range .Methods}}
// {{.Name}} holds details about calls to the {{.Name}} method.
{{.Name}} []struct {
{{- range .Params}}
// {{.Name | Exported}} is the {{.Name}} argument value.
{{.Name | Exported}} {{.TypeString}}
{{- end}}
}
{{- end}}
}
{{- range .Methods}}
lock{{.Name}} {{$.Imports | SyncPkgQualifier}}.RWMutex
{{- end}}
}
{{range .Methods}}
// {{.Name}} calls {{.Name}}Func.
func (mock *{{$mock.MockName}}) {{.Name}}({{.ArgList}}) {{.ReturnArgTypeList}} {
{{- if not $.StubImpl}}
if mock.{{.Name}}Func == nil {
panic("{{$mock.MockName}}.{{.Name}}Func: method is nil but {{$mock.InterfaceName}}.{{.Name}} was just called")
}
{{- end}}
callInfo := struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{- end}}
}{
{{- range .Params}}
{{.Name | Exported}}: {{.Name}},
{{- end}}
}
mock.lock{{.Name}}.Lock()
mock.calls.{{.Name}} = append(mock.calls.{{.Name}}, callInfo)
mock.lock{{.Name}}.Unlock()
{{- if .Returns}}
{{- if $.StubImpl}}
if mock.{{.Name}}Func == nil {
var (
{{- range .Returns}}
{{.Name}} {{.TypeString}}
{{- end}}
)
return {{.ReturnArgNameList}}
}
{{- end}}
return mock.{{.Name}}Func({{.ArgCallList}})
{{- else}}
{{- if $.StubImpl}}
if mock.{{.Name}}Func == nil {
return
}
{{- end}}
mock.{{.Name}}Func({{.ArgCallList}})
{{- end}}
}
// {{.Name}}Calls gets all the calls that were made to {{.Name}}.
// Check the length with:
// len(mocked{{$mock.InterfaceName}}.{{.Name}}Calls())
func (mock *{{$mock.MockName}}) {{.Name}}Calls() []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{- end}}
} {
var calls []struct {
{{- range .Params}}
{{.Name | Exported}} {{.TypeString}}
{{- end}}
}
mock.lock{{.Name}}.RLock()
calls = mock.calls.{{.Name}}
mock.lock{{.Name}}.RUnlock()
return calls
}
{{end -}}
{{end -}}`
// This list comes from the golint codebase. Golint will complain about any of
// these being mixed-case, like "Id" instead of "ID".
var golintInitialisms = []string{
"ACL", "API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS",
"QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SQL", "SSH", "TCP", "TLS", "TTL", "UDP", "UI", "UID", "UUID", "URI",
"URL", "UTF8", "VM", "XML", "XMPP", "XSRF", "XSS",
}
var templateFuncs = template.FuncMap{
"ImportStatement": func(imprt *registry.Package) string {
if imprt.Alias == "" {
return `"` + imprt.Path() + `"`
}
return imprt.Alias + ` "` + imprt.Path() + `"`
},
"SyncPkgQualifier": func(imports []*registry.Package) string {
for _, imprt := range imports {
if imprt.Path() == "sync" {
return imprt.Qualifier()
}
}
return "sync"
},
"Exported": func(s string) string {
if s == "" {
return ""
}
for _, initialism := range golintInitialisms {
if strings.ToUpper(s) == initialism {
return initialism
}
}
return strings.ToUpper(s[0:1]) + s[1:]
},
}

View file

@ -1,125 +0,0 @@
package template
import (
"fmt"
"strings"
"github.com/matryer/moq/internal/registry"
)
// Data is the template data used to render the Moq template.
type Data struct {
PkgName string
SrcPkgQualifier string
Imports []*registry.Package
Mocks []MockData
StubImpl bool
SkipEnsure bool
}
// MocksSomeMethod returns true of any one of the Mocks has at least 1
// method.
func (d Data) MocksSomeMethod() bool {
for _, m := range d.Mocks {
if len(m.Methods) > 0 {
return true
}
}
return false
}
// MockData is the data used to generate a mock for some interface.
type MockData struct {
InterfaceName string
MockName string
Methods []MethodData
}
// MethodData is the data which represents a method on some interface.
type MethodData struct {
Name string
Params []ParamData
Returns []ParamData
}
// ArgList is the string representation of method parameters, ex:
// 's string, n int, foo bar.Baz'.
func (m MethodData) ArgList() string {
params := make([]string, len(m.Params))
for i, p := range m.Params {
params[i] = p.MethodArg()
}
return strings.Join(params, ", ")
}
// ArgCallList is the string representation of method call parameters,
// ex: 's, n, foo'. In case of a last variadic parameter, it will be of
// the format 's, n, foos...'
func (m MethodData) ArgCallList() string {
params := make([]string, len(m.Params))
for i, p := range m.Params {
params[i] = p.CallName()
}
return strings.Join(params, ", ")
}
// ReturnArgTypeList is the string representation of method return
// types, ex: 'bar.Baz', '(string, error)'.
func (m MethodData) ReturnArgTypeList() string {
params := make([]string, len(m.Returns))
for i, p := range m.Returns {
params[i] = p.TypeString()
}
if len(m.Returns) > 1 {
return fmt.Sprintf("(%s)", strings.Join(params, ", "))
}
return strings.Join(params, ", ")
}
// ReturnArgNameList is the string representation of values being
// returned from the method, ex: 'foo', 's, err'.
func (m MethodData) ReturnArgNameList() string {
params := make([]string, len(m.Returns))
for i, p := range m.Returns {
params[i] = p.Name()
}
return strings.Join(params, ", ")
}
// ParamData is the data which represents a parameter to some method of
// an interface.
type ParamData struct {
Var *registry.Var
Variadic bool
}
// Name returns the name of the parameter.
func (p ParamData) Name() string {
return p.Var.Name
}
// MethodArg is the representation of the parameter in the function
// signature, ex: 'name a.Type'.
func (p ParamData) MethodArg() string {
if p.Variadic {
return fmt.Sprintf("%s ...%s", p.Name(), p.TypeString()[2:])
}
return fmt.Sprintf("%s %s", p.Name(), p.TypeString())
}
// CallName returns the string representation of the parameter to be
// used for a method call. For a variadic paramter, it will be of the
// format 'foos...'.
func (p ParamData) CallName() string {
if p.Variadic {
return p.Name() + "..."
}
return p.Name()
}
// TypeString returns the string representation of the type of the
// parameter.
func (p ParamData) TypeString() string {
return p.Var.TypeString()
}

109
vendor/github.com/matryer/moq/main.go generated vendored
View file

@ -1,109 +0,0 @@
package main
import (
"bytes"
"errors"
"flag"
"fmt"
"io"
"io/ioutil"
"os"
"path/filepath"
"github.com/matryer/moq/pkg/moq"
)
// Version is the command version, injected at build time.
var Version string = "dev"
type userFlags struct {
outFile string
pkgName string
formatter string
stubImpl bool
skipEnsure bool
remove bool
args []string
}
func main() {
var flags userFlags
flag.StringVar(&flags.outFile, "out", "", "output file (default stdout)")
flag.StringVar(&flags.pkgName, "pkg", "", "package name (default will infer)")
flag.StringVar(&flags.formatter, "fmt", "", "go pretty-printer: gofmt, goimports or noop (default gofmt)")
flag.BoolVar(&flags.stubImpl, "stub", false,
"return zero values when no mock implementation is provided, do not panic")
printVersion := flag.Bool("version", false, "show the version for moq")
flag.BoolVar(&flags.skipEnsure, "skip-ensure", false,
"suppress mock implementation check, avoid import cycle if mocks generated outside of the tested package")
flag.BoolVar(&flags.remove, "rm", false, "first remove output file, if it exists")
flag.Usage = func() {
fmt.Println(`moq [flags] source-dir interface [interface2 [interface3 [...]]]`)
flag.PrintDefaults()
fmt.Println(`Specifying an alias for the mock is also supported with the format 'interface:alias'`)
fmt.Println(`Ex: moq -pkg different . MyInterface:MyMock`)
}
flag.Parse()
flags.args = flag.Args()
if *printVersion {
fmt.Printf("moq version %s\n", Version)
os.Exit(0)
}
if err := run(flags); err != nil {
fmt.Fprintln(os.Stderr, err)
flag.Usage()
os.Exit(1)
}
}
func run(flags userFlags) error {
if len(flags.args) < 2 {
return errors.New("not enough arguments")
}
if flags.remove && flags.outFile != "" {
if err := os.Remove(flags.outFile); err != nil {
if !errors.Is(err, os.ErrNotExist) {
return err
}
}
}
var buf bytes.Buffer
var out io.Writer = os.Stdout
if flags.outFile != "" {
out = &buf
}
srcDir, args := flags.args[0], flags.args[1:]
m, err := moq.New(moq.Config{
SrcDir: srcDir,
PkgName: flags.pkgName,
Formatter: flags.formatter,
StubImpl: flags.stubImpl,
SkipEnsure: flags.skipEnsure,
})
if err != nil {
return err
}
if err = m.Mock(out, args...); err != nil {
return err
}
if flags.outFile == "" {
return nil
}
// create the file
err = os.MkdirAll(filepath.Dir(flags.outFile), 0750)
if err != nil {
return err
}
return ioutil.WriteFile(flags.outFile, buf.Bytes(), 0600)
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 32 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 29 KiB

View file

@ -1,31 +0,0 @@
package moq
import (
"fmt"
"go/format"
"golang.org/x/tools/imports"
)
func goimports(src []byte) ([]byte, error) {
formatted, err := imports.Process("filename", src, &imports.Options{
TabWidth: 8,
TabIndent: true,
Comments: true,
Fragment: true,
})
if err != nil {
return nil, fmt.Errorf("goimports: %s", err)
}
return formatted, nil
}
func gofmt(src []byte) ([]byte, error) {
formatted, err := format.Source(src)
if err != nil {
return nil, fmt.Errorf("go/format: %s", err)
}
return formatted, nil
}

View file

@ -1,171 +0,0 @@
package moq
import (
"bytes"
"errors"
"go/types"
"io"
"strings"
"github.com/matryer/moq/internal/registry"
"github.com/matryer/moq/internal/template"
)
// Mocker can generate mock structs.
type Mocker struct {
cfg Config
registry *registry.Registry
tmpl template.Template
}
// Config specifies details about how interfaces should be mocked.
// SrcDir is the only field which needs be specified.
type Config struct {
SrcDir string
PkgName string
Formatter string
StubImpl bool
SkipEnsure bool
}
// New makes a new Mocker for the specified package directory.
func New(cfg Config) (*Mocker, error) {
reg, err := registry.New(cfg.SrcDir, cfg.PkgName)
if err != nil {
return nil, err
}
tmpl, err := template.New()
if err != nil {
return nil, err
}
return &Mocker{
cfg: cfg,
registry: reg,
tmpl: tmpl,
}, nil
}
// Mock generates a mock for the specified interface name.
func (m *Mocker) Mock(w io.Writer, namePairs ...string) error {
if len(namePairs) == 0 {
return errors.New("must specify one interface")
}
mocks := make([]template.MockData, len(namePairs))
for i, np := range namePairs {
name, mockName := parseInterfaceName(np)
iface, err := m.registry.LookupInterface(name)
if err != nil {
return err
}
methods := make([]template.MethodData, iface.NumMethods())
for j := 0; j < iface.NumMethods(); j++ {
methods[j] = m.methodData(iface.Method(j))
}
mocks[i] = template.MockData{
InterfaceName: name,
MockName: mockName,
Methods: methods,
}
}
data := template.Data{
PkgName: m.mockPkgName(),
Mocks: mocks,
StubImpl: m.cfg.StubImpl,
SkipEnsure: m.cfg.SkipEnsure,
}
if data.MocksSomeMethod() {
m.registry.AddImport(types.NewPackage("sync", "sync"))
}
if m.registry.SrcPkgName() != m.mockPkgName() {
data.SrcPkgQualifier = m.registry.SrcPkgName() + "."
if !m.cfg.SkipEnsure {
imprt := m.registry.AddImport(m.registry.SrcPkg())
data.SrcPkgQualifier = imprt.Qualifier() + "."
}
}
data.Imports = m.registry.Imports()
var buf bytes.Buffer
if err := m.tmpl.Execute(&buf, data); err != nil {
return err
}
formatted, err := m.format(buf.Bytes())
if err != nil {
return err
}
if _, err := w.Write(formatted); err != nil {
return err
}
return nil
}
func (m *Mocker) methodData(f *types.Func) template.MethodData {
sig := f.Type().(*types.Signature)
scope := m.registry.MethodScope()
n := sig.Params().Len()
params := make([]template.ParamData, n)
for i := 0; i < n; i++ {
p := template.ParamData{
Var: scope.AddVar(sig.Params().At(i), ""),
}
p.Variadic = sig.Variadic() && i == n-1 && p.Var.IsSlice() // check for final variadic argument
params[i] = p
}
n = sig.Results().Len()
results := make([]template.ParamData, n)
for i := 0; i < n; i++ {
results[i] = template.ParamData{
Var: scope.AddVar(sig.Results().At(i), "Out"),
}
}
return template.MethodData{
Name: f.Name(),
Params: params,
Returns: results,
}
}
func (m *Mocker) mockPkgName() string {
if m.cfg.PkgName != "" {
return m.cfg.PkgName
}
return m.registry.SrcPkgName()
}
func (m *Mocker) format(src []byte) ([]byte, error) {
switch m.cfg.Formatter {
case "goimports":
return goimports(src)
case "noop":
return src, nil
}
return gofmt(src)
}
func parseInterfaceName(namePair string) (ifaceName, mockName string) {
parts := strings.SplitN(namePair, ":", 2)
if len(parts) == 2 {
return parts[0], parts[1]
}
ifaceName = parts[0]
return ifaceName, ifaceName + "Mock"
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 726 KiB

Some files were not shown because too many files have changed in this diff Show more