X-Git-Url: https://git.argeo.org/?a=blobdiff_plain;ds=inline;f=org.argeo.cms.lib.sshd%2Fsrc%2Forg%2Fargeo%2Fcms%2Fssh%2FAbstractSsh.java;fp=org.argeo.cms.lib.sshd%2Fsrc%2Forg%2Fargeo%2Fcms%2Fssh%2FAbstractSsh.java;h=f2525bff850b0c914709309f4389d0b2679a7671;hb=4e548693acc16f97b74eaaa95d6841054a172b85;hp=0000000000000000000000000000000000000000;hpb=7b242851c0094d13cbaca5b68261ad92c873a59f;p=lgpl%2Fargeo-commons.git diff --git a/org.argeo.cms.lib.sshd/src/org/argeo/cms/ssh/AbstractSsh.java b/org.argeo.cms.lib.sshd/src/org/argeo/cms/ssh/AbstractSsh.java new file mode 100644 index 000000000..f2525bff8 --- /dev/null +++ b/org.argeo.cms.lib.sshd/src/org/argeo/cms/ssh/AbstractSsh.java @@ -0,0 +1,211 @@ +package org.argeo.cms.ssh; + +import java.io.Console; +import java.io.IOException; +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Scanner; +import java.util.Set; + +import org.apache.sshd.client.SshClient; +import org.apache.sshd.client.channel.ClientChannel; +import org.apache.sshd.client.channel.ClientChannelEvent; +import org.apache.sshd.client.future.ConnectFuture; +import org.apache.sshd.client.session.ClientSession; +import org.apache.sshd.common.util.io.input.NoCloseInputStream; +import org.apache.sshd.common.util.io.output.NoCloseOutputStream; +import org.apache.sshd.sftp.client.fs.SftpFileSystemProvider; +import org.argeo.api.cms.CmsLog; + +@SuppressWarnings("restriction") +public abstract class AbstractSsh { + private final static CmsLog log = CmsLog.getLog(AbstractSsh.class); + + private static SshClient sshClient; + private static SftpFileSystemProvider sftpFileSystemProvider; + + private boolean passwordSet = false; + private ClientSession session; + + private SshKeyPair sshKeyPair; + + public synchronized SshClient getSshClient() { + if (sshClient == null) { + long begin = System.currentTimeMillis(); + sshClient = SshClient.setUpDefaultClient(); + sshClient.start(); + long duration = System.currentTimeMillis() - begin; + if (log.isDebugEnabled()) + log.debug("SSH client started in " + duration + " ms"); + Runtime.getRuntime().addShutdownHook(new Thread(() -> sshClient.stop(), "Stop SSH client")); + } + return sshClient; + } + + synchronized SftpFileSystemProvider getSftpFileSystemProvider() { + if (sftpFileSystemProvider == null) { + sftpFileSystemProvider = new SftpFileSystemProvider(sshClient); + } + return sftpFileSystemProvider; + } + + public void authenticate() { + if (sshKeyPair != null) { + session.addPublicKeyIdentity(sshKeyPair.asKeyPair()); + } else { + + if (!passwordSet) { + String password; + Console console = System.console(); + if (console == null) {// IDE + System.out.print("Password: "); + try (Scanner s = new Scanner(System.in)) { + password = s.next(); + } + } else { + console.printf("Password: "); + char[] pwd = console.readPassword(); + password = new String(pwd); + Arrays.fill(pwd, ' '); + } + session.addPasswordIdentity(password); + passwordSet = true; + } + } + verifyAuth(); + } + + public void verifyAuth() { + try { + session.auth().verify(1000l); + } catch (IOException e) { + throw new IllegalStateException("Cannot verify auth", e); + } + } + + public static char[] readPassword() { + Console console = System.console(); + if (console == null) {// IDE + System.out.print("Password: "); + try (Scanner s = new Scanner(System.in)) { + String password = s.next(); + return password.toCharArray(); + } + } else { + console.printf("Password: "); + char[] pwd = console.readPassword(); + return pwd; + } + } + + void addPassword(String password) { + session.addPasswordIdentity(password); + } + + void loadKey(String password) { + loadKey(password, System.getProperty("user.home") + "/.ssh/id_rsa"); + } + + void loadKey(String password, String keyPath) { +// try { +// KeyPair keyPair = ClientIdentityLoader.DEFAULT.loadClientIdentity(keyPath, +// FilePasswordProvider.of(password)); +// session.addPublicKeyIdentity(keyPair); +// } catch (IOException | GeneralSecurityException e) { +// throw new IllegalStateException(e); +// } + } + + void openSession(URI uri) { + openSession(uri.getUserInfo(), uri.getHost(), uri.getPort() > 0 ? uri.getPort() : null); + } + + void openSession(String login, String host, Integer port) { + if (session != null) + throw new IllegalStateException("Session is already open"); + + if (host == null) + host = "localhost"; + if (port == null) + port = 22; + if (login == null) + login = System.getProperty("user.name"); + String password = null; + int sepIndex = login.indexOf(':'); + if (sepIndex > 0) + if (sepIndex + 1 < login.length()) { + password = login.substring(sepIndex + 1); + login = login.substring(0, sepIndex); + } else { + throw new IllegalArgumentException("Illegal authority: " + login); + } + try { + ConnectFuture connectFuture = getSshClient().connect(login, host, port); + connectFuture.await(); + ClientSession session = connectFuture.getSession(); + if (password != null) { + session.addPasswordIdentity(password); + passwordSet = true; + } + this.session = session; + } catch (IOException e) { + throw new IllegalStateException("Cannot connect to " + host + ":" + port); + } + } + + public void closeSession() { + if (session == null) + throw new IllegalStateException("No session is open"); + try { + session.close(); + } catch (IOException e) { + e.printStackTrace(); + } finally { + session = null; + } + } + + ClientSession getSession() { + return session; + } + + public void setSshKeyPair(SshKeyPair sshKeyPair) { + this.sshKeyPair = sshKeyPair; + } + + public static void openShell(AbstractSsh ssh) { + openShell(ssh.getSession()); + } + + public static void openShell(ClientSession session) { + try (ClientChannel channel = session.createChannel(ClientChannel.CHANNEL_SHELL)) { + channel.setIn(new NoCloseInputStream(System.in)); + channel.setOut(new NoCloseOutputStream(System.out)); + channel.setErr(new NoCloseOutputStream(System.err)); + channel.open(); + + Set events = new HashSet<>(); + events.add(ClientChannelEvent.CLOSED); + channel.waitFor(events, 0); + } catch (IOException e) { + // TODO Auto-generated catch block + e.printStackTrace(); + } finally { + session.close(false); + } + } + + static URI toUri(String username, String host, int port) { + try { + if (username == null) + username = "root"; + return new URI("ssh://" + username + "@" + host + ":" + port); + } catch (URISyntaxException e) { + throw new IllegalArgumentException("Cannot generate SSH URI to " + host + ":" + port + " for " + username, + e); + } + } + +}