/*
	Tag recommender based on Factorization Machines using binary user, item, and tag indicator features.

	Based on the publication(s):
	Steffen Rendle (2010): Factorization Machines, in Proceedings of the 10th IEEE International Conference on Data Mining (ICDM 2010), Sydney, Australia.

	Notes:
	This implementation demonstrates how to use the generic prediction and gradient functions that are offered by the FM library.
	A faster implementation of exactly the same model can be found in tag_rec_fm_fast.h.
	Both implementations are offered to demonstrate (1) how the generic functionality can be used without the need to implement any prediction and
	gradient steps and (2) how to get a faster implementation if both steps are integrated.
	Note that both implementations have the same runtime complexity. 

	Author:   Steffen Rendle, http://www.libfm.org/
	modified: 2010-12-10

	Copyright 2010 Steffen Rendle, see license.txt for more information
*/

#ifndef TAGRECOMMENDERFACTORIZATIONMACHINE_GENERIC_H_
#define TAGRECOMMENDERFACTORIZATIONMACHINE_GENERIC_H_

#include "../../fm_core/fm_data.h"
#include "../../fm_core/fm_model.h"
#include "../../fm_core/fm_sgd.h"



class TagRecommenderFactorizationMachineGeneric : public TagRecommenderFactorizationMachine {
	public:
		fv_vector x, x_pos, x_neg;
		DVector<double> sum_pos, sum_neg, sum_sqr_pos, sum_sqr_neg;
		DVector<bool> grad_visited;
		DVector<double> grad;
	public:				
		virtual void init() {
			x.data = new fv_pair[3];
			x.size = 3;
			x.data[0].value = 1;
			x.data[1].value = 1;
			x.data[2].value = 1;
			x_pos.data = new fv_pair[3];
			x_pos.size = 3;
			x_pos.data[0].value = 1;
			x_pos.data[1].value = 1;
			x_pos.data[2].value = 1;
			x_neg.data = new fv_pair[3];
			x_neg.size = 3;
			x_neg.data[0].value = 1;
			x_neg.data[1].value = 1;
			x_neg.data[2].value = 1;
			
			sum_pos.setSize(fm.num_factor);
			sum_neg.setSize(fm.num_factor);
			sum_sqr_pos.setSize(fm.num_factor);
			sum_sqr_neg.setSize(fm.num_factor);
			grad_visited.setSize(fm.num_attribute);
			grad.setSize(fm.num_attribute);
		}
		
		virtual double predict(int user_id, int item_id, int tag_id) {
			x.data[0].feature_id = user_id;
			x.data[1].feature_id = item_id + num_user;
			x.data[2].feature_id = tag_id + num_user + num_item;
			return fm.predict(x);
		}
		
		virtual void learn(int user_id, int item_id, int tp_id, int tn_id) {
			x_pos.data[0].feature_id = user_id;
			x_pos.data[1].feature_id = item_id + num_user;
			x_pos.data[2].feature_id = tp_id + num_user + num_item;
			x_neg.data[0].feature_id = user_id;
			x_neg.data[1].feature_id = item_id + num_user;
			x_neg.data[2].feature_id = tn_id + num_user + num_item;
			
			double x_uitp = fm.predict(x_pos, sum_pos, sum_sqr_pos);
     			double x_uitn = fm.predict(x_neg, sum_neg, sum_sqr_neg);
     			double normalizer = TagLearner::partial_loss(LOSS_FUNCTION_LN_SIGMOID, x_uitp - x_uitn);
     		
			fm_pairSGD(&fm, learn_rate, x_pos, x_neg, -normalizer, sum_pos, sum_neg, grad_visited, grad);
     	}
		
		
};

#endif /*TAGRECOMMENDERFACTORIZATIONMACHINE_GENERIC_H_*/
