Importing libraries

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.1     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.3     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(tidytext)
library(tokenizers)
library(markovchain)
## Package:  markovchain
## Version:  0.8.6
## Date:     2021-05-17
## BugReport: https://github.com/spedygiorgio/markovchain/issues
library(here)
## here() starts at /Users/benedictneo/next-word-predictor

Markov Chains

A Markov chain, or Markov process, is a stochastic model describing a sequence of possible events in which the probability of each event depends only on the state attained in the previous event.

Take weather, for example: we can have three states - rainy, sunny, and cloudy - and use the previous event (it rained today) to estimate the probability of each state tomorrow.
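
To make this concrete, here is a toy weather chain built with the markovchain package loaded above (the transition probabilities are made up purely for illustration):

# toy weather chain; these probabilities are invented for illustration
weather_states <- c("rainy", "sunny", "cloudy")
weather_matrix <- matrix(
  c(0.5, 0.2, 0.3,  # from rainy
    0.1, 0.7, 0.2,  # from sunny
    0.3, 0.3, 0.4), # from cloudy
  nrow = 3, byrow = TRUE,
  dimnames = list(weather_states, weather_states)
)
weather_mc <- new("markovchain",
                  states = weather_states,
                  byrow = TRUE,
                  transitionMatrix = weather_matrix)

conditionalDistribution(weather_mc, "rainy") # P(tomorrow | it rained today)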

Using this principle, we can predict the next word based on the last word typed. The Markov chain models the transition probabilities between states, where each state is a token.

More information about Markov chains can be found on the Wikipedia page.

Below is an example of how it works with a simple sentence.

Markov Chains in R

We take a simple sentence and tokenize it into individual tokens.

text <- c("the quick brown fox jumps over the lazy dog and the angry dog chase the fox")

(tokens <- strsplit(text, split = " ") %>% unlist())
##  [1] "the"   "quick" "brown" "fox"   "jumps" "over"  "the"   "lazy"  "dog"  
## [10] "and"   "the"   "angry" "dog"   "chase" "the"   "fox"

To create a Markov chain without manually calculating the transitions, we can utilize the markovchain package.

simple_markov <- markovchainFit(tokens, method = "laplace")

Markov Chain Visualized

To visualize our Markov chain, we can simply use the plot() function.

set.seed(2021)
plot(simple_markov$estimate)

The arrows indicate transitions to the next state, and the numbers are the probabilities of those transitions. For example, starting at “the”, the chain can move to “angry”, “lazy”, “quick”, or “fox”. Since “the” precedes each of those four words exactly once in our sentence, the transitions are equally likely, which is why each probability is 0.25.
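
We can check these numbers directly with the package's conditionalDistribution() helper, which returns the transition probabilities out of a given state:

# transition probabilities out of the state "the"
conditionalDistribution(simple_markov$estimate, "the")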

Transition Matrix

With our fitted model, we can access $estimate to see the dimension (number of unique words) and the transition matrix.

simple_markov$estimate
## Laplacian Smooth Fit 
##  A  11 - dimensional discrete Markov Chain defined by the following states: 
##  and, angry, brown, chase, dog, fox, jumps, lazy, over, quick, the 
##  The transition matrix  (by rows)  is defined as follows: 
##       and angry brown chase dog  fox jumps lazy over quick the
## and   0.0  0.00     0   0.0   0 0.00     0 0.00    0  0.00   1
## angry 0.0  0.00     0   0.0   1 0.00     0 0.00    0  0.00   0
## brown 0.0  0.00     0   0.0   0 1.00     0 0.00    0  0.00   0
## chase 0.0  0.00     0   0.0   0 0.00     0 0.00    0  0.00   1
## dog   0.5  0.00     0   0.5   0 0.00     0 0.00    0  0.00   0
## fox   0.0  0.00     0   0.0   0 0.00     1 0.00    0  0.00   0
## jumps 0.0  0.00     0   0.0   0 0.00     0 0.00    1  0.00   0
## lazy  0.0  0.00     0   0.0   1 0.00     0 0.00    0  0.00   0
## over  0.0  0.00     0   0.0   0 0.00     0 0.00    0  0.00   1
## quick 0.0  0.00     1   0.0   0 0.00     0 0.00    0  0.00   0
## the   0.0  0.25     0   0.0   0 0.25     0 0.25    0  0.25   0
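
If we need the raw matrix for further computation, it is stored in the transitionMatrix slot of the underlying S4 object:

# pull the raw transition matrix out of the S4 markovchain object
P <- simple_markov$estimate@transitionMatrix
rowSums(P) # every row sums to 1 by construction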

Using markovchainSequence(), we can generate a sentence from our Markov chain model.

markovchainSequence(
  n = 5,
  markovchain = simple_markov$estimate,
  t0 = "the", # set the first word
  include.t0 = TRUE
) %>%
  paste(collapse = " ")
## [1] "the fox jumps over the quick"

Boom! A sentence was generated, although it is just a recombination of words from the sentence the model was fed.

We can also generate multiple sentences by wrapping the call in a for loop.

for (i in 1:5) {
  set.seed(i) # vary the seed so each iteration differs, reproducibly
  markovchainSequence(
    n = 5,
    markovchain = simple_markov$estimate,
    t0 = "the", # set the first word
    include.t0 = TRUE
  ) %>%
    paste(collapse = " ") %>%
    print()
}
## [1] "the fox jumps over the angry"
## [1] "the angry dog chase the quick"
## [1] "the angry dog and the lazy"
## [1] "the lazy dog and the quick"
## [1] "the angry dog chase the angry"

Now let’s move on to fitting our n-grams to the Markov model.

Model fitting with n-grams

The plan was to fit unigrams, bigrams, trigrams, and quadgrams to the Markov chain. However, due to the memory and size limits of the Shiny app, I have resorted to using the unigrams only.

unigrams <- read_rds(here("app", "data/unigrams.rds"))
length(unigrams$word)
## [1] 6737543

There are too many words in our unigrams, and fitting all of them would significantly slow down the training process. We will sample roughly 60k unigrams by filtering on the line number.

sub_unigrams <- unigrams %>%
  filter(line < 800) %>% 
  pull(word)
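
A quick sanity check on the sample size (the exact count depends on the corpus, but it should be in the ballpark of 60k given the filter above):

length(sub_unigrams) # roughly 60k tokens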

Fitting the unigrams

markov_uni <- markovchainFit(sub_unigrams, method = "laplace")
write_rds(markov_uni, here("app/models/markov_uni_small.rds")) # smaller model

Loading the models

markov_uni <- read_rds(here("app/models/markov_uni.rds"))

markovchainSequence with unigrams

word <- 'who'
markovchainSequence(
  n = 1,
  markovchain = markov_uni$estimate,
  t0 = word, # set the first word
  include.t0 = FALSE
) %>% paste(collapse = " ")
## [1] "wanted"

Predicting the next word with the Markov model

next_word <- function(word, num = 5) {
  sents <- c()
  for (i in 1:num) {
    set.seed(i) # vary the seed so each draw differs, reproducibly
    sent <- markovchainSequence(
      n = 1,
      markovchain = markov_uni$estimate,
      t0 = word, # set the first word
      include.t0 = FALSE
    ) %>%
      paste(collapse = " ")
    sents <- c(sents, sent)
  }
  return(sents)
}

next_word('who')
## [1] "do"      "have"    "have"    "reached" "has"

Cleaning the input

bad_words <- read_rds(here("app/data/bad_words.rds"))
clean_input <- function(input) {
  input <- tibble(line = seq_along(input), text = input) %>%
    unnest_tokens(word, text) %>% # lowercase and split into word tokens
    filter(!str_detect(word, "\\d+")) %>% # drop tokens containing digits
    mutate_at("word", str_replace, "[[:punct:]]", "") %>% # remove punctuation
    anti_join(bad_words, by = "word") %>% # remove profane words
    pull(word)
  
  input
}
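
As a quick illustration of what the cleaning step does (assuming none of these tokens appear in bad_words):

clean_input("The Quick, Brown Fox!! 123")
# expected: "the" "quick" "brown" "fox" - lowercased, digits dropped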

Next word function

next_word <- function(word, num = 5) {
  word <- clean_input(word)
  n_words <- length(word)
  
  # if the input has multiple words, keep only the last one
  if (n_words > 1) {
    word <- word[n_words]
    print(word)
  }
  
  sents <- c()
  for (i in 1:num) {
    set.seed(i) # vary the seed so each draw differs, reproducibly
    sent <- markovchainSequence(
      n = 1,
      markovchain = markov_uni$estimate,
      t0 = word, # set the first word
      include.t0 = FALSE
    ) %>%
      paste(collapse = " ")
    sents <- c(sents, sent)
  }
  # duplicates are removed, so fewer than num words may be returned
  return(unique(sents))
}
next_word('Love', 5)
## [1] "with"  "it"    "those" "you"

Dealing with words not in the Markov chain

If a word is not in our Markov chain, the function will spit out a nasty error: Error! Initial state not defined.

To handle this, we wrap the call in try(), so the error is suppressed.

pred <- try(markovchainSequence(
  n = 1,
  markovchain = markov_uni$estimate,
  t0 = 'zzz',
  include.t0 = FALSE
), silent = TRUE)

print(pred[1])
## [1] "Error in markovchainSequence(n = 1, markovchain = markov_uni$estimate,  : \n  Error! Initial state not defined\n"
if (inherits(pred, "try-error")) {
  print("no predictions available :(")
}
## [1] "no predictions available :("

Better alternatives for language models

For a model to predict the next word well, it needs to remember context and the order of the words. For example, in the sentence “I grew up in France, I can speak fluent [MASK]”, if the model retains the information that the person grew up in France (a country), it should be able to predict a word related to the language spoken in France, which is French. Language models today are able to do that very well.

Two approaches better suited for next-word prediction are recurrent neural networks (such as LSTMs), which carry context through a hidden state, and Transformer-based language models (such as BERT, whose [MASK] notation appears above), which use attention to capture long-range dependencies.

Session info

sessionInfo()
## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] stats     graphics  grDevices utils     datasets  methods   base     
## 
## other attached packages:
##  [1] here_1.0.1        markovchain_0.8.6 tokenizers_0.2.1  tidytext_0.3.1   
##  [5] forcats_0.5.1     stringr_1.4.0     dplyr_1.0.5       purrr_0.3.4      
##  [9] readr_1.4.0       tidyr_1.1.3       tibble_3.1.1      ggplot2_3.3.3    
## [13] tidyverse_1.3.1  
## 
## loaded via a namespace (and not attached):
##  [1] Rcpp_1.0.7         lubridate_1.7.10   lattice_0.20-41    rprojroot_2.0.2   
##  [5] assertthat_0.2.1   digest_0.6.27      utf8_1.2.1         R6_2.5.0          
##  [9] cellranger_1.1.0   backports_1.2.1    stats4_4.0.4       reprex_2.0.0      
## [13] evaluate_0.14      highr_0.8          httr_1.4.2         pillar_1.6.0      
## [17] rlang_0.4.10       readxl_1.3.1       rstudioapi_0.13    jquerylib_0.1.4   
## [21] Matrix_1.3-2       rmarkdown_2.7      igraph_1.2.6       munsell_0.5.0     
## [25] broom_0.7.6        compiler_4.0.4     janeaustenr_0.1.5  modelr_0.1.8      
## [29] xfun_0.22          pkgconfig_2.0.3    htmltools_0.5.2    tidyselect_1.1.1  
## [33] expm_0.999-6       fansi_0.4.2        crayon_1.4.1       dbplyr_2.1.1      
## [37] withr_2.4.2        SnowballC_0.7.0    grid_4.0.4         jsonlite_1.7.2    
## [41] gtable_0.3.0       lifecycle_1.0.0    DBI_1.1.1          magrittr_2.0.1    
## [45] scales_1.1.1       RcppParallel_5.1.4 cli_3.0.1          stringi_1.5.3     
## [49] fs_1.5.0           xml2_1.3.2         bslib_0.3.0        ellipsis_0.3.2    
## [53] generics_0.1.0     vctrs_0.3.8        tools_4.0.4        glue_1.4.2        
## [57] hms_1.0.0          parallel_4.0.4     fastmap_1.1.0      yaml_2.2.1        
## [61] colorspace_2.0-1   rvest_1.0.0        matlab_1.0.2       knitr_1.31        
## [65] haven_2.4.1        sass_0.4.0