Statistical Methods in Computational Linguistics

HMMs: Baum-Welch Algorithm

Naive Maximum Likelihood Algorithm

First Guess and Truth

Machine in Charniak's Figure 4.9 (b), p. 64
A First Guess
From State   Observation   To State   Prob
a            0             b          .04
a            1             b          0
a            0             a          .48
a            1             a          .48
Total probability from a: 1.0

From State   Observation   To State   Prob
b            0             b          0
b            1             b          0
b            0             a          0
b            1             a          1.0
Total probability from b: 1.0
Machine in Charniak's Figure 4.9 (a), p. 64
The Truth
From State   Observation   To State   Prob
a            0             b          .67
a            1             b          0
a            0             a          .16
a            1             a          .17
Total probability from a: 1.0

From State   Observation   To State   Prob
b            0             b          0
b            1             b          0
b            0             a          0
b            1             a          1.0
Total probability from b: 1.0
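
As a small illustration, the two machines can be written down as lookup tables and checked for consistency. This is only a sketch: the dictionary layout and the names FIRST_GUESS, TRUTH, and outgoing_totals are assumptions made here, not part of Charniak's presentation. Transitions with probability 0 are simply left out.

```python
# Transition tables keyed by (from_state, observation, to_state).
# Any transition not listed is taken to have probability 0.
FIRST_GUESS = {
    ("a", "0", "b"): 0.04,
    ("a", "0", "a"): 0.48,
    ("a", "1", "a"): 0.48,
    ("b", "1", "a"): 1.0,
}
TRUTH = {
    ("a", "0", "b"): 0.67,
    ("a", "0", "a"): 0.16,
    ("a", "1", "a"): 0.17,
    ("b", "1", "a"): 1.0,
}


def outgoing_totals(model):
    """Sum the probabilities of all transitions leaving each state."""
    totals = {}
    for (src, _obs, _dst), p in model.items():
        totals[src] = totals.get(src, 0.0) + p
    return totals


for name, model in [("first guess", FIRST_GUESS), ("truth", TRUTH)]:
    print(name, {s: round(t, 10) for s, t in outgoing_totals(model).items()})
```

Each state's outgoing probabilities should total 1.0, matching the "Total" rows in the tables.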
Input  

01011

State Sequences (Paths) and Transition Sequences

Path     Transitions on input 0 1 0 1 1
aaaaaa   a -0-> a -1-> a -0-> a -1-> a -1-> a
abaaaa   a -0-> b -1-> a -0-> a -1-> a -1-> a
aaabaa   a -0-> a -1-> a -0-> b -1-> a -1-> a
ababaa   a -0-> b -1-> a -0-> b -1-> a -1-> a

Observations:  
  1. Each 01 subsequence can be produced by either a0a1 (staying in a) or a0b1 (passing through b).
  2. The final 1 must be generated by the transition from a to a.
Path Probabilities

Path     Prob     Product of transition probabilities
aaaaaa   .02548   .48 * .48 * .48 * .48 * .48
abaaaa   .00442   .04 * 1.0 * .48 * .48 * .48
aaabaa   .00442   .48 * .48 * .04 * 1.0 * .48
ababaa   .00077   .04 * 1.0 * .04 * 1.0 * .48
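
The path table can be reproduced by brute force. The sketch below assumes the machine starts in state a (as the forward-probability initialization later in these notes does) and uses the first-guess probabilities; the names P, path_probability, and possible_paths are illustrative only.

```python
from itertools import product

# First-guess transition probabilities: (from_state, observation, to_state).
P = {
    ("a", "0", "b"): 0.04,
    ("a", "0", "a"): 0.48,
    ("a", "1", "a"): 0.48,
    ("b", "1", "a"): 1.0,
}


def path_probability(path, words, model):
    """Multiply the transition probabilities along one state sequence."""
    prob = 1.0
    for src, w, dst in zip(path, words, path[1:]):
        prob *= model.get((src, w, dst), 0.0)
    return prob


def possible_paths(words, model, states="ab", start="a"):
    """Enumerate every state sequence and keep those with nonzero probability."""
    result = {}
    for rest in product(states, repeat=len(words)):
        path = start + "".join(rest)
        p = path_probability(path, words, model)
        if p > 0:
            result[path] = p
    return result


paths = possible_paths("01011", P)
for path, p in sorted(paths.items(), key=lambda kv: -kv[1]):
    print(path, round(p, 5))
```

The enumeration visits 2^5 candidate sequences for this five-symbol input, which is exactly the exponential cost that the Baum-Welch reformulation later avoids.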
Count Calculation

Table computing "eta" in Charniak's 4.20
Path     a[0]b   b[1]a   a[0]a   a[1]a
ababaa   2       2       0       1
abaaaa   1       1       1       2
aaabaa   1       1       1       2
aaaaaa   0       0       2       3

Table computing Charniak's 4.20 (path probability times eta)
Path     Prob     a[0]b    b[1]a    a[0]a    a[1]a
ababaa   .00077   .00154   .00154   0        .00077
abaaaa   .00442   .00442   .00442   .00442   .00884
aaabaa   .00442   .00442   .00442   .00442   .00884
aaaaaa   .02548   0        0        .05096   .07644
Probability Calculation

Path     Prob     a[0]b    b[1]a    a[0]a    a[1]a
ababaa   .00077   .00154   .00154   0        .00077
abaaaa   .00442   .00442   .00442   .00442   .00884
aaabaa   .00442   .00442   .00442   .00442   .00884
aaaaaa   .02548   0        0        .05096   .07644
Total    .03509   .01038   .01038   .05980   .09489
Rounded  .035     .01      .01      .06      .095

Count for transitions from a (a[0]b + a[0]a + a[1]a): .01 + .06 + .095 = .165
Count for transitions from b (b[1]a): .01

Transition                   a[0]b      b[1]a     a[0]a      a[1]a
Count / count from state     .01/.165   .01/.01   .06/.165   .095/.165
New transition probability   .06        1.0       .36        .58
Old transition probability   .04        1.0       .48        .48
Truth                        .67        1.0       .16        .17
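
The count and probability tables can be recomputed mechanically from the four path probabilities. The following is a minimal sketch that takes the rounded path probabilities from the table as given; the helper names eta and reestimate are assumptions chosen to echo the notation, not code from the course.

```python
from collections import Counter

WORDS = "01011"
# Path probabilities copied from the table (rounded to five places).
PATH_PROB = {
    "ababaa": 0.00077,
    "abaaaa": 0.00442,
    "aaabaa": 0.00442,
    "aaaaaa": 0.02548,
}


def eta(path, words):
    """Count how often each transition (src, word, dst) occurs on one path."""
    return Counter(zip(path, words, path[1:]))


def reestimate(path_prob, words):
    """Weight each path's transition counts by the path probability, then
    divide each transition's count by the total count leaving its source state."""
    counts = Counter()
    for path, prob in path_prob.items():
        for transition, n in eta(path, words).items():
            counts[transition] += prob * n
    totals = Counter()
    for (src, _w, _dst), c in counts.items():
        totals[src] += c
    return {tr: c / totals[tr[0]] for tr, c in counts.items()}


new_p = reestimate(PATH_PROB, WORDS)
for tr, p in sorted(new_p.items()):
    print(tr, round(p, 3))
```

Because this sketch does not round the intermediate counts, the a[1]a estimate comes out near .575 rather than the .58 obtained from the rounded figures .095/.165; the other estimates agree with the table to two places.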
Homework Problem 1

  1. Starting with the model labeled "A First Guess", do one iteration of the Maximum Likelihood Algorithm for the training input:
      11000
  2. Compare the new transition probabilities with the old ones. What happened that is a little unsettling?
  3. Do a second iteration of the algorithm.

Maximum Likelihood Probability Estimates

Transition Probability Estimate

Concept: For each transition $s_i[w_k]s_j$, count the number of times $s_i[w_k]s_j$ was taken and divide by the total number of times any transition from $s_i$ was taken.

    \[ P_{MLE}(s_i[w_k]s_j) = \frac{\mathrm{Count}(s_i[w_k]s_j)}{\sum_{m,l} \mathrm{Count}(s_i[w_m]s_l)} \tag{1} \]

Count Calculation

Concept: For each path $s_{1,n+1}$ and each transition $s_i[w_k]s_j$, count the number of times $s_i[w_k]s_j$ appears in $s_{1,n+1}$ given output $w_{1,n}$ (Charniak's "eta" in 4.20):

    \[ \eta(s_i[w_k]s_j, s_{1,n+1}, w_{1,n}) \]

Then, for each path $s_{1,n+1}$ and each transition $s_i[w_k]s_j$, multiply eta by the probability of the path given the output, and sum over all paths:

    \[ \mathrm{Count}(s_i[w_k]s_j) = \sum_{s_{1,n+1}} P(s_{1,n+1} \mid w_{1,n}) \, \eta(s_i[w_k]s_j, s_{1,n+1}, w_{1,n}) \tag{2} \]
Normalization Factor

Concept: In principle, the "path probability" we compute by multiplying all the transition probabilities on a path is a JOINT probability:

    \[ P(w_{1,n}, s_{1,n+1}) = \prod_{i=1}^{n} P(s_i[w_i]s_{i+1}) \]

But the probabilities by which we multiply the counts in Equation (2) [Charniak's 4.20] are conditional:

    \[ P(s_{1,n+1} \mid w_{1,n}) \]

All paths are to be weighted according to their probability GIVEN the corpus.

So we "normalize" using the chain rule:

    \[ P(s_{1,n+1}, w_{1,n}) = P(s_{1,n+1} \mid w_{1,n}) \, P(w_{1,n}) \]
    \[ P(s_{1,n+1} \mid w_{1,n}) = \frac{P(s_{1,n+1}, w_{1,n})}{P(w_{1,n})} \]

Equation (2) becomes (Charniak's 4.21):

    \[ \mathrm{Count}(s_i[w_k]s_j) = \frac{1}{P(w_{1,n})} \sum_{s_{1,n+1}} P(s_{1,n+1}, w_{1,n}) \, \eta(s_i[w_k]s_j, s_{1,n+1}, w_{1,n}) \tag{3} \]

This now uses what we normally mean by the path probability in an HMM.

Notice, however, that in our example Maximum Likelihood iteration we never divided by a normalization factor. This is because ALL our counts would be divided by the same normalization factor, $P(w_{1,n})$. But in Equation (1) we divide a count by a sum of counts:

    \[ P_{MLE}(s_i[w_k]s_j) = \frac{\mathrm{Count}(s_i[w_k]s_j)}{\sum_{m,l} \mathrm{Count}(s_i[w_m]s_l)} \tag{1} \]

So the normalization factors cancel out and can safely be ignored!
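
To make the cancellation explicit, one can substitute Equation (3) into both the numerator and the denominator of Equation (1). This only restates the argument above; the summation indices m and l and the shorthand J for the unnormalized sum are naming choices made here.

```latex
\[
J(s_i[w_k]s_j) = \sum_{s_{1,n+1}} P(s_{1,n+1}, w_{1,n}) \,
                 \eta(s_i[w_k]s_j, s_{1,n+1}, w_{1,n})
\]
\[
P_{MLE}(s_i[w_k]s_j)
  = \frac{\dfrac{1}{P(w_{1,n})} \, J(s_i[w_k]s_j)}
         {\dfrac{1}{P(w_{1,n})} \sum_{m,l} J(s_i[w_m]s_l)}
  = \frac{J(s_i[w_k]s_j)}{\sum_{m,l} J(s_i[w_m]s_l)}
\]
```

The factor $1/P(w_{1,n})$ does not depend on the transition, so it appears once above and once below the line and drops out.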

 

Baum-Welch

Eliminating State Sequences from Max-Likelihood Estimation

Our naive MLE algorithm summed over all state sequences. That is, we needed to find the probabilities of all paths through the model given the corpus. In general the number of paths grows exponentially with the length of the input, so this quickly becomes infeasible. Bad.

We try to fix this by rethinking Counts:

    \[ \mathrm{Count}(s_i[w_k]s_j) = \sum_{t=1}^{n} P(s_i[w_k]s_j \text{ at time tick } t \mid w_{1,n}) \tag{4} \]
    \[ = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} P(S_t = s_i, S_{t+1} = s_j, W_t = w_k, w_{1,n}) \]

We are really mentioning $w_t$ twice here, once as $W_t = w_k$ and once in $w_{1,n}$. We can split $w_{1,n}$ into two sequences, one before $w_t$ and one after, and refer to the same joint events:

    \[ = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} P(w_{1,t-1}, S_t = s_i, S_{t+1} = s_j, W_t = w_k, w_{t+1,n}) \]

We can now apply the chain rule:

    \[ = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} P(w_{1,t-1}, S_t = s_i) \, P(S_{t+1} = s_j, W_t = w_k, w_{t+1,n} \mid w_{1,t-1}, S_t = s_i) \tag{5} \]

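The second probability in the product can have the chain rule REapplied:

    \[ P(S_{t+1} = s_j, W_t = w_k, w_{t+1,n} \mid w_{1,t-1}, S_t = s_i) = P(w_{t+1,n} \mid w_{1,t-1}, S_t = s_i, S_{t+1} = s_j, W_t = w_k) \, P(S_{t+1} = s_j, W_t = w_k \mid w_{1,t-1}, S_t = s_i) \]

Substituting back into (5):

    \[ \mathrm{Count}(s_i[w_k]s_j) = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} P(w_{1,t-1}, S_t = s_i) \, P(w_{t+1,n} \mid w_{1,t-1}, S_t = s_i, S_{t+1} = s_j, W_t = w_k) \, P(S_{t+1} = s_j, W_t = w_k \mid w_{1,t-1}, S_t = s_i) \tag{6} \]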

We now make use of the Markov assumption, that transition probabilities depend, at most, on the input, the previous state, and the next state. That is,

    \[ P(S_{t+1} = s_j, W_t = w_k \mid w_{1,t-1}, S_t = s_i) = P(S_{t+1} = s_j, W_t = w_k \mid S_t = s_i) \]

and

    \[ P(w_{t+1,n} \mid w_{1,t-1}, S_t = s_i, S_{t+1} = s_j, W_t = w_k) = P(w_{t+1,n} \mid S_{t+1} = s_j) \]

Substituting into (6), we have:

    \[ \mathrm{Count}(s_i[w_k]s_j) = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} P(w_{1,t-1}, S_t = s_i) \, P(w_{t+1,n} \mid S_{t+1} = s_j) \, P(S_{t+1} = s_j, W_t = w_k \mid S_t = s_i) \tag{7} \]

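But HOLY COW! Recall the definition of $\alpha_i(t)$:

    \[ \alpha_i(t) = P(w_{1,t-1}, S_t = s_i) \]

And the definition of $\beta_j(t+1)$:

    \[ \beta_j(t+1) = P(w_{t+1,n} \mid S_{t+1} = s_j) \]

And while we're at it, $P(S_{t+1} = s_j, W_t = w_k \mid S_t = s_i)$ is just what we mean by a transition probability:

    \[ P(S_{t+1} = s_j, W_t = w_k \mid S_t = s_i) = P(s_i[w_k]s_j) \]

Substituting, we have the basic equation that makes the Baum-Welch algorithm possible:

    \[ \mathrm{Count}(s_i[w_k]s_j) = \frac{1}{P(w_{1,n})} \sum_{t=1}^{n} \alpha_i(t) \, P(s_i[w_k]s_j) \, \beta_j(t+1) \tag{8} \]

Only time ticks at which the observed word $w_t$ is $w_k$ contribute to the sum; at any other tick the joint event in (4) has probability zero.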

Illustration of Baum-Welch

Steps in Re-estimation

  1. Compute forward-probabilities
  2. Compute backward probabilities
  3. For each time-tick, for each transition, compute re-estimated transition count using equation (8)
  4. Compute re-estimated transition probabilities using re-estimated transition counts and equation (1).
Forward Probabilities

$\alpha_i(t)$ is what goes in each cell of the table:

    \[ \alpha_i(t) = \sum_j \alpha_j(t-1) \, P(s_j[w_{t-1}]s_i) \]

This is the total probability of producing the string up to and including $w_{t-1}$ AND ending up in state $s_i$ at time t (a joint probability, not a conditional one):
Time-relative indexing of time ticks, states, and words
t     1   2   ... n   n+1
s   s1   s2   ... sn   sn+1
w     w1   w2   ... wn  

Initialization of algorithm:

Time: t=1

    t 1 2 3 4 5 6
    wt-1 eps 0 1 0 1 1
    alphaa(t) 1          
    alphab(t) 0          

Time t=2:

    t 1 2 3 4 5 6
    wt-1 eps 0 1 0 1 1
    alphaa(t) 1 .48        
    alphab(t) 0 .04        
Computation of alphaa(2)

    alphaa(2) = Sum_i alphai(1) * P(i[0]a)

    From a:
    P(a[0]a) = .48
    alphaa(1) = 1.0
    P(a[0]a) * alphaa(1) = .48

    From b:
    P(b[0]a) = 0
    alphab(1) = 0
    P(b[0]a) * alphab(1) = 0

    alphaa(2) = .48 + 0 = .48

Computation of alphab(2)

    alphab(2) = Sum_i alphai(1) * P(i[0]b)

    From a:
    P(a[0]b) = .04
    alphaa(1) = 1.0
    P(a[0]b) * alphaa(1) = .04

    From b:
    P(b[0]b) = 0
    alphab(1) = 0
    P(b[0]b) * alphab(1) = 0

    alphab(2) = .04 + 0 = .04

Time: t=6

    t 1 2 3 4 5 6
    wt-1 eps 0 1 0 1 1
    alphaa(t) 1 .48 .27 .13 .072 .035
    alphab(t) 0 .04 0 .01 0 0
Computation of alphaa(3)

    alphaa(3) = Sum_i alphai(2) * P(i[1]a)

    From a:
    P(a[1]a) = .48
    alphaa(2) = .48
    P(a[1]a) * alphaa(2) = .23

    From b:
    P(b[1]a) = 1.0
    alphab(2) = .04
    P(b[1]a) * alphab(2) = .04

    alphaa(3) = .23 + .04 = .27

Computation of alphab(3)

    alphab(3) = Sum_i alphai(2) * P(i[1]b)

    From a:
    P(a[1]b) = 0
    alphaa(2) = .48
    P(a[1]b) * alphaa(2) = 0

    From b:
    P(b[1]b) = 0
    alphab(2) = .04
    P(b[1]b) * alphab(2) = 0

    alphab(3) = 0 + 0 = 0
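
The column-by-column filling of the forward table can be sketched in a few lines. This is an illustration under the assumptions used in the worked example (first-guess probabilities, start state a); the function name forward and the list-of-dicts layout are choices made here, with alpha[t - 1][s] holding the table entry for state s at time t.

```python
# First-guess transition probabilities: (from_state, observation, to_state).
MODEL = {
    ("a", "0", "b"): 0.04,
    ("a", "0", "a"): 0.48,
    ("a", "1", "a"): 0.48,
    ("b", "1", "a"): 1.0,
}


def forward(words, model, states="ab", start="a"):
    """Return the forward table; alpha[t - 1][s] is alpha_s(t) in the notes."""
    # Time 1: all probability mass sits on the start state.
    alpha = [{s: (1.0 if s == start else 0.0) for s in states}]
    for w in words:  # w plays the role of w_(t-1) when filling column t
        prev = alpha[-1]
        alpha.append({
            j: sum(prev[i] * model.get((i, w, j), 0.0) for i in states)
            for j in states
        })
    return alpha


alpha = forward("01011", MODEL)
for s in "ab":
    print("alpha_" + s, [round(col[s], 3) for col in alpha])
```

The final column for state a is the probability of the whole input, about .0351, which agrees with the sum of the four path probabilities computed by the naive method.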
Backward Probabilities

$\beta_i(t)$ is what goes in each cell of the table:

    \[ \beta_i(t) = \sum_j P(s_i[w_t]s_j) \, \beta_j(t+1) \]

This is the total probability of the string from time t on, given that you're in state $s_i$ at time t.

Initialization of algorithm:

Time: t=6

    t 1 2 3 4 5 6
    wt 0 1 0 1 1 eps
    betaa(t)           1.0
    betab(t)           1.0

Time t=5:

    t 1 2 3 4 5 6
    wt 0 1 0 1 1 eps
    betaa(t)         .48 1.0
    betab(t)         1.0 1.0
Computation of betaa(5):

    betaa(5) = Sum_i P(a[1]i) * betai(6)

    To a:
    P(a[1]a) = .48
    betaa(6) = 1.0
    P(a[1]a) * betaa(6) = .48

    To b:
    P(a[1]b) = 0
    betab(6) = 1.0
    P(a[1]b) * betab(6) = 0

    betaa(5) = .48 + 0 = .48

Computation of betab(5):

    betab(5) = Sum_i P(b[1]i) * betai(6)

    To a:
    P(b[1]a) = 1.0
    betaa(6) = 1.0
    P(b[1]a) * betaa(6) = 1.0

    To b:
    P(b[1]b) = 0
    betab(6) = 1.0
    P(b[1]b) * betab(6) = 0

    betab(5) = 1.0 + 0 = 1.0

Time: t=1

    t 1 2 3 4 5 6
    wt 0 1 0 1 1 eps
    betaa(t) .035 .062 .13 .23 .48 1.0
    betab(t) 0 .13 0 .48 1.0 1.0
Computation of betaa(4):

    betaa(4) = Sum_i P(a[1]i) * betai(5)

    To a:
    P(a[1]a) = .48
    betaa(5) = .48
    P(a[1]a) * betaa(5) = .23

    To b:
    P(a[1]b) = 0
    betab(5) = 1.0
    P(a[1]b) * betab(5) = 0

    betaa(4) = .23 + 0 = .23

Computation of betab(4):

    betab(4) = Sum_i P(b[1]i) * betai(5)

    To a:
    P(b[1]a) = 1.0
    betaa(5) = .48
    P(b[1]a) * betaa(5) = .48

    To b:
    P(b[1]b) = 0
    betab(5) = 1.0
    P(b[1]b) * betab(5) = 0

    betab(4) = .48 + 0 = .48
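
The backward table is filled the same way, but from right to left. The sketch below again assumes the first-guess probabilities and initializes every state to 1.0 at the final time tick, as in the initialization table above; the name backward and the list-of-dicts layout are illustrative, with beta[t - 1][s] holding the entry for state s at time t.

```python
# First-guess transition probabilities: (from_state, observation, to_state).
MODEL = {
    ("a", "0", "b"): 0.04,
    ("a", "0", "a"): 0.48,
    ("a", "1", "a"): 0.48,
    ("b", "1", "a"): 1.0,
}


def backward(words, model, states="ab"):
    """Return the backward table; beta[t - 1][s] is beta_s(t) in the notes."""
    n = len(words)
    beta = [None] * (n + 1)
    # Final tick: nothing is left to generate, so every state gets 1.0.
    beta[n] = {s: 1.0 for s in states}
    for t in range(n - 1, -1, -1):
        w = words[t]  # the word emitted on the transition out of this tick
        beta[t] = {
            i: sum(model.get((i, w, j), 0.0) * beta[t + 1][j] for j in states)
            for i in states
        }
    return beta


beta = backward("01011", MODEL)
for s in "ab":
    print("beta_" + s, [round(col[s], 3) for col in beta])
```

The first entry for state a is again the probability of the whole input (about .0351), which gives a useful cross-check against the forward pass.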
Re-estimated Counts

In each cell of the table:

    alphai(t) * P(si[wt]sj) * betaj(t+1)

    t       1      2      3      4      5
    wt      0      1      0      1      1      Count(si[wt]sj)  Count(si)  New P
    a[0]b   .0052  0      .0052  0      0      .010             .165       .06
    b[1]a   0      .0052  0      .0048  0      .010             .010       1.0
    a[0]a   .030   0      .030   0      0      .060             .165       .36
    a[1]a   0      .030   0      .030   .035   .095             .165       .58

Computation of Count(a[0]b) for time tick 1

    = alphaa(1) * P(a[0]b) * betab(2)
    = 1.0 * .04 * .13
    = .0052

Computation of Count(a[0]b) for time tick 3

    = alphaa(3) * P(a[0]b) * betab(4)
    = .27 * .04 * .48
    = .0052

Computation of Count(a[0]a) for time tick 3

    = alphaa(3) * P(a[0]a) * betaa(4)
    = .27 * .48 * .23
    = .030
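
The whole count table can be filled in from the rounded forward and backward values. This is a sketch under stated assumptions: ALPHA holds the forward values for ticks 1 to 5 (tick 6 is not needed for counts), BETA holds the backward values for ticks 1 to 6 with beta_b(4) = .48 and both states set to 1.0 from tick 5 on for b per the recurrence (those last b entries are multiplied by zero here), and the name tick_counts is hypothetical. The 1/P(w) factor of equation (8) is omitted because it cancels in equation (1).

```python
WORDS = "01011"
MODEL = {
    ("a", "0", "b"): 0.04,
    ("a", "0", "a"): 0.48,
    ("a", "1", "a"): 0.48,
    ("b", "1", "a"): 1.0,
}
# Rounded table values; list index 0 corresponds to time tick 1.
ALPHA = {"a": [1.0, 0.48, 0.27, 0.13, 0.072], "b": [0.0, 0.04, 0.0, 0.01, 0.0]}
BETA = {"a": [0.035, 0.062, 0.13, 0.23, 0.48, 1.0],
        "b": [0.0, 0.13, 0.0, 0.48, 1.0, 1.0]}


def tick_counts(words, model, alpha, beta):
    """Per-tick terms alpha_i(t) * P(s_i[w_t]s_j) * beta_j(t+1) per transition."""
    table = {tr: [0.0] * len(words) for tr in model}
    for t, w in enumerate(words):
        for (i, wk, j), p in model.items():
            if wk == w:  # only ticks whose word matches the transition count
                table[(i, wk, j)][t] = alpha[i][t] * p * beta[j][t + 1]
    return table


table = tick_counts(WORDS, MODEL, ALPHA, BETA)
counts = {tr: sum(row) for tr, row in table.items()}
state_totals = {}
for (i, _w, _j), c in counts.items():
    state_totals[i] = state_totals.get(i, 0.0) + c
new_p = {tr: counts[tr] / state_totals[tr[0]] for tr in counts}

for tr in MODEL:
    print(tr, [round(x, 4) for x in table[tr]], round(new_p[tr], 3))
```

With these rounded inputs the re-estimated probabilities land close to the naive Maximum Likelihood results, without ever enumerating a path.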