001/* 002 * Licensed to the Apache Software Foundation (ASF) under one 003 * or more contributor license agreements. See the NOTICE file 004 * distributed with this work for additional information 005 * regarding copyright ownership. The ASF licenses this file 006 * to you under the Apache License, Version 2.0 (the 007 * "License"); you may not use this file except in compliance 008 * with the License. You may obtain a copy of the License at 009 * 010 * http://www.apache.org/licenses/LICENSE-2.0 011 * 012 * Unless required by applicable law or agreed to in writing, 013 * software distributed under the License is distributed on an 014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 015 * KIND, either express or implied. See the License for the 016 * specific language governing permissions and limitations 017 * under the License. 018 */ 019package org.apache.reef.tests.fail.driver; 020 021import org.apache.reef.driver.context.*; 022import org.apache.reef.driver.evaluator.*; 023import org.apache.reef.driver.task.*; 024import org.apache.reef.tang.annotations.Name; 025import org.apache.reef.tang.annotations.NamedParameter; 026import org.apache.reef.tang.annotations.Parameter; 027import org.apache.reef.tang.annotations.Unit; 028import org.apache.reef.tang.exceptions.BindException; 029import org.apache.reef.tests.library.exceptions.DriverSideFailure; 030import org.apache.reef.tests.library.exceptions.SimulatedDriverFailure; 031import org.apache.reef.wake.EventHandler; 032import org.apache.reef.wake.remote.impl.ObjectSerializableCodec; 033import org.apache.reef.wake.time.Clock; 034import org.apache.reef.wake.time.event.Alarm; 035import org.apache.reef.wake.time.event.StartTime; 036import org.apache.reef.wake.time.event.StopTime; 037 038import javax.inject.Inject; 039import javax.xml.bind.DatatypeConverter; 040import java.util.Arrays; 041import java.util.logging.Level; 042import java.util.logging.Logger; 043 044import static 
org.apache.reef.tests.fail.driver.FailDriver.ExpectedMessage.RequiredFlag.OPTIONAL;
import static org.apache.reef.tests.fail.driver.FailDriver.ExpectedMessage.RequiredFlag.REQUIRED;

/**
 * Driver which fails on receiving a certain message class.
 * <p>
 * Every event handler (and the constructor) passes its event to {@link #checkMsgOrder(Object)},
 * which verifies that events arrive in the order listed in {@code EVENT_SEQUENCE} and throws a
 * {@link SimulatedDriverFailure} once an event is an instance of the class named by
 * {@link FailMsgClassName}.
 */
@Unit
public final class FailDriver {

  private static final Logger LOG = Logger.getLogger(FailDriver.class.getName());
  private static final ObjectSerializableCodec<String> CODEC = new ObjectSerializableCodec<>();
  /**
   * Serialized payload: sent to the task by the alarm, used as the memento of the resumed task,
   * and expected as the content of the task's message.
   */
  private static final byte[] HELLO_STR = CODEC.encode("MESSAGE::HELLO");
  /**
   * Delay (milliseconds) before each alarm that sends a message to, suspends, or closes the task.
   */
  private static final int MSG_DELAY = 1000;
  /**
   * Expected order of events. OPTIONAL entries may be skipped; REQUIRED entries may not.
   */
  private static final ExpectedMessage[] EVENT_SEQUENCE = {
      new ExpectedMessage(FailDriver.class, REQUIRED),
      new ExpectedMessage(StartTime.class, REQUIRED),
      new ExpectedMessage(AllocatedEvaluator.class, REQUIRED),
      new ExpectedMessage(FailedEvaluator.class, OPTIONAL),
      new ExpectedMessage(ActiveContext.class, REQUIRED),
      new ExpectedMessage(ContextMessage.class, OPTIONAL),
      new ExpectedMessage(FailedContext.class, OPTIONAL),
      new ExpectedMessage(RunningTask.class, REQUIRED),
      new ExpectedMessage(Alarm.class, REQUIRED),
      new ExpectedMessage(TaskMessage.class, REQUIRED),
      new ExpectedMessage(Alarm.class, REQUIRED),
      new ExpectedMessage(SuspendedTask.class, REQUIRED),
      new ExpectedMessage(RunningTask.class, REQUIRED),
      new ExpectedMessage(Alarm.class, REQUIRED),
      new ExpectedMessage(FailedTask.class, OPTIONAL),
      new ExpectedMessage(CompletedTask.class, REQUIRED),
      new ExpectedMessage(ClosedContext.class, OPTIONAL),
      new ExpectedMessage(CompletedEvaluator.class, REQUIRED),
      new ExpectedMessage(StopTime.class, REQUIRED)
  };
  /** Event class on which to simulate the driver failure. */
  private final transient Class<?> failMsgClass;
  private final transient EvaluatorRequestor requestor;
  private final transient Clock clock;
  /** Most recently reported running task; target of the alarm actions. */
  private transient RunningTask task = null;
  /** Index of the next candidate entry in EVENT_SEQUENCE. */
  private transient int expectIdx = 0;
  private transient DriverState state = DriverState.INIT;

  /**
   * Creates the driver and records its own construction as the first expected event.
   *
   * @param failMsgClassName full name of the event class on which to simulate a failure.
   * @param requestor used to request one evaluator when the driver starts.
   * @param clock used to schedule the alarms that drive the task through its lifecycle.
   * @throws ClassNotFoundException if the system class loader cannot load failMsgClassName.
   */
  @Inject
  public FailDriver(@Parameter(FailMsgClassName.class) final String failMsgClassName,
                    final EvaluatorRequestor requestor, final Clock clock)
      throws ClassNotFoundException {
    this.failMsgClass = ClassLoader.getSystemClassLoader().loadClass(failMsgClassName);
    this.requestor = requestor;
    this.clock = clock;
    // FailDriver itself is the first entry of EVENT_SEQUENCE.
    this.checkMsgOrder(this);
  }

  /**
   * Check if observer methods are called in the right order
   * and generate an exception at the given point in the message sequence.
   * Once the driver is in the FAILED state, further calls do nothing.
   *
   * @param msg a message from one of the observers.
   * @throws SimulatedDriverFailure if failMsgClass matches the message class.
   * @throws DriverSideFailure if messages are out of order.
   */
  private void checkMsgOrder(final Object msg) throws SimulatedDriverFailure, DriverSideFailure {

    final String msgClassName = msg.getClass().getName();
    LOG.log(Level.FINE, "At {0} {1}:{2}", new Object[]{
        this.state, this.expectIdx, msgClassName});

    if (this.state == DriverState.FAILED) {
      // If already failed, do not do anything
      return;
    }

    // Simulate failure at this step?
    if (this.failMsgClass.isInstance(msg)) {
      this.state = DriverState.FAILED;
    }

    // Make sure events arrive in the right order (specified in EVENT_SEQUENCE):
    // non-matching OPTIONAL entries are skipped; a non-matching REQUIRED entry stops the scan.
    boolean notFound = true;
    for (; this.expectIdx < EVENT_SEQUENCE.length; ++this.expectIdx) {
      if (EVENT_SEQUENCE[expectIdx].msgClass.isInstance(msg)) {
        notFound = false;
        break;
      } else if (EVENT_SEQUENCE[expectIdx].requiredFlag == REQUIRED) {
        break;
      }
    }

    if (notFound) {
      LOG.log(Level.SEVERE, "Event out of sequence: {0} {1}:{2}",
          new Object[]{this.state, this.expectIdx, msgClassName});
      throw new DriverSideFailure("Event out of sequence: " + msgClassName);
    }

    LOG.log(Level.INFO, "{0}: send: {1} got: {2}", new Object[]{
        this.state, EVENT_SEQUENCE[this.expectIdx], msgClassName});

    ++this.expectIdx;

    if (this.state == DriverState.FAILED) {
      final SimulatedDriverFailure ex = new SimulatedDriverFailure(
          "Simulated Failure at FailDriver :: " + msgClassName);
      // Logger.log(Level, String, Throwable) performs no {0}-style substitution, so the
      // message is built explicitly and the exception is attached as the thrown cause.
      LOG.log(Level.INFO, "Simulated Failure: " + msgClassName, ex);
      throw ex;
    }
  }

  /** Lifecycle phase of the driver; selects what the next alarm or running task does. */
  private enum DriverState {INIT, SEND_MSG, SUSPEND, RESUME, CLOSE, FAILED}

  /**
   * Name of the message class to specify the failing message handler.
   */
  @NamedParameter(doc = "Full name of the message class to fail on", short_name = "fail")
  public static final class FailMsgClassName implements Name<String> {
  }

  /**
   * Expected message class, together with whether it may be skipped.
   */
  public static final class ExpectedMessage {

    private final transient Class<?> msgClass;
    private final transient RequiredFlag requiredFlag;
    private final transient String repr;

    /**
     * @param clazz class that a matching event must be an instance of.
     * @param requiredFlag whether the event may be skipped (OPTIONAL) or not (REQUIRED).
     */
    public ExpectedMessage(final Class<?> clazz, final RequiredFlag requiredFlag) {
      this.msgClass = clazz;
      this.requiredFlag = requiredFlag;
      this.repr = this.msgClass.getSimpleName() + ":" + this.requiredFlag;
    }

    /** @return "SimpleClassName:FLAG", precomputed in the constructor. */
    @Override
    public String toString() {
      return this.repr;
    }

    /**
     * "Required" flag for message class.
     */
    public enum RequiredFlag {OPTIONAL, REQUIRED}
  }

  /** Submits a context named after the allocated evaluator. */
  final class AllocatedEvaluatorHandler implements EventHandler<AllocatedEvaluator> {
    @Override
    public void onNext(final AllocatedEvaluator eval) {
      checkMsgOrder(eval);
      try {
        eval.submitContext(ContextConfiguration.CONF
            .set(ContextConfiguration.IDENTIFIER, "FailContext_" + eval.getId())
            .build());
      } catch (final BindException ex) {
        LOG.log(Level.SEVERE, "Context configuration error", ex);
        throw new DriverSideFailure("Context configuration error", ex);
      }
    }
  }

  /** Records the completed evaluator; no further action. */
  final class CompletedEvaluatorHandler implements EventHandler<CompletedEvaluator> {
    @Override
    public void onNext(final CompletedEvaluator eval) {
      checkMsgOrder(eval);
      // noop
    }
  }

  /** Logs the evaluator failure and rethrows its cause wrapped in a RuntimeException. */
  final class FailedEvaluatorHandler implements EventHandler<FailedEvaluator> {
    @Override
    public void onNext(final FailedEvaluator eval) {
      LOG.log(Level.WARNING, "Evaluator failed: " + eval.getId(), eval.getEvaluatorException());
      checkMsgOrder(eval);
      throw new RuntimeException(eval.getEvaluatorException());
    }
  }

  /** Submits the first NoopTask on the newly active context. */
  final class ActiveContextHandler implements EventHandler<ActiveContext> {
    @Override
    public void onNext(final ActiveContext context) {
      checkMsgOrder(context);
      try {
        context.submitTask(TaskConfiguration.CONF
            .set(TaskConfiguration.IDENTIFIER, "FailTask_" + context.getId())
            .set(TaskConfiguration.TASK, NoopTask.class)
            .set(TaskConfiguration.ON_MESSAGE, NoopTask.DriverMessageHandler.class)
            .set(TaskConfiguration.ON_SUSPEND, NoopTask.TaskSuspendHandler.class)
            .set(TaskConfiguration.ON_CLOSE, NoopTask.TaskCloseHandler.class)
            .set(TaskConfiguration.ON_TASK_STOP, NoopTask.TaskStopHandler.class)
            .set(TaskConfiguration.ON_SEND_MESSAGE, NoopTask.class)
            .build());
      } catch (final BindException ex) {
        LOG.log(Level.SEVERE, "Task configuration error", ex);
        throw new DriverSideFailure("Task configuration error", ex);
      }
    }
  }

  /** Records the context message; no further action. */
  final class ContextMessageHandler implements EventHandler<ContextMessage> {
    @Override
    public void onNext(final ContextMessage message) {
      checkMsgOrder(message);
      // noop
    }
  }

  /** Records the closed context; no further action. */
  final class ClosedContextHandler implements EventHandler<ClosedContext> {
    @Override
    public void onNext(final ClosedContext context) {
      checkMsgOrder(context);
      // noop
    }
  }

  /** Logs the context failure and records the event. */
  final class FailedContextHandler implements EventHandler<FailedContext> {
    @Override
    public void onNext(final FailedContext context) {
      LOG.log(Level.WARNING, "Context failed: " + context.getId(), context.getReason().orElse(null));
      checkMsgOrder(context);
      // Note: the parent context, if any, is not closed here.
    }
  }

  /**
   * Remembers the running task, advances INIT to SEND_MSG or RESUME to CLOSE,
   * and schedules the next alarm.
   */
  final class RunningTaskHandler implements EventHandler<RunningTask> {
    @Override
    @SuppressWarnings("checkstyle:hiddenfield")
    public void onNext(final RunningTask task) {
      checkMsgOrder(task);
      FailDriver.this.task = task;
      switch (state) {
      case INIT:
        state = DriverState.SEND_MSG;
        break;
      case RESUME:
        state = DriverState.CLOSE;
        break;
      default:
        LOG.log(Level.WARNING, "Unexpected state at TaskRuntime: {0}", state);
        throw new DriverSideFailure("Unexpected state: " + state);
      }
      // After a delay, send message or suspend the task:
      clock.scheduleAlarm(MSG_DELAY, new AlarmHandler());
    }
  }

  /** Resubmits a NoopTask on the same context with HELLO_STR (Base64) as its memento. */
  final class SuspendedTaskHandler implements EventHandler<SuspendedTask> {
    @Override
    @SuppressWarnings("checkstyle:hiddenfield")
    public void onNext(final SuspendedTask task) {
      checkMsgOrder(task);
      state = DriverState.RESUME;
      try {
        task.getActiveContext().submitTask(TaskConfiguration.CONF
            .set(TaskConfiguration.IDENTIFIER, task.getId() + "_RESUMED")
            .set(TaskConfiguration.TASK, NoopTask.class)
            .set(TaskConfiguration.ON_MESSAGE, NoopTask.DriverMessageHandler.class)
            .set(TaskConfiguration.ON_SUSPEND, NoopTask.TaskSuspendHandler.class)
            .set(TaskConfiguration.ON_CLOSE, NoopTask.TaskCloseHandler.class)
            .set(TaskConfiguration.ON_TASK_STOP, NoopTask.TaskStopHandler.class)
            .set(TaskConfiguration.ON_SEND_MESSAGE, NoopTask.class)
            .set(TaskConfiguration.MEMENTO, DatatypeConverter.printBase64Binary(HELLO_STR))
            .build());
      } catch (final BindException ex) {
        LOG.log(Level.SEVERE, "Task configuration error", ex);
        throw new DriverSideFailure("Task configuration error", ex);
      }
    }
  }

  /** Validates the task's message, moves to SUSPEND, and schedules the suspend alarm. */
  final class TaskMessageHandler implements EventHandler<TaskMessage> {
    @Override
    public void onNext(final TaskMessage msg) {
      checkMsgOrder(msg);
      // Explicit checks instead of `assert`: assertions are skipped unless the JVM runs with -ea.
      if (!Arrays.equals(HELLO_STR, msg.get())) {
        LOG.log(Level.WARNING, "Unexpected task message payload");
        throw new DriverSideFailure("Unexpected task message payload");
      }
      if (state != DriverState.SEND_MSG) {
        LOG.log(Level.WARNING, "Unexpected state at TaskMessage: {0}", state);
        throw new DriverSideFailure("Unexpected state: " + state);
      }
      state = DriverState.SUSPEND;
      clock.scheduleAlarm(MSG_DELAY, new AlarmHandler());
    }
  }

  /** Logs the task failure and closes its context when one is present. */
  final class FailedTaskHandler implements EventHandler<FailedTask> {
    @Override
    @SuppressWarnings("checkstyle:hiddenfield")
    public void onNext(final FailedTask task) {
      LOG.log(Level.WARNING, "Task failed: " + task.getId(), task.getReason().orElse(null));
      checkMsgOrder(task);
      if (task.getActiveContext().isPresent()) {
        task.getActiveContext().get().close();
      }
    }
  }

  /** Closes the context of the completed task. */
  final class CompletedTaskHandler implements EventHandler<CompletedTask> {
    @Override
    @SuppressWarnings("checkstyle:hiddenfield")
    public void onNext(final CompletedTask task) {
      checkMsgOrder(task);
      task.getActiveContext().close();
    }
  }

  /** Requests one evaluator (1 core, 128 memory units) at driver start. */
  final class StartHandler implements EventHandler<StartTime> {
    @Override
    public void onNext(final StartTime time) {
      FailDriver.this.checkMsgOrder(time);
      FailDriver.this.requestor.submit(EvaluatorRequest.newBuilder()
          .setNumber(1).setMemory(128).setNumberOfCores(1).build());
    }
  }

  /** Sends HELLO_STR to, suspends, or closes the task, depending on the current state. */
  final class AlarmHandler implements EventHandler<Alarm> {
    @Override
    public void onNext(final Alarm time) {
      FailDriver.this.checkMsgOrder(time);
      switch (FailDriver.this.state) {
      case SEND_MSG:
        FailDriver.this.task.send(HELLO_STR);
        break;
      case SUSPEND:
        FailDriver.this.task.suspend();
        break;
      case CLOSE:
        FailDriver.this.task.close();
        break;
      default:
        LOG.log(Level.WARNING, "Unexpected state at AlarmHandler: {0}", FailDriver.this.state);
        throw new DriverSideFailure("Unexpected state: " + FailDriver.this.state);
      }
    }
  }

  /** Records the stop event; no further action. */
  final class StopHandler implements EventHandler<StopTime> {
    @Override
    public void onNext(final StopTime time) {
      FailDriver.this.checkMsgOrder(time);
      // noop
    }
  }
}