Skip to content

Instantly share code, notes, and snippets.

@spraints
Last active December 22, 2025 19:16
Show Gist options
  • Select an option

  • Save spraints/2527a589757f0ea33f9f2995ac776cd0 to your computer and use it in GitHub Desktop.

Select an option

Save spraints/2527a589757f0ea33f9f2995ac776cd0 to your computer and use it in GitHub Desktop.
A really simple postgres protocol analyzer
test.crt
test.key

Simple postgres protocol analyzer

Synopsis:

(Assume postgres is running on localhost.)

$ go run . &

$ psql "host=127.0.0.1 port=5555 sslmode=disable user=... password=..."
(protocol details spew)

postgres=> select 1;
(more protocol details spew)

Instructions for Adding PostgreSQL Message Types

Overview

This document describes how to add implementations for PostgreSQL protocol message types in messages.go.

Reference

Message types are defined in the PostgreSQL documentation: https://www.postgresql.org/docs/current/protocol-message-formats.html

Implementation Guidelines

1. Switch Statement Ordering

In the appropriate getMessageFormatter() method (frontend for messages sent by the client, backend for messages sent by the server), add new message type cases in alphabetical order:

  • Non-letter type codes (0 for the untyped startup packet, digits such as '1' and '2') first, then all uppercase letters before lowercase letters
  • Within each group (uppercase/lowercase), maintain alphabetical order

Example ordering:

case 'K':  // uppercase
case 'Q':
case 'R':
case 'S':
case 'T':
case 'X':
case 'Z':
case 'p':  // lowercase

2. Type Definition Ordering

Define the message type structs in the same order as they appear in the switch statement.

Each message type struct follows this pattern:

type messageName struct {
    want int
    have []byte
}

3. Format Method

Each message type must implement the messageFormatter interface by providing a Format([]byte) string method.

The Format method should:

  1. Accumulate incoming data (Format may be called several times per message as fragments arrive): m.have = append(m.have, data...)
  2. Check if all data has arrived: if len(m.have) < m.want { return "..." }
  3. Validate the message length
  4. Parse and format the message content according to the protocol specification
  5. Handle error cases gracefully (invalid data, unexpected trailing bytes, etc.)

4. Common Patterns

  • Use decodeString() for null-terminated strings
  • Use byteOrder.UintXX() for reading multi-byte integers
  • Use strings.Builder for constructing complex output
  • Report hex dumps for raw/invalid data: %x
  • Quote strings in output: %q

Example Implementation

See existing message types in messages.go for examples:

  • Simple messages: terminate (no payload), readyForQuery (single byte)
  • String messages: query, parameterStatus
  • Complex messages: rowDescription, authenticationRequest
module pgmitm
go 1.25.0
require github.com/spf13/pflag v1.0.10
github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk=
github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
// Synopsis:
//
// $ go run . --listen 127.0.0.1:5555 --dial 127.0.0.1:5432
//
// $ psql "host=127.0.0.1 port=5555 sslmode=disable user=xx password=yy"
package main
import (
"bytes"
"context"
"crypto/tls"
"fmt"
"io"
"log"
"net"
"os"
"os/signal"
"strings"
"sync"
"sync/atomic"
"syscall"
"time"
"github.com/spf13/pflag"
)
// main runs the proxy and exits with status 1 if it fails.
func main() {
	err := mainImpl()
	if err == nil {
		return
	}
	log.Printf("fatal error: %v", err)
	os.Exit(1)
}
// mainImpl parses the command-line flags, prepares TLS for the client-facing
// and backend-facing sides, and accepts connections on --listen until SIGINT
// or SIGTERM arrives (or Accept fails). Every accepted connection is proxied
// to --dial by serve in its own goroutine, and mainImpl waits for all of them
// to finish before returning.
func mainImpl() error {
	ctx, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
	defer cancel()

	listenAddr := pflag.String("listen", "127.0.0.1:5555", "address to listen on")
	upstreamAddr := pflag.String("dial", "127.0.0.1:5432", "address to forward connections to")
	sslCertFile := pflag.String("ssl-cert", "", "path to an SSL certificate to use for the server (will be generated if the path is present but the file is not)")
	sslKeyFile := pflag.String("ssl-key", "", "path to an SSL key to use for the server (will be generated if the path is present but the file is not)")
	beTLSModeStr := pflag.String("tls-mode", "allow", "what to do about TLS on the backend (disable, require, allow, insecure)")
	pflag.Parse()

	var frontendTLS *tls.Config
	switch {
	case *sslCertFile != "" && *sslKeyFile != "":
		c, err := initSSL(*sslCertFile, *sslKeyFile)
		if err != nil {
			return err
		}
		log.Printf("TLS enabled")
		frontendTLS = c
	case *sslCertFile != "" || *sslKeyFile != "":
		// Both paths are needed; say so instead of silently ignoring the flag.
		log.Printf("warning: --ssl-cert and --ssl-key must be used together; client-side TLS is disabled")
	}

	// The backend certificate is verified against the host part of --dial.
	// net.SplitHostPort copes with bracketed IPv6 literals ("[::1]:5432"),
	// which splitting on the first ':' would reduce to "[".
	serverName, _, err := net.SplitHostPort(*upstreamAddr)
	if err != nil {
		// No port (or otherwise unsplittable): use the address itself.
		serverName = strings.Trim(*upstreamAddr, "[]")
	}
	beTLS, ok := configureBackendTLS(*beTLSModeStr, serverName)
	if !ok {
		return fmt.Errorf("invalid tls-mode %q", *beTLSModeStr)
	}

	l, err := net.Listen("tcp", *listenAddr)
	if err != nil {
		return fmt.Errorf("%s: error listening: %w", *listenAddr, err)
	}
	log.Printf("%s: listening", *listenAddr)

	// Closing the listener is how a shutdown signal breaks the Accept loop;
	// the flag lets the loop tell that apart from a genuine Accept failure.
	var closing atomic.Bool
	go func() {
		<-ctx.Done()
		log.Printf("received shutdown signal")
		closing.Store(true)
		l.Close()
	}()

	var wg sync.WaitGroup
	for {
		conn, err := l.Accept()
		if err != nil {
			if !closing.Load() {
				log.Printf("%s: shutting down because of accept error: %v", *listenAddr, err)
			}
			break
		}
		wg.Go(func() {
			logger := log.New(log.Writer(), "["+conn.RemoteAddr().String()+"] ", log.Flags())
			logger.Println("accepted connection")
			serve(conn, *upstreamAddr, frontendTLS, beTLS, logger)
		})
	}
	// Also stops the signal goroutine when Accept failed on its own.
	cancel()
	wg.Wait()
	return nil
}
// serve proxies a single client connection. It dials the backend at pgAddr,
// negotiates TLS with both peers, and then relays traffic in both directions
// (logging decoded messages) until both relay goroutines return. The client
// and backend connections are closed on return.
func serve(conn net.Conn, pgAddr string, feTLS *tls.Config, beTLS backendTLSPolicy, l *log.Logger) {
	defer conn.Close()
	defer l.Printf("closing connection")
	backendConn, err := net.Dial("tcp", pgAddr)
	if err != nil {
		l.Printf("%s: error dialing backend: %v", pgAddr, err)
		return
	}
	defer backendConn.Close()
	handles, err := negotiate(conn, backendConn, feTLS, beTLS, l)
	if err != nil {
		l.Printf("%s: error negotiating transport: %v", pgAddr, err)
		return
	}
	var wg sync.WaitGroup
	wg.Go(func() { serveClient(handles.ClientReader, handles.ServerWriter, l) })
	wg.Go(func() { serveServer(handles.ServerReader, handles.ClientWriter, l) })
	wg.Wait()
}
// negotiated holds the stream endpoints produced by negotiate. After
// negotiation each peer may be a plain TCP connection or a TLS wrapper, so
// the endpoints are kept as generic io interfaces.
type negotiated struct {
	ClientReader io.ReadCloser  // bytes sent by the client
	ServerReader io.ReadCloser  // bytes sent by the backend
	ClientWriter io.WriteCloser // destination for bytes going to the client
	ServerWriter io.WriteCloser // destination for bytes going to the backend
}
// pgSSLRequest is the complete 8-byte SSLRequest packet: an Int32 length of 8
// followed by the Int32 request code 0x04d2162f (80877103, i.e. 1234<<16 | 5679).
// negotiateClient compares a client's first 8 bytes against it, and
// negotiateServer sends it to ask the backend for TLS.
var pgSSLRequest = []byte{0, 0, 0, 8, 0x04, 0xd2, 0x16, 0x2f}
// negotiate runs the optional TLS negotiation with the client (frontend) and
// then with the backend, and returns the resulting read/write endpoints. The
// client side is handled first, so if it fails nothing is written to the
// backend.
func negotiate(frontend, backend net.Conn, feTLS *tls.Config, beTLS backendTLSPolicy, l *log.Logger) (negotiated, error) {
	var result negotiated
	clientSide, err := negotiateClient(frontend, feTLS, l)
	if err != nil {
		return result, err
	}
	serverSide, err := negotiateServer(backend, beTLS, l)
	if err != nil {
		return result, err
	}
	result.ClientReader, result.ClientWriter = clientSide, clientSide
	result.ServerReader, result.ServerWriter = serverSide, serverSide
	return result, nil
}
// negotiateClient inspects the first packet from a newly accepted client.
//
// If it is an SSLRequest, it is answered: with 'N' when feTLS is nil (the
// client then continues in plaintext on conn), or with 'S' followed by
// wrapping conn in a TLS server that uses feTLS. Any other packet is pushed
// back in front of conn, so the caller still reads it in full.
func negotiateClient(conn net.Conn, feTLS *tls.Config, l *log.Logger) (io.ReadWriteCloser, error) {
	// Every legitimate first packet is at least 8 bytes long (Int32 length
	// plus an Int32 protocol version or request code), so wait for exactly
	// that many. A single conn.Read may return fewer bytes, in which case an
	// SSLRequest would slip through unrecognised.
	preamble := make([]byte, len(pgSSLRequest))
	if _, err := io.ReadFull(conn, preamble); err != nil {
		return nil, fmt.Errorf("error getting startup message: %w", err)
	}
	if !bytes.Equal(pgSSLRequest, preamble) {
		// NOTE(review): a GSSENCRequest also takes this path and is forwarded
		// to the backend unchanged.
		return prepended(preamble, conn), nil
	}
	if feTLS == nil {
		l.Printf("--- SSL requested by client, declining ---")
		if _, err := conn.Write([]byte("N")); err != nil {
			return nil, err
		}
		return conn, nil
	}
	l.Printf("--- performing SSL handshake with client ---")
	if _, err := conn.Write([]byte("S")); err != nil {
		return nil, err
	}
	return tls.Server(conn, feTLS), nil
}
// negotiateServer optionally upgrades the backend connection to TLS according
// to beTLS. When beTLS.Request is false, conn is returned unchanged.
// Otherwise an SSLRequest is sent, and the backend's one-byte answer decides
// the outcome:
//   - 'S' wraps conn in a TLS client using beTLS.Config.
//   - 'N' continues in plaintext unless beTLS.Require is set.
//   - Anything else is an error.
func negotiateServer(conn net.Conn, beTLS backendTLSPolicy, l *log.Logger) (io.ReadWriteCloser, error) {
	if !beTLS.Request {
		return conn, nil
	}
	if _, err := conn.Write(pgSSLRequest); err != nil {
		return nil, err
	}
	answer := make([]byte, 1)
	n, err := conn.Read(answer)
	switch {
	case err != nil:
		return nil, err
	case n != 1:
		return nil, fmt.Errorf("error reading response to SSLRequest")
	}
	switch b := answer[0]; b {
	case 'S':
		l.Println("--- performing SSL handshake with server ---")
		return tls.Client(conn, beTLS.Config), nil
	case 'N':
		if beTLS.Require {
			return nil, fmt.Errorf("TLS is required but backend does not allow it")
		}
		l.Println("--- server declined SSL ---")
		return conn, nil
	default:
		return nil, fmt.Errorf("invalid response to SSLRequest %c(%02x)", b, b)
	}
}
// prepended returns a stream that first yields preamble and then whatever is
// read from conn. Writes and Close go straight to conn. This lets bytes that
// were consumed while sniffing the first packet be "unread".
func prepended(preamble []byte, conn net.Conn) io.ReadWriteCloser {
	p := new(prepend)
	p.r = io.MultiReader(bytes.NewReader(preamble), conn)
	p.w = conn
	p.c = conn
	return p
}

// prepend is an io.ReadWriteCloser assembled from separate reader, writer
// and closer halves.
type prepend struct {
	r io.Reader
	w io.Writer
	c io.Closer
}

// Read reads from the reader half.
func (p *prepend) Read(b []byte) (int, error) {
	return p.r.Read(b)
}

// Write writes to the writer half.
func (p *prepend) Write(b []byte) (int, error) {
	return p.w.Write(b)
}

// Close closes the closer half.
func (p *prepend) Close() error {
	return p.c.Close()
}
// serveClient relays client->server traffic through processMessages, which
// logs each decoded frontend message before forwarding it. When the relay
// stops, there is a 100ms pause, then both endpoints are closed (server
// first, then client).
func serveClient(client io.ReadCloser, server io.WriteCloser, l *log.Logger) {
	defer client.Close()
	defer server.Close()
	prefixed := log.New(l.Writer(), l.Prefix()+"client->server ", l.Flags())
	// hi=1: the client's first packet has no type byte, so the header
	// starts out with a zero type already "read".
	processMessages(client, server, prefixed, frontend{}, 1)
	// Pause briefly before the deferred Closes run.
	time.Sleep(100 * time.Millisecond)
}
// serveServer relays server->client traffic through processMessages, which
// logs each decoded backend message before forwarding it. When the relay
// stops, there is a 100ms pause, then both endpoints are closed (client
// first, then server).
func serveServer(server io.ReadCloser, client io.WriteCloser, l *log.Logger) {
	defer server.Close()
	defer client.Close()
	prefixed := log.New(l.Writer(), l.Prefix()+"server->client ", l.Flags())
	// hi=0: every backend message starts with a type byte.
	processMessages(server, client, prefixed, backend{}, 0)
	// Pause briefly before the deferred Closes run.
	time.Sleep(100 * time.Millisecond)
}
package main
import (
"fmt"
"strings"
)
// frontend decodes messages sent by the client (the "F" messages from
// https://www.postgresql.org/docs/current/protocol-message-formats.html).
type frontend struct{}

// getMessageFormatter returns a formatter for a client message with the given
// type byte and body length (excluding the length field itself). Type 0 is
// the untyped first packet. Types without a dedicated decoder are reported as
// raw bytes rather than panicking: a panic here would take down the whole
// proxy, not just one connection.
func (frontend) getMessageFormatter(msgType byte, msgLength int) messageFormatter {
	switch msgType {
	case 0:
		return &startup{want: msgLength}
	case 'B':
		return &bind{want: msgLength}
	case 'C':
		return &rawMessage{name: "Close", want: msgLength}
	case 'D':
		return &describe{want: msgLength}
	case 'E':
		return &execute{want: msgLength}
	case 'F':
		return &rawMessage{name: "FunctionCall", want: msgLength}
	case 'H':
		return &rawMessage{name: "Flush", want: msgLength}
	case 'P':
		return &parse{want: msgLength}
	case 'Q':
		return &query{want: msgLength}
	case 'S':
		return &syncMsg{want: msgLength}
	case 'X':
		return &terminate{want: msgLength}
	case 'c':
		return &rawMessage{name: "CopyDone", want: msgLength}
	case 'd':
		return &rawMessage{name: "CopyData", want: msgLength}
	case 'f':
		return &rawMessage{name: "CopyFail", want: msgLength}
	case 'p':
		return &gssResponse{want: msgLength}
	default:
		return unknownMessage(msgType, msgLength)
	}
}

// backend decodes messages sent by the server (the "B" messages from
// https://www.postgresql.org/docs/current/protocol-message-formats.html).
type backend struct{}

// getMessageFormatter returns a formatter for a server message with the given
// type byte and body length (excluding the length field itself). Types
// without a dedicated decoder are reported as raw bytes rather than panicking.
func (backend) getMessageFormatter(msgType byte, msgLength int) messageFormatter {
	switch msgType {
	case '1':
		return &parseComplete{want: msgLength}
	case '2':
		return &bindComplete{want: msgLength}
	case '3':
		return &rawMessage{name: "CloseComplete", want: msgLength}
	case 'A':
		return &rawMessage{name: "NotificationResponse", want: msgLength}
	case 'C':
		return &commandComplete{want: msgLength}
	case 'D':
		return &dataRow{want: msgLength}
	case 'E':
		return &errorResponse{want: msgLength}
	case 'G':
		return &rawMessage{name: "CopyInResponse", want: msgLength}
	case 'H':
		return &rawMessage{name: "CopyOutResponse", want: msgLength}
	case 'I':
		return &rawMessage{name: "EmptyQueryResponse", want: msgLength}
	case 'K':
		return &backendKey{want: msgLength}
	case 'N':
		return &rawMessage{name: "NoticeResponse", want: msgLength}
	case 'R':
		return &authenticationRequest{want: msgLength}
	case 'S':
		return &parameterStatus{want: msgLength}
	case 'T':
		return &rowDescription{want: msgLength}
	case 'V':
		return &rawMessage{name: "FunctionCallResponse", want: msgLength}
	case 'W':
		return &rawMessage{name: "CopyBothResponse", want: msgLength}
	case 'Z':
		return &readyForQuery{want: msgLength}
	case 'c':
		return &rawMessage{name: "CopyDone", want: msgLength}
	case 'd':
		return &rawMessage{name: "CopyData", want: msgLength}
	case 'n':
		return &rawMessage{name: "NoData", want: msgLength}
	case 's':
		return &rawMessage{name: "PortalSuspended", want: msgLength}
	case 't':
		return &parameterDescription{want: msgLength}
	case 'v':
		return &rawMessage{name: "NegotiateProtocolVersion", want: msgLength}
	default:
		return unknownMessage(msgType, msgLength)
	}
}

// rawMessage is the fallback formatter for message types without a dedicated
// decoder. It waits for the whole body and then prints the label plus the
// body as hex (just the label for an empty body).
type rawMessage struct {
	name string
	want int
	have []byte
}

func (m *rawMessage) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if len(m.have) == 0 {
		return m.name
	}
	return fmt.Sprintf("%s raw=%x", m.name, m.have)
}

// unknownMessage returns a raw formatter labelled with an unrecognised type.
func unknownMessage(msgType byte, msgLength int) messageFormatter {
	return &rawMessage{
		name: fmt.Sprintf("unknown message '%c' (%d)", msgType, msgType),
		want: msgLength,
	}
}
// messageFormatter renders one protocol message for logging. processMessages
// calls Format with successive chunks of the message body as they arrive (nil
// for a zero-length body). The implementations in this file accumulate the
// chunks and return "..." until the whole body has been seen.
type messageFormatter interface {
	Format([]byte) string
}
// startup formats the first packet a client sends, which has no type byte.
// Normally it is a StartupMessage: an Int32 protocol version followed by
// name/value String pairs and a terminating zero byte. CancelRequest,
// SSLRequest and GSSENCRequest use the same framing and are recognised by
// special request codes in place of the version.
type startup struct {
	want int
	have []byte
}

func (m *startup) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	// Request codes that occupy the protocol-version field.
	const (
		cancelRequestCode = 1234<<16 | 5678 // 80877102
		sslRequestCode    = 1234<<16 | 5679 // 80877103
		gssEncRequestCode = 1234<<16 | 5680 // 80877104
	)
	if m.want < 4 {
		return fmt.Sprintf("StartupMessage too short raw=%x", m.have)
	}
	code := byteOrder.Uint32(m.have)
	rest := m.have[4:]
	switch code {
	case cancelRequestCode:
		// Int32 process ID, then the secret key (its length depends on the
		// protocol version, so print whatever remains).
		if len(rest) < 4 {
			return fmt.Sprintf("CancelRequest too short raw=%x", m.have)
		}
		return fmt.Sprintf("CancelRequest pid=%d secret=%x", byteOrder.Uint32(rest), rest[4:])
	case sslRequestCode, gssEncRequestCode:
		name := "SSLRequest"
		if code == gssEncRequestCode {
			name = "GSSENCRequest"
		}
		if len(rest) > 0 {
			return fmt.Sprintf("%s (unexpected trailing data: %x)", name, rest)
		}
		return name
	}
	if m.want < 5 {
		return fmt.Sprintf("StartupMessage too short raw=%x", m.have)
	}
	var sb strings.Builder
	major, minor := code>>16, code&0xffff
	fmt.Fprintf(&sb, "StartupMessage proto=%d(%d.%d)", code, major, minor)
	// Clients in use today send major version 3 (3.0, or 3.2 with newer
	// libpq); anything else is suspicious.
	if major != 3 {
		fmt.Fprintf(&sb, "(!)")
	}
	p := rest
	for {
		if len(p) == 1 && p[0] == 0 {
			break
		}
		if len(p) < 1 {
			fmt.Fprintf(&sb, " (unexpected end of parameters!)")
			break
		}
		str, next, ok := decodeString(p)
		if !ok {
			fmt.Fprintf(&sb, " (unexpected end of parameters remaining=%x)", p)
			break
		}
		fmt.Fprintf(&sb, " %s=", str)
		p = next
		str, next, ok = decodeString(p)
		if !ok {
			fmt.Fprintf(&sb, " (unexpected end of parameters remaining=%x)", p)
			break
		}
		fmt.Fprintf(&sb, "%q", str)
		p = next
	}
	return sb.String()
}
// bind formats a Bind message. The body holds, in order: the destination
// portal, the source prepared statement, the parameter format codes, the
// parameter values, and the result-column format codes.
type bind struct {
	want int
	have []byte
}

func (m *bind) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 4 {
		return fmt.Sprintf("Bind message too short raw=%x", m.have)
	}
	var sb strings.Builder
	fmt.Fprintf(&sb, "Bind")
	// Portal name
	portalName, rest, ok := decodeString(m.have)
	if !ok {
		fmt.Fprintf(&sb, " (error decoding portal name) raw=%x", m.have)
		return sb.String()
	}
	if len(portalName) == 0 {
		fmt.Fprintf(&sb, " dstportal=<unnamed>")
	} else {
		fmt.Fprintf(&sb, " dstportal=%q", portalName)
	}
	p := rest
	// Statement name
	stmtName, rest, ok := decodeString(p)
	if !ok {
		fmt.Fprintf(&sb, " (error decoding statement name) rest=%x", p)
		return sb.String()
	}
	if len(stmtName) == 0 {
		fmt.Fprintf(&sb, " stmt=<unnamed>")
	} else {
		fmt.Fprintf(&sb, " stmt=%q", stmtName)
	}
	p = rest
	// Parameter format codes. Counts are converted to int before any
	// multiplication; in uint16 arithmetic count*2 wraps above 32767.
	if len(p) < 2 {
		fmt.Fprintf(&sb, " (missing param format count) rest=%x", p)
		return sb.String()
	}
	numParamFormats := int(byteOrder.Uint16(p[0:2]))
	p = p[2:]
	if len(p) < numParamFormats*2 {
		fmt.Fprintf(&sb, " (incomplete param formats) rest=%x", p)
		return sb.String()
	}
	paramFormats := make([]uint16, numParamFormats)
	for i := range paramFormats {
		paramFormats[i] = byteOrder.Uint16(p[i*2 : (i+1)*2])
	}
	p = p[numParamFormats*2:]
	// paramFormat returns the format code for parameter i. With no codes
	// every parameter is text (0), a single code applies to all parameters,
	// and otherwise there is one code per parameter.
	paramFormat := func(i int) uint16 {
		switch {
		case len(paramFormats) == 1:
			return paramFormats[0]
		case i < len(paramFormats):
			return paramFormats[i]
		default:
			return 0
		}
	}
	// Parameter values
	if len(p) < 2 {
		fmt.Fprintf(&sb, " (missing param value count) rest=%x", p)
		return sb.String()
	}
	numParams := int(byteOrder.Uint16(p[0:2]))
	p = p[2:]
	fmt.Fprintf(&sb, " params=%d", numParams)
	for i := 0; i < numParams; i++ {
		if len(p) < 4 {
			fmt.Fprintf(&sb, " (incomplete param %d length) rest=%x", i, p)
			return sb.String()
		}
		paramLen := int32(byteOrder.Uint32(p[0:4]))
		p = p[4:]
		switch {
		case paramLen == -1:
			fmt.Fprintf(&sb, " NULL")
		case paramLen < 0:
			fmt.Fprintf(&sb, " (invalid param length %d)", paramLen)
			return sb.String()
		default:
			if len(p) < int(paramLen) {
				fmt.Fprintf(&sb, " (incomplete param %d data) rest=%x", i, p)
				return sb.String()
			}
			paramData := p[:paramLen]
			p = p[paramLen:]
			// Binary-format values are shown as hex even when their bytes
			// happen to be printable.
			if paramFormat(i) == 0 && isPrintable(paramData) {
				fmt.Fprintf(&sb, " %q", string(paramData))
			} else {
				fmt.Fprintf(&sb, " [%x]", paramData)
			}
		}
	}
	// Result format codes
	if len(p) < 2 {
		fmt.Fprintf(&sb, " (missing result format count) rest=%x", p)
		return sb.String()
	}
	numResultFormats := int(byteOrder.Uint16(p[0:2]))
	p = p[2:]
	if len(p) < numResultFormats*2 {
		fmt.Fprintf(&sb, " (incomplete result formats) rest=%x", p)
		return sb.String()
	}
	if numResultFormats > 0 {
		fmt.Fprintf(&sb, " result_fmts=[")
		for i := 0; i < numResultFormats; i++ {
			if i > 0 {
				fmt.Fprintf(&sb, ",")
			}
			rf := byteOrder.Uint16(p[i*2 : (i+1)*2])
			switch rf {
			case 0:
				fmt.Fprint(&sb, "text")
			case 1:
				fmt.Fprint(&sb, "binary")
			default:
				fmt.Fprintf(&sb, "invalid(%d)", rf)
			}
		}
		fmt.Fprintf(&sb, "]")
		p = p[numResultFormats*2:]
	}
	if len(p) > 0 {
		fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p)
	}
	return sb.String()
}
// describe formats a Describe message: a one-byte object kind ('S' for a
// prepared statement, 'P' for a portal) followed by the object's name.
type describe struct {
	want int
	have []byte
}

func (m *describe) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 2 {
		return fmt.Sprintf("Describe message too short raw=%x", m.have)
	}
	var out strings.Builder
	out.WriteString("Describe")
	kind, body := m.have[0], m.have[1:]
	switch kind {
	case 'S':
		out.WriteString(" statement")
	case 'P':
		out.WriteString(" portal")
	default:
		fmt.Fprintf(&out, " ??(%x %c)", kind, kind)
	}
	name, rest, ok := decodeString(body)
	if !ok {
		fmt.Fprintf(&out, " (error decoding name) rest=%x", body)
		return out.String()
	}
	if len(name) > 0 {
		fmt.Fprintf(&out, " %q", name)
	} else {
		out.WriteString(" <unnamed>")
	}
	if len(rest) != 0 {
		fmt.Fprintf(&out, " (unexpected trailing data: %x)", rest)
	}
	return out.String()
}
// execute formats an Execute message: the portal name followed by a 4-byte
// row limit, where 0 means no limit.
type execute struct {
	want int
	have []byte
}

func (m *execute) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 5 {
		return fmt.Sprintf("Execute message too short raw=%x", m.have)
	}
	portal, tail, ok := decodeString(m.have)
	switch {
	case !ok:
		return fmt.Sprintf("Execute (error decoding portal name) raw=%x", m.have)
	case len(tail) < 4:
		return fmt.Sprintf("Execute (missing max rows) rest=%x", tail)
	}
	limit := byteOrder.Uint32(tail[:4])
	tail = tail[4:]

	var out strings.Builder
	out.WriteString("Execute")
	if len(portal) != 0 {
		fmt.Fprintf(&out, " portal=%q", portal)
	} else {
		out.WriteString(" portal=<unnamed>")
	}
	if limit != 0 {
		fmt.Fprintf(&out, " maxrows=%d", limit)
	} else {
		out.WriteString(" maxrows=unlimited")
	}
	if len(tail) != 0 {
		fmt.Fprintf(&out, " (unexpected trailing data: %x)", tail)
	}
	return out.String()
}
// parse formats a Parse message: the statement name, the query text, and the
// parameter type OIDs the client specified (printed as "?" when 0).
type parse struct {
	want int
	have []byte
}

func (m *parse) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 3 {
		return fmt.Sprintf("Parse message too short raw=%x", m.have)
	}
	var sb strings.Builder
	fmt.Fprintf(&sb, "Parse")
	// Statement name
	stmtName, rest, ok := decodeString(m.have)
	if !ok {
		fmt.Fprintf(&sb, " (error decoding statement name) raw=%x", m.have)
		return sb.String()
	}
	if len(stmtName) == 0 {
		fmt.Fprintf(&sb, " stmt=<unnamed>")
	} else {
		fmt.Fprintf(&sb, " stmt=%q", stmtName)
	}
	p := rest
	// Query string
	queryStr, rest, ok := decodeString(p)
	if !ok {
		fmt.Fprintf(&sb, " (error decoding query) rest=%x", p)
		return sb.String()
	}
	fmt.Fprintf(&sb, " query=%q", queryStr)
	p = rest
	// Number of parameter types. Convert to int before multiplying: in uint16
	// arithmetic numParams*4 wraps once numParams exceeds 16383.
	if len(p) < 2 {
		fmt.Fprintf(&sb, " (missing param count) rest=%x", p)
		return sb.String()
	}
	numParams := int(byteOrder.Uint16(p[0:2]))
	p = p[2:]
	fmt.Fprintf(&sb, " params=%d", numParams)
	if numParams > 0 {
		// Parameter type OIDs
		need := numParams * 4
		if len(p) < need {
			fmt.Fprintf(&sb, " (incomplete param types, need %d bytes) rest=%x", need, p)
			return sb.String()
		}
		fmt.Fprintf(&sb, " types=[")
		for i := 0; i < numParams; i++ {
			if i > 0 {
				fmt.Fprintf(&sb, ",")
			}
			typeOID := byteOrder.Uint32(p[i*4 : (i+1)*4])
			if typeOID == 0 {
				fmt.Fprintf(&sb, "?")
			} else {
				fmt.Fprintf(&sb, "%d", typeOID)
			}
		}
		fmt.Fprintf(&sb, "]")
		p = p[need:]
	}
	if len(p) > 0 {
		fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p)
	}
	return sb.String()
}
// parseComplete formats a ParseComplete message, which carries no body.
type parseComplete struct {
	want int
	have []byte
}

func (m *parseComplete) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want == 0:
		return "ParseComplete"
	default:
		return fmt.Sprintf("ParseComplete (unexpected payload: %x)", m.have)
	}
}
// bindComplete formats a BindComplete message, which carries no body.
type bindComplete struct {
	want int
	have []byte
}

func (m *bindComplete) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want == 0:
		return "BindComplete"
	default:
		return fmt.Sprintf("BindComplete (unexpected payload: %x)", m.have)
	}
}
// commandComplete formats a CommandComplete message, whose body is a single
// String holding the command tag.
type commandComplete struct {
	want int
	have []byte
}

func (m *commandComplete) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 1 {
		return fmt.Sprintf("CommandComplete message too short raw=%x", m.have)
	}
	tag, trailing, ok := decodeString(m.have)
	switch {
	case !ok:
		return fmt.Sprintf("CommandComplete invalid tag raw=%x", m.have)
	case len(trailing) != 0:
		return fmt.Sprintf("CommandComplete %q (unexpected trailing data: %x)", tag, trailing)
	default:
		return fmt.Sprintf("CommandComplete %q", tag)
	}
}
// dataRow formats a DataRow message. The body is a 2-byte column count, then
// for each column a 4-byte signed length (-1 for NULL) followed by that many
// bytes of value.
type dataRow struct {
	want int
	have []byte
}

func (m *dataRow) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 2 {
		return fmt.Sprintf("DataRow message too short raw=%x", m.have)
	}
	cols := int(byteOrder.Uint16(m.have))
	var out strings.Builder
	fmt.Fprintf(&out, "DataRow cols=%d", cols)
	rest := m.have[2:]
	for col := 0; col < cols; col++ {
		if len(rest) < 4 {
			fmt.Fprintf(&out, " (incomplete column %d, need length field) rest=%x", col, rest)
			break
		}
		size := int32(byteOrder.Uint32(rest[:4]))
		rest = rest[4:]
		if size == -1 {
			out.WriteString(" NULL")
			continue
		}
		if size < 0 {
			fmt.Fprintf(&out, " (invalid length %d) rest=%x", size, rest)
			break
		}
		if int(size) > len(rest) {
			fmt.Fprintf(&out, " (incomplete column %d, need %d bytes) rest=%x", col, size, rest)
			break
		}
		value := rest[:size]
		rest = rest[size:]
		// Printable values are quoted; anything else is shown as hex.
		if isPrintable(value) {
			fmt.Fprintf(&out, " %q", string(value))
		} else {
			fmt.Fprintf(&out, " [%x]", value)
		}
	}
	if len(rest) != 0 {
		fmt.Fprintf(&out, " (unexpected trailing data: %x)", rest)
	}
	return out.String()
}
// isPrintable reports whether every byte of data is printable ASCII
// (0x20 through 0x7e) or one of tab, newline, and carriage return.
// Empty input counts as printable.
func isPrintable(data []byte) bool {
	for _, c := range data {
		switch {
		case c >= 32 && c <= 126:
		case c == '\t', c == '\n', c == '\r':
		default:
			return false
		}
	}
	return true
}
// errorResponse formats an ErrorResponse message. The body is a sequence of
// fields, each a one-byte field code followed by a String value, terminated
// by a zero byte.
type errorResponse struct {
	want int
	have []byte
}

// errorFieldNames maps ErrorResponse field codes to the labels printed in the
// log. Codes not listed are printed as hex.
var errorFieldNames = map[byte]string{
	'S': "severity",
	'V': "severity_nonlocal",
	'C': "code",
	'M': "message",
	'D': "detail",
	'H': "hint",
	'P': "position",
	'p': "internal_position",
	'q': "internal_query",
	'W': "where",
	's': "schema",
	't': "table",
	'c': "column",
	'd': "datatype",
	'n': "constraint",
	'F': "file",
	'L': "line",
	'R': "routine",
}

func (m *errorResponse) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 1 {
		return fmt.Sprintf("ErrorResponse message too short raw=%x", m.have)
	}
	var out strings.Builder
	out.WriteString("ErrorResponse")
	rest := m.have
	for {
		if len(rest) == 0 {
			out.WriteString(" (missing terminator)")
			break
		}
		code := rest[0]
		rest = rest[1:]
		if code == 0 {
			break
		}
		value, next, ok := decodeString(rest)
		if !ok {
			fmt.Fprintf(&out, " (error decoding field %c) rest=%x", code, rest)
			break
		}
		rest = next
		if label, known := errorFieldNames[code]; known {
			fmt.Fprintf(&out, " %s=%q", label, value)
		} else {
			fmt.Fprintf(&out, " %x=%q", code, value)
		}
	}
	if len(rest) != 0 {
		fmt.Fprintf(&out, " (unexpected trailing data: %x)", rest)
	}
	return out.String()
}
// backendKey formats a BackendKeyData message: a 4-byte process ID followed
// by the secret key (all remaining bytes).
type backendKey struct {
	want int
	have []byte
}

func (m *backendKey) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 4 {
		return fmt.Sprintf("BackendKeyData too short raw=%x", m.have)
	}
	pid, secret := byteOrder.Uint32(m.have[:4]), m.have[4:]
	return fmt.Sprintf("BackendKeyData pid=%d secret=%x", pid, secret)
}
// query formats a simple-protocol Query message, whose body is a single
// String holding the SQL text.
type query struct {
	want int
	have []byte
}

func (m *query) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 1 {
		return fmt.Sprintf("Query message too short raw=%x", m.have)
	}
	sql, extra, ok := decodeString(m.have)
	if !ok {
		return fmt.Sprintf("Query invalid string raw=%x", m.have)
	}
	if len(extra) == 0 {
		return fmt.Sprintf("Query %q", sql)
	}
	return fmt.Sprintf("Query %q (unexpected trailing data: %x)", sql, extra)
}
// syncMsg formats a Sync message, which carries no body.
type syncMsg struct {
	want int
	have []byte
}

func (m *syncMsg) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want == 0:
		return "Sync"
	default:
		return fmt.Sprintf("Sync (unexpected payload: %x)", m.have)
	}
}
// authenticationRequest formats the Authentication* family of messages. The
// Int32 code at the start of the body selects the variant; some variants
// carry extra data (MD5 salt, GSS/SASL payloads, the SASL mechanism list).
type authenticationRequest struct {
	want int
	have []byte
}

func (m *authenticationRequest) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 4 {
		return fmt.Sprintf("Authentication message too short raw=%x", m.have)
	}
	code := byteOrder.Uint32(m.have)
	payload := m.have[4:]
	switch {
	case m.want == 4 && code == 0:
		return "AuthenticationOk"
	case m.want == 4 && code == 2:
		return "AuthenticationKerberosV5"
	case m.want == 4 && code == 3:
		return "AuthenticationCleartextPassword"
	case m.want == 8 && code == 5:
		return fmt.Sprintf("AuthenticationMD5Password salt=%x", payload)
	case m.want == 4 && code == 7:
		return "AuthenticationGSS"
	case code == 8:
		return fmt.Sprintf("AuthenticationGSSContinue data=%x", payload)
	case m.want == 4 && code == 9:
		return "AuthenticationSSPI"
	case code == 10:
		return formatAuthenticationSASL(payload)
	case code == 11:
		return fmt.Sprintf("AuthenticationSASLContinue data=%x", payload)
	case code == 12:
		return fmt.Sprintf("AuthenticationSASLFinal data=%x", payload)
	default:
		return fmt.Sprintf("unrecognized Authentication message len=%d code=%d payload=%x",
			m.want, code, payload)
	}
}

// formatAuthenticationSASL renders the body of AuthenticationSASL. The body is
// a list of SASL mechanism names, each a String, terminated by an empty
// String. Every mechanism is printed, not just the first, so that offers
// such as SCRAM-SHA-256-PLUS are visible.
func formatAuthenticationSASL(list []byte) string {
	var sb strings.Builder
	sb.WriteString("AuthenticationSASL")
	p := list
	for {
		if len(p) == 0 {
			sb.WriteString(" (missing terminator)")
			break
		}
		name, rest, ok := decodeString(p)
		if !ok {
			fmt.Fprintf(&sb, " invalid name=%q", p)
			break
		}
		if len(name) == 0 {
			// Empty string: end of the mechanism list.
			if len(rest) > 0 {
				fmt.Fprintf(&sb, " (unexpected trailing data: %x)", rest)
			}
			break
		}
		fmt.Fprintf(&sb, " name=%q", name)
		p = rest
	}
	return sb.String()
}
// parameterStatus formats a ParameterStatus message: a run-time parameter
// name and its current value, each a String.
type parameterStatus struct {
	want int
	have []byte
}

func (m *parameterStatus) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 2 {
		return fmt.Sprintf("ParameterStatus message too short raw=%x", m.have)
	}
	name, rest, ok := decodeString(m.have)
	if !ok {
		return fmt.Sprintf("ParameterStatus message invalid name raw=%x", m.have)
	}
	value, rest, ok := decodeString(rest)
	if !ok {
		return fmt.Sprintf("ParameterStatus %s= invalid value raw=%x", name, rest)
	}
	// Report leftover bytes, as the other formatters do, instead of
	// silently dropping them.
	if len(rest) > 0 {
		return fmt.Sprintf("ParameterStatus %s=%q (unexpected trailing data: %x)", name, value, rest)
	}
	return fmt.Sprintf("ParameterStatus %s=%q", name, value)
}
// rowDescription formats a RowDescription message. The body is a 2-byte field
// count, then for each field its name (a String) followed by 18 bytes:
//   - table OID (4)
//   - column attribute number (2)
//   - type OID (4)
//   - type size (2)
//   - type modifier (4)
//   - format code (2)
type rowDescription struct {
	want int
	have []byte
}

func (m *rowDescription) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 2 {
		return fmt.Sprintf("RowDescription message too short raw=%x", m.have)
	}
	fieldCount := int(byteOrder.Uint16(m.have))
	var out strings.Builder
	fmt.Fprintf(&out, "RowDescription fields=%d", fieldCount)
	rest := m.have[2:]
	const fixedLen = 18
	for field := 0; field < fieldCount; field++ {
		name, after, ok := decodeString(rest)
		if !ok {
			fmt.Fprintf(&out, " (error decoding field %d name)", field)
			break
		}
		fmt.Fprintf(&out, " [%q", name)
		rest = after
		if len(rest) < fixedLen {
			fmt.Fprintf(&out, " incomplete] rest=%x", rest)
			break
		}
		attrs := rest[:fixedLen]
		rest = rest[fixedLen:]
		fmt.Fprintf(&out, " table=%d col=%d typeOID=%d",
			byteOrder.Uint32(attrs[0:4]), byteOrder.Uint16(attrs[4:6]), byteOrder.Uint32(attrs[6:10]))
		// A negative size denotes a variable-width type.
		if size := int16(byteOrder.Uint16(attrs[10:12])); size < 0 {
			fmt.Fprintf(&out, " size=var(%d)", size)
		} else {
			fmt.Fprintf(&out, " size=%d", size)
		}
		fmt.Fprintf(&out, " mod=%08x", byteOrder.Uint32(attrs[12:16]))
		switch format := byteOrder.Uint16(attrs[16:18]); format {
		case 0:
			out.WriteString(" fmt=text")
		case 1:
			out.WriteString(" fmt=binary")
		default:
			fmt.Fprintf(&out, " fmt=!!%d!!", format)
		}
		out.WriteString("]")
	}
	if len(rest) != 0 {
		fmt.Fprintf(&out, " (unexpected trailing data: %x)", rest)
	}
	return out.String()
}
// terminate formats a Terminate message, which carries no body.
type terminate struct {
	want int
	have []byte
}

func (m *terminate) Format(data []byte) string {
	m.have = append(m.have, data...)
	switch {
	case len(m.have) < m.want:
		return "..."
	case m.want == 0:
		return "Terminate"
	default:
		return fmt.Sprintf("Terminate (unexpected payload: %x)", m.have)
	}
}
// readyForQuery formats a ReadyForQuery message. Its one-byte body is the
// transaction status: 'I' idle, 'T' in a transaction block, 'E' in a failed
// transaction block.
type readyForQuery struct {
	want int
	have []byte
}

func (m *readyForQuery) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 1 {
		return fmt.Sprintf("ReadyForQuery message too short raw=%x", m.have)
	}
	var out string
	switch status := m.have[0]; status {
	case 'I':
		out = "ReadyForQuery (idle)"
	case 'T':
		out = "ReadyForQuery (in a transaction block)"
	case 'E':
		out = "ReadyForQuery (in a failed transaction block)"
	default:
		out = fmt.Sprintf("ReadyForQuery status=unknown(%c/%x)", status, status)
	}
	// The body should be exactly one byte; report anything beyond it.
	if len(m.have) > 1 {
		out += fmt.Sprintf(" (unexpected trailing data: %x)", m.have[1:])
	}
	return out
}
// parameterDescription formats a ParameterDescription message: a 2-byte
// parameter count followed by one 4-byte type OID per parameter.
type parameterDescription struct {
	want int
	have []byte
}

func (m *parameterDescription) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) < m.want {
		return "..."
	}
	if m.want < 2 {
		return fmt.Sprintf("ParameterDescription message too short raw=%x", m.have)
	}
	// Convert to int before multiplying: in uint16 arithmetic numParams*4
	// wraps once numParams exceeds 16383.
	numParams := int(byteOrder.Uint16(m.have))
	var sb strings.Builder
	fmt.Fprintf(&sb, "ParameterDescription params=%d", numParams)
	p := m.have[2:]
	if numParams > 0 {
		need := numParams * 4
		if len(p) < need {
			fmt.Fprintf(&sb, " (incomplete, need %d bytes) rest=%x", need, p)
			return sb.String()
		}
		fmt.Fprintf(&sb, " types=[")
		for i := 0; i < numParams; i++ {
			if i > 0 {
				fmt.Fprintf(&sb, ",")
			}
			typeOID := byteOrder.Uint32(p[i*4 : (i+1)*4])
			fmt.Fprintf(&sb, "%d", typeOID)
		}
		fmt.Fprintf(&sb, "]")
		p = p[need:]
	}
	if len(p) > 0 {
		fmt.Fprintf(&sb, " (unexpected trailing data: %x)", p)
	}
	return sb.String()
}
// gssResponse formats a frontend 'p' message by printing its whole body as
// hex once it has fully arrived.
type gssResponse struct {
	want int
	have []byte
}

func (m *gssResponse) Format(data []byte) string {
	m.have = append(m.have, data...)
	if len(m.have) >= m.want {
		return fmt.Sprintf("GSSResponse data=%x", m.have)
	}
	return "..."
}
package main
import (
"bytes"
"encoding/binary"
"io"
"log"
)
// byteOrder is used to decode every multi-byte integer in the protocol
// messages handled here (PostgreSQL uses network byte order, big-endian).
var byteOrder = binary.BigEndian

// side chooses a formatter for each message seen in one direction of a
// connection. In this program, frontend is used for client->server traffic
// and backend for server->client traffic.
type side interface {
	getMessageFormatter(msgType byte, msgLen int) messageFormatter
}
// processMessages does some introspection of messages received from r before
// forwarding them to w.
//
// Each message is a 1-byte type, an Int32 length that counts itself but not
// the type byte, and a body. The client's first packet has no type byte, so
// for that direction the caller passes hi=1: the header then starts with a
// zero type byte already "read", and the first message appears as type 0.
//
// hi should be 1 for client->server and 0 for server->client.
//
// Every chunk read from r is logged (via side's formatters) and then written
// to w unchanged. processMessages returns when reading or writing fails,
// when an SSLRequest shows up on the relayed stream, or when a header carries
// an impossible length, since the framing can no longer be trusted.
func processMessages(r io.Reader, w io.Writer, l *log.Logger, side side, hi int) {
	buf := make([]byte, 1024*1024)
	const maxHeaderSize = 5
	var hdr [maxHeaderSize]byte
	bodyRemaining := 0
	var pretty messageFormatter
	for {
		n, err := r.Read(buf)
		if err != nil {
			l.Printf("error reading data: %v", err)
			return
		}
		for p := buf[:n]; len(p) > 0; {
			if bodyRemaining > 0 {
				// Feed as much of the current body as this read contains.
				chunk := min(bodyRemaining, len(p))
				l.Printf(" %s", pretty.Format(p[:chunk]))
				p = p[chunk:]
				bodyRemaining -= chunk
				continue
			}
			// Collect header bytes; a header may be split across reads.
			copied := copy(hdr[hi:], p)
			hi += copied
			p = p[copied:]
			if hi < maxHeaderSize {
				continue
			}
			msgType := hdr[0]
			// Read the length as the protocol's signed Int32 so that absurd
			// values are rejected below rather than stalling the parser.
			length := int(int32(byteOrder.Uint32(hdr[1:])))
			l.Printf("MSG '%c' (%x) (len=%d)", msgType, msgType, length-4)
			if length < 4 {
				// The length includes itself, so this is corrupt framing.
				// Continuing would leave bodyRemaining negative with no
				// branch able to consume the remaining input.
				l.Printf("invalid message length %d, bailing", length)
				return
			}
			// SSLRequest's code is 80877103 = 1234<<16 | 5679, i.e. the bytes
			// 04 d2 16 2f. Only check bytes that were actually received.
			if msgType == 0 && length == 8 && len(p) >= 4 && bytes.Equal(p[:4], []byte{0x04, 0xd2, 0x16, 0x2f}) {
				l.Printf("SSL mode is not supported")
				return
			}
			bodyRemaining = length - 4
			hi = 0
			pretty = side.getMessageFormatter(msgType, bodyRemaining)
			if bodyRemaining == 0 {
				l.Printf(" %s", pretty.Format(nil))
			}
		}
		l.Printf("write(%d)", n)
		if _, err := w.Write(buf[:n]); err != nil {
			l.Printf("error writing data: %v", err)
			return
		}
	}
}
// decodeString decodes a Postgres 'String' atom.
//
// > A null-terminated string (C-style string). There is no specific length
// > limitation on strings. If s is specified it is the exact value that will
// > appear, otherwise the value is variable. Eg. String, String("user").
//
// On success it returns the bytes before the first NUL, the bytes after that
// NUL, and true. If data contains no NUL it returns (nil, data, false).
func decodeString(data []byte) ([]byte, []byte, bool) {
	end := bytes.IndexByte(data, 0)
	if end < 0 {
		return nil, data, false
	}
	str, rest := data[:end], data[end+1:]
	return str, rest, true
}
package main
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net"
"os"
"time"
)
// initSSL returns a server TLS config built from the PEM certificate and key
// stored at certPath and keyPath.
//
// If neither file exists, it generates a self-signed certificate for
// localhost/127.0.0.1 (valid for one year) and a 2048-bit RSA key, writes
// them to those paths, and uses them. If only one of the files exists, or
// either cannot be read or parsed, it returns an error.
func initSSL(certPath, keyPath string) (*tls.Config, error) {
	certData, certErr := os.ReadFile(certPath)
	keyData, keyErr := os.ReadFile(keyPath)
	switch {
	case os.IsNotExist(certErr) && os.IsNotExist(keyErr):
		// fall through to the code after the switch, which generates a new cert and key.
	case certErr != nil:
		return nil, fmt.Errorf("%s: %w", certPath, certErr)
	case keyErr != nil:
		return nil, fmt.Errorf("%s: %w", keyPath, keyErr)
	default:
		cert, err := tls.X509KeyPair(certData, keyData)
		if err != nil {
			return nil, err
		}
		return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
	}

	// Generate self-signed certificate
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("error generating private key: %w", err)
	}
	notBefore := time.Now()
	notAfter := notBefore.Add(365 * 24 * time.Hour) // Valid for 1 year
	serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("error generating serial number: %w", err)
	}
	template := x509.Certificate{
		SerialNumber: serialNumber,
		Subject: pkix.Name{
			Organization: []string{"PostgreSQL MITM Proxy"},
			CommonName:   "localhost",
		},
		NotBefore:             notBefore,
		NotAfter:              notAfter,
		KeyUsage:              x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
		ExtKeyUsage:           []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
		BasicConstraintsValid: true,
		DNSNames:              []string{"localhost"},
		IPAddresses:           []net.IP{net.ParseIP("127.0.0.1")},
	}
	derBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &priv.PublicKey, priv)
	if err != nil {
		return nil, fmt.Errorf("error creating certificate: %w", err)
	}
	privBytes, err := x509.MarshalPKCS8PrivateKey(priv)
	if err != nil {
		return nil, fmt.Errorf("error marshaling private key: %w", err)
	}
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: derBytes})
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "PRIVATE KEY", Bytes: privBytes})

	// The certificate is public, but the private key must not be readable by
	// other users. os.Create would have used mode 0666 (minus umask).
	if err := os.WriteFile(certPath, certPEM, 0o644); err != nil {
		return nil, fmt.Errorf("error writing cert: %w", err)
	}
	if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
		return nil, fmt.Errorf("error writing key: %w", err)
	}

	// Build the config from the in-memory PEM rather than re-reading the files.
	cert, err := tls.X509KeyPair(certPEM, keyPEM)
	if err != nil {
		return nil, fmt.Errorf("error loading generated certificate: %w", err)
	}
	return &tls.Config{Certificates: []tls.Certificate{cert}}, nil
}
// backendTLSPolicy describes how the proxy negotiates TLS with the backend.
type backendTLSPolicy struct {
	Request bool        // send an SSLRequest to the backend
	Require bool        // fail if the backend declines TLS
	Config  *tls.Config // client-side TLS settings used when the backend accepts
}

// configureBackendTLS maps a --tls-mode value to a backendTLSPolicy:
//   - "disable" never requests TLS.
//   - "allow" requests TLS but tolerates refusal.
//   - "require" requests TLS and fails on refusal.
//   - "insecure" requests TLS without certificate verification.
//
// servername is used for verification in "allow" and "require". The boolean
// result is false for an unrecognised mode.
func configureBackendTLS(mode, servername string) (backendTLSPolicy, bool) {
	var policy backendTLSPolicy
	switch mode {
	case "disable":
		// zero policy: plaintext only
	case "require":
		policy.Require = true
		fallthrough
	case "allow":
		policy.Request = true
		policy.Config = strictBETLSConfig(servername)
	case "insecure":
		policy.Request = true
		policy.Config = insecureBETLSConfig()
	default:
		return backendTLSPolicy{}, false
	}
	return policy, true
}

// strictBETLSConfig returns a TLS client config that verifies the backend
// certificate for servername.
func strictBETLSConfig(servername string) *tls.Config {
	cfg := &tls.Config{}
	cfg.ServerName = servername
	return cfg
}

// insecureBETLSConfig returns a TLS client config that accepts any backend
// certificate without verification.
func insecureBETLSConfig() *tls.Config {
	cfg := new(tls.Config)
	cfg.InsecureSkipVerify = true
	return cfg
}
#!/bin/sh
# Run the same query through the proxy on 127.0.0.1:5555 once for each
# client sslmode listed below, to exercise the different TLS negotiation
# paths. Assumes a table named "hi"; replace the password with a real one.
for m in "sslmode=prefer" "sslmode=require" "sslmode=disable"; do
  echo "==== $m ===="
  psql "host=127.0.0.1 port=5555 user=postgres password=xxx $m" -c 'select * from hi'
done
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment