package com.topsunit.query.binding;

import com.topsunit.query.annotations.Param;

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Arrays;
import java.util.Collection;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;

/**
 * Default {@link SqlTranslator} that rewrites named parameter references ({@code :name}) into
 * JDBC positional placeholders ({@code ?}).
 *
 * <p>Translation happens only when the method's parameters carry {@link Param} annotations.
 * Otherwise the SQL is returned unchanged.
 *
 * @author yonghuan
 */
public class DefaultSqlTranslator implements SqlTranslator {

    /**
     * Replaces every {@code :paramName} reference in {@code sql} with positional placeholders.
     *
     * <p>A {@link Collection} argument expands to one {@code ?} per element, joined by commas.
     * For example, {@code in(:ids)} with three ids becomes {@code in(?,?,?)}. Any other
     * argument becomes a single {@code ?}.
     *
     * <p>A reference matches only the complete parameter name, so {@code :id} does not rewrite
     * the prefix of {@code :ids}. A name preceded by {@code ::} is also left untouched, which
     * keeps casts such as {@code x::date} intact.
     *
     * @param sql    the source SQL, possibly containing {@code :name} references
     * @param method the method whose {@link Param} annotations name its parameters
     * @param args   the actual arguments, positionally aligned with {@code method}'s parameters
     * @return the translated SQL, or {@code sql} unchanged when no parameter is annotated
     * @throws IllegalArgumentException if the number of {@link Param}-annotated parameters
     *                                  differs from the number of arguments
     */
    @Override
    public String translate(String sql, Method method, Object[] args) {
        if (useNamedParameter(method, args)) {
            Parameter[] parameters = method.getParameters();
            for (int i = 0; i < parameters.length; i++) {
                // useNamedParameter guarantees every parameter is annotated here.
                String paramName = parameters[i].getAnnotation(Param.class).value();
                String placeholders = placeholdersFor(args[i]);
                sql = namedParameterPattern(paramName)
                        .matcher(sql)
                        .replaceAll(Matcher.quoteReplacement(placeholders));
            }
        }
        // Do not try to cache the translated SQL. The same source SQL yields different SQL
        // for different args. For example, in "select t.id,t.name from t where t.id in(:ids)",
        // the number of generated placeholders (?) depends on the size of ids.
        return sql;
    }

    /**
     * Builds a pattern matching {@code :paramName} as a whole, literal token.
     *
     * <p>The name is quoted so regex metacharacters in it are taken literally. The negative
     * lookahead rejects matches followed by another word character, so {@code :id} does not
     * match inside {@code :ids}. The negative lookbehind rejects a colon that is itself
     * preceded by a colon, as in {@code ::name}.
     */
    private static Pattern namedParameterPattern(String paramName) {
        return Pattern.compile("(?<!:):" + Pattern.quote(paramName) + "(?!\\w)");
    }

    /**
     * Returns the placeholder text for one argument.
     *
     * <p>For a {@link Collection} this is one {@code ?} per element joined by {@code ","}. For
     * any other value it is a single {@code ?}. An empty collection yields an empty string.
     * NOTE(review): {@code in()} is rejected by most databases, so callers presumably must
     * not pass empty collections; confirm against the parameter binder.
     */
    private static String placeholdersFor(Object paramValue) {
        if (paramValue instanceof Collection) {
            return ((Collection<?>) paramValue).stream()
                    .map(element -> "?")
                    .collect(Collectors.joining(","));
        }
        return "?";
    }

    /**
     * Decides whether named-parameter translation applies to {@code method}.
     *
     * @param method the invoked method
     * @param args   the actual arguments; {@code null} is treated as zero arguments
     * @return {@code false} if no parameter carries {@link Param}; {@code true} if the number of
     *         annotated parameters equals the number of arguments
     * @throws IllegalArgumentException if some, but not matching, parameters are annotated
     */
    private boolean useNamedParameter(Method method, Object[] args) {
        long paramAnnotationCount = Arrays.stream(method.getParameters())
                .map(parameter -> parameter.getAnnotation(Param.class))
                .filter(Objects::nonNull)
                .count();
        if (paramAnnotationCount == 0) {
            return false;
        }
        // Treat null args as zero arguments so a mismatch reports the intended exception,
        // not a NullPointerException.
        int argCount = args == null ? 0 : args.length;
        if (paramAnnotationCount != argCount) {
            throw new IllegalArgumentException("Param注解数量与实参数量不一致");
        }
        return true;
    }

}
