Load libraries

library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.3     ✓ purrr   0.3.4
## ✓ tibble  3.1.1     ✓ dplyr   1.0.5
## ✓ tidyr   1.1.3     ✓ stringr 1.4.0
## ✓ readr   1.4.0     ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(here)
## here() starts at /Users/benedictneo/next-word-predictor
library(feather)
library(multidplyr)
library(parallel)
library(tidytext)

We load our ngrams with feather

ngrams_path <- here('data/ngrams')

bigrams <- read_feather(here(ngrams_path, "bigrams.feather")) 
trigrams <- read_feather(here(ngrams_path, "trigrams.feather"))
quadgrams  <- read_feather(here(ngrams_path, "quadgram.feather"))
bad_words <- read_feather(here('data/bad_words.feather'))

Parallel Processing

detect cores

cl <- detectCores()
cl
## [1] 8

This tells us my machine has 8 cores

Creating clusters

cluster <- new_cluster(cl)
cluster
## 8 session cluster [........]

Now that we’ve created our cluster object, we can start using multidplyr

cluster_library(cluster, "tidyverse")

Attatch the tidyverse library to our cluster

Partition dataset for parallel processing

First we partition our data by assigning groups to each row of our data

group <- rep(1:cl, length.out = nrow(bigrams))
bigrams <- bind_cols(tibble(group), bigrams)
head(bigrams, 10)

We’ve grouped our data into 8 groups, since we have 8 cores. This mean each core will handle each subset of our data.

Bigrams

bigrams <- bigrams %>%
    group_by(group) %>% 
    partition(cluster = cluster)
bigrams
## Source: party_df [1,066,311 x 4]
## Groups: group
## Shards: 8 [133,288--133,289 rows]
## 
##   group  line word1     word2     
##   <int> <int> <chr>     <chr>     
## 1     1     1 peter     schiff    
## 2     1     2 charles   martin    
## 3     1     2 character chase     
## 4     1     3 finally   finished  
## 5     1     3 hefty     novels    
## 6     1     3 woolf     experience
## # … with 1,066,305 more rows

The output tells us information about the clusters and how many rows are in each cluster.

matching to bigram

matchBigram <- function(input1, n = 5) {
    prediction <- bigrams %>%
        filter(word1 == input1) %>%
        collect() %>%
        mutate(freq = str_count(word2)) %>%
        arrange(desc(freq)) %>% 
        pull(word2)
    
    prediction[1:n]
}

matchBigram('bad')
## [1] "implementation" "humidification" "representation" "motorcyclists" 
## [5] "relationship"

Here, the words that follow bad are the following. They are also ordered in terms of frequency (most frequent to least)

Trigram

group <- rep(1:cl, length.out = nrow(trigrams))
trigrams <- bind_cols(tibble(group), trigrams)

trigrams <- trigrams %>%
    group_by(group) %>% 
    partition(cluster = cluster)
trigrams
## Source: party_df [400,587 x 5]
## Groups: group
## Shards: 8 [50,073--50,074 rows]
## 
##   group  line word1     word2      word3   
##   <int> <int> <chr>     <chr>      <chr>   
## 1     1     1 peter     schiff     hard    
## 2     1     2 loser     custodial  parent  
## 3     1     7 mining    investment boom    
## 4     1    10 dry       wall       screws  
## 5     1    14 hollywood movie      world   
## 6     1    18 mornings  afternoons evenings
## # … with 400,581 more rows

Match to trigrams

matchTrigram <- function(input1, input2, n = 5) {
    
    # match 1st and 2nd word in trigram, and return third word
    prediction <- trigrams %>%
        filter(word1 == input1, word2 == input2) %>%
        collect() %>%
        mutate(freq = str_count(word3)) %>%
        arrange(desc(freq)) %>%
        pull(word3)
    
    # if no matches, match 1st word in trigram, and return 2nd word
    if (length(prediction) == 0) {
        prediction <- trigrams %>%
            filter(word1 == input2) %>%
            collect() %>%
            mutate(freq = str_count(word2)) %>%
            arrange(desc(freq)) %>%
            pull(word2)
        
        # if no matches, match 2nd word in trigram, and return 3rd word
        if (length(prediction) == 0) {
            prediction <- trigrams %>%
                filter(word2 == input2) %>%
                collect() %>%
                mutate(freq = str_count(word3)) %>%
                arrange(desc(freq)) %>%
                pull(word3)
            
            # all else fails, find match in bigram
            if (length(prediction) == 0) {
                prediction <- matchBigram(input2, n)
            }
        }
    }
    
    prediction[1:n]
}

matchTrigram('I', 'love')
## [1] "itmakesmesmile" "constructive"   "demographics"   "relationship"  
## [5] "freakonomics"

The comments pretty much tell the story, if nothing matches the trigram, we match the last word to our bigrams.

Quadgram

group <- rep(1:cl, length.out = nrow(quadgrams))
quadgrams <- bind_cols(tibble(group), quadgrams)

quadgrams <- quadgrams %>%
    group_by(group) %>% 
    partition(cluster = cluster)
quadgrams
## Source: party_df [160,327 x 6]
## Groups: group
## Shards: 8 [20,040--20,041 rows]
## 
##   group  line word1    word2       word3    word4      
##   <int> <int> <chr>    <chr>       <chr>    <chr>      
## 1     1     2 neighbor recommended chasing  fireflies  
## 2     1    18 minutes  hours       mornings afternoons 
## 3     1    34 anti     star        wars     science    
## 4     1    60 failed   media       blackout controversy
## 5     1    63 yogurt   garlic      lemon    juice      
## 6     1    75 graeme   hall        nature   sanctuary  
## # … with 160,321 more rows

Matching to quadgrams

matchQuadgram <- function(input1, input2, input3, n=5) {
    
    # match 1st, 2nd, 3rd word in quadgram, and return 4th word
    prediction <- quadgrams %>%
        filter(word1 == input1, word2 == input2, word3 == input3) %>%
        collect() %>%
        mutate(freq = str_count(word4)) %>%
        arrange(desc(freq)) %>%
        pull(word4)
    
    # match 1st and 2nd, return 3rd word
    if (length(prediction) == 0) {
        prediction <- quadgrams %>%
            filter(word1 == input2, word2 == input3) %>%
            collect() %>%
            mutate(freq = str_count(word3)) %>%
            arrange(desc(freq)) %>%
            pull(word3)
        
        # match 2nd and 3rd, return 4th
        if (length(prediction) == 0) {
            prediction <- quadgrams %>%
                filter(word2 == input2, word3 == input3) %>%
                collect() %>%
                mutate(freq = str_count(word4)) %>%
                arrange(desc(freq)) %>%
                pull(word4)
            
            # if no matches, find match in trigrams
            if (length(prediction) == 0) {
                prediction <- matchTrigram(input2, input3, n)
            }
        }
    }
    
    prediction[1:n]
}

matchQuadgram('my', 'favourite', 'food')
## [1] "transportation" "entertainment"  "availability"   "restaurants"   
## [5] "specialties"

Clean input text

clean_input <- function(input) {
    
    input <- tibble(line = 1:(length(input)), text = input) %>%
        unnest_tokens(word, text) %>%
        filter(!str_detect(word, "\\d+")) %>%
        mutate_at("word", str_replace, "[[:punct:]]", "") %>% # remove punctuation
        anti_join(bad_words, by = "word") %>% # remove profane words
        pull(word)
    
    input
}

clean_input("I h8 this crap SO much!!!")
## [1] "i"    "this" "so"   "much"

Next word prediction

next_word <- function(input, n=5) {
    input <- clean_input(input)
    wordCount <- length(input)
    
    if (wordCount == 0) {
        pred <- "Please enter a word"
    }
    
    if (wordCount == 1) {
        pred <- matchBigram(input[1], n)
    }
    
    if (wordCount == 2) {
        pred <- matchTrigram(input[1], input[2], n)
    }
    
    if (wordCount == 3) {
        pred <- matchQuadgram(input[1], input[2], input[3], n)
    }
    
    if (wordCount > 3) {
        # match with last three words in input
        input <- input[(wordCount - 2):wordCount]
        pred <- matchQuadgram(input[1], input[2], input[3], n)
    }
    
    if(NA %in% pred) {
        return("No predictions available :(")
    }
    else {
        return(pred)
    }
}

next_word("President of the United")
## [1] "international" "conservation"  "reservations"  "communities"  
## [5] "autoworkers"

Testing a simple example on the final function, it seems my ngrams model fails to give us States, which has to be the most obvious response to the next word prediction. It also shows you the limitations of a simple ngrams model, since it can only give us predictions that is contained in our ngrams

Better alternatives for language models:

For models to predict the next word, it needs to remember context and the order of the words. For example, the sentence “I grew up in France, I can speak fluent [MASK]”, if the model saves the information that the person grew up in France (a country), it should be able to produce a prediction that relates to the langauge in France, which is French. Langauge models today are able to do that very well.

Below are two approaches suitable for a next word prediction

Transformers

Resources:

sessionInfo()
## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
## 
## Matrix products: default
## BLAS:   /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
## 
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
## 
## attached base packages:
## [1] parallel  stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] tidytext_0.3.1     multidplyr_0.1.0   feather_0.3.5.9000 here_1.0.1        
##  [5] forcats_0.5.1      stringr_1.4.0      dplyr_1.0.5        purrr_0.3.4       
##  [9] readr_1.4.0        tidyr_1.1.3        tibble_3.1.1       ggplot2_3.3.3     
## [13] tidyverse_1.3.1   
## 
## loaded via a namespace (and not attached):
##  [1] httr_1.4.2          sass_0.4.0          bit64_4.0.5        
##  [4] jsonlite_1.7.2      modelr_0.1.8        RcppParallel_5.1.4 
##  [7] bslib_0.3.0         assertthat_0.2.1    cellranger_1.1.0   
## [10] yaml_2.2.1          pillar_1.6.0        backports_1.2.1    
## [13] lattice_0.20-41     glue_1.4.2          digest_0.6.27      
## [16] rvest_1.0.0         stringfish_0.15.3   colorspace_2.0-1   
## [19] htmltools_0.5.2     Matrix_1.3-2        pkgconfig_2.0.3    
## [22] broom_0.7.6         haven_2.4.1         scales_1.1.1       
## [25] processx_3.5.2      RApiSerialize_0.1.0 generics_0.1.0     
## [28] ellipsis_0.3.2      withr_2.4.2         cli_3.0.1          
## [31] magrittr_2.0.1      crayon_1.4.1        readxl_1.3.1       
## [34] evaluate_0.14       ps_1.6.0            tokenizers_0.2.1   
## [37] janeaustenr_0.1.5   fs_1.5.0            fansi_0.4.2        
## [40] SnowballC_0.7.0     xml2_1.3.2          tools_4.0.4        
## [43] hms_1.0.0           lifecycle_1.0.0     munsell_0.5.0      
## [46] reprex_2.0.0        callr_3.7.0         compiler_4.0.4     
## [49] jquerylib_0.1.4     qs_0.25.1           rlang_0.4.10       
## [52] grid_4.0.4          rstudioapi_0.13     rmarkdown_2.7      
## [55] arrow_5.0.0.2       gtable_0.3.0        DBI_1.1.1          
## [58] R6_2.5.0            lubridate_1.7.10    knitr_1.31         
## [61] fastmap_1.1.0       bit_4.0.4           utf8_1.2.1         
## [64] rprojroot_2.0.2     stringi_1.5.3       Rcpp_1.0.7         
## [67] vctrs_0.3.8         dbplyr_2.1.1        tidyselect_1.1.1   
## [70] xfun_0.22