
Mahout - Random Forests

1. Random Forest Overview
A Random Forest, developed by Leo Breiman and Adele Cutler, is an ensemble classifier made up of many decision trees.
To classify a new object from an input vector, the input vector is run down each tree in the forest. Each tree returns a classification result, which counts as a "vote" for that class. The forest then chooses the class with the most votes.
Each tree is grown as follows:
   1. If the number of cases in the training set is N, sample N cases at random, with replacement, from the original data. This sample becomes the training set for growing the tree.
   2. If there are M input variables, a number m << M is chosen such that at each node, m variables are selected at random out of the M and the best split on those m is used to split the node. The value of m is held constant while the forest is grown.
   3. Each tree is grown to the largest extent possible; there is no pruning.
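The two randomization steps above can be sketched in plain Java, independent of Mahout; the class and method names below are hypothetical:

```java
import java.util.Random;

// Hypothetical sketch (not part of Mahout): the randomness described above,
// bootstrap sampling for growing a tree and majority voting for classification.
public class ForestSketch {
    // Step 1: draw N cases with replacement from the N original cases.
    static int[] bootstrapSample(int n, Random rng) {
        int[] sample = new int[n];
        for (int i = 0; i < n; i++) {
            sample[i] = rng.nextInt(n); // index into the original training set
        }
        return sample;
    }

    // Classification: each tree casts one vote; the forest returns
    // the class with the most votes.
    static int majorityVote(int[] treeVotes, int numClasses) {
        int[] counts = new int[numClasses];
        for (int vote : treeVotes) {
            counts[vote]++;
        }
        int best = 0;
        for (int c = 1; c < numClasses; c++) {
            if (counts[c] > counts[best]) {
                best = c;
            }
        }
        return best;
    }
}
```

Bagging (step 1) gives each tree a different view of the data, while the per-node random variable selection in step 2 decorrelates the trees.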

Below is a description of how to run a sample test on Hadoop.

2. Preparing the Data
- Source: http://nsl.cs.unb.ca/NSL-KDD/
- Training data: dataid, attribute1, attribute2, ...
. Path: testdata/kdd/KDDTrain.arff
. Delete the metadata rows (lines starting with @).
> hadoop fs -put testdata/kdd/KDDTrain.arff testdata/kdd

- Test data: dataid, attribute1, attribute2, ...
. Path: testdata/kdd/KDDTest.arff
. Delete the metadata rows (lines starting with @).
> hadoop fs -put testdata/kdd/KDDTest.arff testdata/kdd

- Output: dataid, predictionIndex, realLabelIndex
. Path: testdata/kdd/predictions/part-m-00000
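Stripping the @-prefixed metadata rows can be done with a short filter before the upload; this is a hypothetical helper, not part of the Mahout sample:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: drops the ARFF metadata rows (lines starting with "@")
// and blank lines, keeping only the comma-separated data rows.
public class ArffHeaderFilter {
    static List<String> stripHeader(List<String> lines) {
        List<String> out = new ArrayList<>();
        for (String line : lines) {
            String trimmed = line.trim();
            if (trimmed.isEmpty() || trimmed.startsWith("@")) {
                continue;
            }
            out.add(line);
        }
        return out;
    }

    public static void main(String[] args) {
        List<String> raw = Arrays.asList("@relation kdd", "@data", "1,0,tcp,normal");
        System.out.println(stripHeader(raw)); // prints [1,0,tcp,normal]
    }
}
```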

3. Random Forest Sample Sources
- ForestDriver.java (distributed Random Forest training and classification)
The Driver that trains the Random Forest in a distributed fashion, runs the distributed classification, prints the classification results, and evaluates the accuracy.
public class ForestDriver extends AbstractJob {
	public static void run(Configuration conf, String dataPathName,
			List<String> description, String descriptorPathName,
			String forestPathName, int m, Long seed, int nbTrees,
			boolean isPartial, String testdataPathName,
			String predictionOutPathName, boolean complemented,
			Integer minSplitNum, Double minVarianceProportion) 
			throws DescriptorException, IOException, 
			ClassNotFoundException, InterruptedException {

		// Create Descriptor
		Path dataPath = validateInput(dataPathName);
		Path descriptorPath = validateOutput(descriptorPathName);
		createDescriptor(conf, dataPath, description, descriptorPath);

		// Build Forest
		Path forestPath = validateOutput(forestPathName);
		buildForest(conf, dataPath, description, descriptorPath, forestPath,
				m, seed, nbTrees, isPartial, complemented,
				minSplitNum, minVarianceProportion);

		// Predict
		boolean analyze = true;
		ForestClassificationDriver.run(conf, forestPathName, testdataPathName,
				descriptorPathName, predictionOutPathName, analyze);
	}

	/**
	 * Create Descriptor
	 */
	private static void createDescriptor(Configuration conf, Path dataPath,
			List<String> description, Path outPath) throws DescriptorException,
			IOException {

		log.info("Generating the descriptor...");
		String descriptor = DescriptorUtils.generateDescriptor(description);
		log.info("generating the descriptor dataset...");
		Dataset dataset = generateDataset(descriptor, dataPath);
		log.info("storing the dataset description");

		DFUtils.storeWritable(conf, outPath, dataset);
	}

	/**
	 * Generate DataSet
	 */
	private static Dataset generateDataset(String descriptor, Path dataPath)
			throws IOException, DescriptorException {

		FileSystem fs = dataPath.getFileSystem(new Configuration());
		Path[] files = DFUtils.listOutputFiles(fs, dataPath);
		return DataLoader.generateDataset(descriptor, false, fs, files[0]);
	}

	/**
	 * Build Forest
	 */
	private static void buildForest(Configuration conf, Path dataPath,
			List<String> description, Path descriptorPath, Path forestPath,
			int m, Long seed, int nbTrees, boolean isPartial,
			boolean complemented, Integer minSplitNum, Double minVarianceProportion)
			throws IOException, ClassNotFoundException, InterruptedException {

		FileSystem ofs = forestPath.getFileSystem(conf);
		if (ofs.exists(forestPath)) {
			log.error("Forest Output Path already exists");
			return;
		}

		DecisionTreeBuilder treeBuilder = new DecisionTreeBuilder();
		treeBuilder.setM(m);
		treeBuilder.setComplemented(complemented);
		if (minSplitNum != null) {
			treeBuilder.setMinSplitNum(minSplitNum);
		}
		if (minVarianceProportion != null) {
			treeBuilder.setMinVarianceProportion(minVarianceProportion);
		}
		Builder forestBuilder;

		if (isPartial) {
			log.info("Partial Mapred");
			forestBuilder = new PartialBuilder(treeBuilder, dataPath,
					descriptorPath, seed, conf);
		} else {
			log.info("InMemory Mapred");
			forestBuilder = new InMemBuilder(treeBuilder, dataPath,
					descriptorPath, seed, conf);
		}

		forestBuilder.setOutputDirName(forestPath.getName());
		log.info("Building the forest...");
		long time = System.currentTimeMillis();

		DecisionForest forest = forestBuilder.build(nbTrees);
		time = System.currentTimeMillis() - time;
		log.info("Build Time: {}", DFUtils.elapsedTime(time));
		log.info("Forest num Nodes: {}", forest.nbNodes());
		log.info("Forest mean num Nodes: {}", forest.meanNbNodes());
		log.info("Forest mean max Depth: {}", forest.meanMaxDepth());

		// store the forest
		Path forestoutPath = new Path(forestPath, "forest.seq");
		log.info("Storing the forest in: {}", forestoutPath);
		DFUtils.storeWritable(conf, forestoutPath, forest);
	}

	/**
	 * Load data
	 */
	protected static Data loadData(Configuration conf, Path dataPath,
			Dataset dataset) throws IOException {
		log.info("Loading the data...");
		FileSystem fs = dataPath.getFileSystem(conf);
		Data data = DataLoader.loadData(dataset, fs, dataPath);
		log.info("Data Loaded");

		return data;
	}

	/**
	 * Convert Collections to a String List
	 */
	private static List<String> convert(Collection<?> values) {
		List<String> list = new ArrayList<String>(values.size());
		for (Object value : values) {
			list.add(value.toString());
		}
		return list;
	}

	/**
	 * Validation of the Output Path
	 */
	private static Path validateOutput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (fs.exists(path)) {
			throw new IllegalStateException(path.toString() + " already exists");
		}
		return path;
	}

	/**
	 * Validation of the Input Path
	 */
	private static Path validateInput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (!fs.exists(path)) {
			throw new IllegalArgumentException(path.toString()
					+ " does not exist");
		}
		return path;
	}

	public static void main(String[] args) throws Exception {
		ToolRunner.run(new Configuration(), new ForestDriver(), args);
	}
}

- ForestClassificationDriver.java (distributed classification)
The Driver that uses the generated forest model to run the distributed classification, print the classification results, and evaluate the accuracy.
public class ForestClassificationDriver extends AbstractJob 
{
  .....
	public static void run(Configuration conf, String forestPathName,
			String testDataPathName, String descriptorPathName,
			String predictionPathName, boolean analyze) throws IOException,
			ClassNotFoundException, InterruptedException {

		// Classify data
		Path testDataPath = validateInput(testDataPathName);
		Path descriptorPath = validateInput(descriptorPathName);

		Path forestPath = validateInput(forestPathName);
		Path predictionPath = validateOutput(conf, predictionPathName);

		ForestClassifier classifier = new ForestClassifier(conf, forestPath,
				testDataPath, descriptorPath, predictionPath, analyze);
		classifier.run();

		// Analyze Results
		if (analyze) {
			log.info(classifier.getAnalyzer().toString());
		}
	}

	/**
	 * Validation of the Output Path
	 */
	private static Path validateOutput(Configuration conf, String filePath)
			throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(conf);
		if (fs.exists(path)) {
			throw new IllegalStateException(path.toString() + " already exists");
		}
		return path;
	}

	/**
	 * Validation of the Input Path
	 */
	private static Path validateInput(String filePath) throws IOException {
		Path path = new Path(filePath);
		FileSystem fs = path.getFileSystem(new Configuration());
		if (!fs.exists(path)) {
			throw new IllegalArgumentException(path.toString()
					+ " does not exist");
		}
		return path;
	}
}

- ForestClassifier.java (runs the classification job and analyzes the results)
public class ForestClassifier {
	private static final Logger log = LoggerFactory
			.getLogger(ForestClassifier.class);
	private final Path forestPath;
	private final Path inputPath;
	private final Path datasetPath;
	private final Configuration conf;

	private final ResultAnalyzer analyzer;
	private final Dataset dataset;
	private final Path outputPath;

	public ForestClassifier(Configuration conf, Path forestPath,
			Path inputPath, Path datasetPath, Path outputPath, boolean analyze)
			throws IOException {
		this.forestPath = forestPath;
		this.inputPath = inputPath;
		this.datasetPath = datasetPath;
		this.conf = conf;

		if (analyze) {
			dataset = Dataset.load(conf, datasetPath);
			analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()),
					"unknown");
		} else {
			dataset = null;
			analyzer = null;
		}
		this.outputPath = outputPath;
	}

	/**
	 * Configure the classification job.
	 */
	private void configureClassifyJob(Job job) throws IOException {

		job.setJarByClass(ForestClassifier.class);

		FileInputFormat.setInputPaths(job, inputPath);
		FileOutputFormat.setOutputPath(job, outputPath);

		job.setOutputKeyClass(Text.class);
		job.setOutputValueClass(Text.class);

		job.setMapperClass(ClassifyMapper.class);
		job.setNumReduceTasks(0); // classification uses mappers only

		job.setInputFormatClass(ClassifyTextInputFormat.class);
		job.setOutputFormatClass(TextOutputFormat.class);
	}

	public void run() throws IOException, ClassNotFoundException,
			InterruptedException {
		FileSystem fs = FileSystem.get(conf);

		if (fs.exists(outputPath)) {
			throw new IOException(outputPath + " already exists");
		}

		// put the dataset in the DistributedCache
		log.info("Adding the dataset to the DistributedCache");
		DistributedCache.addCacheFile(datasetPath.toUri(), conf);

		// put the forest in the DistributedCache
		log.info("Adding the forest to the DistributedCache");
		DistributedCache.addCacheFile(forestPath.toUri(), conf);

		// Classification
		Job cjob = new Job(conf, "Decision Forest classification");
		log.info("Configuring the Classification Job...");
		configureClassifyJob(cjob);

		log.info("Running the Classification Job...");
		if (!cjob.waitForCompletion(true)) {
			log.error("Classification Job failed!");
			return;
		}

		// Analyze Results
		if (analyzer != null) {
			analyzeOutput(cjob);
		}
	}

	public ResultAnalyzer getAnalyzer() {
		return analyzer;
	}

	/**
	 * Analyze the Classification Results
	 * 
	 * @param job
	 */
	private void analyzeOutput(Job job) throws IOException {
		Configuration conf = job.getConfiguration();
		double prediction;
		double realLabel;

		FileSystem fs = outputPath.getFileSystem(conf);
		Path[] outfiles = DFUtils.listOutputFiles(fs, outputPath);

		int cnt_cnp = 0;

		for (Path path : outfiles) {
			FSDataInputStream input = fs.open(path);
			Scanner scanner = new Scanner(input);

			while (scanner.hasNext()) {
				String line = scanner.nextLine();
				if (line.isEmpty()) {
					continue;
				}

				// id, predict, realLabel with \t sep
				String[] tmp = line.split("\t", -1);
				prediction = Double.parseDouble(tmp[1]);
				realLabel = Double.parseDouble(tmp[2]);

				if (prediction == -1) {
					// the label could not be predicted
					cnt_cnp++;
				} else if (analyzer != null) {
					analyzer.addInstance(
							dataset.getLabelString(prediction),
							new ClassifierResult(
									dataset.getLabelString(realLabel), 1.0));
				}
			}
		}
		log.info("# of data which cannot be predicted: " + cnt_cnp);
	}

	/**
	 * Text input format where each file is processed by a single Mapper.
	 */
	public static class ClassifyTextInputFormat extends TextInputFormat {
		@Override
		protected boolean isSplitable(JobContext jobContext, Path path) {
			return false;
		}
	}

	/**
	 * Classification Mapper.
	 */
	public static class ClassifyMapper extends
			Mapper<LongWritable, Text, LongWritable, Text> {

		private DataConverter converter;
		private DecisionForest forest;
		private final Random rng = RandomUtils.getRandom();
		private Dataset dataset;
		
		private final Text val = new Text();

		@Override
		protected void setup(Context context) throws IOException,
				InterruptedException {
			super.setup(context);

			Configuration conf = context.getConfiguration();
			URI[] files = DistributedCache.getCacheFiles(conf);

			if ((files == null) || (files.length < 2)) {
				throw new IOException(
						"not enough paths in the DistributedCache");
			}

			dataset = Dataset.load(conf, new Path(files[0].getPath()));
			converter = new DataConverter(dataset);
			forest = DecisionForest.load(conf, new Path(files[1].getPath()));
			if (forest == null) {
				throw new InterruptedException("DecisionForest not found!");
			}
		}

		@Override
		protected void map(LongWritable key, Text value, Context context)
				throws IOException, InterruptedException {

			String line = value.toString();
			if (!line.isEmpty()) {
				String[] idVal = line.split(",", -1);
				long id = Long.parseLong(idVal[0]);
				Instance instance = converter.convert(line);
				double prediction = forest.classify(dataset, rng, instance);

				// key: id
				key.set(id);
				// val: prediction and original label, tab-separated
				StringBuilder sb = new StringBuilder();
				sb.append(prediction);
				sb.append('\t');
				sb.append(dataset.getLabel(instance));
				val.set(sb.toString());

				context.write(key, val);
			}
		}
	}
}

4. Running the Sample and Its Output
> bin/hadoop jar mahout-forest-1.0-jar-with-dependencies.jar \
com.mimul.mahout.forest.ForestDriver \
-Dmapred.max.split.size=5074231 \
-d testdata/kdd/KDDTrain.arff \
-ds testdata/kdd/KDD.info \
-fo testdata/kdd/forest \
-dsc N 4 C 2 N C 4 N C 8 N 2 C 19 N L \
-oob \
-sl 7 \
-p \
-t 100 \
-td testdata/kdd/KDDTest.arff \
-po testdata/kdd/predictions

forest.ForestDriver: Generating the descriptor...
forest.ForestDriver: generating the descriptor dataset...
forest.ForestDriver: storing the dataset description
forest.ForestDriver: Partial Mapred
forest.ForestDriver: Building the forest...
input.FileInputFormat: Total input paths to process : 1
mapred.JobClient: Running job: job_201204031111_0006
mapred.JobClient:  map 0% reduce 0%
mapred.JobClient:  map 66% reduce 0%
mapred.JobClient:  map 100% reduce 0%
mapred.JobClient: Job complete: job_201204031111_0006
mapred.JobClient: Counters: 20
mapred.JobClient:   Job Counters 
mapred.JobClient:     SLOTS_MILLIS_MAPS=191122
mapred.JobClient:     Total time spent by all reduces waiting after reserving slots (ms)=0
mapred.JobClient:     Total time spent by all maps waiting after reserving slots (ms)=0
mapred.JobClient:     Rack-local map tasks=2
mapred.JobClient:     Launched map tasks=3
mapred.JobClient:     Data-local map tasks=1
mapred.JobClient:     SLOTS_MILLIS_REDUCES=0
mapred.JobClient:   File Output Format Counters 
mapred.JobClient:     Bytes Written=10360478
mapred.JobClient:   FileSystemCounters
mapred.JobClient:     HDFS_BYTES_READ=15458862
mapred.JobClient:     FILE_BYTES_WRITTEN=67706
mapred.JobClient:     HDFS_BYTES_WRITTEN=10360478
mapred.JobClient:   File Input Format Counters 
mapred.JobClient:     Bytes Read=15380922
mapred.JobClient:   Map-Reduce Framework
mapred.JobClient:     Map input records=125973
mapred.JobClient:     Physical memory (bytes) snapshot=348418048
mapred.JobClient:     Spilled Records=0
mapred.JobClient:     CPU time spent (ms)=171750
mapred.JobClient:     Total committed heap usage (bytes)=166068224
mapred.JobClient:     Virtual memory (bytes) snapshot=1749139456
mapred.JobClient:     Map output records=100
mapred.JobClient:     SPLIT_RAW_BYTES=345
common.HadoopUtil: Deleting hdfs://kth:9000/user/k2/forest
forest.ForestDriver: Build Time: 0h 1m 28s 328
forest.ForestDriver: Forest num Nodes: 526674
forest.ForestDriver: Forest mean num Nodes: 5266
forest.ForestDriver: Forest mean max Depth: 17
forest.ForestDriver: Storing the forest in: testdata/kdd/forest/forest.seq
forest.ForestClassifier: Adding the dataset to the DistributedCache
forest.ForestClassifier: Adding the forest to the DistributedCache
forest.ForestClassifier: Configuring the Classification Job...
forest.ForestClassifier: Running the Classification Job...
input.FileInputFormat: Total input paths to process : 1
mapred.JobClient: Running job: job_201204031111_0007
mapred.JobClient:  map 0% reduce 0%
mapred.JobClient:  map 100% reduce 0%
mapred.JobClient: Job complete: job_201204031111_0007
mapred.JobClient: Counters: 19
mapred.JobClient:   Job Counters 
mapred.JobClient:     SLOTS_MILLIS_MAPS=15369
mapred.JobClient:     Total time spent by all reduces waiting after reserving slots (ms)=0
mapred.JobClient:     Total time spent by all maps waiting after reserving slots (ms)=0
mapred.JobClient:     Rack-local map tasks=1
mapred.JobClient:     Launched map tasks=1
mapred.JobClient:     SLOTS_MILLIS_REDUCES=0
mapred.JobClient:   File Output Format Counters 
mapred.JobClient:     Bytes Written=360704
mapred.JobClient:   FileSystemCounters
mapred.JobClient:     HDFS_BYTES_READ=13148874
mapred.JobClient:     FILE_BYTES_WRITTEN=22093
mapred.JobClient:     HDFS_BYTES_WRITTEN=360704
mapred.JobClient:   File Input Format Counters 
mapred.JobClient:     Bytes Read=2766567
mapred.JobClient:   Map-Reduce Framework
mapred.JobClient:     Map input records=22544
mapred.JobClient:     Physical memory (bytes) snapshot=124559360
mapred.JobClient:     Spilled Records=0
mapred.JobClient:     CPU time spent (ms)=4300
mapred.JobClient:     Total committed heap usage (bytes)=99090432
mapred.JobClient:     Virtual memory (bytes) snapshot=555745280
mapred.JobClient:     Map output records=22544
mapred.JobClient:     SPLIT_RAW_BYTES=114
forest.ForestClassifier: # of data which cannot be predicted: 0
forest.ForestClassificationDriver: ================================
Summary
-------------------------------------------------------
Correctly Classified Instances     :  18036    80.0035%
Incorrectly Classified Instances   :   4508    19.9965%
Total Classified Instances         :  22544

=======================================================
Confusion Matrix
-------------------------------------------------------
a       b       <--Classified as
9456    4253     |  13709       a     = normal
255     8580     |  8835        b     = anomaly
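The Summary follows directly from the confusion matrix: the diagonal entries are the correctly classified instances, 9456 + 8580 = 18036 of 22544, which is the reported 80.0035%. A standalone sanity check (the class name is hypothetical):

```java
// Hypothetical sanity check: recompute accuracy from the confusion matrix above.
public class ConfusionCheck {
    // Percentage of instances on the diagonal of a square confusion matrix.
    static double accuracy(long[][] m) {
        long correct = 0;
        long total = 0;
        for (int i = 0; i < m.length; i++) {
            for (int j = 0; j < m[i].length; j++) {
                total += m[i][j];
                if (i == j) {
                    correct += m[i][j];
                }
            }
        }
        return 100.0 * correct / total;
    }

    public static void main(String[] args) {
        // rows: actual (normal, anomaly); columns: classified as (a, b)
        long[][] m = {{9456, 4253}, {255, 8580}};
        System.out.printf("%.4f%%%n", accuracy(m)); // prints 80.0035%
    }
}
```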

- Browsing the Hadoop filesystem
> hadoop fs -ls testdata/kdd/
Found 5 items
-rw-r--r-- 1 k2 sg    25865 2012-04-09 /user/k2/testdata/kdd/KDD.info
-rw-r--r-- 1 k2 sg  2766567 2012-04-09 /user/k2/testdata/kdd/KDDTest.arff
-rw-r--r-- 1 k2 sg 15379495 2012-04-09 /user/k2/testdata/kdd/KDDTrain.arff
-rw-r--r-- 1 k2 sg 10356328 2012-04-09 /user/k2/testdata/kdd/forest
drwxr-xr-x - k2 sg        0 2012-04-09 /user/k2/testdata/kdd/predictions

> hadoop fs -ls testdata/kdd/predictions
Found 3 items
-rw-r--r-- 1 k2 sg      0 2012-04-09 /user/k2/testdata/kdd/predictions/_SUCCESS
drwxr-xr-x - k2 sg      0 2012-04-09 /user/k2/testdata/kdd/predictions/_logs
-rw-r--r-- 1 k2 sg 360704 2012-04-09 /user/k2/testdata/kdd/predictions/part-m-00000

> hadoop fs -cat testdata/kdd/predictions/part-m-00000
2022533 0.0     0.0
2022534 0.0     0.0
2022535 1.0     1.0
2022536 0.0     0.0
2022537 1.0     1.0
2022538 0.0     1.0
2022539 1.0     1.0
2022540 0.0     0.0
2022541 0.0     0.0
2022542 1.0     1.0
2022543 0.0     0.0
2022544 1.0     1.0
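Each output row is "dataid <TAB> predictionIndex <TAB> realLabelIndex", so accuracy can also be recomputed directly from the prediction file; a hypothetical sketch, not part of the Mahout sample:

```java
// Hypothetical check: count rows of "id \t predictionIndex \t realLabelIndex"
// where the predicted index (column 1) equals the real label index (column 2).
public class PredictionScan {
    static int countCorrect(String[] rows) {
        int correct = 0;
        for (String row : rows) {
            String[] cols = row.split("\t");
            if (Double.parseDouble(cols[1]) == Double.parseDouble(cols[2])) {
                correct++;
            }
        }
        return correct;
    }
}
```

Of the twelve sample rows shown above, only id 2022538 disagrees with its real label, so this count would come out to 11.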

5. Random Forests in R
>install.packages('randomForest')
>library(randomForest)
> iris <- read.csv("C:\\Project\\workspace\\2012\\mahout-random_forest\\src\\test\\resources\\iris.csv", header=F, sep=",", dec=".", quote="")
> 
> train = iris[ c(11:50, 61:100, 111:150), ]
> test = iris[ c(1:10, 51:60, 101:110), ]
> r = randomForest(V5~., data=train, importance=TRUE, do.trace=100)
ntree      OOB      1      2      3
  100:   4.17%  0.00%  7.50%  5.00%
  200:   5.83%  0.00% 10.00%  7.50%
  300:   5.83%  0.00% 10.00%  7.50%
  400:   5.83%  0.00% 10.00%  7.50%
  500:   5.83%  0.00% 10.00%  7.50%
> iris.predict = predict(r, test)
> t = table(observed=test[,'V5'], predict=iris.predict)
> print(r)

Call:
 randomForest(formula=V5~., data=train, importance=TRUE,do.trace = 100) 
     Type of random forest: classification
              Number of trees: 500
No. of variables tried at each split: 2
   OOB estimate of  error rate: 5.83%
Confusion matrix:
           setosa versicolor virginica class.error
setosa         40          0         0       0.000
versicolor      0         36         4       0.100
virginica       0          3        37       0.075

6. Miscellaneous
If you are curious about a Weka-based Random Forest, see the post linked here.