aboutsummaryrefslogtreecommitdiff
path: root/juick-server/src/main/java/com/juick/server/security
diff options
context:
space:
mode:
Diffstat (limited to 'juick-server/src/main/java/com/juick/server/security')
-rw-r--r--juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java51
1 files changed, 36 insertions, 15 deletions
diff --git a/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java b/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java
index df1ae38cb..ce48adbe7 100644
--- a/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java
+++ b/juick-server/src/main/java/com/juick/server/security/HashParamAuthenticationFilter.java
@@ -5,14 +5,15 @@ import com.juick.server.security.entities.JuickUser;
import com.juick.service.UserService;
import org.springframework.security.authentication.AnonymousAuthenticationToken;
import org.springframework.security.authentication.RememberMeAuthenticationToken;
-import org.springframework.security.authentication.UsernamePasswordAuthenticationToken;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.context.SecurityContextHolder;
-import org.springframework.security.web.authentication.UsernamePasswordAuthenticationFilter;
+import org.springframework.security.web.authentication.RememberMeServices;
+import org.springframework.util.Assert;
import org.springframework.web.filter.OncePerRequestFilter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
+import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
@@ -24,10 +25,17 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter {
public static final String PARAM_NAME = "hash";
private final UserService userService;
+ private final RememberMeServices rememberMeServices;
- public HashParamAuthenticationFilter(UserService userService) {
+ public HashParamAuthenticationFilter(
+ final UserService userService,
+ final RememberMeServices rememberMeServices) {
+ Assert.notNull(userService, "userService should not be null");
+ Assert.notNull(rememberMeServices, "rememberMeServices should not be null");
+
this.userService = userService;
+ this.rememberMeServices = rememberMeServices;
}
@Override
@@ -36,17 +44,19 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter {
HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
- String hash = request.getHeader(PARAM_NAME);
-
- if (hash == null)
- hash = request.getParameter(PARAM_NAME);
+ String hash = getHashFromRequest(request);
if (hash != null && authenticationIsRequired()) {
User user = userService.getUserByHash(hash);
- if (!user.isAnonymous())
- SecurityContextHolder.getContext().setAuthentication(
- new RememberMeAuthenticationToken(hash, new JuickUser(user), JuickUser.USER_AUTHORITY));
+ if (!user.isAnonymous()) {
+ Authentication authentication = new RememberMeAuthenticationToken(
+ hash, new JuickUser(user), JuickUser.USER_AUTHORITY);
+
+ SecurityContextHolder.getContext().setAuthentication(authentication);
+
+ rememberMeServices.loginSuccess(request, response, authentication);
+ }
}
filterChain.doFilter(request, response);
@@ -55,12 +65,23 @@ public class HashParamAuthenticationFilter extends OncePerRequestFilter {
private boolean authenticationIsRequired() {
Authentication existingAuth = SecurityContextHolder.getContext().getAuthentication();
- if (existingAuth == null || !existingAuth.isAuthenticated())
- return true;
+ return existingAuth == null ||
+ !existingAuth.isAuthenticated() ||
+ existingAuth instanceof AnonymousAuthenticationToken;
+ }
+
+ private String getHashFromRequest(HttpServletRequest request) {
+ String hash = request.getHeader(PARAM_NAME);
- if (existingAuth instanceof AnonymousAuthenticationToken)
- return true;
+ if (hash == null)
+ hash = request.getParameter(PARAM_NAME);
- return false;
+ if (hash == null)
+ for (Cookie cookie : request.getCookies())
+ if (PARAM_NAME.equals(cookie.getName())) {
+ hash = cookie.getValue();
+ break;
+ }
+ return hash;
}
}