Factorization machines for multi-class classification #336

Open · wants to merge 2 commits into master
127 changes: 127 additions & 0 deletions examples/2_BasicModels/factorization_machines.py
@@ -0,0 +1,127 @@
""" Factorisation Machines for Multiclass Classification.
A simple factorization machine algorithm implementation with TensorFlow.
This example is using the MNIST database
of handwritten digits (http://yann.lecun.com/exdb/mnist/).
The implementation includes regularisation of the parameters for the weights
as well as the interaction factors. It also includes eta (learning rate) parameter

Links:
[MNIST Dataset](http://yann.lecun.com/exdb/mnist/).
Author: Aymeric Damien
Project: https://github.com/aymericdamien/TensorFlow-Examples/
"""

from __future__ import print_function

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)

# Parameters
learning_rate = 0.01
num_steps = 8000
batch_size = 256
display_step = 100

# FM Parameters
number_latent = 5 # number of latent factors per feature and class
num_input = 784 # MNIST data input (img shape: 28*28)
num_classes = 10 # MNIST total classes (0-9 digits)

# Regularization terms for W and V matrices
lambda_w = tf.constant(0.0015, name='lambda_w')
lambda_v = tf.constant(0.0015, name='lambda_v')

# tf Graph input
X = tf.placeholder("float", [None, num_input])
Y = tf.placeholder("float", [None, num_classes])
# The learning rate is fed through a placeholder so the decay schedule in the
# training loop actually reaches the optimizer (a plain Python float would be
# frozen into the graph when the optimizer is built)
lr = tf.placeholder(tf.float32, shape=[], name='lr')

# Store layers weight & bias
weights = {
'W': tf.Variable(tf.random_normal([num_input, num_classes])),
'b': tf.Variable(tf.random_normal([num_classes]))
}
interactions = {
'V': tf.Variable(tf.random_normal([num_input, number_latent, num_classes])),
}
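
# Shapes: W is (num_input, num_classes) and b is (num_classes,); V stacks one
# (num_input, number_latent) factor matrix per class into a rank-3 tensor of
# shape (num_input, number_latent, num_classes)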

# Create model
def fm(x):
    # Linear terms: Wx + b, one score per class
    linear_terms = tf.add(tf.matmul(x, weights['W']), weights['b'])

    # Pairwise interactions via the O(kn) simplification of Rendle, 2010 and
    # Rendle, 2012 (the identity stated above; checked with NumPy below).
    # tensordot is used because V is a rank-3 tensor for multiclass
    # classification.

    # Square of the per-factor sums, then summed over the latent factors
    # (squaring must happen before the reduction over axis 1, the latent
    # dimension, for the identity to hold)
    s1 = tf.reduce_sum(
        tf.pow(
            tf.tensordot(x, interactions['V'], [[1], [0]]),
            2),
        1)

    # Sum of the squared products, summed over the latent factors
    s2 = tf.reduce_sum(
        tf.tensordot(tf.pow(x, 2), tf.pow(interactions['V'], 2), [[1], [0]]),
        1)

    # 0.5 rather than 1/2, which is integer division under Python 2
    outputs = linear_terms + 0.5 * (s1 - s2)
    return outputs
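
# Sanity check (a minimal NumPy sketch, independent of the graph above): verify
# Rendle's identity for a single class on tiny random data,
#   sum_{i<j} <v_i, v_j> x_i x_j
#     = 0.5 * sum_f [ (sum_i v_{i,f} x_i)^2 - sum_i v_{i,f}^2 x_i^2 ]
import numpy as np
_rng = np.random.RandomState(0)
_x, _v = _rng.randn(6), _rng.randn(6, 3)  # 6 features, 3 latent factors
_naive = sum(_v[i].dot(_v[j]) * _x[i] * _x[j]
             for i in range(6) for j in range(i + 1, 6))
_factored = 0.5 * (np.sum(_x.dot(_v) ** 2) - np.sum((_v ** 2).T.dot(_x ** 2)))
assert np.isclose(_naive, _factored)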

# Construct model
logits = fm(X)
prediction = tf.nn.softmax(logits)

# Define loss, regularization, and optimizer
error = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
    logits=logits, labels=Y))

# L1 penalty on the interaction factors V, L2 penalty on the linear weights W
regularization = tf.add(tf.reduce_sum(lambda_v * tf.math.abs(interactions['V'])),
                        tf.reduce_sum(lambda_w * tf.pow(weights['W'], 2)))
loss_op = error + regularization
optimizer = tf.train.AdamOptimizer(learning_rate=lr)
train_op = optimizer.minimize(loss_op)
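
# An equivalent in-graph schedule (a sketch, not wired in here): TensorFlow can
# manage the halving itself instead of feeding lr by hand, e.g.
#   global_step = tf.Variable(0, trainable=False)
#   decayed_lr = tf.train.exponential_decay(0.01, global_step, decay_steps=1000,
#                                           decay_rate=0.5, staircase=True)
#   train_op = tf.train.AdamOptimizer(decayed_lr).minimize(
#       loss_op, global_step=global_step)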

# Evaluate model
correct_pred = tf.equal(tf.argmax(prediction, 1), tf.argmax(Y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
cm = tf.math.confusion_matrix(tf.argmax(Y, 1), tf.argmax(prediction, 1))

# Initialize the variables (i.e. assign their default value)
init = tf.global_variables_initializer()

# Start training
with tf.Session() as sess:

    # Run the initializer
    sess.run(init)

    for step in range(1, num_steps + 1):
        # Halve the learning rate every 1000 steps; the new value reaches the
        # optimizer through the lr placeholder
        if step % 1000 == 0:
            learning_rate = learning_rate / 2
        batch_x, batch_y = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop)
        sess.run(train_op, feed_dict={X: batch_x, Y: batch_y,
                                      lr: learning_rate})
        if step % display_step == 0 or step == 1:
            # Calculate batch loss and accuracy
            loss, acc = sess.run([error, accuracy],
                                 feed_dict={X: batch_x, Y: batch_y})
            print("Step " + str(step) + ", Minibatch Loss= " +
                  "{:.4f}".format(loss) + ", Training Accuracy= " +
                  "{:.3f}".format(acc))

    print("Optimization Finished!")

    # Calculate accuracy for MNIST test images
    print("Testing Accuracy:",
          sess.run(accuracy, feed_dict={X: mnist.test.images,
                                        Y: mnist.test.labels}))
    print("Testing Confusion Matrix: \n",
          sess.run(cm, feed_dict={X: mnist.test.images,
                                  Y: mnist.test.labels}))
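
    # Optional follow-up (a small sketch beyond the original prints): per-class
    # recall derived from the confusion matrix
    cm_val = sess.run(cm, feed_dict={X: mnist.test.images,
                                     Y: mnist.test.labels})
    print("Per-class Recall:",
          cm_val.diagonal().astype(float) / cm_val.sum(axis=1))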