library(tidyverse)## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.3 ✓ purrr 0.3.4
## ✓ tibble 3.1.1 ✓ dplyr 1.0.5
## ✓ tidyr 1.1.3 ✓ stringr 1.4.0
## ✓ readr 1.4.0 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(here)## here() starts at /Users/benedictneo/next-word-predictor
library(feather)
library(multidplyr)
library(parallel)
library(tidytext)We load our ngrams with feather
ngrams_path <- here('data/ngrams')
bigrams <- read_feather(here(ngrams_path, "bigrams.feather"))
trigrams <- read_feather(here(ngrams_path, "trigrams.feather"))
quadgrams <- read_feather(here(ngrams_path, "quadgram.feather"))
bad_words <- read_feather(here('data/bad_words.feather'))cl <- detectCores()
cl## [1] 8
This tells us my machine has 8 cores
cluster <- new_cluster(cl)
cluster## 8 session cluster [........]
Now that we’ve created our cluster object, we can start using multidplyr
cluster_library(cluster, "tidyverse")Attatch the tidyverse library to our cluster
First we partition our data by assigning groups to each row of our data
group <- rep(1:cl, length.out = nrow(bigrams))
bigrams <- bind_cols(tibble(group), bigrams)
head(bigrams, 10)We’ve grouped our data into 8 groups, since we have 8 cores. This mean each core will handle each subset of our data.
bigrams <- bigrams %>%
group_by(group) %>%
partition(cluster = cluster)
bigrams## Source: party_df [1,066,311 x 4]
## Groups: group
## Shards: 8 [133,288--133,289 rows]
##
## group line word1 word2
## <int> <int> <chr> <chr>
## 1 1 1 peter schiff
## 2 1 2 charles martin
## 3 1 2 character chase
## 4 1 3 finally finished
## 5 1 3 hefty novels
## 6 1 3 woolf experience
## # … with 1,066,305 more rows
The output tells us information about the clusters and how many rows are in each cluster.
matchBigram <- function(input1, n = 5) {
prediction <- bigrams %>%
filter(word1 == input1) %>%
collect() %>%
mutate(freq = str_count(word2)) %>%
arrange(desc(freq)) %>%
pull(word2)
prediction[1:n]
}
matchBigram('bad')## [1] "implementation" "humidification" "representation" "motorcyclists"
## [5] "relationship"
Here, the words that follow bad are the following. They are also ordered in terms of frequency (most frequent to least)
group <- rep(1:cl, length.out = nrow(trigrams))
trigrams <- bind_cols(tibble(group), trigrams)
trigrams <- trigrams %>%
group_by(group) %>%
partition(cluster = cluster)
trigrams## Source: party_df [400,587 x 5]
## Groups: group
## Shards: 8 [50,073--50,074 rows]
##
## group line word1 word2 word3
## <int> <int> <chr> <chr> <chr>
## 1 1 1 peter schiff hard
## 2 1 2 loser custodial parent
## 3 1 7 mining investment boom
## 4 1 10 dry wall screws
## 5 1 14 hollywood movie world
## 6 1 18 mornings afternoons evenings
## # … with 400,581 more rows
matchTrigram <- function(input1, input2, n = 5) {
# match 1st and 2nd word in trigram, and return third word
prediction <- trigrams %>%
filter(word1 == input1, word2 == input2) %>%
collect() %>%
mutate(freq = str_count(word3)) %>%
arrange(desc(freq)) %>%
pull(word3)
# if no matches, match 1st word in trigram, and return 2nd word
if (length(prediction) == 0) {
prediction <- trigrams %>%
filter(word1 == input2) %>%
collect() %>%
mutate(freq = str_count(word2)) %>%
arrange(desc(freq)) %>%
pull(word2)
# if no matches, match 2nd word in trigram, and return 3rd word
if (length(prediction) == 0) {
prediction <- trigrams %>%
filter(word2 == input2) %>%
collect() %>%
mutate(freq = str_count(word3)) %>%
arrange(desc(freq)) %>%
pull(word3)
# all else fails, find match in bigram
if (length(prediction) == 0) {
prediction <- matchBigram(input2, n)
}
}
}
prediction[1:n]
}
matchTrigram('I', 'love')## [1] "itmakesmesmile" "constructive" "demographics" "relationship"
## [5] "freakonomics"
The comments pretty much tell the story, if nothing matches the trigram, we match the last word to our bigrams.
group <- rep(1:cl, length.out = nrow(quadgrams))
quadgrams <- bind_cols(tibble(group), quadgrams)
quadgrams <- quadgrams %>%
group_by(group) %>%
partition(cluster = cluster)
quadgrams## Source: party_df [160,327 x 6]
## Groups: group
## Shards: 8 [20,040--20,041 rows]
##
## group line word1 word2 word3 word4
## <int> <int> <chr> <chr> <chr> <chr>
## 1 1 2 neighbor recommended chasing fireflies
## 2 1 18 minutes hours mornings afternoons
## 3 1 34 anti star wars science
## 4 1 60 failed media blackout controversy
## 5 1 63 yogurt garlic lemon juice
## 6 1 75 graeme hall nature sanctuary
## # … with 160,321 more rows
matchQuadgram <- function(input1, input2, input3, n=5) {
# match 1st, 2nd, 3rd word in quadgram, and return 4th word
prediction <- quadgrams %>%
filter(word1 == input1, word2 == input2, word3 == input3) %>%
collect() %>%
mutate(freq = str_count(word4)) %>%
arrange(desc(freq)) %>%
pull(word4)
# match 1st and 2nd, return 3rd word
if (length(prediction) == 0) {
prediction <- quadgrams %>%
filter(word1 == input2, word2 == input3) %>%
collect() %>%
mutate(freq = str_count(word3)) %>%
arrange(desc(freq)) %>%
pull(word3)
# match 2nd and 3rd, return 4th
if (length(prediction) == 0) {
prediction <- quadgrams %>%
filter(word2 == input2, word3 == input3) %>%
collect() %>%
mutate(freq = str_count(word4)) %>%
arrange(desc(freq)) %>%
pull(word4)
# if no matches, find match in trigrams
if (length(prediction) == 0) {
prediction <- matchTrigram(input2, input3, n)
}
}
}
prediction[1:n]
}
matchQuadgram('my', 'favourite', 'food')## [1] "transportation" "entertainment" "availability" "restaurants"
## [5] "specialties"
clean_input <- function(input) {
input <- tibble(line = 1:(length(input)), text = input) %>%
unnest_tokens(word, text) %>%
filter(!str_detect(word, "\\d+")) %>%
mutate_at("word", str_replace, "[[:punct:]]", "") %>% # remove punctuation
anti_join(bad_words, by = "word") %>% # remove profane words
pull(word)
input
}
clean_input("I h8 this crap SO much!!!")## [1] "i" "this" "so" "much"
next_word <- function(input, n=5) {
input <- clean_input(input)
wordCount <- length(input)
if (wordCount == 0) {
pred <- "Please enter a word"
}
if (wordCount == 1) {
pred <- matchBigram(input[1], n)
}
if (wordCount == 2) {
pred <- matchTrigram(input[1], input[2], n)
}
if (wordCount == 3) {
pred <- matchQuadgram(input[1], input[2], input[3], n)
}
if (wordCount > 3) {
# match with last three words in input
input <- input[(wordCount - 2):wordCount]
pred <- matchQuadgram(input[1], input[2], input[3], n)
}
if(NA %in% pred) {
return("No predictions available :(")
}
else {
return(pred)
}
}
next_word("President of the United")## [1] "international" "conservation" "reservations" "communities"
## [5] "autoworkers"
Testing a simple example on the final function, it seems my ngrams model fails to give us States, which has to be the most obvious response to the next word prediction. It also shows you the limitations of a simple ngrams model, since it can only give us predictions that is contained in our ngrams
For models to predict the next word, it needs to remember context and the order of the words. For example, the sentence “I grew up in France, I can speak fluent [MASK]”, if the model saves the information that the person grew up in France (a country), it should be able to produce a prediction that relates to the langauge in France, which is French. Langauge models today are able to do that very well.
Below are two approaches suitable for a next word prediction
Read https://colah.github.io/posts/2015-08-Understanding-LSTMs/ to learn more about LSTM
Resources:
sessionInfo()## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] parallel stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] tidytext_0.3.1 multidplyr_0.1.0 feather_0.3.5.9000 here_1.0.1
## [5] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.5 purrr_0.3.4
## [9] readr_1.4.0 tidyr_1.1.3 tibble_3.1.1 ggplot2_3.3.3
## [13] tidyverse_1.3.1
##
## loaded via a namespace (and not attached):
## [1] httr_1.4.2 sass_0.4.0 bit64_4.0.5
## [4] jsonlite_1.7.2 modelr_0.1.8 RcppParallel_5.1.4
## [7] bslib_0.3.0 assertthat_0.2.1 cellranger_1.1.0
## [10] yaml_2.2.1 pillar_1.6.0 backports_1.2.1
## [13] lattice_0.20-41 glue_1.4.2 digest_0.6.27
## [16] rvest_1.0.0 stringfish_0.15.3 colorspace_2.0-1
## [19] htmltools_0.5.2 Matrix_1.3-2 pkgconfig_2.0.3
## [22] broom_0.7.6 haven_2.4.1 scales_1.1.1
## [25] processx_3.5.2 RApiSerialize_0.1.0 generics_0.1.0
## [28] ellipsis_0.3.2 withr_2.4.2 cli_3.0.1
## [31] magrittr_2.0.1 crayon_1.4.1 readxl_1.3.1
## [34] evaluate_0.14 ps_1.6.0 tokenizers_0.2.1
## [37] janeaustenr_0.1.5 fs_1.5.0 fansi_0.4.2
## [40] SnowballC_0.7.0 xml2_1.3.2 tools_4.0.4
## [43] hms_1.0.0 lifecycle_1.0.0 munsell_0.5.0
## [46] reprex_2.0.0 callr_3.7.0 compiler_4.0.4
## [49] jquerylib_0.1.4 qs_0.25.1 rlang_0.4.10
## [52] grid_4.0.4 rstudioapi_0.13 rmarkdown_2.7
## [55] arrow_5.0.0.2 gtable_0.3.0 DBI_1.1.1
## [58] R6_2.5.0 lubridate_1.7.10 knitr_1.31
## [61] fastmap_1.1.0 bit_4.0.4 utf8_1.2.1
## [64] rprojroot_2.0.2 stringi_1.5.3 Rcpp_1.0.7
## [67] vctrs_0.3.8 dbplyr_2.1.1 tidyselect_1.1.1
## [70] xfun_0.22