EXAMPLES OF NEURAL NETWORK SURVIVAL MODELS

The neural network software can also be used to model survival data.
In this application, the network models the "hazard function", which
determines the instantaneous risk of "death".  The hazard may
depend on the values of inputs (covariates) for a particular case.  It
may be constant through time (for each case), or it may change over
time.  The data may be censored - ie, for some cases, it may be known
only that the individual survived for at least a certain time.

The details of these models are described in model-spec.doc.  Command
files for the examples below are found in the ex-surv directory.  The
sub-directory "pbc" of ex-surv contains files for the examples used in
my talk on
"Survival analysis using a Bayesian neural network" at the 2001 Joint
Statistical Meetings in Atlanta.


Survival models with constant hazard.

I generated 700 cases of synthetic data in which the hazard function
is constant through time, but depends on one input, which was randomly
generated from a standard normal distribution.  When the input was x,
the hazard was h(x) = 1+(x-1)^2, which makes the distribution of the
time of death be exponential with mean 1/h(x).  For the first 200
cases, the time of death was censored at 0.5; if the actual time of
death was later than this, the time recorded was minus the censoring
time (ie, -0.5).  

The first 500 of these cases were used for training, and the remaining
200 for testing.
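
As a concrete illustration of how such data can be generated, here is
a sketch in Python.  This is my own illustration, not the script that
was actually used; in particular, the random seed is arbitrary:

    import numpy as np

    rng = np.random.default_rng(1)     # arbitrary seed, for illustration

    n = 700
    x = rng.standard_normal(n)         # covariate from standard normal
    h = 1 + (x - 1)**2                 # constant hazard for each case
    t = rng.exponential(1 / h)         # time of death, mean 1/h(x)

    # Censor the first 200 cases at time 0.5, recording a censored
    # time as minus the censoring time (ie, -0.5).
    t[:200] = np.where(t[:200] > 0.5, -0.5, t[:200])

    np.savetxt("cdata", np.column_stack([x, t]), fmt="%.6f")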

The network used to model this data had one input (x), and one output,
which was interpreted as log h(x).  One layer of eight hidden units
was used.  The following commands specify the network, model, and data
source:

    > net-spec clog.net 1 8 1 / - 0.05:1 0.05:1 - x0.05:1 - 100 
    > model-spec clog.net survival const-hazard
    > data-spec clog.net 1 1 / cdata@1:500 . cdata@-1:500 .

We can now train the network using the same general methods as for
other models.  Here is one set of commands that work reasonably well:

    > net-gen clog.net fix 0.5
    > mc-spec clog.net repeat 10 heatbath hybrid 100:10 0.2
    > net-mc clog.net 1
    > mc-spec clog.net repeat 4 sample-sigmas heatbath hybrid 500:10 0.4
    > net-mc clog.net 100

This takes 7.0 minutes on the system used (see Ex-system.doc).

Predictions for test cases can now be made using net-pred.  For
example, the following command produces the 10% and 90% quantiles of
the predictive distribution for time of death, based on iterations
from 25 onward.  These predictions can be compared to the actual times
of death (the "targets"):

    > net-pred itq clog.net 25:

    Number of iterations used: 76

    Case  Inputs Targets 10% Qnt 90% Qnt

       1   -0.01    0.29    0.05    1.05
       2   -1.43    0.04    0.02    0.37
       3   -0.68    0.03    0.02    0.52
       4    0.47    0.11    0.09    1.67
       5    0.63    0.97    0.07    1.74

            ( middle lines omitted )

     196    1.39    0.21    0.08    1.99
     197    0.04    0.07    0.06    1.11
     198    1.66    0.56    0.09    1.61
     199    1.33    1.48    0.07    1.76
     200   -1.41    0.08    0.01    0.39
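
Since the generating distribution for a case with covariate x is
exponential with rate h(x) = 1+(x-1)^2, the true quantiles have the
closed form -log(1-p)/h(x), and can be compared with the estimates
above (which also reflect posterior uncertainty about the network, so
an exact match is not expected).  A small Python check of this sort:

    import numpy as np

    def true_quantiles(x, probs=(0.10, 0.90)):
        # Quantiles of an exponential with hazard h(x) = 1+(x-1)^2.
        h = 1 + (x - 1)**2
        return [-np.log(1 - p) / h for p in probs]

    print(true_quantiles(-0.01))   # about [0.052, 1.14]; case 1 above
                                   # has estimates of 0.05 and 1.05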

Since there is only one input for this model, we can easily look at
the posterior distribution of the log of the hazard function with the
following command:

    > net-eval clog.net 25:%5 / -3 3 100 | plot

This plots sixteen functions (one for each of the networks from
iterations 25, 30, ..., 100) that are drawn from the posterior
distribution of the log of the hazard function.  These functions
mostly match the actual function, which is log(1+(x-1)^2), for values
of x up to about one.  There is considerable uncertainty for values of
x above one, however.  Some of the functions from the posterior are a
good match for the true function in this region, but most are below
the true function.  Note that the data here are fairly sparse.
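
To compare against the true function, one could write it to a file and
plot it alongside, for instance with the following Python fragment (my
own illustration; the output file name here is hypothetical):

    import numpy as np

    x = np.linspace(-3, 3, 100)
    np.savetxt("true-log-hazard",
               np.column_stack([x, np.log(1 + (x - 1)**2)]))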


Survival models with piecewise-constant hazard.

One can also define models in which the hazard function depends on
time, as well as, perhaps, on various inputs.  To demonstrate this, I
generated synthetic data for which the hazard function was given by
h(x,t) = (1+(x-1)^2) * t, which makes the cumulative distribution
function for the time of death of an individual for whom the covariate
is x be 1-exp(-(1+(x-1)^2)*t^2/2).

I generated 1000 cases (none of which were censored), and used the
first 700 for training and the remaining 300 for testing.
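
Since the cumulative hazard is the integral of h(x,s) with respect to
s from 0 to t, namely (1+(x-1)^2)*t^2/2, inverting the cumulative
distribution function above gives a direct sampler for the time of
death.  Here is a Python sketch of this generation scheme; it is my
own illustration, and assumes (as in the previous example, though not
stated above) that the covariate is standard normal:

    import numpy as np

    rng = np.random.default_rng(1)     # arbitrary seed

    n = 1000
    x = rng.standard_normal(n)         # assumed covariate distribution
    a = 1 + (x - 1)**2                 # so that h(x,t) = a*t

    # Invert the CDF 1-exp(-a*t^2/2): if u is uniform on (0,1), then
    # t = sqrt(-2*log(1-u)/a) has the required distribution.
    u = rng.uniform(size=n)
    t = np.sqrt(-2 * np.log(1 - u) / a)

    np.savetxt("vdata", np.column_stack([x, t]), fmt="%.6f")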

The network for this problem has two input units.  The first input is
set to the time, t, the second to the value of the covariate, x.
There is a single output unit, whose value will be interpreted as the
log of the hazard function.  If we decide to use one layer of eight
hidden units, we might specify the network as follows:

    > net-spec vlog.net 2 8 1 / - 0.05:1:1 0.05:1 - x0.05:1 - 100 

The prior for the input-hidden weights uses separate hyperparameters
for the two input units (representing time, t, and the covariate, x),
allowing the smoothness of the hazard function to be different for the
two.  In particular, if the hazard is actually constant, the time
input should end up being almost ignored.  (This can be demonstrated
by applying this model to the data used in the previous example.)

The hazard function determines the likelihood by means of its integral
from time zero up to the observed time of death, or the censoring
time.  Computing this exactly would be infeasible if the hazard
function were defined to be the network output when the first input is
set to a given time.  To allow these integrals to be done in a
reasonable amount of compute time, the hazard function is instead
piecewise constant with respect to time (but continuous with respect
to the covariates).  The time points where the hazard function changes
are given in the model-spec command.  For example, we can use the
following specification:

    > model-spec vlog.net survival pw-const-hazard 0.05 0.1 0.2 0.35 \
                                                    0.5 0.7 1.0 1.5

The eight numbers above are the times at which the hazard function
changes.  The log of the value of the hazard before the first time
(0.05) is found from the output of the network when the first input is
set to this time (and other inputs are set to the covariates for the
case); similarly for the log of the value of the hazard function after
the last time point (1.5).  The log of the hazard for other pieces is
found by setting the first input of the network to the average of the
times at the ends of the piece.  For instance, the log of the value of
the hazard function in the time range 0.7 to 1.0 is found by setting
the first input to (0.7+1.0)/2 = 0.85.  Alternatively, if the "log"
option is used, everything is done in terms of the log of the time
(see model-spec.doc for details).
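
To make the piecewise scheme concrete, here is a minimal Python sketch
of how the log likelihood for one case could be computed under it.
The function net(t, x) is a placeholder standing for the network's
log-hazard output; the actual implementation in the software may of
course differ in its details:

    import numpy as np

    # Time points from the model-spec command above.
    breaks = [0.05, 0.1, 0.2, 0.35, 0.5, 0.7, 1.0, 1.5]

    def piece_times(breaks):
        # Times at which the network is evaluated for each piece: the
        # first break for the first piece, midpoints for interior
        # pieces, and the last break for the final, unbounded piece.
        return ([breaks[0]]
                + [(a + b) / 2 for a, b in zip(breaks, breaks[1:])]
                + [breaks[-1]])

    def log_likelihood(net, x, t_death):
        # Log likelihood for one uncensored case: log h at the time of
        # death, minus the integral of h from time 0 to the time of
        # death.  (For a censored case, one would instead return just
        # -cum at the censoring time, the log probability of surviving
        # past it.)
        edges = [0.0] + breaks
        log_h = [net(t, x) for t in piece_times(breaks)]
        cum = 0.0
        for i, lh in enumerate(log_h):
            lo = edges[i]
            hi = edges[i + 1] if i + 1 < len(edges) else np.inf
            cum += np.exp(lh) * (min(hi, t_death) - lo)
            if t_death <= hi:
                return lh - cum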

Note that the number of time points specified is limited only by the
greater amount of computation required when there are many of them.
The hyperparameters controlling the weight priors will be able to
control "overfitting" even if the hazard function has many pieces.

The data specification for this problem says there is just one input
(the covariate) and one target (the time of death, or censoring time):

    > data-spec vlog.net 1 1 / vdata@1:700 . vdata@-1:700 .

Note that data-spec should be told there is only one input even though
net-spec is told there are two inputs - the extra input for the net is
the time, not an input read from the data file.

The following commands sample from the posterior distribution:

    > net-gen vlog.net fix 0.5
    > mc-spec vlog.net repeat 10 heatbath hybrid 100:10 0.2
    > net-mc vlog.net 1
    > mc-spec vlog.net repeat 4 sample-sigmas heatbath hybrid 500:10 0.4
    > net-mc vlog.net 100

This takes 28 minutes on the system used (see Ex-system.doc).

The net-pred command can be used to make predictions for test cases.
Here is a command to display the median of the predictive distribution
for time of death, based on iterations from 25 on, along with the
covariates (inputs) and the actual times of death (targets) for the
test cases:

    > net-pred itd vlog.net 25:

    Number of iterations used: 76
    
    Case  Inputs Targets Medians |Error|

       1    0.00    0.07    0.84  0.7711
       2   -0.22    0.42    0.74  0.3155
       3   -2.24    0.04    0.39  0.3480
       4   -0.81    0.26    0.58  0.3172
       5    1.02    3.12    1.04  2.0784
 
             ( middle lines omitted )

     296   -0.81    0.28    0.58  0.2966
     297    1.19    1.60    1.05  0.5503
     298   -1.53    0.57    0.48  0.0940
     299    0.99    2.67    1.09  1.5773
     300    0.28    1.13    0.92  0.2030

    Average abs. error guessing median:    0.32997+-0.01644

The predictive medians can be plotted against the inputs as follows:

    > net-pred idb vlog.net 25: | plot-points

The fairly small amount of scatter seen in this plot is a result of
computing the medians by Monte Carlo, which produces some random
error.  The exact median time of death as predicted by the model is a
smooth function of the input.

These predictions can be compared to the true median times of death
for individuals with the given covariates, as determined by the way
the data was generated.  These true medians are in the file "vmedians".
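
Under the generating distribution above, the true median also has a
closed form: setting the cumulative distribution function to 1/2 and
solving gives t = sqrt(2*log(2)/(1+(x-1)^2)).  A quick Python check
(illustration only):

    import numpy as np

    def true_median(x):
        # Solve 1 - exp(-(1+(x-1)^2) * t^2 / 2) = 1/2 for t.
        return np.sqrt(2 * np.log(2) / (1 + (x - 1)**2))

    print(true_median(-0.81))   # about 0.57; the estimated median for
                                # case 296 above is 0.58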

The log of the true hazard function for this problem is the sum of a
function of x only and a function of t only, namely log(1+(x-1)^2) +
log(t) - what is called a "proportional hazards" model.  A network
model specified as follows
can discover this:

    > net-spec v2log.net 2 8 8 1 / - 0.05:1:1 0.05:1 - - 0.05:1:1 0.05:1 - \
                                   x0.05:1 x0.05:1 - 100 

The input-hidden weights for the two hidden layers have hierarchical
priors that allow one of them to look at only the time input and the
other to look at only the covariate.  There are no connections between
the hidden layers.  The final output can therefore be an additive
function of time and the covariate, if the hyperparameters are set in
this way.

The remaining commands can be the same as above, except that the
stepsize needs to be reduced to get a reasonable acceptance rate (the
number of leapfrog steps is increased to compensate):

    > model-spec v2log.net survival pw-const-hazard 0.05 0.1 0.2 0.35 \
                                                     0.5 0.7 1.0 1.5

    > data-spec v2log.net 1 1 / vdata@1:700 . vdata@-1:700 .

    > net-gen v2log.net fix 0.5
    > mc-spec v2log.net repeat 10 heatbath hybrid 100:10 0.1
    > net-mc v2log.net 1

    > mc-spec v2log.net repeat 4 sample-sigmas heatbath hybrid 1000:10 0.25
    > net-mc v2log.net 400

This takes 15 hours to run on the system used (see Ex-system.doc).

The model does in fact discover that an additive form for the log
hazard is appropriate, as can be seen by examining the hyperparameters
with a command such as the following:

    > net-plt t h1@h3@ v2log.net | plot

One can see from this that the Markov chain spends most of its time in
regions of the hyperparameter space in which one hidden layer looks
almost only at the time input, and the other looks almost only at the
covariate.