// Source: git.argeo.org Git - lgpl/argeo-commons.git/blob - AbstractSsh.java
// Blob: f2525bff850b0c914709309f4389d0b2679a7671
// Path: [lgpl/argeo-commons.git] / AbstractSsh.java
package org.argeo.cms.ssh;

import java.io.Console;
import java.io.IOException;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Objects;
import java.util.Scanner;
import java.util.Set;

import org.apache.sshd.client.SshClient;
import org.apache.sshd.client.channel.ClientChannel;
import org.apache.sshd.client.channel.ClientChannelEvent;
import org.apache.sshd.client.future.ConnectFuture;
import org.apache.sshd.client.session.ClientSession;
import org.apache.sshd.common.util.io.input.NoCloseInputStream;
import org.apache.sshd.common.util.io.output.NoCloseOutputStream;
import org.apache.sshd.sftp.client.fs.SftpFileSystemProvider;
import org.argeo.api.cms.CmsLog;

22 @SuppressWarnings("restriction")
23 public abstract class AbstractSsh {
24 private final static CmsLog log = CmsLog.getLog(AbstractSsh.class);
25
26 private static SshClient sshClient;
27 private static SftpFileSystemProvider sftpFileSystemProvider;
28
29 private boolean passwordSet = false;
30 private ClientSession session;
31
32 private SshKeyPair sshKeyPair;
33
34 public synchronized SshClient getSshClient() {
35 if (sshClient == null) {
36 long begin = System.currentTimeMillis();
37 sshClient = SshClient.setUpDefaultClient();
38 sshClient.start();
39 long duration = System.currentTimeMillis() - begin;
40 if (log.isDebugEnabled())
41 log.debug("SSH client started in " + duration + " ms");
42 Runtime.getRuntime().addShutdownHook(new Thread(() -> sshClient.stop(), "Stop SSH client"));
43 }
44 return sshClient;
45 }
46
47 synchronized SftpFileSystemProvider getSftpFileSystemProvider() {
48 if (sftpFileSystemProvider == null) {
49 sftpFileSystemProvider = new SftpFileSystemProvider(sshClient);
50 }
51 return sftpFileSystemProvider;
52 }
53
54 public void authenticate() {
55 if (sshKeyPair != null) {
56 session.addPublicKeyIdentity(sshKeyPair.asKeyPair());
57 } else {
58
59 if (!passwordSet) {
60 String password;
61 Console console = System.console();
62 if (console == null) {// IDE
63 System.out.print("Password: ");
64 try (Scanner s = new Scanner(System.in)) {
65 password = s.next();
66 }
67 } else {
68 console.printf("Password: ");
69 char[] pwd = console.readPassword();
70 password = new String(pwd);
71 Arrays.fill(pwd, ' ');
72 }
73 session.addPasswordIdentity(password);
74 passwordSet = true;
75 }
76 }
77 verifyAuth();
78 }
79
80 public void verifyAuth() {
81 try {
82 session.auth().verify(1000l);
83 } catch (IOException e) {
84 throw new IllegalStateException("Cannot verify auth", e);
85 }
86 }
87
88 public static char[] readPassword() {
89 Console console = System.console();
90 if (console == null) {// IDE
91 System.out.print("Password: ");
92 try (Scanner s = new Scanner(System.in)) {
93 String password = s.next();
94 return password.toCharArray();
95 }
96 } else {
97 console.printf("Password: ");
98 char[] pwd = console.readPassword();
99 return pwd;
100 }
101 }
102
103 void addPassword(String password) {
104 session.addPasswordIdentity(password);
105 }
106
107 void loadKey(String password) {
108 loadKey(password, System.getProperty("user.home") + "/.ssh/id_rsa");
109 }
110
111 void loadKey(String password, String keyPath) {
112 // try {
113 // KeyPair keyPair = ClientIdentityLoader.DEFAULT.loadClientIdentity(keyPath,
114 // FilePasswordProvider.of(password));
115 // session.addPublicKeyIdentity(keyPair);
116 // } catch (IOException | GeneralSecurityException e) {
117 // throw new IllegalStateException(e);
118 // }
119 }
120
121 void openSession(URI uri) {
122 openSession(uri.getUserInfo(), uri.getHost(), uri.getPort() > 0 ? uri.getPort() : null);
123 }
124
125 void openSession(String login, String host, Integer port) {
126 if (session != null)
127 throw new IllegalStateException("Session is already open");
128
129 if (host == null)
130 host = "localhost";
131 if (port == null)
132 port = 22;
133 if (login == null)
134 login = System.getProperty("user.name");
135 String password = null;
136 int sepIndex = login.indexOf(':');
137 if (sepIndex > 0)
138 if (sepIndex + 1 < login.length()) {
139 password = login.substring(sepIndex + 1);
140 login = login.substring(0, sepIndex);
141 } else {
142 throw new IllegalArgumentException("Illegal authority: " + login);
143 }
144 try {
145 ConnectFuture connectFuture = getSshClient().connect(login, host, port);
146 connectFuture.await();
147 ClientSession session = connectFuture.getSession();
148 if (password != null) {
149 session.addPasswordIdentity(password);
150 passwordSet = true;
151 }
152 this.session = session;
153 } catch (IOException e) {
154 throw new IllegalStateException("Cannot connect to " + host + ":" + port);
155 }
156 }
157
158 public void closeSession() {
159 if (session == null)
160 throw new IllegalStateException("No session is open");
161 try {
162 session.close();
163 } catch (IOException e) {
164 e.printStackTrace();
165 } finally {
166 session = null;
167 }
168 }
169
170 ClientSession getSession() {
171 return session;
172 }
173
174 public void setSshKeyPair(SshKeyPair sshKeyPair) {
175 this.sshKeyPair = sshKeyPair;
176 }
177
178 public static void openShell(AbstractSsh ssh) {
179 openShell(ssh.getSession());
180 }
181
182 public static void openShell(ClientSession session) {
183 try (ClientChannel channel = session.createChannel(ClientChannel.CHANNEL_SHELL)) {
184 channel.setIn(new NoCloseInputStream(System.in));
185 channel.setOut(new NoCloseOutputStream(System.out));
186 channel.setErr(new NoCloseOutputStream(System.err));
187 channel.open();
188
189 Set<ClientChannelEvent> events = new HashSet<>();
190 events.add(ClientChannelEvent.CLOSED);
191 channel.waitFor(events, 0);
192 } catch (IOException e) {
193 // TODO Auto-generated catch block
194 e.printStackTrace();
195 } finally {
196 session.close(false);
197 }
198 }
199
200 static URI toUri(String username, String host, int port) {
201 try {
202 if (username == null)
203 username = "root";
204 return new URI("ssh://" + username + "@" + host + ":" + port);
205 } catch (URISyntaxException e) {
206 throw new IllegalArgumentException("Cannot generate SSH URI to " + host + ":" + port + " for " + username,
207 e);
208 }
209 }
210
211 }