/*
	Tag recommender based on the PITF (Pairwise Interaction Tensor Factorization) method.

	Based on the publication(s):
	Steffen Rendle, Lars Schmidt-Thieme (2010): Pairwise Interaction Tensor Factorization for Personalized Tag Recommendation, in Proceedings of the Third ACM International Conference on Web Search and Data Mining (WSDM 2010), ACM.

	Author:   Steffen Rendle, http://www.libfm.org/
	modified: 2010-12-10

	Copyright 2010 Steffen Rendle, see license.txt for more information
*/

#ifndef TAG_REC_PITF_H_
#define TAG_REC_PITF_H_

#include "BPRLearner.h"

class TagRecommenderPITF : public TagRecommender {
	protected:	
		DMatrixDouble U, I, T_U, T_I;	
	public:	
		int loss_function;
		int num_neg_samples;
		int num_iterations;
		double learn_rate;

		int num_feature;
		double regular;
		
		int num_user;
		int num_item;
		int num_tag;
	
		double init_stdev;
		double init_mean;
				
		virtual void train(Dataset& dataset) {
			TagLearnerBPR learner;
			learner.num_iterations = this->num_iterations;
			learner.num_neg_samples = this->num_neg_samples;
			learner.train(dataset, *this);
		}
				
		virtual void init() {
			this->U.setSize(num_user, num_feature);
			this->I.setSize(num_item, num_feature);			
			this->T_U.setSize(num_tag,  num_feature);
			this->T_I.setSize(num_tag,  num_feature);

			this->U.init(init_mean, init_stdev);
			this->I.init(init_mean, init_stdev);			
			this->T_U.init(init_mean, init_stdev);
			this->T_I.init(init_mean, init_stdev);
		}
				
		virtual void predictTopTags(int user_id, int item_id, WeightedTag* tags, int num_tags) {
			for (int t_x = 0; t_x < num_tags; t_x++) {
      				tags[t_x].weight = predict(user_id, item_id, tags[t_x].tag_id);
			}
		}
		
		virtual double predict(int user_id, int item_id, int tag_id) {
			double result = 0;
			double u_dot_t = 0;
			double i_dot_t = 0;
			for (int f = 0; f < num_feature; f++) {
				u_dot_t += U(user_id,f) * T_U(tag_id,f);
				i_dot_t += I(item_id,f) * T_I(tag_id,f);
			}
			result += u_dot_t + i_dot_t;
			
			if (isnan(result)) {
				throw "Prediction is NAN";
			}
			return result;
		}
		
		inline virtual void learn(int user_id, int item_id, int tp_id, int tn_id) {
		double x_uitp = predict(user_id, item_id, tp_id);
     		double x_uitn = predict(user_id, item_id, tn_id);
     		double normalizer = TagLearner::partial_loss(loss_function, x_uitp - x_uitn);
     		
     		// update the features:
     		for (int f = 0; f < num_feature; f++) {
     			double u_f = U(user_id,f);
     			double i_f = I(item_id,f);
     			double t_p_U_f = T_U(tp_id,f);
     			double t_p_I_f = T_I(tp_id,f);
     			double t_n_U_f = T_U(tn_id,f);
     			double t_n_I_f = T_I(tn_id,f);
     			
	     		U(user_id,f) += learn_rate * (normalizer * (t_p_U_f - t_n_U_f) - regular * U(user_id,f));
	     		I(item_id,f) += learn_rate * (normalizer * (t_p_I_f - t_n_I_f) - regular * I(item_id,f));
			T_U(tp_id,f) += learn_rate * (normalizer * (u_f) - regular * T_U(tp_id,f));
			T_U(tn_id,f) += learn_rate * (normalizer * (-u_f) - regular * T_U(tn_id,f));
			T_I(tp_id,f) += learn_rate * (normalizer * (i_f) - regular * T_I(tp_id,f));
			T_I(tn_id,f) += learn_rate * (normalizer * (-i_f) - regular * T_I(tn_id,f));
     		}
     	}

};

#endif /*TAG_REC_PITF_H_*/
