/*
	Tag recommender based on Factorization Machines (FM) using binary user, item and tag indicator features.

	Based on the publication(s):
	Steffen Rendle (2010): Factorization Machines, in Proceedings of the 10th IEEE International Conference on Data Mining (ICDM 2010), Sydney, Australia.

	Notes:
	This implementation integrates the prediction and gradient steps of the FM model for the special case of PITF-like data (one binary indicator each for the user, the item and the tag).
	An implementation using the generic prediction and gradient functions that are offered by the FM library can be found in tag_rec_fm_generic.h.
	Both implementations are offered to demonstrate (1) how the generic functionality can be used without the need to implement any prediction and
	gradient steps and (2) how to get a faster implementation if both steps are integrated.
	Note that both implementations have the same runtime complexity. 

	Author:   Steffen Rendle, http://www.libfm.org/
	modified: 2010-07-14

	Copyright 2010-2011 Steffen Rendle, see license.txt for more information
*/

#ifndef TAGRECOMMENDERFACTORIZATIONMACHINE_FAST_H_
#define TAGRECOMMENDERFACTORIZATIONMACHINE_FAST_H_

#include "../../fm_core/fm_data.h"
#include "../../fm_core/fm_model.h"
#include "../../fm_core/fm_sgd.h"



// Tag recommender that scores and trains a Factorization Machine directly on
// (user, item, tag) triples: the FM prediction and the SGD update are written
// out by hand for this data layout instead of using the generic FM routines.
//
// Feature layout implied by the index arithmetic below (one row of fm.w / fm.v
// per feature):
//   users : [0, num_user)
//   items : [num_user, num_user + num_item)
//   tags  : [num_user + num_item, ...)
//
// fm, num_user, num_item and learn_rate are not declared in this class;
// presumably they are inherited from TagRecommenderFactorizationMachine.
// NOTE(review): this header does not itself include the declarations of
// TagRecommenderFactorizationMachine or TagLearner, so it appears to rely on
// the including file providing them first -- confirm.
class TagRecommenderFactorizationMachineFast : public TagRecommenderFactorizationMachine {
	public:

		// Score of tag `tag_id` for the (user_id, item_id) pair:
		//   w[tag] + sum_k (v[user][k] + v[item][k]) * v[tag][k]
		// Only tag-dependent terms are evaluated. Any global, user or item bias
		// and the user-item interaction would be identical for every tag of a
		// fixed (user, item) pair. They therefore change neither the tag ranking
		// nor the score difference used by learn().
		virtual double predict(int user_id, int item_id, int tag_id) {
			const int item_row = item_id + num_user;
			const int tag_row = tag_id + num_user + num_item;

			double score = fm.w(tag_row);
			for (int k = 0; k < fm.num_factor; k++) {
				score += (fm.v(user_id, k) + fm.v(item_row, k)) * fm.v(tag_row, k);
			}
			return score;
		}

		// One in-place pairwise SGD step for the (user_id, item_id) pair on
		// tags tp_id and tn_id (presumably a positive and a negative tag).
		// The step is driven by the negated
		// TagLearner::partial_loss(LOSS_FUNCTION_LN_SIGMOID, ...) of the score
		// gap predict(tp) - predict(tn). Each parameter also receives L2-style
		// shrinkage weighted by its own fm.regw / fm.regv entry.
		inline virtual void learn(int user_id, int item_id, int tp_id, int tn_id) {
			const double score_pos = predict(user_id, item_id, tp_id);
			const double score_neg = predict(user_id, item_id, tn_id);

			// Every update below is scaled by this value. Its sign is flipped
			// for the negative tag because d(score_pos - score_neg)/d(param)
			// is negated there.
			const double scale = -TagLearner::partial_loss(LOSS_FUNCTION_LN_SIGMOID, score_pos - score_neg);

			const int item_row = item_id + num_user;
			const int pos_row = tp_id + num_user + num_item;
			const int neg_row = tn_id + num_user + num_item;

			// Tag biases: the derivative of the score gap is +1 for the
			// positive tag's bias and -1 for the negative tag's bias.
			double& bias_pos = fm.w(pos_row);
			bias_pos -= learn_rate * (scale + fm.regw(pos_row) * bias_pos);
			double& bias_neg = fm.w(neg_row);
			bias_neg -= learn_rate * (-scale + fm.regw(neg_row) * bias_neg);

			// Latent factors, one dimension at a time.
			for (int k = 0; k < fm.num_factor; k++) {
				double& user_f = fm.v(user_id, k);
				double& item_f = fm.v(item_row, k);
				double& pos_f = fm.v(pos_row, k);
				double& neg_f = fm.v(neg_row, k);

				// Snapshot the partner terms of each gradient before any factor
				// in this dimension changes. This way all four updates use the
				// pre-step values.
				const double user_item_sum = item_f + user_f;
				const double tag_diff = pos_f - neg_f;

				user_f -= learn_rate * (scale * tag_diff + fm.regv(user_id) * user_f);
				item_f -= learn_rate * (scale * tag_diff + fm.regv(item_row) * item_f);
				pos_f -= learn_rate * (scale * user_item_sum + fm.regv(pos_row) * pos_f);
				neg_f -= learn_rate * (-scale * user_item_sum + fm.regv(neg_row) * neg_f);
			}
		}
};

#endif /*TAGRECOMMENDERFACTORIZATIONMACHINE_FAST_H_*/
