001/*
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *   http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing,
013 * software distributed under the License is distributed on an
014 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
015 * KIND, either express or implied.  See the License for the
016 * specific language governing permissions and limitations
017 * under the License.
018 */
019package org.apache.reef.examples.group.bgd.parameters;
020
import org.apache.reef.examples.group.bgd.loss.LogisticLossFunction;
import org.apache.reef.examples.group.bgd.loss.LossFunction;
import org.apache.reef.examples.group.bgd.loss.SquaredErrorLossFunction;
import org.apache.reef.examples.group.bgd.loss.WeightedLogisticLossFunction;
import org.apache.reef.tang.annotations.Parameter;

import javax.inject.Inject;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
030
031/**
032 * Type of loss function used in example.
033 */
034public class BGDLossType {
035
036  private static final Map<String, Class<? extends LossFunction>> LOSS_FUNCTIONS =
037      new HashMap<String, Class<? extends LossFunction>>() {{
038        put("logLoss", LogisticLossFunction.class);
039        put("weightedLogLoss", WeightedLogisticLossFunction.class);
040        put("squaredError", SquaredErrorLossFunction.class);
041      }};
042
043  private final Class<? extends LossFunction> lossFunction;
044
045  private final String lossFunctionStr;
046
047  @Inject
048  public BGDLossType(@Parameter(LossFunctionType.class) final String lossFunctionStr) {
049    this.lossFunctionStr = lossFunctionStr;
050    this.lossFunction = LOSS_FUNCTIONS.get(lossFunctionStr);
051    if (this.lossFunction == null) {
052      throw new RuntimeException("Specified loss function type: " + lossFunctionStr +
053          " is not implemented. Supported types are logLoss|weightedLogLoss|squaredError");
054    }
055  }
056
057  public Class<? extends LossFunction> getLossFunction() {
058    return this.lossFunction;
059  }
060
061  public String lossFunctionString() {
062    return lossFunctionStr;
063  }
064}