WebSocket authentication
[lgpl/argeo-commons.git] / org.argeo.cms.ee / src / org / argeo / cms / websocket / server / CmsWebSocketConfigurator.java
index 880eb0ed5b7a8b1b9e381d70e362efb501a28a70..279b610b6edf8cb9d0ef57943cde4ae73762eafb 100644 (file)
@@ -1,11 +1,12 @@
 package org.argeo.cms.websocket.server;
 
-import java.security.AccessController;
 import java.security.PrivilegedAction;
+import java.util.ArrayList;
 import java.util.List;
 
 import javax.security.auth.Subject;
 import javax.security.auth.login.LoginContext;
+import javax.security.auth.login.LoginException;
 import javax.websocket.Extension;
 import javax.websocket.HandshakeResponse;
 import javax.websocket.server.HandshakeRequest;
@@ -14,10 +15,12 @@ import javax.websocket.server.ServerEndpointConfig.Configurator;
 
 import org.argeo.api.cms.CmsAuth;
 import org.argeo.api.cms.CmsLog;
-import org.argeo.api.cms.CmsState;
 import org.argeo.cms.auth.RemoteAuthCallbackHandler;
+import org.argeo.cms.auth.RemoteAuthRequest;
+import org.argeo.cms.auth.RemoteAuthResponse;
 import org.argeo.cms.auth.RemoteAuthSession;
-import org.argeo.cms.servlet.ServletHttpSession;
+import org.argeo.cms.auth.RemoteAuthUtils;
+import org.argeo.cms.servlet.CmsServletContext;
 
 /**
  * <strong>Disabled until third party issues are solved.</strong>. Customises
@@ -29,16 +32,6 @@ public class CmsWebSocketConfigurator extends Configurator {
 
        private final static CmsLog log = CmsLog.getLog(CmsWebSocketConfigurator.class);
        final static String HEADER_WWW_AUTHENTICATE = "WWW-Authenticate";
-       
-       private CmsState cmsState;
-       
-       public void start() {
-               
-       }
-       
-       public void stop() {
-               
-       }
 
        @Override
        public boolean checkOrigin(String originHeaderValue) {
@@ -76,34 +69,56 @@ public class CmsWebSocketConfigurator extends Configurator {
 
        @Override
        public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
-               if (true)
-                       return;
+//             if (true)
+//                     return;
 
-               RemoteAuthSession httpSession = new ServletHttpSession(
-                               (javax.servlet.http.HttpSession) request.getHttpSession());
-               if (log.isDebugEnabled() && httpSession != null)
-                       log.debug("Web socket HTTP session id: " + httpSession.getId());
+               WebSocketHandshakeRequest remoteAuthRequest = new WebSocketHandshakeRequest(request);
+               WebSocketHandshakeResponse remoteAuthResponse = new WebSocketHandshakeResponse(response);
+//             RemoteAuthSession httpSession = new ServletHttpSession(
+//                             (javax.servlet.http.HttpSession) request.getHttpSession());
+               RemoteAuthSession remoteAuthSession = remoteAuthRequest.getSession();
+               if (log.isDebugEnabled() && remoteAuthSession != null)
+                       log.debug("Web socket HTTP session id: " + remoteAuthSession.getId());
 
-               if (httpSession == null) {
-                       rejectResponse(response, null);
-               }
+//             if (remoteAuthSession == null) {
+//                     rejectResponse(response, null);
+//             }
+               ClassLoader currentThreadContextClassLoader = Thread.currentThread().getContextClassLoader();
+               Thread.currentThread().setContextClassLoader(CmsServletContext.class.getClassLoader());
+               LoginContext lc;
                try {
-                       LoginContext lc = new LoginContext(CmsAuth.LOGIN_CONTEXT_USER, new RemoteAuthCallbackHandler(httpSession));
+                       lc = CmsAuth.USER.newLoginContext(new RemoteAuthCallbackHandler(remoteAuthRequest, remoteAuthResponse));
                        lc.login();
-                       if (log.isDebugEnabled())
-                               log.debug("Web socket logged-in as " + lc.getSubject());
-                       Subject.doAs(lc.getSubject(), new PrivilegedAction<Void>() {
+               } catch (LoginException e) {
+                       // FIXME better analyse failure so as not to try endlessly
+                       if (authIsRequired(remoteAuthRequest, remoteAuthResponse)) {
+                               int statusCode = RemoteAuthUtils.askForWwwAuth(remoteAuthResponse, "Argeo", true);
+                               remoteAuthResponse.setHeader("Status-Code", Integer.toString(statusCode));
+                               return;
+                       } else {
+                               lc = RemoteAuthUtils.anonymousLogin(remoteAuthRequest, remoteAuthResponse);
+                       }
+                       if (lc == null)
+                               rejectResponse(response, e);
+               } finally {
+                       Thread.currentThread().setContextClassLoader(currentThreadContextClassLoader);
+               }
 
-                               @Override
-                               public Void run() {
-                                       sec.getUserProperties().put(REMOTE_USER, AccessController.getContext());
-                                       return null;
-                               }
+               Subject subject = lc.getSubject();
+               Subject.doAs(subject, new PrivilegedAction<Void>() {
 
-                       });
-               } catch (Exception e) {
-                       rejectResponse(response, e);
-               }
+                       @Override
+                       public Void run() {
+                               // TODO also set login context in order to log out ?
+                               RemoteAuthUtils.configureRequestSecurity(remoteAuthRequest);
+                               return null;
+                       }
+
+               });
+       }
+
+       protected boolean authIsRequired(RemoteAuthRequest remoteAuthRequest, RemoteAuthResponse remoteAuthResponse) {
+               return true;
        }
 
        /**
@@ -113,6 +128,7 @@ public class CmsWebSocketConfigurator extends Configurator {
         * @param e can be null
         */
        protected void rejectResponse(HandshakeResponse response, Exception e) {
+               response.getHeaders().put(HandshakeResponse.SEC_WEBSOCKET_ACCEPT, new ArrayList<String>());
                // violent implementation, as suggested in
                // https://stackoverflow.com/questions/21763829/jsr-356-how-to-abort-a-websocket-connection-during-the-handshake
 //             throw new IllegalStateException("Web socket cannot be authenticated");