package com.ximai.common.utils.reflect;

import java.io.File;
import java.net.URL;
import java.net.URLDecoder;
import java.util.Enumeration;
import java.util.LinkedHashSet;
import java.util.Set;

/**
 * Utilities for discovering, at runtime, the classes that belong to a Java package.
 */
public class PackageUtil {

    /** File-name suffix of compiled class files. */
    private static final String CLASS_SUFFIX = ".class";

    /**
     * Finds all classes in the given package and its sub-packages that are reachable through
     * exploded ({@code file:}) class-path directories.
     *
     * <p>Resources served through any other URL protocol (for example {@code jar:} or
     * {@code jrt:}) are ignored. Classes are loaded via {@link ClassLoader#loadClass(String)}
     * using the current thread's context class loader, falling back to the loader of this class
     * when no context loader is set. A class file that fails to load is reported on
     * {@code System.err} and skipped; the scan continues with the remaining files.
     *
     * @param pack the package to scan; must not be {@code null}
     * @return the discovered classes in discovery order; never {@code null}, possibly empty. If
     *     enumerating the package resources fails, the error is printed and the classes found so
     *     far are returned.
     */
    public static Set<Class<?>> getAllClasses(Package pack) {
        // Package name, e.g. com.example
        String packageName = pack.getName();
        ClassLoader classLoader = Thread.currentThread().getContextClassLoader();
        if (classLoader == null) {
            // Some threads have no context loader; fall back to the loader that loaded this class.
            classLoader = PackageUtil.class.getClassLoader();
        }
        Set<Class<?>> classes = new LinkedHashSet<>();
        try {
            // Each class-path root that contains this package contributes one URL.
            Enumeration<URL> resources = classLoader.getResources(packageName.replace(".", "/"));
            while (resources.hasMoreElements()) {
                URL resource = resources.nextElement();
                if ("file".equals(resource.getProtocol())) {
                    // URLDecoder performs form decoding, which turns a literal '+' into a space;
                    // escape '+' first so paths such as "/work/c++/classes" survive decoding.
                    String filePath = URLDecoder.decode(resource.getFile().replace("+", "%2B"), "UTF-8");
                    findClassesInPackageByFile(packageName, filePath, classLoader, classes);
                }
            }
        } catch (Exception e) {
            // Best effort: report the failure and return whatever was collected so far.
            e.printStackTrace();
        }
        return classes;
    }

    /**
     * Recursively collects the classes stored under {@code packagePath}.
     *
     * @param packageName fully qualified name of the package that {@code packagePath} holds
     *     (empty for the unnamed package)
     * @param packagePath file-system directory corresponding to {@code packageName}
     * @param classLoader loader used to load each discovered class
     * @param classes     output set the loaded classes are added to
     */
    private static void findClassesInPackageByFile(
            String packageName, String packagePath, ClassLoader classLoader, Set<Class<?>> classes) {
        File[] files = new File(packagePath).listFiles();
        if (files == null) {
            // Not a directory, or it could not be read.
            return;
        }
        for (File file : files) {
            String fileName = file.getName();
            if (file.isDirectory()) {
                // Recurse into the sub-package.
                findClassesInPackageByFile(
                        qualify(packageName, fileName), file.getAbsolutePath(), classLoader, classes);
            } else if (fileName.endsWith(CLASS_SUFFIX)) {
                String className = qualify(
                        packageName, fileName.substring(0, fileName.length() - CLASS_SUFFIX.length()));
                try {
                    classes.add(classLoader.loadClass(className));
                } catch (ClassNotFoundException | LinkageError e) {
                    // A single unloadable file (e.g. module-info.class, or a class whose
                    // dependencies are missing) must not abort the whole scan.
                    System.err.println("PackageUtil: skipping " + className + ": " + e);
                }
            }
        }
    }

    /** Joins a package name and a simple name, handling the unnamed (empty) package. */
    private static String qualify(String packageName, String simpleName) {
        return packageName.isEmpty() ? simpleName : packageName + "." + simpleName;
    }
}