diff options
Diffstat (limited to 'src/main/java/com/juick/service')
-rw-r--r-- | src/main/java/com/juick/service/PostgreGINService.java | 36 |
1 files changed, 25 insertions, 11 deletions
diff --git a/src/main/java/com/juick/service/PostgreGINService.java b/src/main/java/com/juick/service/PostgreGINService.java index f4f70bf4..ca326bd5 100644 --- a/src/main/java/com/juick/service/PostgreGINService.java +++ b/src/main/java/com/juick/service/PostgreGINService.java @@ -1,5 +1,5 @@ /* - * Copyright (C) 2008-2020, Juick + * Copyright (C) 2008-2022, Juick * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as @@ -19,16 +19,16 @@ package com.juick.service; import com.juick.model.User; import org.apache.commons.lang3.StringUtils; +import org.postgresql.PGConnection; import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; -import org.springframework.jdbc.core.namedparam.MapSqlParameterSource; import org.springframework.stereotype.Repository; import org.springframework.transaction.annotation.Transactional; import javax.inject.Inject; +import javax.sql.DataSource; +import java.sql.SQLException; import java.util.Collections; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.stream.Collectors; /** @@ -45,10 +45,24 @@ public class PostgreGINService extends BaseJdbcService implements SearchService @Inject UserService userService; + @Inject + DataSource dataSource; + + private String escapeQuery(String searchString) { + try { + if (dataSource.getConnection().isWrapperFor(PGConnection.class)) { + PGConnection connection = dataSource.getConnection().unwrap(PGConnection.class); + return connection.escapeLiteral(searchString); + } + } catch (SQLException e) { + e.printStackTrace(); + } + return searchString; + } - public String sortHint(String searchString) { - boolean isOneWord = searchString.split("[^\\S\\+]+").length == 1; - return isOneWord ? " message_id desc" : String.format(" ts_rank(to_tsvector('russian', \"txt\"), plainto_tsquery('russian', '%s')) DESC, message_id desc", searchString); + public String sortHint(String escapedSearchString) { + boolean isOneWord = escapedSearchString.split("[^\\S\\+]+").length == 1; + return isOneWord ? " message_id desc" : String.format(" ts_rank(to_tsvector('russian', txt), plainto_tsquery('russian', '%s')) DESC, message_id desc", escapedSearchString); } @Override @@ -58,8 +72,8 @@ public class PostgreGINService extends BaseJdbcService implements SearchService String usersFilter = userService.getUserBLUsers(visitor.getUid()).stream().map(u -> String.valueOf(u.getUid())).collect(Collectors.joining(",")); var filter = usersFilter.isBlank() ? - "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + searchString + "')" - : "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + searchString + "') AND messages.user_id not in (" + usersFilter + ")"; + "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + escapeQuery(searchString) + "')" + : "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + escapeQuery(searchString) + "') AND messages.user_id not in (" + usersFilter + ")"; var offset = page * maxResult; var sql = String.format(""" SELECT messages.message_id FROM messages_txt inner join messages @@ -67,7 +81,7 @@ public class PostgreGINService extends BaseJdbcService implements SearchService WHERE %s ORDER BY %s LIMIT %d OFFSET %d - """, filter, sortHint(searchString), maxResult, offset); + """, filter, sortHint(escapeQuery(searchString)), maxResult, offset); return getJdbcTemplate().queryForList(sql, Integer.class); } @@ -83,7 +97,7 @@ public class PostgreGINService extends BaseJdbcService implements SearchService WHERE to_tsvector('russian', txt) @@ plainto_tsquery('russian', '%s') AND user_id=%d ORDER BY %s LIMIT %d OFFSET %d - """, searchString, userId, sortHint(searchString), maxResult, offset); + """, escapeQuery(searchString), userId, sortHint(escapeQuery(searchString)), maxResult, offset); return getJdbcTemplate().queryForList(sql, Integer.class); } |