* Generate RSA key and spawn SFTP server for test
diff --git a/sftp_test.go b/sftp_test.go index 8c48652..839a19d 100644 --- a/sftp_test.go +++ b/sftp_test.go
@@ -14,11 +14,18 @@ package afero import ( - "os" - "io/ioutil" "testing" - - "github.com/spf13/afero" + "os" + "log" + "fmt" + "net" + "flag" + "time" + "io/ioutil" + "crypto/rsa" + _rand "crypto/rand" + "encoding/pem" + "crypto/x509" "golang.org/x/crypto/ssh" "github.com/pkg/sftp" @@ -30,7 +37,7 @@ sftpc *sftp.Client } -func SftpConnect(user string, host string) (*SftpFsContext, error) { +func SftpConnect(user, password, host string) (*SftpFsContext, error) { pemBytes, err := ioutil.ReadFile(os.Getenv("HOME") + "/.ssh/id_rsa") if err != nil { return nil,err @@ -44,6 +51,7 @@ sshcfg := &ssh.ClientConfig{ User: user, Auth: []ssh.AuthMethod{ + ssh.Password(password), ssh.PublicKeys(signer), }, } @@ -73,20 +81,194 @@ return nil } +func RunSftpServer() { + var ( + readOnly bool + debugLevelStr string + debugLevel int + debugStderr bool + rootDir string + ) + + flag.BoolVar(&readOnly, "R", false, "read-only server") + flag.BoolVar(&debugStderr, "e", false, "debug to stderr") + flag.StringVar(&debugLevelStr, "l", "none", "debug level") + flag.StringVar(&rootDir, "root", ".", "root directory") + flag.Parse() + + debugStream := ioutil.Discard + if debugStderr { + debugStream = os.Stderr + debugLevel = 1 + } + + // An SSH server is represented by a ServerConfig, which holds + // certificate details and handles authentication of ServerConns. + config := &ssh.ServerConfig{ + PasswordCallback: func(c ssh.ConnMetadata, pass []byte) (*ssh.Permissions, error) { + // Should use constant-time compare (or better, salt+hash) in + // a production setting. + fmt.Fprintf(debugStream, "Login: %s\n", c.User()) + if c.User() == "test" && string(pass) == "test" { + return nil, nil + } + return nil, fmt.Errorf("password rejected for %q", c.User()) + }, + } + + privateBytes, err := ioutil.ReadFile("./test/id_rsa") + if err != nil { + log.Fatal("Failed to load private key", err) + } + + private, err := ssh.ParsePrivateKey(privateBytes) + if err != nil { + log.Fatal("Failed to parse private key", err) + } + + config.AddHostKey(private) + + // Once a ServerConfig has been configured, connections can be + // accepted. + listener, err := net.Listen("tcp", "0.0.0.0:2022") + if err != nil { + log.Fatal("failed to listen for connection", err) + } + fmt.Printf("Listening on %v\n", listener.Addr()) + + nConn, err := listener.Accept() + if err != nil { + log.Fatal("failed to accept incoming connection", err) + } + + // Before use, a handshake must be performed on the incoming + // net.Conn. + _, chans, reqs, err := ssh.NewServerConn(nConn, config) + if err != nil { + log.Fatal("failed to handshake", err) + } + fmt.Fprintf(debugStream, "SSH server established\n") + + // The incoming Request channel must be serviced. + go ssh.DiscardRequests(reqs) + + // Service the incoming Channel channel. + for newChannel := range chans { + // Channels have a type, depending on the application level + // protocol intended. In the case of an SFTP session, this is "subsystem" + // with a payload string of "<length=4>sftp" + fmt.Fprintf(debugStream, "Incoming channel: %s\n", newChannel.ChannelType()) + if newChannel.ChannelType() != "session" { + newChannel.Reject(ssh.UnknownChannelType, "unknown channel type") + fmt.Fprintf(debugStream, "Unknown channel type: %s\n", newChannel.ChannelType()) + continue + } + channel, requests, err := newChannel.Accept() + if err != nil { + log.Fatal("could not accept channel.", err) + } + fmt.Fprintf(debugStream, "Channel accepted\n") + + // Sessions have out-of-band requests such as "shell", + // "pty-req" and "env". Here we handle only the + // "subsystem" request. + go func(in <-chan *ssh.Request) { + for req := range in { + fmt.Fprintf(debugStream, "Request: %v\n", req.Type) + ok := false + switch req.Type { + case "subsystem": + fmt.Fprintf(debugStream, "Subsystem: %s\n", req.Payload[4:]) + if string(req.Payload[4:]) == "sftp" { + ok = true + } + } + fmt.Fprintf(debugStream, " - accepted: %v\n", ok) + req.Reply(ok, nil) + } + }(requests) + + server, err := sftp.NewServer(channel, channel, debugStream, debugLevel, readOnly, rootDir) + if err != nil { + log.Fatal(err) + } + if err := server.Serve(); err != nil { + log.Fatal("sftp server completed with error:", err) + } + } +} + + +// MakeSSHKeyPair make a pair of public and private keys for SSH access. +// Public key is encoded in the format for inclusion in an OpenSSH authorized_keys file. +// Private Key generated is PEM encoded +func MakeSSHKeyPair(bits int, pubKeyPath, privateKeyPath string) error { + privateKey, err := rsa.GenerateKey(_rand.Reader, bits) + if err != nil { + return err + } + + // generate and write private key as PEM + privateKeyFile, err := os.Create(privateKeyPath) + defer privateKeyFile.Close() + if err != nil { + return err + } + + privateKeyPEM := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(privateKey)} + if err := pem.Encode(privateKeyFile, privateKeyPEM); err != nil { + return err + } + + // generate and write public key + pub, err := ssh.NewPublicKey(&privateKey.PublicKey) + if err != nil { + return err + } + + return ioutil.WriteFile(pubKeyPath, ssh.MarshalAuthorizedKey(pub), 0655) +} + func TestSftpCreate(t *testing.T) { - ctx, err := SftpConnect("user", "host:port") + os.Mkdir("./test", 0777) + MakeSSHKeyPair(1024, "./test/id_rsa.pub", "./test/id_rsa") + + go RunSftpServer() + time.Sleep(2 * time.Second) + + ctx, err := SftpConnect("test", "test", "localhost:2022") if err != nil { t.Fatal(err) } defer ctx.Disconnect() - var AppFs afero.Fs = afero.SftpFs{ + var AppFs Fs = SftpFs{ SftpClient: ctx.sftpc, } - file, err := AppFs.Create("aferoSftpFsTestFile") + AppFs.MkdirAll("test/dir1/dir2/dir3", os.FileMode(0777)) + AppFs.Mkdir("test/foo", os.FileMode(0000)) + AppFs.Mkdir("test/bar", os.FileMode(0777)) + + file, err := AppFs.Create("file1") if err != nil { t.Error(err) } defer file.Close() + + file.Write([]byte("whohoo\n")) + file.WriteString("sdsdsdsdsddssdsd") + + f1, err := AppFs.Open("file1") + if err != nil { + log.Fatalf("open: %v", err) + } + defer f1.Close() + + b := make([]byte, 100) + + _, err = f1.Read(b) + fmt.Println(string(b)) + + AppFs.Remove("test") }