/*
	Tag recommender tool

	Author:   Steffen Rendle, http://www.libfm.org/
	modified: 2011-07-14

	Copyright 2010-2011 Steffen Rendle, see license.txt for more information
*/
	
#include <cstdlib>
#include <cstdio>
#include <ctime>
#include <exception>
#include <iostream>
#include <string>
#include <iterator>
#include <algorithm>
#include <iomanip>
#include "../util/util.h"
#include "../util/cmdline.h"
#include "../fm_core/fm_model.h"

#include "src/Data.h"
#include "src/tag_rec_pitf.h"
#include "src/tag_rec_fm.h"
#include "src/tag_rec_fm_fast.h"
#include "src/tag_rec_fm_generic.h"


using namespace std;

/**
 * Entry point of the tag recommender command line tool.
 *
 * Steps performed:
 *   (1) register and check the command line parameters,
 *   (2) load the training data and the test split,
 *   (3) configure the recommender selected with -method
 *       ('pitf', 'fm' or 'fmgeneric'),
 *   (4) train and evaluate it on the loaded dataset,
 *   (5) optionally write the predictions to the file given with -out.
 *
 * @param argc number of command line arguments
 * @param argv command line arguments
 * @return 0 on success (including when only the help screen is shown),
 *         EXIT_FAILURE if an error was reported
 */
int main(int argc, char **argv) {

	// Seed rand() from the current time.
	srand(static_cast<unsigned int>(time(NULL)));
	try {
		CMDLine cmdline(argc, argv);
		std::cout << "Tag Recommender" << std::endl;
		std::cout << "  Version: 1.10" << std::endl;
		std::cout << "  Author:  Steffen Rendle, srendle@ismll.de, http://www.libfm.org/" << std::endl;
		std::cout << "  License: Free for academic use. See license.txt." << std::endl;
		std::cout << "----------------------------------------------------------------------------" << std::endl;

		const std::string param_train_file	= cmdline.registerParameter("train", "filename for training data [MANDATORY]");
		const std::string param_test_file	= cmdline.registerParameter("test", "filename for test data [MANDATORY]");
		const std::string param_out		= cmdline.registerParameter("out", "filename for output; default=''");
		const std::string param_num_pred_out	= cmdline.registerParameter("num_out", "how many tags per post should be written; default=10");

		const std::string param_method		= cmdline.registerParameter("method", "method: 'pitf' or 'fm' or 'fmgeneric' [MANDATORY]");
		const std::string param_dim		= cmdline.registerParameter("dim", "dim of factorization; default=64");
		const std::string param_regular		= cmdline.registerParameter("regular", "regularization; default=0.0");
		const std::string param_init_stdev	= cmdline.registerParameter("init_stdev", "stdev for initialization of 2-way factors; default=0.01");
		const std::string param_num_iter	= cmdline.registerParameter("iter", "number of iterations for SGD; default=100");
		const std::string param_learn_rate	= cmdline.registerParameter("learn_rate", "learn_rate for SGD; default=0.1");
		const std::string param_num_sample      = cmdline.registerParameter("num_sample", "number of the pair samples drawn for each training tuple, default 100");

		const std::string param_help            = cmdline.registerParameter("help", "this screen");

		// Show the help screen when asked for it or when called without arguments.
		if (cmdline.hasParameter(param_help) || (argc == 1)) {
			cmdline.print_help();
			return 0;
		}
		cmdline.checkParameters();

		// (1) Load the data
		std::cout << "Loading train...\t";
		Dataset dataset = Dataset(cmdline.getValue(param_train_file));
		std::cout << "Loading test... \t";
		dataset.loadTestSplit(cmdline.getValue(param_test_file));

		// (2) Setup the learning method:
		// Read the method name once; it selects the branch below.
		const std::string method = cmdline.getValue(param_method);
		const bool use_pitf = (method.compare("pitf") == 0);
		const bool use_fm_fast = (method.compare("fm") == 0);
		const bool use_fm_generic = (method.compare("fmgeneric") == 0);

		// NOTE(review): the recommender is never deleted; the process exits right
		// after it has been used. TagRecommender's destructor is not visible from
		// this file, so deleting through the base pointer is not added here.
		TagRecommender* rec = NULL;
		if (use_pitf) {
			std::cout << "Method: PITF (BPR)" << std::endl;
			TagRecommenderPITF *tf = new TagRecommenderPITF();

			tf->loss_function = LOSS_FUNCTION_LN_SIGMOID;
			tf->learn_rate = cmdline.getValue(param_learn_rate, 0.1);
			tf->num_neg_samples = cmdline.getValue(param_num_sample, 100);
			tf->num_iterations = cmdline.getValue(param_num_iter, 100);

			// Ids are used as 0-based indices, hence max id + 1 entries per mode.
			tf->num_user = dataset.max_user_id+1;
			tf->num_item = dataset.max_item_id+1;
			tf->num_tag  = dataset.max_tag_id+1;

			tf->init_mean = 0;
			tf->init_stdev = cmdline.getValue(param_init_stdev, 0.01);

			tf->num_feature = cmdline.getValue(param_dim, 64);
			tf->regular = cmdline.getValue(param_regular, 0.0);

			tf->init();
			rec = tf;
		} else if (use_fm_fast || use_fm_generic) {
			TagRecommenderFactorizationMachine *tf;
			if (use_fm_fast) {
				std::cout << "Method: FM (BPR)" << std::endl;
				tf = new TagRecommenderFactorizationMachineFast();
			} else {
				std::cout << "Method: FM (BPR), generic implementation" << std::endl;
				tf = new TagRecommenderFactorizationMachineGeneric();
			}

			tf->learn_rate = cmdline.getValue(param_learn_rate, 0.1);
			tf->num_neg_samples = cmdline.getValue(param_num_sample, 100);
			tf->num_iterations = cmdline.getValue(param_num_iter, 100);

			tf->num_user = dataset.max_user_id+1;
			tf->num_item = dataset.max_item_id+1;
			tf->num_tag  = dataset.max_tag_id+1;
			// The FM gets one attribute per user, per item and per tag.
			tf->fm.num_attribute = tf->num_user + tf->num_item + tf->num_tag;

			tf->fm.init_mean = 0;
			tf->fm.init_stdev = cmdline.getValue(param_init_stdev, 0.01);

			tf->fm.k0 = 0;
			tf->fm.k1 = 1;
			tf->fm.num_factor = cmdline.getValue(param_dim, 64);

			tf->fm.init();

			// Regularization is set after fm.init(), as in the original order;
			// only the factor regularization (regv) takes the -regular value.
			tf->fm.reg0 = 0;
			tf->fm.regw.init(0);
			tf->fm.regv.init(cmdline.getValue(param_regular, 0.0));

			tf->init();
			rec = tf;
		} else {
			// Thrown as std::string so that it is reported by the handler below;
			// a bare string literal would not match a std::string handler.
			throw std::string("unknown method");
		}
		rec->N = 10;

		// (3) learning
		rec->train(dataset);
		std::cout << "model trained" << std::endl;
		rec->evaluate(&dataset);
		// (4) Save prediction
		if (cmdline.hasParameter(param_out)) {
			rec->savePrediction(dataset.test_posts, cmdline.getValue(param_out), dataset.max_tag_id+1, cmdline.getValue(param_num_pred_out, 10));
		}

		return EXIT_SUCCESS;
	} catch (const std::string &e) {
		std::cerr << e << std::endl;
	} catch (const char *e) {
		// Errors thrown as plain string literals.
		std::cerr << e << std::endl;
	} catch (const std::exception &e) {
		// Standard library failures, e.g. std::bad_alloc.
		std::cerr << e.what() << std::endl;
	}

	// Only reached when one of the handlers above reported an error.
	return EXIT_FAILURE;
}
