Improve web socket authentication
authorMathieu Baudier <mbaudier@argeo.org>
Thu, 9 May 2019 16:11:28 +0000 (18:11 +0200)
committerMathieu Baudier <mbaudier@argeo.org>
Thu, 9 May 2019 16:11:28 +0000 (18:11 +0200)
org.argeo.cms/src/org/argeo/cms/auth/AnonymousLoginModule.java
org.argeo.cms/src/org/argeo/cms/auth/CmsAuthUtils.java
org.argeo.cms/src/org/argeo/cms/auth/HttpRequestCallback.java
org.argeo.cms/src/org/argeo/cms/auth/HttpRequestCallbackHandler.java
org.argeo.cms/src/org/argeo/cms/auth/HttpSessionLoginModule.java
org.argeo.cms/src/org/argeo/cms/auth/UserAdminLoginModule.java
org.argeo.cms/src/org/argeo/cms/websocket/CmsWebSocketConfigurator.java [new file with mode: 0644]
org.argeo.ext.equinox.jetty/src/org/argeo/equinox/jetty/WebSocketJettyCustomizer.java

index 19c0d60edff4f892ec1170775bd28f452628b847..e91fd6033be2623557f38a14bca1b70a309b21e7 100644 (file)
@@ -54,7 +54,7 @@ public class AnonymousLoginModule implements LoginModule {
                Locale locale = Locale.getDefault();
                if (request != null)
                        locale = request.getLocale();
-               CmsAuthUtils.addAuthorization(subject, authorization, locale, request);
+               CmsAuthUtils.addAuthorization(subject, authorization);
                CmsAuthUtils.registerSessionAuthorization(request, subject, authorization, locale);
                if (log.isTraceEnabled())
                        log.trace("Anonymous logged in to CMS: " + subject);
index dde2d73f50efffa23f6267d454e9b9c24f75de76..9a60e913465eb530c4ebba7be3ecba72791681a9 100644 (file)
@@ -1,6 +1,7 @@
 package org.argeo.cms.auth;
 
 import java.security.Principal;
+import java.util.Collection;
 import java.util.Locale;
 import java.util.Set;
 import java.util.UUID;
@@ -25,6 +26,9 @@ import org.argeo.node.security.AnonymousPrincipal;
 import org.argeo.node.security.DataAdminPrincipal;
 import org.argeo.node.security.NodeSecurityUtils;
 import org.argeo.osgi.useradmin.AuthenticatingUser;
+import org.osgi.framework.BundleContext;
+import org.osgi.framework.InvalidSyntaxException;
+import org.osgi.framework.ServiceReference;
 import org.osgi.service.http.HttpContext;
 import org.osgi.service.useradmin.Authorization;
 
@@ -41,8 +45,7 @@ class CmsAuthUtils {
        final static String SHARED_STATE_SPNEGO_OUT_TOKEN = "org.argeo.cms.auth.spnegoOutToken";
        final static String SHARED_STATE_CERTIFICATE_CHAIN = "org.argeo.cms.auth.certificateChain";
 
-       static void addAuthorization(Subject subject, Authorization authorization, Locale locale,
-                       HttpServletRequest request) {
+       static void addAuthorization(Subject subject, Authorization authorization) {
                assert subject != null;
                checkSubjectEmpty(subject);
                assert authorization != null;
@@ -175,6 +178,29 @@ class CmsAuthUtils {
                }
        }
 
+       public static CmsSession cmsSessionFromHttpSession(BundleContext bc, String httpSessionId) {
+               Authorization authorization = null;
+               Collection<ServiceReference<CmsSession>> sr;
+               try {
+                       sr = bc.getServiceReferences(CmsSession.class,
+                                       "(" + CmsSession.SESSION_LOCAL_ID + "=" + httpSessionId + ")");
+               } catch (InvalidSyntaxException e) {
+                       throw new CmsException("Cannot get CMS session for id " + httpSessionId, e);
+               }
+               CmsSession cmsSession;
+               if (sr.size() == 1) {
+                       cmsSession = bc.getService(sr.iterator().next());
+//                     locale = cmsSession.getLocale();
+                       authorization = cmsSession.getAuthorization();
+                       if (authorization.getName() == null)
+                               return null;// anonymous is not sufficient
+               } else if (sr.size() == 0)
+                       return null;
+               else
+                       throw new CmsException(sr.size() + ">1 web sessions detected for http session " + httpSessionId);
+               return cmsSession;
+       }
+
        public static <T extends Principal> T getSinglePrincipal(Subject subject, Class<T> clss) {
                Set<T> principals = subject.getPrincipals(clss);
                if (principals.isEmpty())
index 611b324d54fb88d81f900f2f3ecb0b25a42e0de1..3a3a9f9e0a5639468b66a7cc6d52592073422290 100644 (file)
@@ -3,10 +3,12 @@ package org.argeo.cms.auth;
 import javax.security.auth.callback.Callback;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 
 public class HttpRequestCallback implements Callback {
        private HttpServletRequest request;
        private HttpServletResponse response;
+       private HttpSession httpSession;
 
        public HttpServletRequest getRequest() {
                return request;
@@ -24,4 +26,12 @@ public class HttpRequestCallback implements Callback {
                this.response = response;
        }
 
+       public HttpSession getHttpSession() {
+               return httpSession;
+       }
+
+       public void setHttpSession(HttpSession httpSession) {
+               this.httpSession = httpSession;
+       }
+
 }
index bcc403fa996c4060cc4e3baa27158a40392cf298..df971e687f0c31bbd2eaf957a376be0c736a9c76 100644 (file)
@@ -8,6 +8,7 @@ import javax.security.auth.callback.LanguageCallback;
 import javax.security.auth.callback.UnsupportedCallbackException;
 import javax.servlet.http.HttpServletRequest;
 import javax.servlet.http.HttpServletResponse;
+import javax.servlet.http.HttpSession;
 
 /**
  * Callback handler populating {@link HttpRequestCallback}s with the provided
@@ -16,18 +17,27 @@ import javax.servlet.http.HttpServletResponse;
 public class HttpRequestCallbackHandler implements CallbackHandler {
        final private HttpServletRequest request;
        final private HttpServletResponse response;
+       final private HttpSession httpSession;
 
        public HttpRequestCallbackHandler(HttpServletRequest request, HttpServletResponse response) {
                this.request = request;
+               this.httpSession = request.getSession(false);
                this.response = response;
        }
 
+       public HttpRequestCallbackHandler(HttpSession httpSession) {
+               this.httpSession = httpSession;
+               this.request = null;
+               this.response = null;
+       }
+
        @Override
        public void handle(Callback[] callbacks) throws IOException, UnsupportedCallbackException {
                for (Callback callback : callbacks)
                        if (callback instanceof HttpRequestCallback) {
                                ((HttpRequestCallback) callback).setRequest(request);
                                ((HttpRequestCallback) callback).setResponse(response);
+                               ((HttpRequestCallback) callback).setHttpSession(httpSession);
                        } else if (callback instanceof LanguageCallback) {
                                ((LanguageCallback) callback).setLocale(request.getLocale());
                        }
index 1bbe359b9f8b478a6d986e75f32d47717e42873b..f42e79c98980d185e469329aedc8810b6775aa3d 100644 (file)
@@ -3,7 +3,6 @@ package org.argeo.cms.auth;
 import java.io.IOException;
 import java.security.cert.X509Certificate;
 import java.util.Base64;
-import java.util.Collection;
 import java.util.Locale;
 import java.util.Map;
 import java.util.StringTokenizer;
@@ -24,8 +23,6 @@ import org.argeo.cms.CmsException;
 import org.argeo.cms.internal.kernel.Activator;
 import org.osgi.framework.BundleContext;
 import org.osgi.framework.FrameworkUtil;
-import org.osgi.framework.InvalidSyntaxException;
-import org.osgi.framework.ServiceReference;
 import org.osgi.service.http.HttpContext;
 import org.osgi.service.useradmin.Authorization;
 
@@ -68,44 +65,46 @@ public class HttpSessionLoginModule implements LoginModule {
                        return false;
                }
                request = httpCallback.getRequest();
-               if (request == null)
-                       return false;
-               authorization = (Authorization) request.getAttribute(HttpContext.AUTHORIZATION);
-               if (authorization == null) {// search by session ID
-                       HttpSession httpSession = request.getSession(false);
-                       if (httpSession == null) {
-                               // TODO make sure this is always safe
-                               if (log.isTraceEnabled())
-                                       log.trace("Create http session");
-                               httpSession = request.getSession(true);
-                       }
+               if (request == null) {
+                       HttpSession httpSession = httpCallback.getHttpSession();
+                       if (httpSession == null)
+                               return false;
+                       // TODO factorize with below
                        String httpSessionId = httpSession.getId();
                        if (log.isTraceEnabled())
                                log.trace("HTTP login: " + request.getPathInfo() + " #" + httpSessionId);
-                       Collection<ServiceReference<CmsSession>> sr;
-                       try {
-                               sr = bc.getServiceReferences(CmsSession.class,
-                                               "(" + CmsSession.SESSION_LOCAL_ID + "=" + httpSessionId + ")");
-                       } catch (InvalidSyntaxException e) {
-                               throw new CmsException("Cannot get CMS session for id " + httpSessionId, e);
-                       }
-                       if (sr.size() == 1) {
-                               CmsSession cmsSession = bc.getService(sr.iterator().next());
-                               locale = cmsSession.getLocale();
+                       CmsSession cmsSession = CmsAuthUtils.cmsSessionFromHttpSession(bc, httpSessionId);
+                       if (cmsSession != null) {
                                authorization = cmsSession.getAuthorization();
-                               if (authorization.getName() == null)
-                                       authorization = null;// anonymous is not sufficient
+                               locale = cmsSession.getLocale();
                                if (log.isTraceEnabled())
                                        log.trace("Retrieved authorization from " + cmsSession);
-                       } else if (sr.size() == 0)
-                               authorization = null;
-                       else
-                               throw new CmsException(sr.size() + ">1 web sessions detected for http session " + httpSessionId);
-
+                       }
+               } else {
+                       authorization = (Authorization) request.getAttribute(HttpContext.AUTHORIZATION);
+                       if (authorization == null) {// search by session ID
+                               HttpSession httpSession = request.getSession(false);
+                               if (httpSession == null) {
+                                       // TODO make sure this is always safe
+                                       if (log.isTraceEnabled())
+                                               log.trace("Create http session");
+                                       httpSession = request.getSession(true);
+                               }
+                               String httpSessionId = httpSession.getId();
+                               if (log.isTraceEnabled())
+                                       log.trace("HTTP login: " + request.getPathInfo() + " #" + httpSessionId);
+                               CmsSession cmsSession = CmsAuthUtils.cmsSessionFromHttpSession(bc, httpSessionId);
+                               if (cmsSession != null) {
+                                       authorization = cmsSession.getAuthorization();
+                                       locale = cmsSession.getLocale();
+                                       if (log.isTraceEnabled())
+                                               log.trace("Retrieved authorization from " + cmsSession);
+                               }
+                       }
+                       sharedState.put(CmsAuthUtils.SHARED_STATE_HTTP_REQUEST, request);
+                       extractHttpAuth(request);
+                       extractClientCertificate(request);
                }
-               sharedState.put(CmsAuthUtils.SHARED_STATE_HTTP_REQUEST, request);
-               extractHttpAuth(request);
-               extractClientCertificate(request);
                if (authorization == null) {
                        if (log.isTraceEnabled())
                                log.trace("HTTP login: " + false);
@@ -127,10 +126,11 @@ public class HttpSessionLoginModule implements LoginModule {
 
                if (authorization != null) {
                        // Locale locale = request.getLocale();
-                       if (locale == null)
+                       if (locale == null && request != null)
                                locale = request.getLocale();
-                       subject.getPublicCredentials().add(locale);
-                       CmsAuthUtils.addAuthorization(subject, authorization, locale, request);
+                       if (locale != null)
+                               subject.getPublicCredentials().add(locale);
+                       CmsAuthUtils.addAuthorization(subject, authorization);
                        CmsAuthUtils.registerSessionAuthorization(request, subject, authorization, locale);
                        cleanUp();
                        return true;
@@ -159,6 +159,10 @@ public class HttpSessionLoginModule implements LoginModule {
 
        private void extractHttpAuth(final HttpServletRequest httpRequest) {
                String authHeader = httpRequest.getHeader(CmsAuthUtils.HEADER_AUTHORIZATION);
+               extractHttpAuth(authHeader);
+       }
+
+       private void extractHttpAuth(String authHeader) {
                if (authHeader != null) {
                        StringTokenizer st = new StringTokenizer(authHeader);
                        if (st.hasMoreTokens()) {
index 7297513c2849afe1ab98a2af95ef83691a27df48..cdb0f4ca27edfff0ad89340753d357ab097c09d7 100644 (file)
@@ -243,7 +243,7 @@ public class UserAdminLoginModule implements LoginModule {
 
                // Log and monitor new login
                HttpServletRequest request = (HttpServletRequest) sharedState.get(CmsAuthUtils.SHARED_STATE_HTTP_REQUEST);
-               CmsAuthUtils.addAuthorization(subject, authorization, locale, request);
+               CmsAuthUtils.addAuthorization(subject, authorization);
 
                // Unlock keyring (underlying login to the JCR repository)
                char[] password = (char[]) sharedState.get(CmsAuthUtils.SHARED_STATE_PWD);
diff --git a/org.argeo.cms/src/org/argeo/cms/websocket/CmsWebSocketConfigurator.java b/org.argeo.cms/src/org/argeo/cms/websocket/CmsWebSocketConfigurator.java
new file mode 100644 (file)
index 0000000..cd435aa
--- /dev/null
@@ -0,0 +1,91 @@
+package org.argeo.cms.websocket;
+
+import java.util.List;
+
+import javax.security.auth.login.LoginContext;
+import javax.security.auth.login.LoginException;
+import javax.servlet.http.HttpSession;
+import javax.websocket.Extension;
+import javax.websocket.HandshakeResponse;
+import javax.websocket.server.HandshakeRequest;
+import javax.websocket.server.ServerEndpointConfig;
+import javax.websocket.server.ServerEndpointConfig.Configurator;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.argeo.cms.auth.HttpRequestCallbackHandler;
+import org.argeo.node.NodeConstants;
+
+public final class CmsWebSocketConfigurator extends Configurator {
+       private final static Log log = LogFactory.getLog(CmsWebSocketConfigurator.class);
+       final static String HEADER_WWW_AUTHENTICATE = "WWW-Authenticate";
+
+       @Override
+       public boolean checkOrigin(String originHeaderValue) {
+               return true;
+       }
+
+       @Override
+       public <T> T getEndpointInstance(Class<T> endpointClass) throws InstantiationException {
+               try {
+                       return endpointClass.getDeclaredConstructor().newInstance();
+               } catch (Exception e) {
+                       throw new IllegalArgumentException("Cannot get endpoint instance", e);
+               }
+       }
+
+       @Override
+       public List<Extension> getNegotiatedExtensions(List<Extension> installed, List<Extension> requested) {
+               return requested;
+       }
+
+       @Override
+       public String getNegotiatedSubprotocol(List<String> supported, List<String> requested) {
+               if ((requested == null) || (requested.size() == 0))
+                       return "";
+               if ((supported == null) || (supported.isEmpty()))
+                       return "";
+               for (String possible : requested) {
+                       if (possible == null)
+                               continue;
+                       if (supported.contains(possible))
+                               return possible;
+               }
+               return "";
+       }
+
+       @Override
+       public void modifyHandshake(ServerEndpointConfig sec, HandshakeRequest request, HandshakeResponse response) {
+               HttpSession httpSession = (HttpSession) request.getHttpSession();
+               if (log.isDebugEnabled() && httpSession != null)
+                       log.debug("Web socket HTTP session id: " + httpSession.getId());
+
+               if (httpSession == null) {
+                       rejectResponse(response);
+               }
+               try {
+                       LoginContext lc = new LoginContext(NodeConstants.LOGIN_CONTEXT_USER,
+                                       new HttpRequestCallbackHandler(httpSession));
+                       lc.login();
+                       if (log.isDebugEnabled())
+                               log.debug("Web socket logged-in as " + lc.getSubject());
+                       sec.getUserProperties().put("subject", lc.getSubject());
+               } catch (LoginException e) {
+                       rejectResponse(response);
+               }
+
+//             List<String> authHeaders = request.getHeaders().get(HEADER_WWW_AUTHENTICATE);
+//             String authHeader;
+//             if (authHeaders != null && authHeaders.size() == 1) {
+//                     authHeader = authHeaders.get(0);
+//             } else {
+//                     return;
+//             }
+       }
+
+       private void rejectResponse(HandshakeResponse response) {
+               // violent implementation, as suggested in
+               // https://stackoverflow.com/questions/21763829/jsr-356-how-to-abort-a-websocket-connection-during-the-handshake
+               throw new IllegalStateException("Web socket cannot be authenticated");
+       }
+}
index 83934537c364a99e08d6304d35d4244e03d53b08..a74c706580a7621533426c14a0cae1a8c93e5665 100644 (file)
@@ -119,8 +119,9 @@ public class WebSocketJettyCustomizer extends JettyCustomizer {
 
                                                CmsSession cmsSession = getCmsSession(httpSessionId);
                                                if (cmsSession == null) {
-                                                       session.disconnect();
-                                                       return;
+//                                                     session.disconnect();
+//                                                     return;
+
 //                                                     try {
 //                                                             session.getUpgradeResponse().sendForbidden("Web Sockets must always be authenticated.");
 //                                                     } catch (IOException e) {