Synopsis:
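(This assumes PostgreSQL is running on localhost at the default --dial address, 127.0.0.1:5432; the proxy listens on 127.0.0.1:5555 by default.)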
$ go run . &
$ psql "host=127.0.0.1 port=5555 sslmode=disable user=... password=..."
(protocol details spew)
postgres=> select 1;
(more protocol details spew)
| test.crt | |
| test.key |
This document describes how to add implementations for PostgreSQL protocol message types in messages.go.
Message types are defined in the PostgreSQL documentation: https://www.postgresql.org/docs/current/protocol-message-formats.html
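In the getMessageFormatter() method for the relevant side (frontend for messages sent by the client, backend for messages sent by the server), add new message type cases in byte-value (ASCII) order: digits first, then uppercase letters, then lowercase letters.
Example ordering:
case 'K': // uppercase
case 'Q':
case 'R':
case 'S':
case 'T':
case 'X':
case 'Z':
case 'p': // lowercase

Define the message type structs in the same order as they appear in the switch statement.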
Each message type struct follows this pattern:
type messageName struct {
	want int
	have []byte
}

Each message type must implement the messageFormatter interface by providing a Format([]byte) string method.
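The Format method should:
- Accumulate incoming data: m.have = append(m.have, data...)
- Return a placeholder until the whole payload has arrived: if len(m.have) < m.want { return "..." }
- Use decodeString() for null-terminated strings
- Use byteOrder.UintXX() for reading multi-byte integers
- Use strings.Builder for constructing complex output
- Format binary data with %x and strings with %q

See existing message types in messages.go for examples:
- No or fixed-size payload: terminate (no payload), readyForQuery (single byte)
- Null-terminated strings: query, parameterStatus
- Multi-field payloads: rowDescription, authenticationRequest
| module pgmitm | |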
| go 1.25.0 | |
| require github.com/spf13/pflag v1.0.10 |
| github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= | |
| github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= |
| // Synopsis: | |
| // | |
| // $ go run . --listen 127.0.0.1:5555 --dial 127.0.0.1:5432 | |
| // | |
| // $ psql "host=127.0.0.1 port=5555 sslmode=disable user=xx password=yy" | |
| package main | |
| import ( | |
| "bytes" | |
| "context" | |
| "crypto/tls" | |
| "fmt" | |
| "io" | |
| "log" | |
| "net" | |
| "os" | |
| "os/signal" | |
| "strings" | |
| "sync" | |
| "sync/atomic" | |
| "syscall" | |
| "time" | |
| "github.com/spf13/pflag" | |
| ) | |
| func main() { | |
| if err := mainImpl(); err != nil { | |
| log.Printf("fatal error: %v", err) | |
| os.Exit(1) | |
| } | |
| } | |
| func mainImpl() error { | |
| ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM) | |
| defer cancel() | |
| listenAddr := pflag.String("listen", "127.0.0.1:5555", "address to listen on") | |
| upstreamAddr := pflag.String("dial", "127.0.0.1:5432", "address to forward connections to") | |
| sslCertFile := pflag.String("ssl-cert", "", "path to an SSL certificate to use for the server (will be generated if the path is present but the file is not)") | |
| sslKeyFile := pflag.String("ssl-key", "", "path to an SSL key to use for the server (will be generated if the path is present but the file is not)") | |
| beTLSModeStr := pflag.String("tls-mode", "allow", "what to do about TLS on the backend (disable, require, allow, insecure)") | |
| pflag.Parse() | |
| var frontendTLS *tls.Config | |
| if *sslCertFile != "" && *sslKeyFile != "" { | |
| if c, err := initSSL(*sslCertFile, *sslKeyFile); err != nil { | |
| return err | |
| } else { | |
| log.Printf("TLS enabled") | |
| frontendTLS = c | |
| } | |
| } | |
| beTLS, ok := configureBackendTLS(*beTLSModeStr, strings.Split(*upstreamAddr, ":")[0]) | |
| if !ok { | |
| return fmt.Errorf("invalid tls-mode %q", *beTLSModeStr) | |
| } | |
| l, err := net.Listen("tcp", *listenAddr) | |
| if err != nil { | |
| return fmt.Errorf("%s: error listening: %w", *listenAddr, err) | |
| } | |
| log.Printf("%s: listening", *listenAddr) | |
| var closing int32 | |
| go func(ctx context.Context, l net.Listener) { | |
| <-ctx.Done() | |
| log.Printf("received shutdown signal") | |
| atomic.StoreInt32(&closing, 1) | |
| l.Close() | |
| }(ctx, l) | |
| var wg sync.WaitGroup | |
| for { | |
| conn, err := l.Accept() | |
| if err != nil { | |
| if atomic.LoadInt32(&closing) == 0 { | |
| log.Printf("%s: shutting down because of accept error: %v", *listenAddr, err) | |
| } | |
| break | |
| } | |
| wg.Go(func() { | |
| logger := log.New(log.Writer(), "["+conn.RemoteAddr().String()+"] ", log.Flags()) | |
| logger.Println("accepted connection") | |
| serve(conn, *upstreamAddr, frontendTLS, beTLS, logger) | |
| }) | |
| } | |
| cancel() | |
| wg.Wait() | |
| return nil | |
| } | |
| func serve(conn net.Conn, pgAddr string, feTLS *tls.Config, beTLS backendTLSPolicy, l *log.Logger) { | |
| defer conn.Close() | |
| defer l.Printf("closing connection") | |
| bg, err := net.Dial("tcp", pgAddr) | |
| if err != nil { | |
| l.Printf("%s: error dialing backend: %v", pgAddr, err) | |
| return | |
| } | |
| defer bg.Close() | |
| handles, err := negotiate(conn, bg, feTLS, beTLS, l) | |
| if err != nil { | |
| l.Printf("%s: error negitiating transport: %v", pgAddr, err) | |
| return | |
| } | |
| var wg sync.WaitGroup | |
| wg.Go(func() { serveClient(handles.ClientReader, handles.ServerWriter, l) }) | |
| wg.Go(func() { serveServer(handles.ServerReader, handles.ClientWriter, l) }) | |
| wg.Wait() | |
| } | |
// negotiated holds the streams for one proxied connection after SSL
// negotiation: the client-facing and server-facing legs, each of which may
// be a TLS-wrapped connection or the plain one.
type negotiated struct {
	ClientReader io.ReadCloser
	ServerReader io.ReadCloser
	ClientWriter io.WriteCloser
	ServerWriter io.WriteCloser
}
// pgSSLRequest is the complete 8-byte SSLRequest packet: an Int32 length of 8
// (the length counts itself) followed by the request code 80877103, which is
// 1234<<16 | 5679 = 0x04d2162f.
var pgSSLRequest = []byte{
	0x00, 0x00, 0x00, 0x08, // length
	0x04, 0xd2, 0x16, 0x2f, // SSLRequest code
}
| func negotiate(frontend, backend net.Conn, feTLS *tls.Config, beTLS backendTLSPolicy, l *log.Logger) (negotiated, error) { | |
| c, err := negotiateClient(frontend, feTLS, l) | |
| if err != nil { | |
| return negotiated{}, err | |
| } | |
| s, err := negotiateServer(backend, beTLS, l) | |
| if err != nil { | |
| return negotiated{}, err | |
| } | |
| return negotiated{ | |
| ClientReader: c, | |
| ClientWriter: c, | |
| ServerReader: s, | |
| ServerWriter: s, | |
| }, nil | |
| } | |
| func negotiateClient(conn net.Conn, feTLS *tls.Config, l *log.Logger) (io.ReadWriteCloser, error) { | |
| preamble := make([]byte, len(pgSSLRequest)) | |
| n, err := conn.Read(preamble) | |
| if err != nil { | |
| return nil, fmt.Errorf("error getting startup message: %v", err) | |
| } | |
| if !bytes.Equal(pgSSLRequest, preamble[:n]) { | |
| return prepended(preamble[:n], conn), nil | |
| } | |
| if feTLS == nil { | |
| l.Printf("--- SSL requested by client, declining ---") | |
| if _, err := conn.Write([]byte("N")); err != nil { | |
| return nil, err | |
| } | |
| return conn, nil | |
| } | |
| l.Printf("--- performing SSL handshake with client ---") | |
| if _, err := conn.Write([]byte("S")); err != nil { | |
| return nil, err | |
| } | |
| return tls.Server(conn, feTLS), nil | |
| } | |
| func negotiateServer(conn net.Conn, beTLS backendTLSPolicy, l *log.Logger) (io.ReadWriteCloser, error) { | |
| if !beTLS.Request { | |
| return conn, nil | |
| } | |
| if _, err := conn.Write(pgSSLRequest); err != nil { | |
| return nil, err | |
| } | |
| var resp [1]byte | |
| n, err := conn.Read(resp[:]) | |
| if err != nil { | |
| return nil, err | |
| } | |
| if n != 1 { | |
| return nil, fmt.Errorf("error reading response to SSLRequest") | |
| } | |
| switch resp[0] { | |
| case 'N': | |
| if beTLS.Require { | |
| return nil, fmt.Errorf("TLS is required but backend does not allow it") | |
| } | |
| l.Println("--- server declined SSL ---") | |
| return conn, nil | |
| case 'S': | |
| l.Println("--- performing SSL handshake with server ---") | |
| return tls.Client(conn, beTLS.Config), nil | |
| default: | |
| return nil, fmt.Errorf("invalid response to SSLRequest %c(%02x)", resp[0], resp[0]) | |
| } | |
| } | |
// prepended returns a ReadWriteCloser that replays preamble before reading
// any further data from conn. Writes and Close go straight to conn.
func prepended(preamble []byte, conn net.Conn) io.ReadWriteCloser {
	replay := bytes.NewReader(preamble)
	pc := new(prepend)
	pc.r = io.MultiReader(replay, conn)
	pc.w = conn
	pc.c = conn
	return pc
}

// prepend stitches separate reader, writer and closer halves together into a
// single io.ReadWriteCloser.
type prepend struct {
	r io.Reader // source for Read
	w io.Writer // sink for Write
	c io.Closer // target of Close
}

// Read delegates to the reader half.
func (pc *prepend) Read(buf []byte) (int, error) {
	return pc.r.Read(buf)
}

// Write delegates to the writer half.
func (pc *prepend) Write(buf []byte) (int, error) {
	return pc.w.Write(buf)
}

// Close delegates to the closer half.
func (pc *prepend) Close() error {
	return pc.c.Close()
}
| func serveClient(client io.ReadCloser, server io.WriteCloser, l *log.Logger) { | |
| defer client.Close() | |
| defer server.Close() | |
| l = log.New(l.Writer(), l.Prefix()+"client->server ", l.Flags()) | |
| processMessages(client, server, l, frontend{}, 1) | |
| // Wait a moment before closing connections. | |
| time.Sleep(100 * time.Millisecond) | |
| } | |
| func serveServer(server io.ReadCloser, client io.WriteCloser, l *log.Logger) { | |
| defer server.Close() | |
| defer client.Close() | |
| l = log.New(l.Writer(), l.Prefix()+"server->client ", l.Flags()) | |
| processMessages(server, client, l, backend{}, 0) | |
| // Wait a moment before closing connections. | |
| time.Sleep(100 * time.Millisecond) | |
| } |
| package main | |
| import ( | |
| "fmt" | |
| "strings" | |
| ) | |
| type frontend struct{} | |
| func (frontend) getMessageFormatter(msgType byte, msgLength int) messageFormatter { | |
| // Just the "F" messages from https://www.postgresql.org/docs/current/protocol-message-formats.html | |
| switch msgType { | |
| case 0: | |
| return &startup{want: msgLength} | |
| case 'B': | |
| return &bind{want: msgLength} | |
| case 'D': | |
| return &describe{want: msgLength} | |
| case 'E': | |
| return &execute{want: msgLength} | |
| case 'P': | |
| return &parse{want: msgLength} | |
| case 'Q': | |
| return &query{want: msgLength} | |
| case 'S': | |
| return &syncMsg{want: msgLength} | |
| case 'X': | |
| return &terminate{want: msgLength} | |
| case 'p': | |
| return &gssResponse{want: msgLength} | |
| default: | |
| panic(fmt.Sprintf("'%c' (%d)", msgType, msgType)) | |
| } | |
| } | |
| type backend struct{} | |
| func (backend) getMessageFormatter(msgType byte, msgLength int) messageFormatter { | |
| // Just the "B" messages from https://www.postgresql.org/docs/current/protocol-message-formats.html | |
| switch msgType { | |
| case '1': | |
| return &parseComplete{want: msgLength} | |
| case '2': | |
| return &bindComplete{want: msgLength} | |
| case 'C': | |
| return &commandComplete{want: msgLength} | |
| case 'D': | |
| return &dataRow{want: msgLength} | |
| case 'E': | |
| return &errorResponse{want: msgLength} | |
| case 'K': | |
| return &backendKey{want: msgLength} | |
| case 'R': | |
| return &authenticationRequest{want: msgLength} | |
| case 'S': | |
| return ¶meterStatus{want: msgLength} | |
| case 'T': | |
| return &rowDescription{want: msgLength} | |
| case 'Z': | |
| return &readyForQuery{want: msgLength} | |
| case 't': | |
| return ¶meterDescription{want: msgLength} | |
| default: | |
| panic(fmt.Sprintf("'%c' (%d)", msgType, msgType)) | |
| } | |
| } | |
// messageFormatter renders one protocol message for logging. Format is
// called with successive chunks of the message payload. The implementations
// in this file buffer the chunks and return "..." until the whole payload
// has arrived, then return the rendered message.
type messageFormatter interface {
	Format(payload []byte) string
}
| type startup struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *startup) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 5 { | |
| return fmt.Sprintf("StartupMessage too short raw=%x", m.have) | |
| } | |
| var sb strings.Builder | |
| protoVersion := byteOrder.Uint32(m.have) | |
| fmt.Fprintf(&sb, "StartupMessage proto=%d", protoVersion) | |
| if protoVersion != 196610 { | |
| fmt.Fprintf(&sb, "(!)") | |
| } | |
| p := m.have[4:] | |
| for { | |
| if len(p) == 1 && p[0] == 0 { | |
| break | |
| } | |
| if len(p) < 1 { | |
| fmt.Fprintf(&sb, " (unexpected end of parameters!)") | |
| break | |
| } | |
| str, rest, ok := decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (unexpected end of parameters remaining=%x)", p) | |
| break | |
| } | |
| fmt.Fprintf(&sb, " %s=", str) | |
| p = rest | |
| str, rest, ok = decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (unexpected end of parameters remaining=%x)", p) | |
| break | |
| } | |
| fmt.Fprintf(&sb, "%q", str) | |
| p = rest | |
| } | |
| return sb.String() | |
| } | |
| type bind struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *bind) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 4 { | |
| return fmt.Sprintf("Bind message too short raw=%x", m.have) | |
| } | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "Bind") | |
| // Portal name | |
| portalName, rest, ok := decodeString(m.have) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding portal name) raw=%x", m.have) | |
| return sb.String() | |
| } | |
| if len(portalName) == 0 { | |
| fmt.Fprintf(&sb, " dstportal=<unnamed>") | |
| } else { | |
| fmt.Fprintf(&sb, " dstportal=%q", portalName) | |
| } | |
| p := rest | |
| // Statement name | |
| stmtName, rest, ok := decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding statement name) rest=%x", p) | |
| return sb.String() | |
| } | |
| if len(stmtName) == 0 { | |
| fmt.Fprintf(&sb, " stmt=<unnamed>") | |
| } else { | |
| fmt.Fprintf(&sb, " stmt=%q", stmtName) | |
| } | |
| p = rest | |
| // Parameter format codes | |
| if len(p) < 2 { | |
| fmt.Fprintf(&sb, " (missing param format count) rest=%x", p) | |
| return sb.String() | |
| } | |
| numParamFormats := byteOrder.Uint16(p[0:2]) | |
| p = p[2:] | |
| if len(p) < int(numParamFormats)*2 { | |
| fmt.Fprintf(&sb, " (incomplete param formats) rest=%x", p) | |
| return sb.String() | |
| } | |
| paramFormats := make([]uint16, numParamFormats) | |
| for i := 0; i < int(numParamFormats); i++ { | |
| paramFormats[i] = byteOrder.Uint16(p[i*2 : (i+1)*2]) | |
| } | |
| p = p[numParamFormats*2:] | |
| // Parameter values | |
| if len(p) < 2 { | |
| fmt.Fprintf(&sb, " (missing param value count) rest=%x", p) | |
| return sb.String() | |
| } | |
| numParams := byteOrder.Uint16(p[0:2]) | |
| p = p[2:] | |
| fmt.Fprintf(&sb, " params=%d", numParams) | |
| for i := 0; i < int(numParams); i++ { | |
| if len(p) < 4 { | |
| fmt.Fprintf(&sb, " (incomplete param %d length) rest=%x", i, p) | |
| return sb.String() | |
| } | |
| paramLen := int32(byteOrder.Uint32(p[0:4])) | |
| p = p[4:] | |
| switch { | |
| case paramLen == -1: | |
| fmt.Fprintf(&sb, " NULL") | |
| case paramLen < 0: | |
| fmt.Fprintf(&sb, " (invalid param length %d)", paramLen) | |
| return sb.String() | |
| default: | |
| if len(p) < int(paramLen) { | |
| fmt.Fprintf(&sb, " (incomplete param %d data) rest=%x", i, p) | |
| return sb.String() | |
| } | |
| paramData := p[:paramLen] | |
| p = p[paramLen:] | |
| if isPrintable(paramData) { | |
| fmt.Fprintf(&sb, " %q", string(paramData)) | |
| } else { | |
| fmt.Fprintf(&sb, " [%x]", paramData) | |
| } | |
| } | |
| } | |
| // Result format codes | |
| if len(p) < 2 { | |
| fmt.Fprintf(&sb, " (missing result format count) rest=%x", p) | |
| return sb.String() | |
| } | |
| numResultFormats := byteOrder.Uint16(p[0:2]) | |
| p = p[2:] | |
| if len(p) < int(numResultFormats)*2 { | |
| fmt.Fprintf(&sb, " (incomplete result formats) rest=%x", p) | |
| return sb.String() | |
| } | |
| if numResultFormats > 0 { | |
| fmt.Fprintf(&sb, " result_fmts=[") | |
| for i := 0; i < int(numResultFormats); i++ { | |
| if i > 0 { | |
| fmt.Fprintf(&sb, ",") | |
| } | |
| rf := byteOrder.Uint16(p[i*2 : (i+1)*2]) | |
| switch rf { | |
| case 0: | |
| fmt.Fprint(&sb, "text") | |
| case 1: | |
| fmt.Fprint(&sb, "binary") | |
| default: | |
| fmt.Fprintf(&sb, "invalid(%d)", rf) | |
| } | |
| } | |
| fmt.Fprintf(&sb, "]") | |
| p = p[numResultFormats*2:] | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
| type describe struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *describe) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 2 { | |
| return fmt.Sprintf("Describe message too short raw=%x", m.have) | |
| } | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "Describe") | |
| objType := m.have[0] | |
| switch objType { | |
| case 'S': | |
| fmt.Fprintf(&sb, " statement") | |
| case 'P': | |
| fmt.Fprintf(&sb, " portal") | |
| default: | |
| fmt.Fprintf(&sb, " ??(%x %c)", objType, objType) | |
| } | |
| name, rest, ok := decodeString(m.have[1:]) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding name) rest=%x", m.have[1:]) | |
| return sb.String() | |
| } | |
| if len(name) == 0 { | |
| fmt.Fprintf(&sb, " <unnamed>") | |
| } else { | |
| fmt.Fprintf(&sb, " %q", name) | |
| } | |
| if len(rest) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", rest) | |
| } | |
| return sb.String() | |
| } | |
| type execute struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *execute) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 5 { | |
| return fmt.Sprintf("Execute message too short raw=%x", m.have) | |
| } | |
| // Portal name | |
| portalName, rest, ok := decodeString(m.have) | |
| if !ok { | |
| return fmt.Sprintf("Execute (error decoding portal name) raw=%x", m.have) | |
| } | |
| if len(rest) < 4 { | |
| return fmt.Sprintf("Execute (missing max rows) rest=%x", rest) | |
| } | |
| maxRows := byteOrder.Uint32(rest[0:4]) | |
| rest = rest[4:] | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "Execute") | |
| if len(portalName) == 0 { | |
| fmt.Fprintf(&sb, " portal=<unnamed>") | |
| } else { | |
| fmt.Fprintf(&sb, " portal=%q", portalName) | |
| } | |
| if maxRows == 0 { | |
| fmt.Fprintf(&sb, " maxrows=unlimited") | |
| } else { | |
| fmt.Fprintf(&sb, " maxrows=%d", maxRows) | |
| } | |
| if len(rest) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", rest) | |
| } | |
| return sb.String() | |
| } | |
| type parse struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *parse) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 3 { | |
| return fmt.Sprintf("Parse message too short raw=%x", m.have) | |
| } | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "Parse") | |
| // Statement name | |
| stmtName, rest, ok := decodeString(m.have) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding statement name) raw=%x", m.have) | |
| return sb.String() | |
| } | |
| if len(stmtName) == 0 { | |
| fmt.Fprintf(&sb, " stmt=<unnamed>") | |
| } else { | |
| fmt.Fprintf(&sb, " stmt=%q", stmtName) | |
| } | |
| p := rest | |
| // Query string | |
| queryStr, rest, ok := decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding query) rest=%x", p) | |
| return sb.String() | |
| } | |
| fmt.Fprintf(&sb, " query=%q", queryStr) | |
| p = rest | |
| // Number of parameter types | |
| if len(p) < 2 { | |
| fmt.Fprintf(&sb, " (missing param count) rest=%x", p) | |
| return sb.String() | |
| } | |
| numParams := byteOrder.Uint16(p[0:2]) | |
| p = p[2:] | |
| fmt.Fprintf(&sb, " params=%d", numParams) | |
| if numParams > 0 { | |
| // Parameter type OIDs | |
| if len(p) < int(numParams)*4 { | |
| fmt.Fprintf(&sb, " (incomplete param types, need %d bytes) rest=%x", numParams*4, p) | |
| return sb.String() | |
| } | |
| fmt.Fprintf(&sb, " types=[") | |
| for i := 0; i < int(numParams); i++ { | |
| if i > 0 { | |
| fmt.Fprintf(&sb, ",") | |
| } | |
| typeOID := byteOrder.Uint32(p[i*4 : (i+1)*4]) | |
| if typeOID == 0 { | |
| fmt.Fprintf(&sb, "?") | |
| } else { | |
| fmt.Fprintf(&sb, "%d", typeOID) | |
| } | |
| } | |
| fmt.Fprintf(&sb, "]") | |
| p = p[numParams*4:] | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
// parseComplete formats a ParseComplete ('1') message, which carries no
// payload; any payload bytes are reported.
type parseComplete struct {
	want int    // payload length to wait for before formatting
	have []byte // payload bytes accumulated so far
}

// Format accumulates data and, once the full payload is present, renders it.
func (m *parseComplete) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want != 0:
		return fmt.Sprintf("ParseComplete (unexpected payload: %x)", m.have)
	default:
		return "ParseComplete"
	}
}
// bindComplete formats a BindComplete ('2') message, which carries no
// payload; any payload bytes are reported.
type bindComplete struct {
	want int    // payload length to wait for before formatting
	have []byte // payload bytes accumulated so far
}

// Format accumulates data and, once the full payload is present, renders it.
func (m *bindComplete) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want != 0:
		return fmt.Sprintf("BindComplete (unexpected payload: %x)", m.have)
	default:
		return "BindComplete"
	}
}
| type commandComplete struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *commandComplete) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 1 { | |
| return fmt.Sprintf("CommandComplete message too short raw=%x", m.have) | |
| } | |
| tag, rest, ok := decodeString(m.have) | |
| if !ok { | |
| return fmt.Sprintf("CommandComplete invalid tag raw=%x", m.have) | |
| } | |
| if len(rest) > 0 { | |
| return fmt.Sprintf("CommandComplete %q (unexpected trailing data: %x)", tag, rest) | |
| } | |
| return fmt.Sprintf("CommandComplete %q", tag) | |
| } | |
| type dataRow struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *dataRow) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 2 { | |
| return fmt.Sprintf("DataRow message too short raw=%x", m.have) | |
| } | |
| numCols := byteOrder.Uint16(m.have) | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "DataRow cols=%d", numCols) | |
| p := m.have[2:] | |
| for i := 0; i < int(numCols); i++ { | |
| if len(p) < 4 { | |
| fmt.Fprintf(&sb, " (incomplete column %d, need length field) rest=%x", i, p) | |
| break | |
| } | |
| colLen := int32(byteOrder.Uint32(p[0:4])) | |
| p = p[4:] | |
| if colLen == -1 { | |
| fmt.Fprintf(&sb, " NULL") | |
| continue | |
| } | |
| if colLen < 0 { | |
| fmt.Fprintf(&sb, " (invalid length %d) rest=%x", colLen, p) | |
| break | |
| } | |
| if len(p) < int(colLen) { | |
| fmt.Fprintf(&sb, " (incomplete column %d, need %d bytes) rest=%x", i, colLen, p) | |
| break | |
| } | |
| colData := p[:colLen] | |
| p = p[colLen:] | |
| // Try to display as string if it looks printable | |
| if isPrintable(colData) { | |
| fmt.Fprintf(&sb, " %q", string(colData)) | |
| } else { | |
| fmt.Fprintf(&sb, " [%x]", colData) | |
| } | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
| func isPrintable(data []byte) bool { | |
| for _, b := range data { | |
| if b < 32 && b != '\t' && b != '\n' && b != '\r' { | |
| return false | |
| } | |
| if b > 126 { | |
| return false | |
| } | |
| } | |
| return true | |
| } | |
| type errorResponse struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *errorResponse) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 1 { | |
| return fmt.Sprintf("ErrorResponse message too short raw=%x", m.have) | |
| } | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "ErrorResponse") | |
| p := m.have | |
| for { | |
| if len(p) == 0 { | |
| fmt.Fprintf(&sb, " (missing terminator)") | |
| break | |
| } | |
| fieldType := p[0] | |
| p = p[1:] | |
| if fieldType == 0 { | |
| // Terminator | |
| break | |
| } | |
| fieldValue, rest, ok := decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding field %c) rest=%x", fieldType, p) | |
| break | |
| } | |
| p = rest | |
| switch fieldType { | |
| case 'S': | |
| fmt.Fprintf(&sb, " severity=%q", fieldValue) | |
| case 'V': | |
| fmt.Fprintf(&sb, " severity_nonlocal=%q", fieldValue) | |
| case 'C': | |
| fmt.Fprintf(&sb, " code=%q", fieldValue) | |
| case 'M': | |
| fmt.Fprintf(&sb, " message=%q", fieldValue) | |
| case 'D': | |
| fmt.Fprintf(&sb, " detail=%q", fieldValue) | |
| case 'H': | |
| fmt.Fprintf(&sb, " hint=%q", fieldValue) | |
| case 'P': | |
| fmt.Fprintf(&sb, " position=%q", fieldValue) | |
| case 'p': | |
| fmt.Fprintf(&sb, " internal_position=%q", fieldValue) | |
| case 'q': | |
| fmt.Fprintf(&sb, " internal_query=%q", fieldValue) | |
| case 'W': | |
| fmt.Fprintf(&sb, " where=%q", fieldValue) | |
| case 's': | |
| fmt.Fprintf(&sb, " schema=%q", fieldValue) | |
| case 't': | |
| fmt.Fprintf(&sb, " table=%q", fieldValue) | |
| case 'c': | |
| fmt.Fprintf(&sb, " column=%q", fieldValue) | |
| case 'd': | |
| fmt.Fprintf(&sb, " datatype=%q", fieldValue) | |
| case 'n': | |
| fmt.Fprintf(&sb, " constraint=%q", fieldValue) | |
| case 'F': | |
| fmt.Fprintf(&sb, " file=%q", fieldValue) | |
| case 'L': | |
| fmt.Fprintf(&sb, " line=%q", fieldValue) | |
| case 'R': | |
| fmt.Fprintf(&sb, " routine=%q", fieldValue) | |
| default: | |
| fmt.Fprintf(&sb, " %x=%q", fieldType, fieldValue) | |
| } | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
| type backendKey struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *backendKey) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 4 { | |
| return fmt.Sprintf("BackendKeyData too short raw=%x", m.have) | |
| } | |
| pid := byteOrder.Uint32(m.have) | |
| return fmt.Sprintf("BackendKeyData pid=%d secret=%x", pid, m.have[4:]) | |
| } | |
| type query struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *query) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 1 { | |
| return fmt.Sprintf("Query message too short raw=%x", m.have) | |
| } | |
| queryString, rest, ok := decodeString(m.have) | |
| if !ok { | |
| return fmt.Sprintf("Query invalid string raw=%x", m.have) | |
| } | |
| if len(rest) > 0 { | |
| return fmt.Sprintf("Query %q (unexpected trailing data: %x)", queryString, rest) | |
| } | |
| return fmt.Sprintf("Query %q", queryString) | |
| } | |
// syncMsg formats a Sync ('S') message, which carries no payload; any
// payload bytes are reported.
type syncMsg struct {
	want int    // payload length to wait for before formatting
	have []byte // payload bytes accumulated so far
}

// Format accumulates data and, once the full payload is present, renders it.
func (m *syncMsg) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want != 0:
		return fmt.Sprintf("Sync (unexpected payload: %x)", m.have)
	default:
		return "Sync"
	}
}
| type authenticationRequest struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *authenticationRequest) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 4 { | |
| return fmt.Sprintf("Authentication message too short raw=%x", m.have) | |
| } | |
| code := byteOrder.Uint32(m.have) | |
| switch { | |
| case m.want == 4 && code == 0: | |
| return "AuthenticationOk" | |
| case m.want == 4 && code == 2: | |
| return "AuthenticationKerberosV5" | |
| case m.want == 4 && code == 3: | |
| return "AuthenticationCleartextPassword" | |
| case m.want == 8 && code == 5: | |
| return fmt.Sprintf("AuthenticationMD5Password salt=%x", m.have[4:]) | |
| case m.want == 4 && code == 7: | |
| return "AuthenticationGSS" | |
| case code == 8: | |
| return fmt.Sprintf("AuthenticationGSSContinue data=%x", m.have[4:]) | |
| case m.want == 4 && code == 9: | |
| return "AuthenticationSSPI" | |
| case code == 10: | |
| name, _, ok := decodeString(m.have[4:]) | |
| if ok { | |
| return fmt.Sprintf("AuthenticationSASL name=%q", name) | |
| } else { | |
| return fmt.Sprintf("AuthenticationSASL invalid name=%q", m.have[4:]) | |
| } | |
| case code == 11: | |
| return fmt.Sprintf("AuthenticationSASLContinue data=%x", m.have[4:]) | |
| case code == 12: | |
| return fmt.Sprintf("AuthenticationSASLFinal data=%x", m.have[4:]) | |
| default: | |
| return fmt.Sprintf("unrecognized Authentication message len=%d code=%d payload=%x", | |
| m.want, code, m.have[4:]) | |
| } | |
| } | |
| type parameterStatus struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *parameterStatus) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 2 { | |
| return fmt.Sprintf("ParameterStatus message too short raw=%x", m.have) | |
| } | |
| name, rest, ok := decodeString(m.have) | |
| if !ok { | |
| return fmt.Sprintf("ParameterStatus message invalid name raw=%x", m.have) | |
| } | |
| value, rest, ok := decodeString(rest) | |
| if !ok { | |
| return fmt.Sprintf("ParameterStatus %s= invalid value raw=%x", name, rest) | |
| } | |
| return fmt.Sprintf("ParameterStatus %s=%q", name, value) | |
| } | |
| type rowDescription struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *rowDescription) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 2 { | |
| return fmt.Sprintf("RowDescription message too short raw=%x", m.have) | |
| } | |
| numFields := byteOrder.Uint16(m.have) | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "RowDescription fields=%d", numFields) | |
| p := m.have[2:] | |
| for i := 0; i < int(numFields); i++ { | |
| // Field name | |
| name, rest, ok := decodeString(p) | |
| if !ok { | |
| fmt.Fprintf(&sb, " (error decoding field %d name)", i) | |
| break | |
| } | |
| fmt.Fprintf(&sb, " [%q", name) | |
| p = rest | |
| // Need at least 18 bytes for the remaining fields | |
| if len(p) < 18 { | |
| fmt.Fprintf(&sb, " incomplete] rest=%x", p) | |
| break | |
| } | |
| tableOID := byteOrder.Uint32(p[0:4]) | |
| colAttrNum := byteOrder.Uint16(p[4:6]) | |
| typeOID := byteOrder.Uint32(p[6:10]) | |
| typeSize := int16(byteOrder.Uint16(p[10:12])) | |
| typeMod := byteOrder.Uint32(p[12:16]) | |
| formatCode := byteOrder.Uint16(p[16:18]) | |
| fmt.Fprintf(&sb, " table=%d col=%d typeOID=%d", tableOID, colAttrNum, typeOID) | |
| if typeSize < 0 { | |
| fmt.Fprintf(&sb, " size=var(%d)", typeSize) | |
| } else { | |
| fmt.Fprintf(&sb, " size=%d", typeSize) | |
| } | |
| fmt.Fprintf(&sb, " mod=%08x", typeMod) | |
| switch formatCode { | |
| case 0: | |
| fmt.Fprintf(&sb, " fmt=text") | |
| case 1: | |
| fmt.Fprintf(&sb, " fmt=binary") | |
| default: | |
| fmt.Fprintf(&sb, " fmt=!!%d!!", formatCode) | |
| } | |
| fmt.Fprintf(&sb, "]") | |
| p = p[18:] | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
// terminate formats a Terminate ('X') message, which carries no payload;
// any payload bytes are reported.
type terminate struct {
	want int    // payload length to wait for before formatting
	have []byte // payload bytes accumulated so far
}

// Format accumulates data and, once the full payload is present, renders it.
func (m *terminate) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want != 0:
		return fmt.Sprintf("Terminate (unexpected payload: %x)", m.have)
	default:
		return "Terminate"
	}
}
// readyForQuery formats a ReadyForQuery ('Z') message, whose one-byte
// payload is the transaction status:
//   - 'I' idle
//   - 'T' in a transaction block
//   - 'E' in a failed transaction block
type readyForQuery struct {
	want int    // payload length to wait for before formatting
	have []byte // payload bytes accumulated so far
}

// Format accumulates data and, once the full payload is present, renders it.
func (m *readyForQuery) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 1 {
		return fmt.Sprintf("ReadyForQuery message too short raw=%x", m.have)
	}
	status := m.have[0]
	switch status {
	case 'I':
		return "ReadyForQuery (idle)"
	case 'T':
		return "ReadyForQuery (in a transaction block)"
	case 'E':
		return "ReadyForQuery (in a failed transaction block)"
	default:
		return fmt.Sprintf("ReadyForQuery status=unknown(%c/%x)", status, status)
	}
}
| type parameterDescription struct { | |
| want int | |
| have []byte | |
| } | |
| func (m *parameterDescription) Format(data []byte) string { | |
| m.have = append(m.have, data...) | |
| if len(m.have) < m.want { | |
| return "..." | |
| } | |
| if m.want < 2 { | |
| return fmt.Sprintf("ParameterDescription message too short raw=%x", m.have) | |
| } | |
| numParams := byteOrder.Uint16(m.have) | |
| var sb strings.Builder | |
| fmt.Fprintf(&sb, "ParameterDescription params=%d", numParams) | |
| p := m.have[2:] | |
| if numParams > 0 { | |
| if len(p) < int(numParams)*4 { | |
| fmt.Fprintf(&sb, " (incomplete, need %d bytes) rest=%x", numParams*4, p) | |
| return sb.String() | |
| } | |
| fmt.Fprintf(&sb, " types=[") | |
| for i := 0; i < int(numParams); i++ { | |
| if i > 0 { | |
| fmt.Fprintf(&sb, ",") | |
| } | |
| typeOID := byteOrder.Uint32(p[i*4 : (i+1)*4]) | |
| fmt.Fprintf(&sb, "%d", typeOID) | |
| } | |
| fmt.Fprintf(&sb, "]") | |
| p = p[numParams*4:] | |
| } | |
| if len(p) > 0 { | |
| fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p) | |
| } | |
| return sb.String() | |
| } | |
// gssResponse formats a GSSResponse message. The body is not decoded; it is
// shown as raw hex.
type gssResponse struct {
	want int    // number of body bytes announced by the header
	have []byte // body bytes received so far
}

// Format buffers data and returns "..." until want bytes have been seen, then
// returns the whole body in hex.
func (m *gssResponse) Format(data []byte) string {
	if m.have = append(m.have, data...); len(m.have) >= m.want {
		return fmt.Sprintf("GSSResponse data=%x", m.have)
	}
	return "..."
}
| package main | |
| import ( | |
| "bytes" | |
| "encoding/binary" | |
| "io" | |
| "log" | |
| ) | |
| var byteOrder = binary.BigEndian | |
| type side interface { | |
| getMessageFormatter(msgType byte, msgLen int) messageFormatter | |
| } | |
| // processMessages does some introspection of messages received from r before | |
| // forwarding them to w. | |
| // | |
| // hi should be 1 for client->server and 0 for server->client. | |
| func processMessages(r io.Reader, w io.Writer, l *log.Logger, side side, hi int) { | |
| buf := make([]byte, 1024*1024) | |
| const maxHeaderSize = 5 | |
| var hdr [maxHeaderSize]byte | |
| bodyRemaining := 0 | |
| var pretty messageFormatter | |
| for { | |
| n, err := r.Read(buf) | |
| if err != nil { | |
| l.Printf("error reading data: %v", err) | |
| return | |
| } | |
| for p := buf[:n]; len(p) > 0; { | |
| switch { | |
| case bodyRemaining == 0: | |
| // we need more header bytes | |
| copied := copy(hdr[hi:], p) | |
| hi += copied | |
| p = p[copied:] | |
| if hi == maxHeaderSize { | |
| // fall through | |
| msgType := hdr[0] | |
| length := int(byteOrder.Uint32(hdr[1:])) | |
| l.Printf("MSG '%c' (%x) (len=%d)", msgType, msgType, length-4) | |
| // SSL mode is []uint16{1234,5678} or []byte{0x04, 0xd2, 0x16, 0x2f}. | |
| if msgType == 0 && length == 8 && bytes.Equal(p[:4], []byte{0x04, 0xd2, 0x16, 0x2f}) { | |
| l.Printf("SSL mode is not supported") | |
| return | |
| } | |
| bodyRemaining = length - 4 | |
| hi = 0 | |
| pretty = side.getMessageFormatter(msgType, bodyRemaining) | |
| if bodyRemaining == 0 { | |
| l.Printf(" %s", pretty.Format(nil)) | |
| } | |
| } | |
| case bodyRemaining > len(p): | |
| l.Printf(" %s", pretty.Format(p)) | |
| bodyRemaining -= len(p) | |
| p = nil | |
| case bodyRemaining > 0: | |
| l.Printf(" %s", pretty.Format(p[:bodyRemaining])) | |
| p = p[bodyRemaining:] | |
| bodyRemaining = 0 | |
| default: | |
| l.Printf("invalid state, bailing") | |
| } | |
| } | |
| l.Printf("write(%d)", n) | |
| if _, err := w.Write(buf[:n]); err != nil { | |
| l.Printf("error writing data: %v", err) | |
| return | |
| } | |
| } | |
| } | |
// decodeString decodes a Postgres 'String' atom.
//
// > A null-terminated string (C-style string). There is no specific length
// > limitation on strings. If s is specified it is the exact value that will
// > appear, otherwise the value is variable. Eg. String, String("user").
//
// On success it returns the bytes before the first NUL, the bytes after that
// NUL, and true. Both slices alias data. When data holds no NUL it returns
// (nil, data, false), so the caller keeps the undecoded input.
func decodeString(data []byte) ([]byte, []byte, bool) {
	end := bytes.IndexByte(data, 0)
	if end < 0 {
		return nil, data, false
	}
	str, rest := data[:end], data[end+1:]
	return str, rest, true
}
| package main | |
| import ( | |
| "crypto/rand" | |
| "crypto/rsa" | |
| "crypto/tls" | |
| "crypto/x509" | |
| "crypto/x509/pkix" | |
| "encoding/pem" | |
| "fmt" | |
| "math/big" | |
| "net" | |
| "os" | |
| "time" | |
| ) | |
// initSSL returns a server-side TLS configuration that presents the
// PEM-encoded certificate at certPath with the private key at keyPath.
//
// When neither file exists, a self-signed RSA-2048 certificate for
// "localhost" / 127.0.0.1, valid for one year, is generated and written to
// those paths first. The key file is created with permission 0600 so it is
// readable only by its owner. When only one of the two files exists, or
// either cannot be read or parsed, an error is returned.
func initSSL(certPath, keyPath string) (*tls.Config, error) {
	certData, certErr := os.ReadFile(certPath)
	keyData, keyErr := os.ReadFile(keyPath)
	switch {
	case os.IsNotExist(certErr) && os.IsNotExist(keyErr):
		certPEM, keyPEM, err := generateSelfSignedPEM()
		if err != nil {
			return nil, err
		}
		if err := writeCertAndKeyPEM(certPath, keyPath, certPEM, keyPEM); err != nil {
			return nil, err
		}
		// Parse from memory; no need to re-read what was just written.
		cert, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			return nil, fmt.Errorf("error loading generated certificate: %w", err)
		}
		return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
	case certErr != nil:
		return nil, fmt.Errorf("%s: %w", certPath, certErr)
	case keyErr != nil:
		return nil, fmt.Errorf("%s: %w", keyPath, keyErr)
	}
	cert, err := tls.X509KeyPair(certData, keyData)
	if err != nil {
		return nil, err
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}

// generateSelfSignedPEM creates an RSA-2048 key and a matching self-signed
// server certificate for localhost/127.0.0.1 that is valid for one year. It
// returns the certificate as a PEM "CERTIFICATE" block and the key as a
// PKCS#8 PEM "PRIVATE KEY" block.
func generateSelfSignedPEM() ([]byte, []byte, error) {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, nil, fmt.Errorf("error generating private key: %w", err)
	}
	notBefore := time.Now()
	notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, nil, fmt.Errorf("error generating serial number: %w", err)
	}
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"PostgreSQL MITM Proxy"},
			CommonName:   "localhost",
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, nil, fmt.Errorf("error creating certificate: %w", err)
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, nil, fmt.Errorf("error marshaling private key: %w", err)
	}
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})
	return certPEM, keyPEM, nil
}

// writeCertAndKeyPEM writes the certificate (permission 0644) and the private
// key (permission 0600, owner-only) to their paths. If the key cannot be
// written, the certificate file is removed so that a later run does not find
// only half of a pair and refuse to start.
func writeCertAndKeyPEM(certPath, keyPath string, certPEM, keyPEM []byte) error {
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		return fmt.Errorf("error writing cert: %w", err)
	}
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		_ = os.Remove(certPath) // best effort; the write error is what we report
		return fmt.Errorf("error writing key: %w", err)
	}
	return nil
}
// backendTLSPolicy describes how TLS to the backend server is to be handled.
// How Request and Require are acted on is decided by the caller (not shown
// here). Presumably Request means "ask for TLS" and Require means "fail if
// TLS is refused"; confirm against the caller.
type backendTLSPolicy struct {
	Request bool
	Require bool
	// Config is the client TLS configuration to use; nil when TLS is disabled.
	Config *tls.Config
}

// configureBackendTLS maps a backend TLS mode name to a policy.
//
//	"disable"  - no TLS
//	"allow"    - request TLS, verify the server against servername
//	"require"  - like "allow", and also sets Require
//	"insecure" - request TLS without certificate verification
//
// The boolean is false (with a zero policy) for any other mode.
func configureBackendTLS(mode, servername string) (backendTLSPolicy, bool) {
	var policy backendTLSPolicy
	switch mode {
	case "disable":
		// The zero policy already means "no TLS".
	case "allow", "require":
		policy = backendTLSPolicy{
			Request: true,
			Require: mode == "require",
			Config:  strictBETLSConfig(servername),
		}
	case "insecure":
		policy = backendTLSPolicy{Request: true, Config: insecureBETLSConfig()}
	default:
		return backendTLSPolicy{}, false
	}
	return policy, true
}

// strictBETLSConfig returns a client TLS config with certificate verification
// left on. crypto/tls therefore checks the backend's chain and host name
// against servername.
func strictBETLSConfig(servername string) *tls.Config {
	cfg := new(tls.Config)
	cfg.ServerName = servername
	return cfg
}

// insecureBETLSConfig returns a client TLS config that accepts any backend
// certificate: the link is encrypted but the server is not authenticated.
func insecureBETLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.InsecureSkipVerify = true
	return cfg
}
#!/bin/sh
# Run one query through the proxy on 127.0.0.1:5555 under each client
# sslmode setting, printing a banner before each attempt.
for mode in sslmode=prefer sslmode=require sslmode=disable
do
	echo "==== $mode ===="
	psql "host=127.0.0.1 port=5555 user=postgres password=xxx $mode" \
		-c 'select * from hi'
done