Functionally complete JShell
[lgpl/argeo-commons.git] / org.argeo.cms.jshell / src / org / argeo / cms / jshell / JShellClient.java
index f458c5c29d4e8559a3824a57596fbedf6c9861ef..f6846f196c0192204e54acc5e3e0ec895461c9bc 100644 (file)
@@ -1,12 +1,19 @@
 package org.argeo.cms.jshell;
 
-import java.io.ByteArrayOutputStream;
-import java.io.Console;
+import static java.lang.System.Logger.Level.ERROR;
+import static java.lang.System.Logger.Level.TRACE;
+import static java.net.StandardProtocolFamily.UNIX;
+import static java.nio.charset.StandardCharsets.UTF_8;
+
+import java.io.BufferedReader;
+import java.io.File;
 import java.io.IOException;
 import java.io.InputStream;
+import java.io.InputStreamReader;
 import java.io.OutputStream;
 import java.lang.System.Logger;
-import java.net.StandardProtocolFamily;
+import java.lang.management.ManagementFactory;
+import java.net.StandardSocketOptions;
 import java.net.UnixDomainSocketAddress;
 import java.nio.ByteBuffer;
 import java.nio.channels.Channels;
@@ -17,39 +24,83 @@ import java.nio.channels.WritableByteChannel;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.UUID;
 
 public class JShellClient {
        private final static Logger logger = System.getLogger(JShellClient.class.getName());
 
-       public final static String STDIO = "stdio";
-       public final static String STDERR = "stderr";
-       public final static String CMDIO = "cmdio";
+       public final static String STD = "std";
+       public final static String CTL = "ctl";
 
-       private static String ttyConfig;
+       public final static String JSH = "jsh";
+       public final static String JTERM = "jterm";
 
-       public static void main(String[] args) throws IOException, InterruptedException {
+       private static String sttyExec = "/usr/bin/stty";
+
+       /** Benchmark based on uptime. */
+       private static boolean benchmark = false;
+
+       /**
+        * The real path (following symbolic links) to the directory were to create
+        * sessions.
+        */
+       private Path localBase;
+
+       /** The symbolic name of the bundle from which to run. */
+       private String symbolicName;
+
+       /** The script to run. */
+       private Path script;
+       /** Additional arguments of the script */
+       private List<String> scriptArgs;
+
+       private String ttyConfig;
+       private boolean terminal;
+
+       /** Workaround to be able to test in Eclipse console */
+       private boolean inEclipse = false;
+
+       public JShellClient(Path targetStateDirectory, String symbolicName, Path script, List<String> scriptArgs) {
                try {
-                       Path targetStateDirectory = Paths.get(args[0]);
-                       String symbolicName = args[1];
-                       Path localBase = targetStateDirectory.resolve("jsh");
+                       this.terminal = System.console() != null && script == null;
+                       if (inEclipse && script == null)
+                               terminal = true;
+                       if (terminal) {
+                               localBase = targetStateDirectory.resolve(JTERM);
+                       } else {
+                               localBase = targetStateDirectory.resolve(JSH);
+                       }
                        if (Files.isSymbolicLink(localBase)) {
                                localBase = localBase.toRealPath();
                        }
+                       this.symbolicName = symbolicName;
+                       this.script = script;
+                       this.scriptArgs = scriptArgs == null ? new ArrayList<>() : scriptArgs;
+               } catch (IOException e) {
+                       throw new IllegalStateException("Cannot initialise client", e);
+               }
+       }
 
-                       Console console = System.console();
-                       if (console != null) {
+       public void run() {
+               try {
+                       if (terminal)
                                toRawTerminal();
-                       }
+                       SocketPipeSource std = new SocketPipeSource(STD, script != null);
+                       std.setInputStream(System.in);
+                       std.setOutputStream(System.out);
 
-                       SocketPipeSource stdio = new SocketPipeSource();
-                       stdio.setInputStream(System.in);
-                       stdio.setOutputStream(System.out);
+                       SocketPipeSource ctl = new SocketPipeSource(CTL, false);
+                       ctl.setOutputStream(System.err);
 
                        Runtime.getRuntime().addShutdownHook(new Thread(() -> {
-                               // logger.log(Logger.Level.INFO, "Shutting down...");
                                System.out.println("\nShutting down...");
-                               stdio.shutdown();
+                               toOriginalTerminal();
+                               std.shutdown();
+                               ctl.shutdown();
                        }, "Shut down JShell client"));
 
                        Path bundleSnDir = localBase.resolve(symbolicName);
@@ -57,282 +108,235 @@ public class JShellClient {
                                Files.createDirectory(bundleSnDir);
                        UUID uuid = UUID.randomUUID();
                        Path sessionDir = bundleSnDir.resolve(uuid.toString());
+
+                       // creating the directory will trigger opening of the session on server side
                        Files.createDirectory(sessionDir);
-                       Path stdioPath = sessionDir.resolve(JShellClient.STDIO);
 
-                       while (!(Files.exists(stdioPath))) {
+                       Path stdPath = sessionDir.resolve(JShellClient.STD);
+                       Path ctlPath = sessionDir.resolve(JShellClient.CTL);
+
+                       while (!(Files.exists(stdPath) && Files.exists(ctlPath))) {
                                // TODO timeout
-                               Thread.sleep(50);
-
-//                             // wait for sockets to be available
-//                             WatchService watchService = FileSystems.getDefault().newWatchService();
-//                             sessionDir.register(watchService, StandardWatchEventKinds.ENTRY_CREATE);
-//                             WatchKey key;
-//                             watch: while ((key = watchService.take()) != null) {
-//                                     for (WatchEvent<?> event : key.pollEvents()) {
-//                                             Path path = sessionDir.resolve((Path) event.context());
-//                                             if (Files.isSameFile(stdioPath, path)) {
-//                                                     break watch;
-//                                             }
-//                                     }
-//                             }
-//                             watchService.close();
+                               try {
+                                       Thread.sleep(1);
+                               } catch (InterruptedException e) {
+                                       // silent
+                               }
                        }
 
-                       UnixDomainSocketAddress stdioSocketAddress = UnixDomainSocketAddress.of(stdioPath.toRealPath());
+                       UnixDomainSocketAddress stdSocketAddress = UnixDomainSocketAddress.of(stdPath.toRealPath());
+                       UnixDomainSocketAddress ctlSocketAddress = UnixDomainSocketAddress.of(ctlPath.toRealPath());
 
-                       try (SocketChannel stdioChannel = SocketChannel.open(StandardProtocolFamily.UNIX)) {
-                               stdioChannel.connect(stdioSocketAddress);
-                               stdio.forward(stdioChannel);
-                       }
+                       try (SocketChannel stdChannel = SocketChannel.open(UNIX);
+                                       SocketChannel ctlChannel = SocketChannel.open(UNIX);) {
+                               ctlChannel.connect(ctlSocketAddress);
+                               ctl.process(ctlChannel);
+                               if (script != null) {
+                                       new ScriptThread(ctlChannel).start();
+                               }
+                               stdChannel.connect(stdSocketAddress);
+                               std.process(stdChannel);
 
-               } catch (IOException | InterruptedException e) {
-                       // TODO Auto-generated catch block
+                               while (!std.isCompleted() && !ctl.isCompleted()) {
+                                       // isCompleted() will block
+                               }
+                       }
+                       if (benchmark)
+                               System.err.println(ManagementFactory.getRuntimeMXBean().getUptime());
+                       std.shutdown();
+                       ctl.shutdown();
+               } catch (IOException e) {
                        e.printStackTrace();
                } finally {
-                       if (ttyConfig != null)
-                               try {
-                                       stty(ttyConfig.trim());
-                               } catch (Exception e) {
-                                       System.err.println("Exception restoring tty config");
-                               }
+                       toOriginalTerminal();
                }
 
        }
 
-       private static void toRawTerminal() throws IOException, InterruptedException {
+       public static void main(String[] args) throws IOException, InterruptedException {
+               if (benchmark)
+                       System.err.println(ManagementFactory.getRuntimeMXBean().getUptime());
+               List<String> plainArgs = new ArrayList<>();
+               Map<String, List<String>> options = new HashMap<>();
+               String currentOption = null;
+               for (int i = 0; i < args.length; i++) {
+                       if (args[i].startsWith("-")) {
+                               currentOption = args[i];
+                               if (!options.containsKey(currentOption))
+                                       options.put(currentOption, new ArrayList<>());
+                               i++;
+                               options.get(currentOption).add(args[i]);
+                       } else {
+                               plainArgs.add(args[i]);
+                       }
+               }
 
-               ttyConfig = stty("-g");
+               Path targetStateDirectory = Paths.get(options.get("-d").get(0));
+               String symbolicName = options.get("-b").get(0);
+
+               Path script = plainArgs.isEmpty() ? null : Paths.get(plainArgs.get(0));
+               List<String> scriptArgs = new ArrayList<>();
+               for (int i = 1; i < plainArgs.size(); i++)
+                       scriptArgs.add(plainArgs.get(i));
+
+               JShellClient client = new JShellClient(targetStateDirectory, symbolicName, script, scriptArgs);
+               client.run();
+       }
 
+       /*
+        * TERMINAL
+        */
+       /** Set the terminal to raw mode. */
+       protected synchronized void toRawTerminal() {
+               boolean isWindows = File.separatorChar == '\\';
+               if (isWindows)
+                       return;
+               if (inEclipse)
+                       return;
+               // save current configuration
+               ttyConfig = stty("-g");
+               if (ttyConfig == null)
+                       return;
+               ttyConfig.trim();
                // set the console to be character-buffered instead of line-buffered
                stty("-icanon min 1");
-
                // disable character echoing
                stty("-echo");
-
-               Runtime.getRuntime().addShutdownHook(new Thread(() -> toOriginalTerminal(), "Reset terminal"));
        }
 
-       private static void toOriginalTerminal() {
-               if (ttyConfig != null)
-                       try {
-                               stty(ttyConfig.trim());
-                       } catch (Exception e) {
-                               System.err.println("Exception restoring tty config");
-                       }
+       /** Restore original terminal configuration. */
+       protected synchronized void toOriginalTerminal() {
+               if (ttyConfig == null)
+                       return;
+               try {
+                       stty(ttyConfig);
+               } catch (Exception e) {
+                       e.printStackTrace();
+               }
+               ttyConfig = null;
        }
 
        /**
         * Execute the stty command with the specified arguments against the current
         * active terminal.
         */
-       private static String stty(final String args) throws IOException, InterruptedException {
-               String cmd = "stty " + args + " < /dev/tty";
+       protected String stty(String args) {
+               List<String> cmd = new ArrayList<>();
+               cmd.add("/bin/sh");
+               cmd.add("-c");
+               cmd.add(sttyExec + " " + args + " < /dev/tty");
 
-               return exec(new String[] { "sh", "-c", cmd });
-       }
-
-       /**
-        * Execute the specified command and return the output (both stdout and stderr).
-        */
-       private static String exec(final String[] cmd) throws IOException, InterruptedException {
-               ByteArrayOutputStream bout = new ByteArrayOutputStream();
-
-               Process p = Runtime.getRuntime().exec(cmd);
-               int c;
-               InputStream in = p.getInputStream();
+               logger.log(TRACE, () -> cmd.toString());
 
-               while ((c = in.read()) != -1) {
-                       bout.write(c);
+               try {
+                       ProcessBuilder pb = new ProcessBuilder(cmd);
+                       Process p = pb.start();
+                       String firstLine = new BufferedReader(new InputStreamReader(p.getInputStream())).readLine();
+                       p.waitFor();
+                       logger.log(TRACE, () -> firstLine);
+                       return firstLine;
+               } catch (IOException | InterruptedException e) {
+                       e.printStackTrace();
+                       return null;
                }
+       }
 
-               in = p.getErrorStream();
+       /*
+        * SCRIPT
+        */
+       private class ScriptThread extends Thread {
+               private SocketChannel channel;
 
-               while ((c = in.read()) != -1) {
-                       bout.write(c);
+               public ScriptThread(SocketChannel channel) {
+                       super("JShell script writer");
+                       this.channel = channel;
                }
 
-               p.waitFor();
+               @Override
+               public void run() {
+                       try {
+                               if (benchmark)
+                                       System.err.println(ManagementFactory.getRuntimeMXBean().getUptime());
+                               StringBuilder sb = new StringBuilder();
+//                             sb.append("/set feedback silent\n");
+                               if (!scriptArgs.isEmpty()) {
+                                       // additional arguments as $1, $2, etc.
+                                       for (String arg : scriptArgs)
+                                               sb.append('\"').append(arg).append('\"').append(";\n");
+                               }
+                               if (sb.length() > 0)
+                                       writeLine(sb);
 
-               String result = new String(bout.toByteArray());
-               return result;
-       }
+                               ByteBuffer buffer = ByteBuffer.allocate(1024);
+                               try (BufferedReader reader = Files.newBufferedReader(script)) {
+                                       String line;
+                                       lines: while ((line = reader.readLine()) != null) {
+                                               if (line.startsWith("#"))
+                                                       continue lines;
+                                               buffer.put((line + "\n").getBytes(UTF_8));
+                                               buffer.flip();
+                                               channel.write(buffer);
+                                               buffer.rewind();
+                                       }
+                               }
 
-//     void pipe() throws IOException {
-//             // Set up Server Socket and bind to the port 8000
-//             ServerSocketChannel server = ServerSocketChannel.open();
-//             SocketAddress endpoint = new InetSocketAddress(8000);
-//             server.socket().bind(endpoint);
-//
-//             server.configureBlocking(false);
-//
-//             // Set up selector so we can run with a single thread but multiplex between 2
-//             // channels
-//             Selector selector = Selector.open();
-//             server.register(selector, SelectionKey.OP_ACCEPT);
-//
-//             ByteBuffer buffer = ByteBuffer.allocate(1024);
-//
-//             while (true) {
-//                     // block until data comes in
-//                     selector.select();
-//
-//                     Set<SelectionKey> keys = selector.selectedKeys();
-//
-//                     for (SelectionKey key : keys) {
-//                             if (!key.isValid()) {
-//                                     // not valid or writable so skip
-//                                     continue;
-//                             }
-//
-//                             if (key.isAcceptable()) {
-//                                     // Accept socket channel for client connection
-//                                     ServerSocketChannel channel = (ServerSocketChannel) key.channel();
-//                                     SocketChannel accept = channel.accept();
-//                                     setupConnection(selector, accept);
-//                             } else if (key.isReadable()) {
-//                                     try {
-//                                             // Read into the buffer from the socket and then write the buffer into the
-//                                             // attached socket.
-//                                             SocketChannel recv = (SocketChannel) key.channel();
-//                                             SocketChannel send = (SocketChannel) key.attachment();
-//                                             recv.read(buffer);
+//                             ByteBuffer buffer = ByteBuffer.allocate(1024);
+//                             try (SeekableByteChannel scriptChannel = Files.newByteChannel(script, StandardOpenOption.READ)) {
+//                                     while (channel.isConnected()) {
+//                                             if (scriptChannel.read(buffer) < 0)
+//                                                     break;
 //                                             buffer.flip();
-//                                             send.write(buffer);
+//                                             channel.write(buffer);
 //                                             buffer.rewind();
-//                                     } catch (IOException e) {
-//                                             e.printStackTrace();
-//
-//                                             // Close sockets
-//                                             if (key.channel() != null)
-//                                                     key.channel().close();
-//                                             if (key.attachment() != null)
-//                                                     ((SocketChannel) key.attachment()).close();
-//                                     }
-//                             } else if (key.isWritable()) {
-//
-//                             }
-//                     }
-//
-//                     // Clear keys for next select
-//                     keys.clear();
-//             }
-//
-//     }
-
-//     public static void mainX(String[] args) throws IOException, InterruptedException {
-//             toRawTerminal();
-//             try {
-//                     boolean client = true;
-//                     if (client) {
-//                             ReadableByteChannel inChannel;
-//                             WritableByteChannel outChannel;
-//                             inChannel = Channels.newChannel(System.in);
-//                             outChannel = Channels.newChannel(System.out);
-//
-//                             SocketChannel channel = SocketChannel.open(StandardProtocolFamily.UNIX);
-//                             channel.connect(ioSocketAddress());
-//
-//                             new Thread(() -> {
-//
-//                                     try {
-//                                             ByteBuffer buffer = ByteBuffer.allocate(1024);
-//                                             while (true) {
-//                                                     if (channel.read(buffer) < 0)
-//                                                             break;
-//                                                     buffer.flip();
-//                                                     outChannel.write(buffer);
-//                                                     buffer.rewind();
-//                                             }
-//                                             System.exit(0);
-//                                     } catch (IOException e) {
-//                                             e.printStackTrace();
 //                                     }
-//                             }, "Read out").start();
-//
-//                             ByteBuffer buffer = ByteBuffer.allocate(1);
-//                             while (channel.isConnected()) {
-//                                     if (inChannel.read(buffer) < 0)
-//                                             break;
-//                                     buffer.flip();
-//                                     channel.write(buffer);
-//                                     buffer.rewind();
-//                             }
-//
-//                     } else {
-//                             ServerSocketChannel serverChannel = ServerSocketChannel.open(StandardProtocolFamily.UNIX);
-//                             serverChannel.bind(ioSocketAddress());
-//
-//                             SocketChannel channel = serverChannel.accept();
-//
-//                             while (true) {
-//                                     readSocketMessage(channel).ifPresent(message -> System.out.printf("[Client message] %s", message));
-//                                     Thread.sleep(100);
 //                             }
-//                     }
-//             } finally {
-//                     toOriginalTerminal();
-//             }
-//     }
-//
-//     private static Optional<String> readSocketMessage(SocketChannel channel) throws IOException {
-//             ByteBuffer buffer = ByteBuffer.allocate(1024);
-//             int bytesRead = channel.read(buffer);
-//             if (bytesRead < 0)
-//                     return Optional.empty();
-//
-//             byte[] bytes = new byte[bytesRead];
-//             buffer.flip();
-//             buffer.get(bytes);
-//             String message = new String(bytes);
-//             return Optional.of(message);
-//     }
-//
-//     public static void setupConnection(Selector selector, SocketChannel client) throws IOException {
-//             // Connect to the remote server
-//             SocketAddress address = new InetSocketAddress("192.168.1.74", 8000);
-//             SocketChannel remote = SocketChannel.open(address);
-//
-//             // Make sockets non-blocking (should be better performance)
-//             client.configureBlocking(false);
-//             remote.configureBlocking(false);
-//
-//             client.register(selector, SelectionKey.OP_READ, remote);
-//             remote.register(selector, SelectionKey.OP_READ, client);
-//     }
-//
-//     static UnixDomainSocketAddress ioSocketAddress() throws IOException {
-//             String system = "default";
-//             String bundleSn = "org.argeo.slc.jshell";
-//
-//             String xdgRunDir = System.getenv("XDG_RUNTIME_DIR");
-//             Path baseRunDir = Paths.get(xdgRunDir);
-//             Path jshellSocketBase = baseRunDir.resolve("jshell").resolve(system).resolve(bundleSn);
-//
-//             Files.createDirectories(jshellSocketBase);
-//
-//             Path ioSocketPath = jshellSocketBase.resolve("io");
-//
-//             UnixDomainSocketAddress ioSocketAddress = UnixDomainSocketAddress.of(ioSocketPath);
-//             System.out.println(ioSocketAddress);
-//             return ioSocketAddress;
-//     }
 
+                               // exit
+                               if (channel.isConnected())
+                                       writeLine("/exit");
+                       } catch (IOException e) {
+                               logger.log(ERROR, "Cannot execute " + script, e);
+                       }
+               }
+
+               private void writeLine(Object obj) throws IOException {
+                       channel.write(ByteBuffer.wrap((obj + "\n").getBytes(UTF_8)));
+               }
+       }
 }
 
+/** Pipe streams to a channel. */
 class SocketPipeSource {
        private ReadableByteChannel inChannel;
        private WritableByteChannel outChannel;
 
-       private Thread readOutThread;
+       private Thread readThread;
        private Thread forwardThread;
 
-       public void forward(SocketChannel channel) throws IOException {
-               forwardThread = Thread.currentThread();
-               readOutThread = new Thread(() -> {
+       private int inBufferSize = 1;
+       private int outBufferSize = 1024;
+
+       private final String id;
+       private final boolean batch;
+
+       private boolean completed = false;
+
+       public SocketPipeSource(String id, boolean batch) {
+               this.id = id;
+               this.batch = batch;
+       }
+
+       public void process(SocketChannel channel) throws IOException {
+               if (batch) {
+                       Integer socketRcvBuf = channel.getOption(StandardSocketOptions.SO_RCVBUF);
+                       inBufferSize = socketRcvBuf;
+                       outBufferSize = socketRcvBuf;
+               }
+
+               readThread = new Thread(() -> {
 
                        try {
-                               ByteBuffer buffer = ByteBuffer.allocate(1024);
+                               ByteBuffer buffer = ByteBuffer.allocate(outBufferSize);
                                while (true) {
                                        if (channel.read(buffer) < 0)
                                                break;
@@ -340,14 +344,14 @@ class SocketPipeSource {
                                        outChannel.write(buffer);
                                        buffer.rewind();
                                }
-                               System.exit(0);
                        } catch (ClosedByInterruptException e) {
                                // silent
                        } catch (IOException e) {
                                e.printStackTrace();
                        }
-               }, "Read out");
-               readOutThread.start();
+                       markCompleted();
+               }, "JShell read " + id);
+               readThread.start();
 
                // TODO make it smarter than a 1 byte buffer
                // we should recognize control characters
@@ -357,43 +361,71 @@ class SocketPipeSource {
 //                     break;
 //             }
 
-               ByteBuffer buffer = ByteBuffer.allocate(1);
-               while (channel.isConnected()) {
-                       if (inChannel.read(buffer) < 0)
-                               break;
+               if (inChannel != null) {
+                       forwardThread = new Thread(() -> {
+                               try {
+                                       ByteBuffer buffer = ByteBuffer.allocate(inBufferSize);
+                                       while (channel.isConnected()) {
+                                               if (inChannel.read(buffer) < 0)
+                                                       break;
 //                     int b = (int) buffer.get(0);
 //                     if (b == 0x1B) {
 //                             System.out.println("Ctrl+C");
 //                     }
 
-                       buffer.flip();
-                       channel.write(buffer);
-                       buffer.rewind();
-               }
+                                               buffer.flip();
+                                               channel.write(buffer);
+                                               buffer.rewind();
+                                       }
+                               } catch (IOException e) {
+                                       e.printStackTrace();
+                               }
+                       }, "JShell write " + id);
+                       forwardThread.setDaemon(true);
+                       forwardThread.start();
+                       // end
+                       // TODO make it more robust
+                       // we want to be asynchronous when read only
+//                     try {
+//                             // TODO add timeout
+//                             readThread.join();
+//                     } catch (InterruptedException e) {
+//                             e.printStackTrace();
+//                     }
 
-               // end
-               // TODO make it more robust
-               try {
-                       // TODO add timeout
-                       readOutThread.join();
-               } catch (InterruptedException e) {
-                       e.printStackTrace();
                }
        }
 
+       public synchronized boolean isCompleted() {
+               if (!completed)
+                       try {
+                               wait();
+                       } catch (InterruptedException e) {
+                               // silent
+                       }
+               return completed;
+       }
+
+       protected synchronized void markCompleted() {
+               completed = true;
+               notifyAll();
+       }
+
        public void shutdown() {
-               try {
-                       inChannel.close();
-               } catch (IOException e) {
-                       e.printStackTrace();
-               }
+               if (inChannel != null)
+                       try {
+                               inChannel.close();
+                       } catch (IOException e) {
+                               e.printStackTrace();
+                       }
                try {
                        outChannel.close();
                } catch (IOException e) {
                        e.printStackTrace();
                }
-               forwardThread.interrupt();
-               readOutThread.interrupt();
+               if (inChannel != null)
+                       forwardThread.interrupt();
+               readThread.interrupt();
        }
 
        public void setInputStream(InputStream in) {