| /* |
| Copyright 2017 The Transicator Authors |
| |
| Licensed under the Apache License, Version 2.0 (the "License"); |
| you may not use this file except in compliance with the License. |
| You may obtain a copy of the License at |
| |
| http://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| See the License for the specific language governing permissions and |
| limitations under the License. |
| */ |
| package pgclient |
| |
| import ( |
| "bytes" |
| "crypto/rand" |
| "crypto/tls" |
| "encoding/binary" |
| "fmt" |
| "io" |
| "net" |
| "regexp" |
| |
| log "github.com/Sirupsen/logrus" |
| ) |
| |
| const ( |
| mockUserName = "mock" |
| mockPassword = "mocketty" |
| mockDatabaseName = "turtle" |
| ) |
| |
| type mockState int |
| |
| const ( |
| mockIdle mockState = 1 |
| ) |
| |
| // MockAuth specifies what type of authentication the server supports |
| type MockAuth int |
| |
| // Different auth types |
| const ( |
| MockTrust MockAuth = 0 |
| MockClear MockAuth = 1 |
| MockMD5 MockAuth = 2 |
| ) |
| |
| var insertRE = regexp.MustCompile("insert into mock values \\('([\\w]+)', '([\\w]+)'\\)") |
| |
| /* |
| A MockServer is a server that implements a little bit of the Postgres wire |
| protocol. We can use it for testing of the wire protocol client. In particular, |
| we can use it to test the myriad of password authentication and TLS options |
| without having to start and stop a real Postgres server in the test suite. |
| */ |
| type MockServer struct { |
| listener net.Listener |
| mockTable map[string]string |
| authType MockAuth |
| tlsConfig *tls.Config |
| forceTLS bool |
| } |
| |
| /* |
| NewMockServer starts a new server in the current process, listening on the |
| specified port. |
| */ |
| func NewMockServer() *MockServer { |
| return &MockServer{ |
| mockTable: make(map[string]string), |
| authType: MockTrust, |
| } |
| } |
| |
| /* |
| SetAuthType sets what kind of password authentication to require |
| */ |
| func (m *MockServer) SetAuthType(auth MockAuth) { |
| m.authType = auth |
| } |
| |
| /* |
| SetTLSInfo sets the cert and key file and makes it possible for the server |
| to support TLS. |
| */ |
| func (m *MockServer) SetTLSInfo(certFile, keyFile string) error { |
| cert, err := tls.LoadX509KeyPair(certFile, keyFile) |
| if err != nil { |
| return err |
| } |
| m.tlsConfig = &tls.Config{ |
| Certificates: []tls.Certificate{cert}, |
| } |
| return nil |
| } |
| |
| /* |
| SetForceTLS sets up the server to reject any non-TLS clients. |
| */ |
| func (m *MockServer) SetForceTLS() { |
| m.forceTLS = true |
| } |
| |
| /* |
| Start listening for stuff. |
| */ |
| func (m *MockServer) Start(port int) (err error) { |
| listener, err := net.ListenTCP("tcp", &net.TCPAddr{ |
| Port: port, |
| }) |
| |
| if err != nil { |
| return err |
| } |
| |
| m.listener = listener |
| go m.acceptLoop() |
| return nil |
| } |
| |
| /* |
| Address returns the listen address in host:port format. |
| */ |
| func (m *MockServer) Address() string { |
| return m.listener.Addr().String() |
| } |
| |
| /* |
| Stop stops the server listening for new connections. |
| */ |
| func (m *MockServer) Stop() { |
| if m.listener != nil { |
| m.listener.Close() |
| } |
| } |
| |
| /* |
| acceptLoop sits and accepts new connections. |
| */ |
| func (m *MockServer) acceptLoop() { |
| for { |
| conn, err := m.listener.Accept() |
| if err != nil { |
| return |
| } |
| go m.connectLoop(conn) |
| } |
| } |
| |
| /* |
| connectLoop responds to a new connection by handling the Postgres |
| authentication and startup protocol. |
| */ |
| func (m *MockServer) connectLoop(c net.Conn) { |
| defer c.Close() |
| |
| startup, err := readMockMessage(c, true) |
| if err != nil { |
| log.Errorf("Error reading startup message: %s\n", err) |
| return |
| } |
| |
| protoVersion, _ := startup.ReadInt32() |
| |
| if protoVersion == sslMagicNumber { |
| // SSL startup attempt. Respond with "S" or "N" |
| if m.tlsConfig == nil { |
| c.Write([]byte{'N'}) |
| } else { |
| // Respond with the appropriate byte and upgrade to TLS |
| c.Write([]byte{'S'}) |
| c = tls.Server(c, m.tlsConfig) |
| } |
| |
| // Look for a new startup packet now. |
| startup, err = readMockMessage(c, true) |
| if err != nil { |
| // This might happen if TLS handshake failed, which is a valid test case |
| return |
| } |
| protoVersion, _ = startup.ReadInt32() |
| } else if m.forceTLS { |
| // We will just close any connection that doesn't ask for TLS |
| return |
| } |
| |
| if protoVersion != protocolVersion { |
| sendError(c, fmt.Sprintf("Invalid read protocol version %d\n", protoVersion)) |
| return |
| } |
| |
| var paramName, paramVal string |
| for { |
| paramName, _ = startup.ReadString() |
| if paramName == "" { |
| break |
| } |
| paramVal, _ = startup.ReadString() |
| |
| if paramName == "user" { |
| if paramVal != mockUserName { |
| sendError(c, fmt.Sprintf("Invalid user name %s", paramVal)) |
| return |
| } |
| } |
| if paramName == "database" { |
| if paramVal != mockDatabaseName { |
| sendError(c, fmt.Sprintf("Invalid database name %s\n", paramVal)) |
| return |
| } |
| } |
| } |
| |
| authOK := m.authLoop(c) |
| |
| if authOK { |
| sendAuthResponse(c, 0) |
| sendReady(c) |
| m.readLoop(c) |
| } |
| } |
| |
| /* |
| authLoop runs the various Postgres authentication options. |
| */ |
| func (m *MockServer) authLoop(c net.Conn) bool { |
| switch m.authType { |
| case MockTrust: |
| return true |
| |
| case MockClear: |
| sendAuthResponse(c, 3) |
| return m.readPassword(c, mockPassword) |
| |
| case MockMD5: |
| salt := make([]byte, 4) |
| rand.Read(salt) |
| out := NewServerOutputMessage(AuthenticationResponse) |
| out.WriteInt32(5) |
| out.WriteBytes(salt) |
| c.Write(out.Encode()) |
| return m.readPassword(c, passwordMD5(mockUserName, mockPassword, salt)) |
| |
| default: |
| return false |
| } |
| } |
| |
| func (m *MockServer) readPassword(c net.Conn, expected string) bool { |
| msg, _ := readMockMessage(c, false) |
| if msg.ServerType() != PasswordMessage { |
| sendError(c, "Expected PasswordMessage") |
| return false |
| } |
| |
| pwd, _ := msg.ReadString() |
| if pwd == expected { |
| return true |
| } |
| sendError(c, "Invalid password") |
| return false |
| } |
| |
| /* |
| readLoop now reads and parses SQL commands until it's time to shut the |
| connection down. |
| */ |
| func (m *MockServer) readLoop(c net.Conn) { |
| state := mockIdle |
| |
| for { |
| msg, err := readMockMessage(c, false) |
| if err != nil { |
| return |
| } |
| |
| switch state { |
| case mockIdle: |
| m.readIdle(c, msg) |
| } |
| } |
| } |
| |
| func (m *MockServer) readIdle(c net.Conn, msg *InputMessage) { |
| switch msg.ServerType() { |
| case Query: |
| sql, _ := msg.ReadString() |
| match := insertRE.FindStringSubmatch(sql) |
| if match != nil { |
| m.mockTable[match[1]] = match[2] |
| out := NewServerOutputMessage(CommandComplete) |
| out.WriteString("INSERT 1") |
| c.Write(out.Encode()) |
| sendReady(c) |
| |
| } else { |
| sendError(c, fmt.Sprintf("Invalid SQL \"%s\"", sql)) |
| sendReady(c) |
| } |
| |
| default: |
| sendError(c, fmt.Sprintf("Invalid message %s", msg.ServerType())) |
| sendReady(c) |
| } |
| } |
| |
| func readMockMessage(c net.Conn, isStartup bool) (msg *InputMessage, err error) { |
| var hdr []byte |
| if isStartup { |
| hdr = make([]byte, 4) |
| } else { |
| hdr = make([]byte, 5) |
| } |
| |
| _, err = io.ReadFull(c, hdr) |
| if err != nil { |
| return |
| } |
| |
| hdrBuf := bytes.NewBuffer(hdr) |
| var msgType PgOutputType |
| |
| if !isStartup { |
| var msgTypeVal byte |
| msgTypeVal, err = hdrBuf.ReadByte() |
| if err != nil { |
| return |
| } |
| msgType = PgOutputType(msgTypeVal) |
| } |
| |
| var msgLen int32 |
| err = binary.Read(hdrBuf, networkByteOrder, &msgLen) |
| if err != nil { |
| return |
| } |
| |
| if msgLen < 4 { |
| err = fmt.Errorf("Invalid message length %d", msgLen) |
| return |
| } |
| |
| bodBuf := make([]byte, msgLen-4) |
| _, err = io.ReadFull(c, bodBuf) |
| if err != nil { |
| return |
| } |
| |
| msg = NewServerInputMessage(msgType, bodBuf) |
| return |
| } |
| |
| func sendError(c net.Conn, msg string) { |
| out := NewServerOutputMessage(ErrorResponse) |
| out.WriteByte('S') |
| out.WriteString("FATAL") |
| out.WriteByte('M') |
| out.WriteString(msg) |
| out.WriteByte(0) |
| c.Write(out.Encode()) |
| } |
| |
| func sendAuthResponse(c net.Conn, code int32) { |
| out := NewServerOutputMessage(AuthenticationResponse) |
| out.WriteInt32(code) |
| c.Write(out.Encode()) |
| } |
| |
| func sendReady(c net.Conn) { |
| out := NewServerOutputMessage(ReadyForQuery) |
| out.WriteByte('I') |
| c.Write(out.Encode()) |
| } |