aboutsummaryrefslogtreecommitdiff
path: root/src/main/java/com/juick/service
diff options
context:
space:
mode:
Diffstat (limited to 'src/main/java/com/juick/service')
-rw-r--r--src/main/java/com/juick/service/PostgreGINService.java36
1 files changed, 25 insertions, 11 deletions
diff --git a/src/main/java/com/juick/service/PostgreGINService.java b/src/main/java/com/juick/service/PostgreGINService.java
index f4f70bf4..ca326bd5 100644
--- a/src/main/java/com/juick/service/PostgreGINService.java
+++ b/src/main/java/com/juick/service/PostgreGINService.java
@@ -1,5 +1,5 @@
/*
- * Copyright (C) 2008-2020, Juick
+ * Copyright (C) 2008-2022, Juick
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as
@@ -19,16 +19,16 @@ package com.juick.service;
import com.juick.model.User;
import org.apache.commons.lang3.StringUtils;
+import org.postgresql.PGConnection;
import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
-import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
import org.springframework.stereotype.Repository;
import org.springframework.transaction.annotation.Transactional;
import javax.inject.Inject;
+import javax.sql.DataSource;
+import java.sql.SQLException;
import java.util.Collections;
-import java.util.HashMap;
import java.util.List;
-import java.util.Map;
import java.util.stream.Collectors;
/**
@@ -45,10 +45,24 @@ public class PostgreGINService extends BaseJdbcService implements SearchService
@Inject
UserService userService;
+ @Inject
+ DataSource dataSource;
+
+ private String escapeQuery(String searchString) {
+ try {
+ if (dataSource.getConnection().isWrapperFor(PGConnection.class)) {
+ PGConnection connection = dataSource.getConnection().unwrap(PGConnection.class);
+ return connection.escapeLiteral(searchString);
+ }
+ } catch (SQLException e) {
+ e.printStackTrace();
+ }
+ return searchString;
+ }
- public String sortHint(String searchString) {
- boolean isOneWord = searchString.split("[^\\S\\+]+").length == 1;
- return isOneWord ? " message_id desc" : String.format(" ts_rank(to_tsvector('russian', \"txt\"), plainto_tsquery('russian', '%s')) DESC, message_id desc", searchString);
+ public String sortHint(String escapedSearchString) {
+ boolean isOneWord = escapedSearchString.split("[^\\S\\+]+").length == 1;
+ return isOneWord ? " message_id desc" : String.format(" ts_rank(to_tsvector('russian', txt), plainto_tsquery('russian', '%s')) DESC, message_id desc", escapedSearchString);
}
@Override
@@ -58,8 +72,8 @@ public class PostgreGINService extends BaseJdbcService implements SearchService
String usersFilter = userService.getUserBLUsers(visitor.getUid()).stream().map(u -> String.valueOf(u.getUid())).collect(Collectors.joining(","));
var filter = usersFilter.isBlank() ?
- "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + searchString + "')"
- : "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + searchString + "') AND messages.user_id not in (" + usersFilter + ")";
+ "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + escapeQuery(searchString) + "')"
+ : "to_tsvector('russian', txt) @@ plainto_tsquery('russian', '" + escapeQuery(searchString) + "') AND messages.user_id not in (" + usersFilter + ")";
var offset = page * maxResult;
var sql = String.format("""
SELECT messages.message_id FROM messages_txt inner join messages
@@ -67,7 +81,7 @@ public class PostgreGINService extends BaseJdbcService implements SearchService
WHERE %s
ORDER BY %s
LIMIT %d OFFSET %d
- """, filter, sortHint(searchString), maxResult, offset);
+ """, filter, sortHint(escapeQuery(searchString)), maxResult, offset);
return getJdbcTemplate().queryForList(sql, Integer.class);
}
@@ -83,7 +97,7 @@ public class PostgreGINService extends BaseJdbcService implements SearchService
WHERE to_tsvector('russian', txt) @@ plainto_tsquery('russian', '%s') AND user_id=%d
ORDER BY %s
LIMIT %d OFFSET %d
- """, searchString, userId, sortHint(searchString), maxResult, offset);
+ """, escapeQuery(searchString), userId, sortHint(escapeQuery(searchString)), maxResult, offset);
return getJdbcTemplate().queryForList(sql, Integer.class);
}