library(tidyverse)
## ── Attaching packages ─────────────────────────────────────── tidyverse 1.3.1 ──
## ✓ ggplot2 3.3.3 ✓ purrr 0.3.4
## ✓ tibble 3.1.1 ✓ dplyr 1.0.5
## ✓ tidyr 1.1.3 ✓ stringr 1.4.0
## ✓ readr 1.4.0 ✓ forcats 0.5.1
## ── Conflicts ────────────────────────────────────────── tidyverse_conflicts() ──
## x dplyr::filter() masks stats::filter()
## x dplyr::lag() masks stats::lag()
library(here)
## here() starts at /Users/benedictneo/next-word-predictor
library(feather)
library(multidplyr)
library(parallel)
library(tidytext)
We load our ngrams with feather
# Directory holding the pre-built n-gram tables.
ngrams_path <- here("data/ngrams")
ngrams_path

# One feather file per n-gram order, plus the profanity list that
# clean_input() filters against.
bigrams <- read_feather(here(ngrams_path, "bigrams.feather"))
trigrams <- read_feather(here(ngrams_path, "trigrams.feather"))
quadgrams <- read_feather(here(ngrams_path, "quadgram.feather"))
bad_words <- read_feather(here("data/bad_words.feather"))
# Number of CPU cores reported by the machine; used below as the number of
# worker sessions and data shards.
cl <- detectCores()
cl
## [1] 8
This tells us my machine has 8 cores
# Start one multidplyr worker session per core.
cluster <- new_cluster(cl)
cluster
## 8 session cluster [........]
Now that we’ve created our cluster object, we can start using multidplyr
cluster_library(cluster, "tidyverse")
Attach the tidyverse library to our cluster
First we partition our data by assigning groups to each row of our data
# Assign every bigram row a shard id, cycling through 1..cl so the shards end
# up (almost) equally sized.
group <- rep(seq_len(cl), length.out = nrow(bigrams))
bigrams <- bind_cols(tibble(group), bigrams)
head(bigrams, 10)
We’ve grouped our data into 8 groups, since we have 8 cores. This means each core will handle one subset of our data.
# Spread the bigram rows over the worker sessions, one shard per `group` value.
bigrams <- bigrams %>%
  group_by(group) %>%
  partition(cluster = cluster)
bigrams
## Source: party_df [1,066,311 x 4]
## Groups: group
## Shards: 8 [133,288--133,289 rows]
##
## group line word1 word2
## <int> <int> <chr> <chr>
## 1 1 1 peter schiff
## 2 1 2 charles martin
## 3 1 2 character chase
## 4 1 3 finally finished
## 5 1 3 hefty novels
## 6 1 3 woolf experience
## # … with 1,066,305 more rows
The output tells us information about the clusters and how many rows are in each cluster.
# Predict the words most likely to follow `input1`, using the partitioned
# `bigrams` table.
#
# input1: the preceding word (compared against `word1`).
# n:      number of predictions to return (default 5).
#
# Returns a character vector of length `n`, most frequent follower first,
# padded with NA when fewer than `n` distinct followers exist.
matchBigram <- function(input1, n = 5) {
  prediction <- bigrams %>%
    filter(word1 == input1) %>%
    collect() %>%
    # The collected rows may still be grouped by the shard id `group`;
    # drop the grouping so occurrences are counted across all shards.
    ungroup() %>%
    # Rank followers by how often they occur. The previous
    # `str_count(word2)` counted the characters in each word, so the
    # longest words were returned instead of the most frequent ones.
    count(word2, name = "freq", sort = TRUE) %>%
    pull(word2)

  # seq_len() keeps the NA padding of `1:n` for n >= 1 but yields an empty
  # result (rather than the first element) when n is 0.
  prediction[seq_len(n)]
}
matchBigram('bad')
## [1] "implementation" "humidification" "representation" "motorcyclists"
## [5] "relationship"
Here are the words predicted to follow "bad". They are meant to be ordered by frequency (most frequent to least).
# Assign every trigram row a shard id, cycling through 1..cl.
group <- rep(seq_len(cl), length.out = nrow(trigrams))
trigrams <- bind_cols(tibble(group), trigrams)

# Spread the trigram rows over the worker sessions, one shard per `group`.
trigrams <- trigrams %>%
  group_by(group) %>%
  partition(cluster = cluster)
trigrams
## Source: party_df [400,587 x 5]
## Groups: group
## Shards: 8 [50,073--50,074 rows]
##
## group line word1 word2 word3
## <int> <int> <chr> <chr> <chr>
## 1 1 1 peter schiff hard
## 2 1 2 loser custodial parent
## 3 1 7 mining investment boom
## 4 1 10 dry wall screws
## 5 1 14 hollywood movie world
## 6 1 18 mornings afternoons evenings
## # … with 400,581 more rows
# Predict the next word from the two preceding words using the partitioned
# `trigrams` table, backing off to looser matches when nothing is found.
#
# input1, input2: the two most recent words, in order.
# n:              number of predictions to return (default 5).
#
# Returns a character vector of length `n`, most frequent candidate first,
# padded with NA when fewer than `n` candidates exist.
matchTrigram <- function(input1, input2, n = 5) {
  # Collect the matched rows and rank the distinct values of `column` by how
  # often they occur. (The previous `str_count()` on the word itself measured
  # its length in characters, not its frequency.)
  rank_by_count <- function(matches, column) {
    matches %>%
      collect() %>%
      # The collected rows may still be grouped by the shard id `group`.
      ungroup() %>%
      count({{ column }}, sort = TRUE) %>%
      pull({{ column }})
  }

  # match 1st and 2nd word in trigram, and return third word
  prediction <- trigrams %>%
    filter(word1 == input1, word2 == input2) %>%
    rank_by_count(word3)

  # if no matches, match 1st word in trigram, and return 2nd word
  if (length(prediction) == 0) {
    prediction <- trigrams %>%
      filter(word1 == input2) %>%
      rank_by_count(word2)
  }

  # if no matches, match 2nd word in trigram, and return 3rd word
  if (length(prediction) == 0) {
    prediction <- trigrams %>%
      filter(word2 == input2) %>%
      rank_by_count(word3)
  }

  # all else fails, find match in bigram
  if (length(prediction) == 0) {
    prediction <- matchBigram(input2, n)
  }

  prediction[seq_len(n)]
}
matchTrigram('I', 'love')
## [1] "itmakesmesmile" "constructive" "demographics" "relationship"
## [5] "freakonomics"
The comments pretty much tell the story: if nothing matches in the trigrams, we match the last word against our bigrams.
# Assign every quadgram row a shard id, cycling through 1..cl.
group <- rep(seq_len(cl), length.out = nrow(quadgrams))
quadgrams <- bind_cols(tibble(group), quadgrams)

# Spread the quadgram rows over the worker sessions, one shard per `group`.
quadgrams <- quadgrams %>%
  group_by(group) %>%
  partition(cluster = cluster)
quadgrams
## Source: party_df [160,327 x 6]
## Groups: group
## Shards: 8 [20,040--20,041 rows]
##
## group line word1 word2 word3 word4
## <int> <int> <chr> <chr> <chr> <chr>
## 1 1 2 neighbor recommended chasing fireflies
## 2 1 18 minutes hours mornings afternoons
## 3 1 34 anti star wars science
## 4 1 60 failed media blackout controversy
## 5 1 63 yogurt garlic lemon juice
## 6 1 75 graeme hall nature sanctuary
## # … with 160,321 more rows
# Predict the next word from the three preceding words using the partitioned
# `quadgrams` table, backing off to looser matches when nothing is found.
#
# input1, input2, input3: the three most recent words, in order.
# n:                      number of predictions to return (default 5).
#
# Returns a character vector of length `n`, most frequent candidate first,
# padded with NA when fewer than `n` candidates exist.
matchQuadgram <- function(input1, input2, input3, n = 5) {
  # Collect the matched rows and rank the distinct values of `column` by how
  # often they occur. (The previous `str_count()` on the word itself measured
  # its length in characters, not its frequency.)
  rank_by_count <- function(matches, column) {
    matches %>%
      collect() %>%
      # The collected rows may still be grouped by the shard id `group`.
      ungroup() %>%
      count({{ column }}, sort = TRUE) %>%
      pull({{ column }})
  }

  # match 1st, 2nd, 3rd word in quadgram, and return 4th word
  prediction <- quadgrams %>%
    filter(word1 == input1, word2 == input2, word3 == input3) %>%
    rank_by_count(word4)

  # if no matches, match the last two inputs against the 1st and 2nd word,
  # and return 3rd word
  if (length(prediction) == 0) {
    prediction <- quadgrams %>%
      filter(word1 == input2, word2 == input3) %>%
      rank_by_count(word3)
  }

  # if no matches, match 2nd and 3rd word, and return 4th word
  if (length(prediction) == 0) {
    prediction <- quadgrams %>%
      filter(word2 == input2, word3 == input3) %>%
      rank_by_count(word4)
  }

  # if no matches, find match in trigrams
  if (length(prediction) == 0) {
    prediction <- matchTrigram(input2, input3, n)
  }

  prediction[seq_len(n)]
}
matchQuadgram('my', 'favourite', 'food')
## [1] "transportation" "entertainment" "availability" "restaurants"
## [5] "specialties"
# Turn raw user text into the vector of word tokens used for prediction.
#
# input: character vector of text; each element is treated as one line.
#
# Returns a character vector of tokens (lower-cased by unnest_tokens(), as the
# example output shows) with tokens containing digits, punctuation characters
# and words listed in `bad_words` removed.
clean_input <- function(input) {
  # seq_along() instead of 1:length(input), which yields c(1, 0) for empty
  # input and no longer lines up with `text`.
  tibble(line = seq_along(input), text = input) %>%
    unnest_tokens(word, text) %>%
    filter(!str_detect(word, "\\d+")) %>%
    # remove punctuation: every occurrence, not just the first one per token
    mutate(word = str_replace_all(word, "[[:punct:]]", "")) %>%
    # a token made up only of punctuation is empty by now
    filter(word != "") %>%
    anti_join(bad_words, by = "word") %>% # remove profane words
    pull(word)
}
clean_input("I h8 this crap SO much!!!")
## [1] "i" "this" "so" "much"
# Predict the next word for a piece of free text.
#
# input: character vector of text typed by the user.
# n:     maximum number of predictions to return (default 5).
#
# Returns a character vector with up to `n` predicted words, or a single
# message string when the cleaned input is empty or nothing could be predicted.
next_word <- function(input, n = 5) {
  input <- clean_input(input)
  wordCount <- length(input)

  if (wordCount == 0) {
    return("Please enter a word")
  }

  if (wordCount == 1) {
    pred <- matchBigram(input[1], n)
  } else if (wordCount == 2) {
    pred <- matchTrigram(input[1], input[2], n)
  } else {
    # match with last three words in input
    input <- tail(input, 3)
    pred <- matchQuadgram(input[1], input[2], input[3], n)
  }

  # The matchers pad their result with NA up to length n. Drop the padding
  # instead of discarding every prediction as soon as a single NA is present.
  pred <- pred[!is.na(pred)]

  if (length(pred) == 0) {
    "No predictions available :("
  } else {
    pred
  }
}
next_word("President of the United")
## [1] "international" "conservation" "reservations" "communities"
## [5] "autoworkers"
Testing a simple example on the final function, it seems my n-gram model fails to give us "States", which is the most obvious next word here. It also shows the limitations of a simple n-gram model: it can only give us predictions that are contained in our n-grams.
For a model to predict the next word, it needs to remember the context and the order of the words. For example, take the sentence “I grew up in France, I can speak fluent [MASK]”: if the model retains the information that the person grew up in France (a country), it should be able to produce a prediction that relates to the language of France, which is French. Language models today are able to do that very well.
Below are two approaches suitable for a next word prediction
Read https://colah.github.io/posts/2015-08-Understanding-LSTMs/ to learn more about LSTM
Resources:
sessionInfo()
## R version 4.0.4 (2021-02-15)
## Platform: x86_64-apple-darwin17.0 (64-bit)
## Running under: macOS Big Sur 10.16
##
## Matrix products: default
## BLAS: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRblas.dylib
## LAPACK: /Library/Frameworks/R.framework/Versions/4.0/Resources/lib/libRlapack.dylib
##
## locale:
## [1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
##
## attached base packages:
## [1] parallel stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] tidytext_0.3.1 multidplyr_0.1.0 feather_0.3.5.9000 here_1.0.1
## [5] forcats_0.5.1 stringr_1.4.0 dplyr_1.0.5 purrr_0.3.4
## [9] readr_1.4.0 tidyr_1.1.3 tibble_3.1.1 ggplot2_3.3.3
## [13] tidyverse_1.3.1
##
## loaded via a namespace (and not attached):
## [1] httr_1.4.2 sass_0.4.0 bit64_4.0.5
## [4] jsonlite_1.7.2 modelr_0.1.8 RcppParallel_5.1.4
## [7] bslib_0.3.0 assertthat_0.2.1 cellranger_1.1.0
## [10] yaml_2.2.1 pillar_1.6.0 backports_1.2.1
## [13] lattice_0.20-41 glue_1.4.2 digest_0.6.27
## [16] rvest_1.0.0 stringfish_0.15.3 colorspace_2.0-1
## [19] htmltools_0.5.2 Matrix_1.3-2 pkgconfig_2.0.3
## [22] broom_0.7.6 haven_2.4.1 scales_1.1.1
## [25] processx_3.5.2 RApiSerialize_0.1.0 generics_0.1.0
## [28] ellipsis_0.3.2 withr_2.4.2 cli_3.0.1
## [31] magrittr_2.0.1 crayon_1.4.1 readxl_1.3.1
## [34] evaluate_0.14 ps_1.6.0 tokenizers_0.2.1
## [37] janeaustenr_0.1.5 fs_1.5.0 fansi_0.4.2
## [40] SnowballC_0.7.0 xml2_1.3.2 tools_4.0.4
## [43] hms_1.0.0 lifecycle_1.0.0 munsell_0.5.0
## [46] reprex_2.0.0 callr_3.7.0 compiler_4.0.4
## [49] jquerylib_0.1.4 qs_0.25.1 rlang_0.4.10
## [52] grid_4.0.4 rstudioapi_0.13 rmarkdown_2.7
## [55] arrow_5.0.0.2 gtable_0.3.0 DBI_1.1.1
## [58] R6_2.5.0 lubridate_1.7.10 knitr_1.31
## [61] fastmap_1.1.0 bit_4.0.4 utf8_1.2.1
## [64] rprojroot_2.0.2 stringi_1.5.3 Rcpp_1.0.7
## [67] vctrs_0.3.8 dbplyr_2.1.1 tidyselect_1.1.1
## [70] xfun_0.22