Improve SSH client.
[lgpl/argeo-commons.git] / org.argeo.core / src / org / argeo / ssh / SshKeyPair.java
index 8feca58f95048693c01647819cb7589cce446b54..f9b34859885b1e077d0b91a1846a2426e8ab802b 100644 (file)
@@ -1,8 +1,14 @@
 package org.argeo.ssh;
 
+import java.io.IOException;
+import java.io.InputStreamReader;
 import java.io.Reader;
 import java.io.StringReader;
 import java.io.StringWriter;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
 import java.security.GeneralSecurityException;
 import java.security.KeyFactory;
 import java.security.KeyPair;
@@ -28,6 +34,8 @@ import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
 
 @SuppressWarnings("restriction")
 public class SshKeyPair {
+       public final static String RSA_KEY_TYPE = "ssh-rsa";
+
        private PublicKey publicKey;
        private PrivateKey privateKey;
        private KeyPair keyPair;
@@ -54,7 +62,7 @@ public class SshKeyPair {
                return PublicKeyEntry.toString(publicKey);
        }
 
-       public String getPrivateKeyAsString(char[] password) {
+       public String getPrivateKeyAsPemString(char[] password) {
                try {
                        Object obj;
 
@@ -79,9 +87,36 @@ public class SshKeyPair {
                }
        }
 
+       public static SshKeyPair loadOrGenerate(Path privateKeyPath, int size, char[] password) {
+               try {
+                       SshKeyPair sshKeyPair;
+                       if (Files.exists(privateKeyPath)) {
+//                             String privateKeyStr = new String(Files.readAllBytes(privateKeyPath), StandardCharsets.US_ASCII);
+                               sshKeyPair = load(
+                                               new InputStreamReader(Files.newInputStream(privateKeyPath), StandardCharsets.US_ASCII),
+                                               password);
+                               // TOD make sure public key is consistemt
+                       } else {
+                               sshKeyPair = generate(size);
+                               Files.write(privateKeyPath,
+                                               sshKeyPair.getPrivateKeyAsPemString(password).getBytes(StandardCharsets.US_ASCII));
+                               Path publicKeyPath = privateKeyPath.resolveSibling(privateKeyPath.getFileName() + ".pub");
+                               Files.write(publicKeyPath,
+                                               sshKeyPair.getPublicKeyAsOpenSshString().getBytes(StandardCharsets.US_ASCII));
+                       }
+                       return sshKeyPair;
+               } catch (IOException e) {
+                       throw new RuntimeException("Cannot read or write private key " + privateKeyPath, e);
+               }
+       }
+
        public static SshKeyPair generate(int size) {
+               return generate(RSA_KEY_TYPE, size);
+       }
+
+       public static SshKeyPair generate(String keyType, int size) {
                try {
-                       KeyPair keyPair = KeyUtils.generateKeyPair("ssh-rsa", size);
+                       KeyPair keyPair = KeyUtils.generateKeyPair(keyType, size);
                        PublicKey publicKey = keyPair.getPublic();
                        PrivateKey privateKey = keyPair.getPrivate();
                        return new SshKeyPair(publicKey, privateKey);
@@ -93,7 +128,7 @@ public class SshKeyPair {
        public static SshKeyPair load(Reader reader, char[] password) {
                try (PEMParser pemParser = new PEMParser(reader)) {
                        Object object = pemParser.readObject();
-                       JcaPEMKeyConverter converter = new JcaPEMKeyConverter().setProvider("BC");
+                       JcaPEMKeyConverter converter = new JcaPEMKeyConverter();// .setProvider("BC");
                        KeyPair kp;
                        if (object instanceof PKCS8EncryptedPrivateKeyInfo) {
                                // Encrypted key - we will use provided password
@@ -125,22 +160,23 @@ public class SshKeyPair {
        }
 
        public static void main(String args[]) {
-               SshKeyPair okp = SshKeyPair.generate(1024);
-               System.out.println("Public:\n" + okp.getPublicKeyAsOpenSshString());
-               System.out.println("Private (plain):\n" + okp.getPrivateKeyAsString(null));
-               System.out.println("Private (encrypted):\n" + okp.getPrivateKeyAsString("demo".toCharArray()));
-
-               StringReader reader = new StringReader(okp.getPrivateKeyAsString(null));
-               okp = SshKeyPair.load(reader, null);
-               System.out.println("Public:\n" + okp.getPublicKeyAsOpenSshString());
-               System.out.println("Private (plain):\n" + okp.getPrivateKeyAsString(null));
-               System.out.println("Private (encrypted):\n" + okp.getPrivateKeyAsString("demo".toCharArray()));
-
-               reader = new StringReader(okp.getPrivateKeyAsString("demo".toCharArray()));
-               okp = SshKeyPair.load(reader, "demo".toCharArray());
-               System.out.println("Public:\n" + okp.getPublicKeyAsOpenSshString());
-               System.out.println("Private (plain):\n" + okp.getPrivateKeyAsString(null));
-               System.out.println("Private (encrypted):\n" + okp.getPrivateKeyAsString("demo".toCharArray()));
+               Path privateKeyPath = Paths.get(System.getProperty("user.dir") + "/id_rsa");
+               SshKeyPair skp = SshKeyPair.loadOrGenerate(privateKeyPath, 1024, null);
+               System.out.println("Public:\n" + skp.getPublicKeyAsOpenSshString());
+               System.out.println("Private (plain):\n" + skp.getPrivateKeyAsPemString(null));
+               System.out.println("Private (encrypted):\n" + skp.getPrivateKeyAsPemString("demo".toCharArray()));
+
+               StringReader reader = new StringReader(skp.getPrivateKeyAsPemString(null));
+               skp = SshKeyPair.load(reader, null);
+               System.out.println("Public:\n" + skp.getPublicKeyAsOpenSshString());
+               System.out.println("Private (plain):\n" + skp.getPrivateKeyAsPemString(null));
+               System.out.println("Private (encrypted):\n" + skp.getPrivateKeyAsPemString("demo".toCharArray()));
+
+               reader = new StringReader(skp.getPrivateKeyAsPemString("demo".toCharArray()));
+               skp = SshKeyPair.load(reader, "demo".toCharArray());
+               System.out.println("Public:\n" + skp.getPublicKeyAsOpenSshString());
+               System.out.println("Private (plain):\n" + skp.getPrivateKeyAsPemString(null));
+               System.out.println("Private (encrypted):\n" + skp.getPrivateKeyAsPemString("demo".toCharArray()));
        }
 
 }