package jp.agentec.sinaburocast.servlet;

import java.lang.ref.Reference;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.List;

import javax.servlet.ServletConfig;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;

import net.sf.ehcache.CacheManager;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.seasar.framework.util.DriverManagerUtil;

/**
 * Startup servlet for the web application.
 * <p>
 * On initialization it records the real file-system path of {@code /WEB-INF/}
 * so other classes can obtain it through {@link #getWebInfPath()}. On
 * destruction it deregisters drivers via {@code DriverManagerUtil}, shuts down
 * the ehcache {@code CacheManager} singleton, and then stops leftover threads
 * and clears thread-local entries that still reference this web application's
 * class loader.
 *
 * @author th-tsukada
 *
 */
public class StartupServlet extends HttpServlet {
    private static final long serialVersionUID = 1L;
    private final static Log logger= LogFactory.getLog(StartupServlet.class);

    /** Pause between two liveness checks while waiting for a thread to end (ms). */
    private static final long THREAD_WAIT_INTERVAL_MILLIS = 100L;
    /** Maximum number of liveness checks performed by one {@link #waitThread(Thread)} call. */
    private static final int THREAD_WAIT_MAX_CHECKS = 5;

    /** Real path of {@code /WEB-INF/}; set in {@link #init(ServletConfig)}, may be null. */
    private static String WEBINF_PATH;

    public StartupServlet() {
    }

    /**
     * Returns the real path of {@code /WEB-INF/} captured at servlet initialization.
     *
     * @return the path, or {@code null} if the servlet has not been initialized yet
     *         or the container could not translate the path
     */
    public static String getWebInfPath() {
        return WEBINF_PATH;
    }

    /**
     * Releases resources held by the web application.
     * <p>
     * Every cleanup step is isolated so that a failure in one of them does not
     * prevent the following steps from running.
     */
    @Override
    public void destroy() {
        super.destroy();
        try {
            DriverManagerUtil.deregisterAllDrivers();
        } catch (RuntimeException e) {
            logger.warn("Failed to deregister drivers", e);
        }
        try {
            CacheManager.getInstance().shutdown();
        } catch (RuntimeException e) {
            logger.warn("Failed to shut down the cache manager", e);
        }
        try {
            stopAllThreads();
        } catch (RuntimeException e) {
            logger.warn("Failed to stop threads / clear thread locals", e);
        }
    }

    @Override
    public void init() throws ServletException { // Note: both init methods get called
        super.init();
    }

    /**
     * Initializes the servlet and stores the real path of {@code /WEB-INF/} in a
     * static field.
     *
     * @param config the servlet configuration supplied by the container
     * @throws ServletException if the superclass initialization fails
     */
    @Override
    public void init(ServletConfig config) throws ServletException {
        super.init(config);
        WEBINF_PATH = this.getServletContext().getRealPath("/WEB-INF/"); // writing to the static field is intentional
        if (WEBINF_PATH == null) {
            // getRealPath returns null when the container cannot map the path to the file system
            logger.warn("Real path of /WEB-INF/ is not available");
        }
    }

    /**
     * Stops leftover threads of this web application and clears their
     * thread-local entries.
     * <p>
     * Without this, FATAL errors keep being reported when Tomcat shuts down.
     * The approach is based on the source published at
     * http://d.hatena.ne.jp/shinsuke_sugaya/20100211/1265857652
     */
    private void stopAllThreads() {
        Thread[] threads = getThreads();
        ClassLoader cl = this.getClass().getClassLoader();

        stopWebappThreads(threads, cl);
        clearThreadLocals(threads, cl);
    }

    /**
     * Interrupts, and as a last resort stops, every live thread whose context
     * class loader is {@code cl}, except the current thread and threads in
     * JVM-controlled thread groups.
     *
     * @param threads candidate threads; null elements are skipped
     * @param cl      class loader of this web application
     */
    @SuppressWarnings("deprecation")
    private void stopWebappThreads(Thread[] threads, ClassLoader cl) {
        List<String> jvmThreadGroupList = new ArrayList<String>();
        jvmThreadGroupList.add("system");
        jvmThreadGroupList.add("RMI Runtime");

        for (Thread thread : threads) {
            // Never touch the thread that is running this cleanup
            if (thread == null || thread == Thread.currentThread()) {
                continue;
            }

            // Only threads bound to this web application's class loader are handled
            ClassLoader ccl = thread.getContextClassLoader();
            if (ccl == null || ccl != cl) {
                continue;
            }

            // Leave JVM controlled threads alone
            ThreadGroup tg = thread.getThreadGroup();
            if (tg != null && jvmThreadGroupList.contains(tg.getName())) {
                continue;
            }

            // Give the thread a chance to finish on its own
            waitThread(thread);
            if (!thread.isAlive()) {
                continue;
            }

            if (logger.isInfoEnabled()) {
                logger.info("Interrupting a thread ["
                        + thread.getName() + "]...");
            }
            thread.interrupt();

            waitThread(thread);
            if (!thread.isAlive()) {
                continue;
            }

            if (logger.isInfoEnabled()) {
                logger.info("Stopping a thread [" + thread.getName()
                        + "]...");
            }
            try {
                thread.stop();
            } catch (RuntimeException e) {
                // Thread.stop() may be rejected (SecurityException, or
                // UnsupportedOperationException on recent JDKs); keep going
                // so the remaining threads and thread locals are still handled.
                logger.warn("Could not stop thread [" + thread.getName() + "]", e);
            }
        }
    }

    /**
     * Removes, from both thread-local maps of every given thread, the entries
     * whose key or value refers to {@code cl}. Relies on reflective access to
     * JDK internals; when that access is unavailable the cleanup is skipped.
     *
     * @param threads threads to inspect; null elements are skipped
     * @param cl      class loader of this web application
     */
    private void clearThreadLocals(Thread[] threads, ClassLoader cl) {
        Field threadLocalsField;
        Field inheritableThreadLocalsField;
        Field tableField;
        try {
            threadLocalsField = Thread.class.getDeclaredField("threadLocals");
            threadLocalsField.setAccessible(true);
            inheritableThreadLocalsField = Thread.class
                    .getDeclaredField("inheritableThreadLocals");
            inheritableThreadLocalsField.setAccessible(true);
            // Make the underlying array of ThreadLocal.ThreadLocalMap.Entry
            // objects accessible
            Class<?> tlmClass = Class
                    .forName("java.lang.ThreadLocal$ThreadLocalMap");
            tableField = tlmClass.getDeclaredField("table");
            tableField.setAccessible(true);
        } catch (Exception e) {
            // Without all three fields nothing can be cleared, so stop here
            // instead of failing once per thread.
            logger.info("Skipping thread-local cleanup: " + e);
            return;
        }

        for (Thread thread : threads) {
            if (thread != null) {
                clearThreadLocalMapOf(thread, threadLocalsField, tableField, cl);
                clearThreadLocalMapOf(thread, inheritableThreadLocalsField, tableField, cl);
            }
        }
    }

    /**
     * Clears one thread-local map of one thread. A failure is logged at debug
     * level and does not affect other maps or threads.
     *
     * @param thread     thread owning the map
     * @param mapField   accessible field of {@link Thread} that holds the map
     * @param tableField accessible {@code table} field of the map class
     * @param cl         class loader of this web application
     */
    private void clearThreadLocalMapOf(Thread thread, Field mapField,
            Field tableField, ClassLoader cl) {
        try {
            clearThreadLocalMap(cl, mapField.get(thread), tableField);
        } catch (Exception e) {
            if (logger.isDebugEnabled()) {
                logger.debug("Could not clear " + mapField.getName()
                        + " of thread [" + thread.getName() + "]", e);
            }
        }
    }

    /**
     * Waits for the given thread to end, checking up to
     * {@link #THREAD_WAIT_MAX_CHECKS} times with a pause of
     * {@link #THREAD_WAIT_INTERVAL_MILLIS} between checks. If the current thread
     * is interrupted while waiting, the wait continues and the interrupt status
     * is restored before returning.
     *
     * @param thread thread to wait for
     */
    private void waitThread(Thread thread) {
        boolean interrupted = false;
        int count = 0;
        while (thread.isAlive() && count < THREAD_WAIT_MAX_CHECKS) {
            try {
                Thread.sleep(THREAD_WAIT_INTERVAL_MILLIS);
            } catch (InterruptedException e) {
                // Remember the interrupt; it is re-asserted once waiting is over
                interrupted = true;
            }
            count++;
        }
        if (interrupted) {
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Gets the set of current threads as an array.
     *
     * @return an array enumerated from the root thread group; unused trailing
     *         slots are null
     */
    private Thread[] getThreads() {
        // Get the current thread group
        ThreadGroup tg = Thread.currentThread().getThreadGroup();
        // Find the root thread group
        while (tg.getParent() != null) {
            tg = tg.getParent();
        }

        int threadCountGuess = tg.activeCount() + 50;
        Thread[] threads = new Thread[threadCountGuess];
        int threadCountActual = tg.enumerate(threads);
        // Make sure we don't miss any threads
        while (threadCountActual == threadCountGuess) {
            threadCountGuess *= 2;
            threads = new Thread[threadCountGuess];
            // Note tg.enumerate(Thread[]) silently ignores any threads that
            // can't fit into the array
            threadCountActual = tg.enumerate(threads);
        }

        return threads;
    }

    /**
     * Removes from a {@code ThreadLocal.ThreadLocalMap} every entry whose key or
     * value is {@code cl} itself or an object whose class was loaded by
     * {@code cl}.
     * <p>
     * Entries whose weak key has already been cleared cannot be removed by key;
     * for those the map's own {@code expungeStaleEntries} method is invoked once
     * after the scan.
     *
     * @param cl                 class loader of this web application
     * @param map                the thread-local map, may be null (nothing to do)
     * @param internalTableField accessible {@code table} field of the map class
     * @throws NoSuchMethodException     if an expected map method is missing
     * @throws IllegalAccessException    if reflective access is denied
     * @throws NoSuchFieldException      if the entry class has no {@code value} field
     * @throws InvocationTargetException if a reflectively invoked map method fails
     */
    private void clearThreadLocalMap(ClassLoader cl, Object map,
            Field internalTableField) throws NoSuchMethodException,
            IllegalAccessException, NoSuchFieldException,
            InvocationTargetException {
        if (map == null) {
            return;
        }
        Method mapRemove = map.getClass().getDeclaredMethod("remove",
                ThreadLocal.class);
        mapRemove.setAccessible(true);
        Object[] table = (Object[]) internalTableField.get(map);
        if (table == null) {
            return;
        }

        Field valueField = null;
        boolean staleEntryFound = false;
        for (Object element : table) {
            if (element == null) {
                continue;
            }
            // Each entry is a weak reference to its ThreadLocal key; read the
            // key once so it cannot be cleared between the check and the removal.
            Object key = ((Reference<?>) element).get();

            // All entries share one class, so the value field is resolved once
            if (valueField == null) {
                valueField = element.getClass().getDeclaredField("value");
                valueField.setAccessible(true);
            }
            Object value = valueField.get(element);

            if (!isLoadedBy(cl, key) && !isLoadedBy(cl, value)) {
                continue;
            }

            if (key == null) {
                // Key already collected: remove(null) would fail, handle below
                staleEntryFound = true;
                continue;
            }

            if (logger.isInfoEnabled()) {
                logger.info("Removing " + describe(key)
                        + " from a thread local...");
            }
            mapRemove.invoke(map, key);
        }

        if (staleEntryFound) {
            Method expunge = map.getClass().getDeclaredMethod(
                    "expungeStaleEntries");
            expunge.setAccessible(true);
            expunge.invoke(map);
        }
    }

    /**
     * Tells whether {@code candidate} is the class loader {@code cl} itself or an
     * object whose class was loaded by {@code cl}.
     */
    private static boolean isLoadedBy(ClassLoader cl, Object candidate) {
        if (candidate == null) {
            return false;
        }
        return cl.equals(candidate)
                || cl == candidate.getClass().getClassLoader();
    }

    /**
     * Returns {@code key.toString()}, falling back to the class name when the
     * key's own {@code toString()} throws.
     */
    private static String describe(Object key) {
        try {
            return key.toString();
        } catch (RuntimeException e) {
            return key.getClass().getName();
        }
    }

}