001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.examples.group.bgd; 020 021import org.apache.reef.examples.group.bgd.data.Example; 022import org.apache.reef.examples.group.bgd.loss.LossFunction; 023import org.apache.reef.examples.group.bgd.operatornames.*; 024import org.apache.reef.examples.group.bgd.parameters.AllCommunicationGroup; 025import org.apache.reef.examples.group.bgd.parameters.ProbabilityOfFailure; 026import org.apache.reef.examples.group.bgd.utils.StepSizes; 027import org.apache.reef.examples.group.utils.math.DenseVector; 028import org.apache.reef.examples.group.utils.math.Vector; 029import org.apache.reef.io.network.group.api.operators.Broadcast; 030import org.apache.reef.io.network.group.api.operators.Reduce; 031import org.apache.reef.io.network.group.api.task.CommunicationGroupClient; 032import org.apache.reef.io.network.group.api.task.GroupCommClient; 033import org.apache.reef.io.network.util.Pair; 034import org.apache.reef.tang.annotations.Parameter; 035import org.apache.reef.task.Task; 036 037import javax.inject.Inject; 038import java.util.List; 039import java.util.logging.Logger; 040 041/** 042 * Slave task for BGD example. 
 *
 * <p>Worker side of the Batch Gradient Descent protocol: in a loop it receives a
 * {@link ControlMessages} value broadcast by the master and, depending on the
 * message, receives the current model and/or descent direction and sends back
 * loss/gradient values or line-search evaluations through group-communication
 * reduce operators.
 *
 * <p>NOTE(review): the receive/send sequence in {@link #call(byte[])} presumably
 * must mirror the master's send/receive order exactly — do not reorder these
 * operations without checking the master task.
 */
public class SlaveTask implements Task {

  private static final Logger LOG = Logger.getLogger(SlaveTask.class.getName());

  /** Probability of throwing a simulated failure on each control message; see {@link #failPerhaps()}. */
  private final double failureProb;

  // Group-communication operators, all obtained from the same communication group
  // (AllCommunicationGroup) in the constructor.
  private final CommunicationGroupClient communicationGroup;
  private final Broadcast.Receiver<ControlMessages> controlMessageBroadcaster;
  private final Broadcast.Receiver<Vector> modelBroadcaster;
  private final Reduce.Sender<Pair<Pair<Double, Integer>, Vector>> lossAndGradientReducer;
  private final Broadcast.Receiver<Pair<Vector, Vector>> modelAndDescentDirectionBroadcaster;
  private final Broadcast.Receiver<Vector> descentDirectionBroadcaster;
  private final Reduce.Sender<Pair<Vector, Integer>> lineSearchEvaluationsReducer;
  private final Broadcast.Receiver<Double> minEtaBroadcaster;

  // Training examples for this task; loaded lazily via loadData().
  private List<Example> examples = null;
  private final ExampleList dataSet;
  private final LossFunction lossFunction;
  private final StepSizes ts;

  // Latest model and descent direction received from the master. Both start null
  // and are only valid after the corresponding broadcast has been received.
  private Vector model = null;
  private Vector descentDirection = null;

  /**
   * Constructs the slave task and wires up all group-communication operators.
   *
   * @param groupCommClient entry point to the group-communication service.
   * @param dataSet provider of this task's training examples.
   * @param lossFunction loss used for both gradient computation and line search.
   * @param pFailure probability of simulating a failure per control message.
   * @param ts supplier of the candidate step sizes evaluated during line search.
   */
  @Inject
  public SlaveTask(
      final GroupCommClient groupCommClient,
      final ExampleList dataSet,
      final LossFunction lossFunction,
      @Parameter(ProbabilityOfFailure.class) final double pFailure,
      final StepSizes ts) {

    this.dataSet = dataSet;
    this.lossFunction = lossFunction;
    this.failureProb = pFailure;
    LOG.info("Using pFailure=" + this.failureProb);
    this.ts = ts;

    this.communicationGroup = groupCommClient.getCommunicationGroup(AllCommunicationGroup.class);
    this.controlMessageBroadcaster = communicationGroup.getBroadcastReceiver(ControlMessageBroadcaster.class);
    this.modelBroadcaster = communicationGroup.getBroadcastReceiver(ModelBroadcaster.class);
    this.lossAndGradientReducer = communicationGroup.getReduceSender(LossAndGradientReducer.class);
    this.modelAndDescentDirectionBroadcaster =
        communicationGroup.getBroadcastReceiver(ModelAndDescentDirectionBroadcaster.class);
    this.descentDirectionBroadcaster = communicationGroup.getBroadcastReceiver(DescentDirectionBroadcaster.class);
    this.lineSearchEvaluationsReducer = communicationGroup.getReduceSender(LineSearchEvaluationsReducer.class);
    this.minEtaBroadcaster = communicationGroup.getBroadcastReceiver(MinEtaBroadcaster.class);
  }

  /**
   * Main control loop: reacts to control messages broadcast by the master until
   * a {@code Stop} message is received.
   *
   * @param memento unused.
   * @return always {@code null}; all results flow back through the reduce operators.
   * @throws Exception on group-communication errors or a simulated failure.
   */
  @Override
  public byte[] call(final byte[] memento) throws Exception {
    /*
     * In the case where there will be evaluator failure and data is not in
     * memory we want to load the data while waiting to join the communication
     * group
     */
    loadData();

    for (boolean repeat = true; repeat;) {

      final ControlMessages controlMessage = controlMessageBroadcaster.receive();
      switch (controlMessage) {

        case Stop:
          // Master signalled end of the computation: leave the loop.
          repeat = false;
          break;

        case ComputeGradientWithModel:
          // Receive a fresh model and reply with (loss, #examples) and gradient.
          failPerhaps();
          this.model = modelBroadcaster.receive();
          lossAndGradientReducer.send(computeLossAndGradient());
          break;

        case ComputeGradientWithMinEta:
          // Receive the chosen step size, advance the model along the (already
          // received) descent direction, then reply with loss and gradient.
          failPerhaps();
          final double minEta = minEtaBroadcaster.receive();
          assert descentDirection != null;
          this.descentDirection.scale(minEta);
          assert model != null;
          this.model.add(descentDirection);
          lossAndGradientReducer.send(computeLossAndGradient());
          break;

        case DoLineSearch:
          // Receive only the descent direction (model unchanged) and reply with
          // the loss at every candidate step size.
          failPerhaps();
          this.descentDirection = descentDirectionBroadcaster.receive();
          lineSearchEvaluationsReducer.send(lineSearchEvals());
          break;

        case DoLineSearchWithModel:
          // Receive both model and descent direction (e.g. after a restart),
          // then reply with the line-search evaluations.
          failPerhaps();
          final Pair<Vector, Vector> modelAndDescentDir = modelAndDescentDirectionBroadcaster.receive();
          this.model = modelAndDescentDir.getFirst();
          this.descentDirection = modelAndDescentDir.getSecond();
          lineSearchEvaluationsReducer.send(lineSearchEvals());
          break;

        default:
          break;
      }
    }

    return null;
  }

  /**
   * Simulates an evaluator failure with probability {@link #failureProb} by
   * throwing an unchecked exception.
   */
  private void failPerhaps() {
    if (Math.random() < failureProb) {
      throw new RuntimeException("Simulated Failure");
    }
  }

  /**
   * Evaluates the loss at every candidate step size along the current descent
   * direction.
   *
   * <p>Per-example predictions under the current model ({@code zed}) and under
   * the descent direction ({@code ee}) are computed once; the prediction at step
   * size {@code d} is then {@code zed + d * ee}, avoiding a full re-scoring of
   * each example for every step size.
   *
   * @return a pair of (loss at each step size, number of local examples).
   */
  private Pair<Vector, Integer> lineSearchEvals() {

    if (examples == null) {
      loadData();
    }

    final Vector zed = new DenseVector(examples.size());
    final Vector ee = new DenseVector(examples.size());

    for (int i = 0; i < examples.size(); i++) {
      final Example example = examples.get(i);
      double f = example.predict(model);
      zed.set(i, f);
      f = example.predict(descentDirection);
      ee.set(i, f);
    }

    final double[] t = ts.getT();
    final Vector evaluations = new DenseVector(t.length);
    int i = 0;
    for (final double d : t) {
      double loss = 0;
      for (int j = 0; j < examples.size(); j++) {
        final Example example = examples.get(j);
        // Prediction of (model + d * descentDirection) via linearity of predict.
        final double val = zed.get(j) + d * ee.get(j);
        loss += this.lossFunction.computeLoss(example.getLabel(), val);
      }
      evaluations.set(i++, loss);
    }

    return new Pair<>(evaluations, examples.size());
  }

  /**
   * Computes the total loss and the gradient of the current model over the local
   * examples.
   *
   * @return a pair of ((total loss, number of local examples), gradient vector).
   */
  private Pair<Pair<Double, Integer>, Vector> computeLossAndGradient() {

    if (examples == null) {
      loadData();
    }

    final Vector gradient = new DenseVector(model.size());
    double loss = 0.0;
    for (final Example example : examples) {
      final double f = example.predict(model);
      final double g = this.lossFunction.computeGradient(example.getLabel(), f);
      example.addGradient(gradient, g);
      loss += this.lossFunction.computeLoss(example.getLabel(), f);
    }

    return new Pair<>(new Pair<>(loss, examples.size()), gradient);
  }

  /** Materializes the local training examples from the injected data set. */
  private void loadData() {
    LOG.info("Loading data");
    examples = dataSet.getExamples();
  }
}