package com.juick.database; import org.apache.commons.io.IOUtils; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.support.TransactionTemplate; import org.springframework.util.Assert; import javax.annotation.PostConstruct; import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; import java.util.function.Supplier; import java.util.regex.Matcher; import java.util.regex.Pattern; import static org.springframework.transaction.TransactionDefinition.PROPAGATION_REQUIRED; /** * Created by aalexeev on 12/13/16. */ public class MySqlUpdater { private static final Pattern UPDATE_PATTERN = Pattern.compile( "update\\s+(version|`version`)\\s+set\\s+(version|`version`)\\s+=\\s+(\\d+)", Pattern.CASE_INSENSITIVE); private final Logger logger = LoggerFactory.getLogger(getClass()); private final JdbcTemplate jdbcTemplate; private final TransactionTemplate transactionTemplate; private final String updateSqlResource; public MySqlUpdater(JdbcTemplate jdbcTemplate, PlatformTransactionManager transactionManager, String updateSqlResource) { Assert.notNull(jdbcTemplate); Assert.notNull(transactionManager); Assert.notNull(updateSqlResource); this.jdbcTemplate = jdbcTemplate; this.transactionTemplate = new TransactionTemplate(transactionManager); this.updateSqlResource = updateSqlResource; } @PostConstruct public void init() { try ( InputStream is = Thread.currentThread().getContextClassLoader().getResourceAsStream(updateSqlResource); ) { if (is != null) { String content = IOUtils.toString(is, StandardCharsets.UTF_8); if (StringUtils.isNotEmpty(content)) { String[] sqlArray = content.split(";"); if (sqlArray.length > 0) { List sqlList = new ArrayList<>(sqlArray.length); for (String sql : sqlArray) if (!sql.isEmpty()) { String sqlTrimmed = sql.trim(); if (!sqlTrimmed.isEmpty()) sqlList.add(sqlTrimmed); } if (!sqlList.isEmpty()) processingSql(sqlList); } } } } catch (Exception e) { logger.error("MySqlUpdater initialization exception", e); } } private void processingSql(final List sqls) { long currentDbVersion = getSingleResult(this::getVersionRaw); long actualVersion; List changesSql = new ArrayList<>(); for (String sql : sqls) { changesSql.add(sql); Matcher m = UPDATE_PATTERN.matcher(sql); if (m.matches()) { String actual = m.group(3); actualVersion = Long.valueOf(actual); if (actualVersion > currentDbVersion) { updateInTransaction(changesSql); currentDbVersion = actualVersion; } changesSql.clear(); } } } private void updateInTransaction(final List sqls) { transactionTemplate.setReadOnly(false); transactionTemplate.setPropagationBehavior(PROPAGATION_REQUIRED); transactionTemplate.execute(status -> { for (String sql : sqls) jdbcTemplate.execute(sql); return 0; }); } private T getSingleResult(Supplier supplier) { transactionTemplate.setReadOnly(true); transactionTemplate.setPropagationBehavior(PROPAGATION_REQUIRED); return transactionTemplate.execute(status -> supplier.get()); } private long getVersionRaw() { int cnt = jdbcTemplate.query( "SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ?", rs -> { int result = 0; if (rs.next()) result = rs.getInt(1); return result; }, "juick", "version"); long version = 0l; if (cnt == 1) { List list = jdbcTemplate.queryForList("select version from version", Long.class); if (!list.isEmpty()) version = list.get(0); } return version; } }