library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.3 ✓ purrr 0.3.4
## ✓ tibble 3.1.1 ✓ dplyr 1.0.5
## ✓ tidyr 1.1.3 ✓ stringr 1.4.0
## ✓ readr 1.4.0 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(tidytext)
library(tokenizers)
library(markovchain)
## Package: markovchain
## Version: 0.8.6
## Date: 2021-05-17
## BugReport: https://github.com/spedygiorgio/markovchain/issues
library(here)
## here() starts at /Users/benedictneo/next-word-predictor
A Markov chain, or Markov process, is a stochastic model describing a sequence of possible events in which the probability of each event depends only on the state attained in the previous event.
Take weather for example: we can have 3 states - rainy, sunny and cloudy - and use the previous event (it rained today) to estimate the probabilities for tomorrow's weather.
Using this principle, we can predict the next word based on the last word typed. The Markov chain model captures the transition probabilities between states, where each state is a token.
More information about Markov chains can be found on the Wikipedia page.
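As a rough sketch of the weather example (the transition probabilities below are made up purely for illustration), such a chain can be written down directly with the markovchain package:
weather_states <- c("rainy", "sunny", "cloudy")
# hypothetical transition probabilities; each row is a "from" state and must sum to 1
weather_matrix <- matrix(c(0.5, 0.2, 0.3,
                           0.1, 0.7, 0.2,
                           0.3, 0.4, 0.3),
                         nrow = 3, byrow = TRUE,
                         dimnames = list(weather_states, weather_states))
weather_mc <- new("markovchain",
                  states = weather_states,
                  transitionMatrix = weather_matrix,
                  name = "Weather")
conditionalDistribution(weather_mc, "rainy") # distribution over tomorrow's weather, given rain today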
Below is an example of how it works with a simple sentence.
We take a simple sentence and tokenize it into individual tokens.
<- c("the quick brown fox jumps over the lazy dog and the angry dog chase the fox")
text
<- strsplit(text, split = " ") %>% unlist()) (tokens
## [1] "the" "quick" "brown" "fox" "jumps" "over" "the" "lazy" "dog"
## [10] "and" "the" "angry" "dog" "chase" "the" "fox"
To create a Markov chain without manually calculating the transitions, we can utilize the markovchain package.
simple_markov <- markovchainFit(tokens, method = "laplace")
To visualize our Markov chain, we can simply use the plot function.
set.seed(2021)
plot(simple_markov$estimate)
The arrows indicate the transitions to the next state, and the numbers are the probabilities of those transitions. For example, starting at “the”, the chain can move to “angry”, “lazy”, “quick”, or “fox”. Since “the” precedes each of those 4 words exactly once in the sentence, the transitions are equally likely, which is why each probability is 0.25.
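As a quick check, the markovchain package also provides transitionProbability() to look up a single entry of the matrix; based on the fitted matrix printed below, this lookup should return 0.25.
# probability of moving from "the" to "lazy" in the fitted chain
transitionProbability(simple_markov$estimate, "the", "lazy")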
With our Markov model, we can run $estimate on it, and it will show the dimensions (number of words) and the transition matrix.
simple_markov$estimate
## Laplacian Smooth Fit
## A 11 - dimensional discrete Markov Chain defined by the following states:
## and, angry, brown, chase, dog, fox, jumps, lazy, over, quick, the
## The transition matrix (by rows) is defined as follows:
## and angry brown chase dog fox jumps lazy over quick the
## and 0.0 0.00 0 0.0 0 0.00 0 0.00 0 0.00 1
## angry 0.0 0.00 0 0.0 1 0.00 0 0.00 0 0.00 0
## brown 0.0 0.00 0 0.0 0 1.00 0 0.00 0 0.00 0
## chase 0.0 0.00 0 0.0 0 0.00 0 0.00 0 0.00 1
## dog 0.5 0.00 0 0.5 0 0.00 0 0.00 0 0.00 0
## fox 0.0 0.00 0 0.0 0 0.00 1 0.00 0 0.00 0
## jumps 0.0 0.00 0 0.0 0 0.00 0 0.00 1 0.00 0
## lazy 0.0 0.00 0 0.0 1 0.00 0 0.00 0 0.00 0
## over 0.0 0.00 0 0.0 0 0.00 0 0.00 0 0.00 1
## quick 0.0 0.00 1 0.0 0 0.00 0 0.00 0 0.00 0
## the 0.0 0.25 0 0.0 0 0.25 0 0.25 0 0.25 0
Using markovchainSequence(), we can generate a sentence based on our Markov chain model.
markovchainSequence(
  n = 5,
  markovchain = simple_markov$estimate,
  t0 = "the", # set the first word
  include.t0 = T
) %>%
  paste(collapse = " ")
## [1] "the fox jumps over the quick"
Boom! A sentence was generated, although it reads almost exactly like the sentence it was fed.
We can also generate multiple sentences by wrapping it in a for loop.
for (i in 1:5) {
  set.seed(i)
  markovchainSequence(
    n = 5,
    markovchain = simple_markov$estimate,
    t0 = "the", # set the first word
    include.t0 = T
  ) %>%
    paste(collapse = " ") %>%
    print()
}
## [1] "the fox jumps over the angry"
## [1] "the angry dog chase the quick"
## [1] "the angry dog and the lazy"
## [1] "the lazy dog and the quick"
## [1] "the angry dog chase the angry"
Now let’s move on to fitting our n-grams to the Markov model.
The plan was to fit the unigrams, bigrams, trigrams and quadgrams to the Markov chain. However, due to the memory and size limits of the Shiny app, I have resorted to using the unigrams only.
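For reference, here is a hedged sketch of how a bigram version could be fit with the same package: collapse every pair of consecutive tokens into one composite state, so the chain is still first-order but effectively conditions on two words (illustrated on the toy tokens from earlier; not used in the app because of the size limits above).
# composite states like "the quick", "quick brown", ...
bigram_states <- paste(head(tokens, -1), tail(tokens, -1))
markov_bi <- markovchainFit(bigram_states, method = "laplace")
# to predict, the last two typed words would be pasted into one state before calling markovchainSequence()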
<- read_rds(here("app", "data/unigrams.rds")) unigrams
length(unigrams$word)
## [1] 6737543
There are too many words in our unigrams, and they would significantly slow down the training process. We will sample roughly 60k unigram tokens by filtering on the number of lines we want.
sub_unigrams <- unigrams %>%
  filter(line < 800) %>%
  pull(word)

markov_uni <- markovchainFit(sub_unigrams, method = "laplace")
write_rds(markov_uni, here("app/models/markov_uni_small.rds")) # smaller model

markov_uni <- read_rds(here("app/models/markov_uni.rds"))
word <- 'who'

markovchainSequence(
  n = 1,
  markovchain = markov_uni$estimate,
  t0 = word,
  include.t0 = F
) %>% paste(collapse = " ")
## [1] "wanted"
next_word <- function(word, num = 5) {
  sents <- c()
  for (i in 1:num) {
    set.seed(i) # randomize generation
    sent <- markovchainSequence(
      n = 1,
      markovchain = markov_uni$estimate,
      t0 = word, # set the first word
      include.t0 = F
    ) %>%
      paste(collapse = " ")
    sents <- c(sents, sent)
  }
  return(sents)
}
next_word('who')
## [1] "do" "have" "have" "reached" "has"
<- read_rds(here("app/data/bad_words.rds"))
bad_words <- function(input) {
clean_input <- tibble(line = 1:(length(input)), text = input) %>%
input unnest_tokens(word, text) %>%
filter(!str_detect(word, "\\d+")) %>%
mutate_at("word", str_replace, "[[:punct:]]", "") %>% # remove punctuation
anti_join(bad_words, by = "word") %>% # remove profane words
pull(word)
input }
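A quick hypothetical check of the cleaning step (the exact result depends on what is in bad_words):
clean_input("Love 2 see it!")
# expected: "love" "see" "it" - lowercased, the number dropped, punctuation stripped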
We then update next_word() to clean the input first and, when multiple words are given, keep only the last word as the starting state; unique() also removes duplicate predictions.
next_word <- function(word, num = 5) {
  word <- clean_input(word)
  length <- length(word)

  if (length > 1) {
    word <- word[length] # keep only the last word
    print(word)
  }

  sents <- c()
  for (i in 1:num) {
    set.seed(i) # randomize generation
    sent <- markovchainSequence(
      n = 1,
      markovchain = markov_uni$estimate,
      t0 = word, # set the first word
      include.t0 = F
    ) %>%
      paste(collapse = " ")
    sents <- c(sents, sent)
  }
  return(unique(sents))
}
next_word('Love', 5)
## [1] "with" "it" "those" "you"
If a word is not in our Markov chain, the function will spit out a nasty error: Error! Initial state not defined.
To handle this, we wrap the call in a try block so the error is suppressed.
pred <- try(markovchainSequence(
  n = 1,
  markovchain = markov_uni$estimate,
  t0 = 'zzz',
  include.t0 = F
), silent = T)
print(pred[1])
## [1] "Error in markovchainSequence(n = 1, markovchain = markov_uni$estimate, : \n Error! Initial state not defined\n"
if (nchar(pred[1]) > 100) {
  print("no predictions available :(")
}
## [1] "no predictions available :("
For a model to predict the next word well, it needs to remember context and the order of the words. For example, given the sentence “I grew up in France, I can speak fluent [MASK]”, if the model retains the information that the person grew up in France (a country), it should be able to produce a prediction related to the language spoken in France, which is French. Language models today are able to do that very well.
Below are two approaches suitable for next word prediction.
Read https://colah.github.io/posts/2015-08-Understanding-LSTMs/ to learn more about LSTMs.
sessionInfo()
## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] here_1.0.1 markovchain_0.8.6 tokenizers_0.2.1 tidytext_0.3.1
## [5] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.5 purrr_0.3.4
## [9] readr_1.4.0 tidyr_1.1.3 tibble_3.1.1 ggplot2_3.3.3
## [13] tidyverse_1.3.1
##
## loaded via a namespace (and not attached):
## [1] Rcpp_1.0.7 lubridate_1.7.10 lattice_0.20-41 rprojroot_2.0.2
## [5] assertthat_0.2.1 digest_0.6.27 utf8_1.2.1 R6_2.5.0
## [9] cellranger_1.1.0 backports_1.2.1 stats4_4.0.4 reprex_2.0.0
## [13] evaluate_0.14 highr_0.8 httr_1.4.2 pillar_1.6.0
## [17] rlang_0.4.10 readxl_1.3.1 rstudioapi_0.13 jquerylib_0.1.4
## [21] Matrix_1.3-2 rmarkdown_2.7 igraph_1.2.6 munsell_0.5.0
## [25] broom_0.7.6 compiler_4.0.4 janeaustenr_0.1.5 modelr_0.1.8
## [29] xfun_0.22 pkgconfig_2.0.3 htmltools_0.5.2 tidyselect_1.1.1
## [33] expm_0.999-6 fansi_0.4.2 crayon_1.4.1 dbplyr_2.1.1
## [37] withr_2.4.2 SnowballC_0.7.0 grid_4.0.4 jsonlite_1.7.2
## [41] gtable_0.3.0 lifecycle_1.0.0 DBI_1.1.1 magrittr_2.0.1
## [45] scales_1.1.1 RcppParallel_5.1.4 cli_3.0.1 stringi_1.5.3
## [49] fs_1.5.0 xml2_1.3.2 bslib_0.3.0 ellipsis_0.3.2
## [53] generics_0.1.0 vctrs_0.3.8 tools_4.0.4 glue_1.4.2
## [57] hms_1.0.0 parallel_4.0.4 fastmap_1.1.0 yaml_2.2.1
## [61] colorspace_2.0-1 rvest_1.0.0 matlab_1.0.2 knitr_1.31
## [65] haven_2.4.1 sass_0.4.0