Lab 3: Bugs and Java basics, continued

17 Jan 2014

In preparation for the next exercise and lab, it might be useful to start reading the following (but be assured that we will still provide more help during the next lab):

  • Learning the Java Language (you can skip the following topics on first reading: nested classes, annotations, generics)
  • Collections (you can skip the following topics on first reading: Aggregate Operations, custom implementation, Interoperability)

Updating the state of CRP for collapsed Gibbs sampling

In this lab, we will create a data structure that will be useful in the next exercise, when you implement collapsed Gibbs sampling.

Here is an example of where to find documentation for the existing classes that the Java standard library provides:

http://docs.oracle.com/javase/6/docs/api/java/util/Map.html

Start with the following instructions:

  1. Start Eclipse
  2. Create a project called exercise3
  3. Create a package called exercise3
  4. Create a class in this package called CRPState
  5. Add the following code:
package exercise3;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;


/**
 * Seating arrangement for a Chinese restaurant process: a partition of
 * customers (identified by integers) into tables (clusters).
 *
 * <p>Two maps are kept in sync: one from cluster id to the set of customers
 * seated there, and one from customer to the id of its cluster. A cluster is
 * dropped as soon as its last customer is removed, and cluster ids are handed
 * out from a counter that only increases, so an id is never reused.
 */
public class CRPState
{
  /** Cluster id -> customers currently seated at that table (never empty). */
  private final Map<Integer,Set<Integer>> cluster2Customers = new HashMap<Integer,Set<Integer>>();

  /** Customer -> id of the cluster the customer is seated at. */
  private final Map<Integer,Integer> customer2Cluster = new HashMap<Integer, Integer>();

  /** Id that will be given to the next newly created table. */
  private int _nextClusterId = 0;

  /**
   * Looks up the cluster a customer belongs to.
   *
   * @param customer the customer to look up
   * @return the customer's cluster id, or {@code null} if the customer is not seated
   */
  public Integer getClusterIdOfCustomer(Integer customer)
  {
    return customer2Cluster.get(customer);
  }


  /**
   * Returns the table (set of customers) the given customer sits at.
   *
   * <p>The returned set is the one stored internally, not a copy; callers
   * should treat it as read-only so the two internal maps stay consistent.
   *
   * @param customer the customer to look up
   * @return the customer's table, or {@code null} if the customer is not seated
   */
  public Set<Integer> getTableOfCustomer(Integer customer)
  {
    Integer clusterId = customer2Cluster.get(customer);
    return cluster2Customers.get(clusterId);
  }

  /**
   * @return a new list holding the ids of all current clusters, in ascending order
   */
  public List<Integer> getAllClusterIds()
  {
    List<Integer> result = new ArrayList<Integer>(cluster2Customers.keySet());
    Collections.sort(result);
    return result;
  }

  /**
   * Seats a customer at an already existing table.
   *
   * @param customer  the customer to seat; must not be seated yet
   * @param clusterId id of an existing cluster
   * @throws RuntimeException if the customer is already seated, or if no
   *         cluster with the given id exists (the state is left unchanged)
   */
  public void addCustomerToExistingTable(Integer customer, Integer clusterId)
  {
    checkCustomerNotAlreadyThere(customer);

    // Validate before mutating anything so a bad id cannot leave the two maps out of sync.
    Set<Integer> table = cluster2Customers.get(clusterId);
    if (table == null)
      throw new IllegalArgumentException("Cluster " + clusterId + " does not exist in the CRPState");

    table.add(customer);
    customer2Cluster.put(customer, clusterId);
  }

  /**
   * Seats a customer alone at a newly created table.
   *
   * @param customer the customer to seat; must not be seated yet
   * @throws RuntimeException if the customer is already seated, or if the
   *         supply of cluster ids is exhausted
   */
  public void addCustomerToNewTable(Integer customer)
  {
    checkCustomerNotAlreadyThere(customer);

    Set<Integer> newTable = new HashSet<Integer>();
    newTable.add(customer);

    Integer clusterId = getNextClusterId();
    cluster2Customers.put(clusterId, newTable);
    customer2Cluster.put(customer, clusterId);
  }

  /** Rejects a customer that is already seated somewhere. */
  private void checkCustomerNotAlreadyThere(Integer customer)
  {
    if (customer2Cluster.containsKey(customer))
      throw new IllegalArgumentException("Customer already in there. Remove the customer first.");
  }

  /**
   * @return the number of customers currently seated
   */
  public int nCustomers()
  {
    return customer2Cluster.size();
  }


  /**
   * Removes a customer from its table, deleting the table if it becomes empty.
   *
   * @param customer the customer to remove
   * @throws RuntimeException if the customer is not currently seated (the
   *         state is left unchanged)
   */
  public void removeCustomer(Integer customer)
  {
    Integer clusterId = customer2Cluster.get(customer);
    // An unknown customer has no cluster id and hence no table; report that
    // explicitly instead of dereferencing a null table.
    Set<Integer> table = (clusterId == null) ? null : cluster2Customers.get(clusterId);

    if (table == null || !table.remove(customer))
      throw new IllegalArgumentException("Customer " + customer + " was not in the CRPState");

    if (table.isEmpty())
      cluster2Customers.remove(clusterId);

    customer2Cluster.remove(customer);
  }

  /**
   * Returns the current partition of the customers as a set of tables.
   *
   * <p>Every table is copied, so the result is a snapshot: it is not affected
   * by later changes to this state, and modifying it does not affect this
   * state. Copying also keeps the outer hash set valid, since its elements'
   * hash codes can no longer change underneath it.
   *
   * @return a snapshot of all tables
   */
  public Set<Set<Integer>> partition()
  {
    Set<Set<Integer>> snapshot = new HashSet<Set<Integer>>();
    for (Set<Integer> table : cluster2Customers.values())
      snapshot.add(new HashSet<Integer>(table));
    return snapshot;
  }

  /** Hands out a fresh cluster id; fails rather than overflowing the counter. */
  private int getNextClusterId()
  {
    if (_nextClusterId == Integer.MAX_VALUE)
      throw new IllegalStateException("Ids exhausted.");
    return _nextClusterId++;
  }

  /**
   * Demo: seats customers 0..9 one at a time, opening a new table with
   * probability alpha0 / (alpha0 + number already seated) and otherwise
   * joining the table of a uniformly chosen earlier customer. The partition
   * is printed after each arrival.
   */
  public static void main(String [] args)
  {
    CRPState state = new CRPState();
    Random rand = new Random(1);
    double alpha0 = 0.5;

    for (int i = 0; i < 10; i++)
    {
      Integer customer = i;
      // With nobody seated this is 1.0, so the first customer always opens a table.
      double tableCreatePr = alpha0 / (alpha0 + state.nCustomers());
      if (rand.nextDouble() < tableCreatePr)
        state.addCustomerToNewTable(customer);
      else
      {
        // pick an old customer at random; this works because the customers
        // seated so far are exactly 0 .. nCustomers() - 1
        Integer oldCustomer = rand.nextInt(state.nCustomers());
        // join the table of that customer
        state.addCustomerToExistingTable(customer, state.getClusterIdOfCustomer(oldCustomer));
      }
      System.out.println(state.partition());
    }

  }
}