diff options
Diffstat (limited to 'juick-server/src/main/java/com/juick/server/security')
-rw-r--r-- | juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java | 51 |
1 files changed, 36 insertions, 15 deletions
diff --git a/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java b/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java index df1ae38cb..ce48adbe7 100644 --- a/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java +++ b/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java @@ -5,14 +5,15 @@ import com.juick.server.security.entities.JuickUser; import com.juick.service.UserService; import org.springframework.security.authentication.AnonymousAuthenticationToken; import org.springframework.security.authentication.RememberMeAuthenticationToken; -import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; import org.springframework.security.core.context.SecurityContextHolder; -import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter; +import org.springframework.security.web.authentication.RememberMeServices; +import org.springframework.util.Assert; import org.springframework.web.filter.OncePerRequestFilter; import javax.servlet.FilterChain; import javax.servlet.ServletException; +import javax.servlet.http.Cookie; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; @@ -24,10 +25,17 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter { public static final String PARAM_NAME = "hash"; private final UserService userService; + private final RememberMeServices rememberMeServices; - public HashParamAuthenticationFilter(UserService userService) { + public HashParamAuthenticationFilter( + final UserService userService, + final RememberMeServices rememberMeServices) { + Assert.notNull(userService, "userService should not be null"); + Assert.notNull(rememberMeServices, "rememberMeServices should not be null"); + this.userService = userService; + this.rememberMeServices = rememberMeServices; } @Override @@ -36,17 +44,19 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter { HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { - String hash = request.getHeader(PARAM_NAME); - - if (hash == null) - hash = request.getParameter(PARAM_NAME); + String hash = getHashFromRequest(request); if (hash != null && authenticationIsRequired()) { User user = userService.getUserByHash(hash); - if (!user.isAnonymous()) - SecurityContextHolder.getContext().setAuthentication( - new RememberMeAuthenticationToken(hash, new JuickUser(user), JuickUser.USER_AUTHORITY)); + if (!user.isAnonymous()) { + Authentication authentication = new RememberMeAuthenticationToken( + hash, new JuickUser(user), JuickUser.USER_AUTHORITY); + + SecurityContextHolder.getContext().setAuthentication(authentication); + + rememberMeServices.loginSuccess(request, response, authentication); + } } filterChain.doFilter(request, response); @@ -55,12 +65,23 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter { private boolean authenticationIsRequired() { Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication(); - if (existingAuth == null || !existingAuth.isAuthenticated()) - return true; + return existingAuth == null || + !existingAuth.isAuthenticated() || + existingAuth instanceof AnonymousAuthenticationToken; + } + + private String getHashFromRequest(HttpServletRequest request) { + String hash = request.getHeader(PARAM_NAME); - if (existingAuth instanceof AnonymousAuthenticationToken) - return true; + if (hash == null) + hash = request.getParameter(PARAM_NAME); - return false; + if (hash == null) + for (Cookie cookie : request.getCookies()) + if (PARAM_NAME.equals(cookie.getName())) { + hash = cookie.getValue(); + break; + } + return hash; } } |