001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.bgd; 020 021import org.apache.reef.examples.group.bgd.operatornames.*; 022import org.apache.reef.examples.group.bgd.parameters.*; 023import org.apache.reef.examples.group.bgd.utils.StepSizes; 024import org.apache.reef.examples.group.utils.math.DenseVector; 025import org.apache.reef.examples.group.utils.math.Vector; 026import org.apache.reef.examples.group.utils.timer.Timer; 027import org.apache.reef.exception.evaluator.NetworkException; 028import org.apache.reef.io.Tuple; 029import org.apache.reef.io.network.group.api.operators.Broadcast; 030import org.apache.reef.io.network.group.api.operators.Reduce; 031import org.apache.reef.io.network.group.api.GroupChanges; 032import org.apache.reef.io.network.group.api.task.CommunicationGroupClient; 033import org.apache.reef.io.network.group.api.task.GroupCommClient; 034import org.apache.reef.io.network.util.Pair; 035import org.apache.reef.io.serialization.Codec; 036import org.apache.reef.io.serialization.SerializableCodec; 037import org.apache.reef.tang.annotations.Parameter; 038import org.apache.reef.task.Task; 039 040import javax.inject.Inject; 041import java.util.ArrayList; 042import java.util.logging.Level; 043import java.util.logging.Logger; 044 045/** 046 * Master task for BGD example. 047 */ 048public class MasterTask implements Task { 049 050 public static final String TASK_ID = "MasterTask"; 051 052 private static final Logger LOG = Logger.getLogger(MasterTask.class.getName()); 053 054 private final CommunicationGroupClient communicationGroupClient; 055 private final Broadcast.Sender<ControlMessages> controlMessageBroadcaster; 056 private final Broadcast.Sender<Vector> modelBroadcaster; 057 private final Reduce.Receiver<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer; 058 private final Broadcast.Sender<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster; 059 private final Broadcast.Sender<Vector> descentDriectionBroadcaster; 060 private final Reduce.Receiver<Pair<Vector, Integer>> lineSearchEvaluationsReducer; 061 private final Broadcast.Sender<Double> minEtaBroadcaster; 062 private final boolean ignoreAndContinue; 063 private final StepSizes ts; 064 private final double lambda; 065 private final int maxIters; 066 private final ArrayList<Double> losses = new ArrayList<>(); 067 private final Codec<ArrayList<Double>> lossCodec = new SerializableCodec<>(); 068 private final Vector model; 069 070 private boolean sendModel = true; 071 private double minEta = 0; 072 073 @Inject 074 public MasterTask( 075 final GroupCommClient groupCommClient, 076 @Parameter(ModelDimensions.class) final int dimensions, 077 @Parameter(Lambda.class) final double lambda, 078 @Parameter(Iterations.class) final int maxIters, 079 @Parameter(EnableRampup.class) final boolean rampup, 080 final StepSizes ts) { 081 082 this.lambda = lambda; 083 this.maxIters = maxIters; 084 this.ts = ts; 085 this.ignoreAndContinue = rampup; 086 this.model = new DenseVector(dimensions); 087 this.communicationGroupClient = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class); 088 this.controlMessageBroadcaster = communicationGroupClient.getBroadcastSender(ControlMessageBroadcaster.class); 089 this.modelBroadcaster = communicationGroupClient.getBroadcastSender(ModelBroadcaster.class); 090 this.lossAndGradientReducer = communicationGroupClient.getReduceReceiver(LossAndGradientReducer.class); 091 this.modelAndDescentDirectionBroadcaster = 092 communicationGroupClient.getBroadcastSender(ModelAndDescentDirectionBroadcaster.class); 093 this.descentDriectionBroadcaster = communicationGroupClient.getBroadcastSender(DescentDirectionBroadcaster.class); 094 this.lineSearchEvaluationsReducer = communicationGroupClient.getReduceReceiver(LineSearchEvaluationsReducer.class); 095 this.minEtaBroadcaster = communicationGroupClient.getBroadcastSender(MinEtaBroadcaster.class); 096 } 097 098 @Override 099 public byte[] call(final byte[] memento) throws Exception { 100 101 double gradientNorm = Double.MAX_VALUE; 102 for (int iteration = 1; !converged(iteration, gradientNorm); ++iteration) { 103 try (final Timer t = new Timer("Current Iteration(" + iteration + ")")) { 104 final Pair<Double, Vector> lossAndGradient = computeLossAndGradient(); 105 losses.add(lossAndGradient.getFirst()); 106 final Vector descentDirection = getDescentDirection(lossAndGradient.getSecond()); 107 108 updateModel(descentDirection); 109 110 gradientNorm = descentDirection.norm2(); 111 } 112 } 113 LOG.log(Level.INFO, "OUT: Stop"); 114 controlMessageBroadcaster.send(ControlMessages.Stop); 115 116 for (final Double loss : losses) { 117 LOG.log(Level.INFO, "OUT: LOSS = {0}", loss); 118 } 119 return lossCodec.encode(losses); 120 } 121 122 private void updateModel(final Vector descentDirection) throws NetworkException, InterruptedException { 123 try (final Timer t = new Timer("GetDescentDirection + FindMinEta + UpdateModel")) { 124 final Vector lineSearchEvals = lineSearch(descentDirection); 125 minEta = findMinEta(model, descentDirection, lineSearchEvals); 126 model.multAdd(minEta, descentDirection); 127 } 128 129 LOG.log(Level.INFO, "OUT: New Model = {0}", model); 130 } 131 132 private Vector lineSearch(final Vector descentDirection) throws NetworkException, InterruptedException { 133 Vector lineSearchResults = null; 134 boolean allDead = false; 135 do { 136 try (final Timer t = new Timer("LineSearch - Broadcast(" 137 + (sendModel ? "ModelAndDescentDirection" : "DescentDirection") + ") + Reduce(LossEvalsInLineSearch)")) { 138 if (sendModel) { 139 LOG.log(Level.INFO, "OUT: DoLineSearchWithModel"); 140 controlMessageBroadcaster.send(ControlMessages.DoLineSearchWithModel); 141 modelAndDescentDirectionBroadcaster.send(new Pair<>(model, descentDirection)); 142 } else { 143 LOG.log(Level.INFO, "OUT: DoLineSearch"); 144 controlMessageBroadcaster.send(ControlMessages.DoLineSearch); 145 descentDriectionBroadcaster.send(descentDirection); 146 } 147 final Pair<Vector, Integer> lineSearchEvals = lineSearchEvaluationsReducer.reduce(); 148 if (lineSearchEvals != null) { 149 final int numExamples = lineSearchEvals.getSecond(); 150 lineSearchResults = lineSearchEvals.getFirst(); 151 lineSearchResults.scale(1.0 / numExamples); 152 LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples); 153 LOG.log(Level.INFO, "OUT: LineSearchEvals: {0}", lineSearchResults); 154 allDead = false; 155 } else { 156 allDead = true; 157 } 158 } 159 160 sendModel = chkAndUpdate(); 161 } while (allDead || !ignoreAndContinue && sendModel); 162 return lineSearchResults; 163 } 164 165 private Pair<Double, Vector> computeLossAndGradient() throws NetworkException, InterruptedException { 166 Pair<Double, Vector> returnValue = null; 167 boolean allDead = false; 168 do { 169 try (final Timer t = new Timer("Broadcast(" + (sendModel ? "Model" : "MinEta") + ") + Reduce(LossAndGradient)")) { 170 if (sendModel) { 171 LOG.log(Level.INFO, "OUT: ComputeGradientWithModel"); 172 controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithModel); 173 modelBroadcaster.send(model); 174 } else { 175 LOG.log(Level.INFO, "OUT: ComputeGradientWithMinEta"); 176 controlMessageBroadcaster.send(ControlMessages.ComputeGradientWithMinEta); 177 minEtaBroadcaster.send(minEta); 178 } 179 final Pair<Pair<Double, Integer>, Vector> lossAndGradient = lossAndGradientReducer.reduce(); 180 181 if (lossAndGradient != null) { 182 final int numExamples = lossAndGradient.getFirst().getSecond(); 183 LOG.log(Level.INFO, "OUT: #Examples: {0}", numExamples); 184 final double lossPerExample = lossAndGradient.getFirst().getFirst() / numExamples; 185 LOG.log(Level.INFO, "OUT: Loss: {0}", lossPerExample); 186 final double objFunc = (lambda / 2) * model.norm2Sqr() + lossPerExample; 187 LOG.log(Level.INFO, "OUT: Objective Func Value: {0}", objFunc); 188 final Vector gradient = lossAndGradient.getSecond(); 189 gradient.scale(1.0 / numExamples); 190 LOG.log(Level.INFO, "OUT: Gradient: {0}", gradient); 191 returnValue = new Pair<>(objFunc, gradient); 192 allDead = false; 193 } else { 194 allDead = true; 195 } 196 } 197 sendModel = chkAndUpdate(); 198 } while (allDead || !ignoreAndContinue && sendModel); 199 return returnValue; 200 } 201 202 private boolean chkAndUpdate() { 203 long t1 = System.currentTimeMillis(); 204 final GroupChanges changes = communicationGroupClient.getTopologyChanges(); 205 long t2 = System.currentTimeMillis(); 206 LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec"); 207 if (changes.exist()) { 208 LOG.log(Level.INFO, "OUT: There exist topology changes. Asking to update Topology"); 209 t1 = System.currentTimeMillis(); 210 communicationGroupClient.updateTopology(); 211 t2 = System.currentTimeMillis(); 212 LOG.log(Level.INFO, "OUT: Time to get TopologyChanges = " + (t2 - t1) / 1000.0 + " sec"); 213 return true; 214 } else { 215 LOG.log(Level.INFO, "OUT: No changes in topology exist. So not updating topology"); 216 return false; 217 } 218 } 219 220 private boolean converged(final int iters, final double gradNorm) { 221 return iters >= maxIters || Math.abs(gradNorm) <= 1e-3; 222 } 223 224 private double findMinEta(final Vector theModel, final Vector descentDir, final Vector lineSearchEvals) { 225 final double wNormSqr = theModel.norm2Sqr(); 226 final double dNormSqr = descentDir.norm2Sqr(); 227 final double wDotd = theModel.dot(descentDir); 228 final double[] t = ts.getT(); 229 int i = 0; 230 for (final double eta : t) { 231 final double modelNormSqr = wNormSqr + (eta * eta) * dNormSqr + 2 * eta * wDotd; 232 final double loss = lineSearchEvals.get(i) + ((lambda / 2) * modelNormSqr); 233 lineSearchEvals.set(i, loss); 234 ++i; 235 } 236 LOG.log(Level.INFO, "OUT: Regularized LineSearchEvals: {0}", lineSearchEvals); 237 final Tuple<Integer, Double> minTup = lineSearchEvals.min(); 238 LOG.log(Level.INFO, "OUT: MinTup: {0}", minTup); 239 final double minT = t[minTup.getKey()]; 240 LOG.log(Level.INFO, "OUT: MinT: {0}", minT); 241 return minT; 242 } 243 244 private Vector getDescentDirection(final Vector gradient) { 245 gradient.multAdd(lambda, model); 246 gradient.scale(-1); 247 LOG.log(Level.INFO, "OUT: DescentDirection: {0}", gradient); 248 return gradient; 249 } 250}