package com.huigou.uasp.log.application.impl;

import com.huigou.data.dialect.Dialect;
import com.huigou.data.dialect.DialectUtils;
import com.huigou.data.domain.query.QueryAbstractRequest;
import com.huigou.data.domain.query.QueryPageRequest;
import com.huigou.data.jdbc.util.RowSetUtil;
import com.huigou.data.jdbc.util.SQLRowSetOracleResultSetExtractor;
import com.huigou.data.query.model.QueryModel;
import com.huigou.uasp.log.application.SlicedQueryStrategy;
import com.huigou.util.Constants;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.domain.PageRequest;
import org.springframework.jdbc.core.JdbcTemplate;
import org.springframework.jdbc.core.ResultSetExtractor;
import org.springframework.jdbc.core.SqlRowSetResultSetExtractor;
import org.springframework.jdbc.support.rowset.SqlRowSet;
import org.springframework.stereotype.Component;

import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Implements paged ("sliced") queries using an inner join. The generated SQL has the shape:
 * <p>select t.* from t inner join (select t.id from t where t.name like 'xx%' order by create_date desc limit 20,40) lim on t.id=lim.id</p>
 * The inner sub-query selects and paginates only ids; the outer query then fetches the full rows
 * for those ids.
 *
 * @author yonghuan
 */
@Component
public class InnerJoinSlicedQueryStrategy implements SlicedQueryStrategy {

    private static final Logger LOG = LoggerFactory.getLogger(InnerJoinSlicedQueryStrategy.class);
    private JdbcTemplate jdbcTemplate;

    @Autowired
    public void setJdbcTemplate(JdbcTemplate jdbcTemplate) {
        this.jdbcTemplate = jdbcTemplate;
    }

    /**
     * Returns whether this strategy can serve the request: it refuses export queries, and
     * otherwise accepts only Oracle- or MySQL-family databases, the two dialects for which
     * pagination SQL is generated.
     *
     * @param tableName    the queried table (not inspected here)
     * @param queryRequest the incoming request
     * @return {@code true} if {@link #slicedQuery} can handle the request
     */
    @Override
    public boolean supports(String tableName, QueryAbstractRequest queryRequest) {
        QueryModel queryModel = queryRequest.initQueryModel();
        if (queryModel.isExportQuery()) {
            return false;
        }
        Dialect dialect = DialectUtils.guessDialect(jdbcTemplate.getDataSource());
        return dialect.isOracleFamily() || dialect.isMySqlFamily();
    }

    /**
     * Runs a count query and, if any rows match, a paged inner-join query over {@code tableName}.
     *
     * @param tableName    table to query; concatenated into the SQL as-is
     * @param queryRequest supplies the sort fields and the (optional) page model
     * @param whereClause  SQL fragment appended after the table alias {@code t}; concatenated as-is
     * @param args         bind parameters for the placeholders in {@code whereClause}
     * @return a map holding the rows under {@link Constants#ROWS}, the total count under
     * {@link Constants#RECORD} and, when a page model is present, the page index under
     * {@link Constants#PAGE_PARAM_NAME}
     */
    @Override
    public Map<String, Object> slicedQuery(String tableName, QueryAbstractRequest queryRequest, String whereClause, Object[] args) {
        QueryModel queryModel = queryRequest.initQueryModel();
        QueryPageRequest pageModel = queryRequest.getPageModel();

        String countSql = "select count(*) from " + tableName + " t " + whereClause;
        // Log before executing so that a failing statement still appears in the log.
        LOG.info(countSql);
        LOG.info("SQL Parameters: {}", ArrayUtils.toString(args));
        Long countResult = jdbcTemplate.queryForObject(countSql, Long.class, args);
        long count = countResult == null ? 0L : countResult;

        List<?> rows = Collections.emptyList();
        if (count > 0) {
            Dialect dialect = DialectUtils.guessDialect(jdbcTemplate.getDataSource());
            String sortOrders = buildSortOrders(queryModel);
            String idSql = buildIdSql(tableName, whereClause, sortOrders, pageModel, dialect);
            StringBuilder sql = new StringBuilder("select log.* from ")
                    .append(tableName)
                    .append(" log inner join (")
                    .append(idSql)
                    .append(") lim on log.id=lim.id");
            // The join does not guarantee the sub-query's order, so the sort is re-applied here.
            // It is skipped when there are no sort fields: a bare "order by" is invalid SQL.
            if (StringUtils.isNotBlank(sortOrders)) {
                sql.append(" order by ").append(sortOrders);
            }
            String finalSql = sql.toString();
            LOG.info(finalSql);
            LOG.info("SQL Parameters: {}", ArrayUtils.toString(args));
            SqlRowSet srs = (SqlRowSet) jdbcTemplate.query(finalSql, args, getResultSetExtractor(dialect));
            rows = RowSetUtil.toMapList(srs);
        }

        Map<String, Object> result = new HashMap<>(4);
        result.put(Constants.ROWS, rows);
        result.put(Constants.RECORD, count);
        if (pageModel != null) {
            result.put(Constants.PAGE_PARAM_NAME, pageModel.getPageIndex());
        }
        return result;
    }

    /**
     * Builds the comma-separated "column direction" list from the query model's sort fields,
     * with duplicates removed. Returns an empty string when there are no sort fields.
     */
    private static String buildSortOrders(QueryModel queryModel) {
        return queryModel.getSortFieldList().stream()
                .map(sortField -> String.format("%s %s", sortField.getColumnName(), sortField.getDirection()))
                .distinct()
                .collect(Collectors.joining(","));
    }

    /**
     * Builds the sub-query that selects the ids of the requested page, sorted by
     * {@code sortOrders} when non-blank and paginated for the given dialect when
     * {@code pageModel} is non-null.
     */
    private static String buildIdSql(String tableName, String whereClause, String sortOrders,
                                     QueryPageRequest pageModel, Dialect dialect) {
        StringBuilder idSql = new StringBuilder("select t.id from ").append(tableName).append(" t ")
                .append(whereClause);
        if (StringUtils.isNotBlank(sortOrders)) {
            idSql.append(" order by ").append(sortOrders);
        }
        if (pageModel == null) {
            return idSql.toString();
        }
        // The page model's index is 1-based while PageRequest expects a 0-based page number.
        PageRequest pageRequest = new PageRequest(pageModel.getPageIndex() - 1, pageModel.getPageSize());
        if (dialect.isOracleFamily()) {
            // Oracle pagination. ROWNUM is 1-based, so the page covers rows offset+1 .. offset+pageSize.
            // The lower bound must be strict; ">=" would repeat the previous page's last row.
            long offset = pageRequest.getOffset();
            long lastRow = offset + pageRequest.getPageSize();
            return "select t_t_.id  from (select t_.id,rownum as rownum_ from ("
                    + idSql
                    + ") t_ where rownum<=" + lastRow
                    + ") t_t_ where t_t_.rownum_>" + offset;
        }
        if (dialect.isMySqlFamily()) {
            // MySQL pagination: "limit offset,rowCount".
            idSql.append(String.format(" limit %s,%s ", pageRequest.getOffset(), pageRequest.getPageSize()));
        }
        return idSql.toString();
    }

    /**
     * Chooses the result-set extractor for the dialect: the Oracle-specific extractor for
     * Oracle-family databases, Spring's {@link SqlRowSetResultSetExtractor} otherwise.
     */
    private static ResultSetExtractor<?> getResultSetExtractor(Dialect dialect) {
        if (dialect.isOracleFamily()) {
            return new SQLRowSetOracleResultSetExtractor();
        }
        return new SqlRowSetResultSetExtractor();
    }

}
