Skip to content

Commit

Permalink
add local embedding and other services
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesHWade committed Jan 18, 2024
1 parent d89ab0f commit 0042b93
Show file tree
Hide file tree
Showing 7 changed files with 113 additions and 26 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ repos:
- id: no-browser-statement
- id: no-debug-statement
- id: deps-in-desc
args: [--allow_private_imports]
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
Expand Down
24 changes: 20 additions & 4 deletions R/embedding-py.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,28 @@
# get_transformer_model <- function(model_name = "jinaai/jina-embeddings-v2-base-en") {
# py_pkg_is_available()
# transformer <- reticulate::import("sentence_transformers")
# cli::cli_process_start("Downloading model. This may take a few minutes.")
# model <- transformer$SentenceTransformer(model_name)
# cli::cli_process_done()
# model
# }

# uses transformers instead of sentence transformers
get_transformer_model <- function(model_name = "jinaai/jina-embeddings-v2-base-en") {
py_pkg_is_available()
transformer <- reticulate::import("sentence_transformers")
py_pkg_is_available("transformers")
transformer <- reticulate::import("transformers")

Check warning on line 13 in R/embedding-py.R

View check run for this annotation

Codecov / codecov/patch

R/embedding-py.R#L12-L13

Added lines #L12 - L13 were not covered by tests
cli::cli_process_start("Downloading model. This may take a few minutes.")
model <- transformer$SentenceTransformer(model_name)
model <- transformer$AutoModel$from_pretrained(model_name,
trust_remote_code = TRUE

Check warning on line 16 in R/embedding-py.R

View check run for this annotation

Codecov / codecov/patch

R/embedding-py.R#L15-L16

Added lines #L15 - L16 were not covered by tests
)
cli::cli_process_done()
model
}


create_text_embeddings <- function(text, model) {
model$encode(text) |> as.numeric()
tibble::tibble(
usage = "local",
embedding = model$encode(text) |> as.numeric() |> list()

Check warning on line 26 in R/embedding-py.R

View check run for this annotation

Codecov / codecov/patch

R/embedding-py.R#L24-L26

Added lines #L24 - L26 were not covered by tests
)
}
29 changes: 20 additions & 9 deletions R/embedding.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,12 @@ add_embeddings <- function(index,
model <- get_transformer_model()
index |>
dplyr::mutate(
embeddings = purrr::map(chunks, \(x) {
create_text_embeddings(x, model)
}),
embedding_method = "local"
embeddings = purrr::map(
.x = chunks,
.f = \(x) create_text_embeddings(x, model),
.progress = "Creating Embeddings Locally"

Check warning on line 91 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L88-L91

Added lines #L88 - L91 were not covered by tests
),
embedding_method = glue::glue("local: {model$name_or_path}")

Check warning on line 93 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L93

Added line #L93 was not covered by tests
)
} else {
index |>
Expand Down Expand Up @@ -118,11 +120,17 @@ join_embeddings_from_index <- function(x) {
create_index <- function(domain,
overwrite = FALSE,
dont_ask = FALSE,
pkg_version = NULL) {
pkg_version = NULL,
local_embeddings = FALSE) {
index_dir <-
file.path(tools::R_user_dir("gpttools", which = "data"), "index")
index_file <-
glue::glue("{index_dir}/{domain}.parquet")

if (local_embeddings) {
index_file <- glue::glue("{index_dir}/local/{domain}.parquet")

Check warning on line 129 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L128-L129

Added lines #L128 - L129 were not covered by tests
} else {
index_file <- glue::glue("{index_dir}/{domain}.parquet")

Check warning on line 131 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L131

Added line #L131 was not covered by tests
}


if (file.exists(index_file) && rlang::is_false(overwrite)) {
cli::cli_abort(
Expand Down Expand Up @@ -155,7 +163,7 @@ create_index <- function(domain,
index <-
index |>
# join_embeddings_from_index() |>
add_embeddings() |>
add_embeddings(local_embeddings = local_embeddings) |>
tidyr::unnest(embeddings) |>

Check warning on line 167 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L166-L167

Added lines #L166 - L167 were not covered by tests
dplyr::mutate(version = pkg_version)
if (rlang::is_false(dir.exists(index_dir))) {
Expand All @@ -172,10 +180,13 @@ create_index <- function(domain,

get_top_matches <- function(index, query_embedding, k = 5) {
k <- min(k, nrow(index))
print(head(index))

Check warning on line 183 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L183

Added line #L183 was not covered by tests
index |>
dplyr::mutate(
similarity = purrr::map_dbl(embedding, \(x) {
lsa::cosine(query_embedding, unlist(x))
score <- lsa::cosine(query_embedding, unlist(x))
cli::cli_inform("Similarity: {score}")
score

Check warning on line 189 in R/embedding.R

View check run for this annotation

Codecov / codecov/patch

R/embedding.R#L187-L189

Added lines #L187 - L189 were not covered by tests
})
) |>
dplyr::arrange(dplyr::desc(similarity)) |>
Expand Down
1 change: 1 addition & 0 deletions R/gpt-query.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ gpt_chat <- function(instructions) {
)
)
cli::cli_process_start(msg = "Sending query to OpenAI")
cli::cli_progress_update()

Check warning on line 38 in R/gpt-query.R

View check run for this annotation

Codecov / codecov/patch

R/gpt-query.R#L38

Added line #L38 was not covered by tests
answer <- gptstudio::openai_create_chat_completion(prompt)
cli::cli_process_done(msg_done = "Received response from OpenAI")
text_to_insert <- c(
Expand Down
12 changes: 9 additions & 3 deletions R/harvest-docs.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ crawl <- function(url,
overwrite = FALSE,
num_cores = parallel::detectCores() - 1,
pkg_version = NULL,
use_azure_openai = FALSE) {
service = c("openai", "local", "azure")) {
parsed_url <- urltools::url_parse(url)
local_domain <- parsed_url$domain
url_path <- parsed_url$path
Expand Down Expand Up @@ -181,16 +181,22 @@ crawl <- function(url,
sink = scraped_text_file
)
if (index_create) {
if (use_azure_openai) {
if (service == "azure") {

Check warning on line 184 in R/harvest-docs.R

View check run for this annotation

Codecov / codecov/patch

R/harvest-docs.R#L184

Added line #L184 was not covered by tests
create_index_azure(local_domain_name,
overwrite = overwrite,
pkg_version = pkg_version
)
} else {
} else if (service == "openai") {

Check warning on line 189 in R/harvest-docs.R

View check run for this annotation

Codecov / codecov/patch

R/harvest-docs.R#L189

Added line #L189 was not covered by tests
create_index(local_domain_name,
overwrite = overwrite,
pkg_version = pkg_version
)
} else if (service == "local") {
create_index(local_domain_name,
overwrite = overwrite,
pkg_version = pkg_version,
local_embeddings = TRUE

Check warning on line 198 in R/harvest-docs.R

View check run for this annotation

Codecov / codecov/patch

R/harvest-docs.R#L194-L198

Added lines #L194 - L198 were not covered by tests
)
}
}
}
Expand Down
26 changes: 22 additions & 4 deletions R/history.R
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ check_context <- function(context) {
#' generate responses.
#'
#' @param query The input query to be processed.
#' @param service Name of the AI service to use, defaults to openai.
#' @param model Name of the openai model to use, defaults to gpt-3.5-turbo
#' @param index Index to look for context.
#' @param add_context Whether to add context to the query or not. Default is
Expand All @@ -147,6 +148,7 @@ check_context <- function(context) {
#' @param save_history Whether to save the chat history or not. Default is TRUE.
#' @param overwrite Whether to overwrite the history file or not. Default is
#' FALSE.
#' @param local Whether to use the local model or not. Default is FALSE.
#'
#' @return A list containing the prompt, context, and answer.
#' @export
Expand All @@ -162,7 +164,8 @@ check_context <- function(context) {
#' result <- chat_with_context(query = query, context = context)
#' }
chat_with_context <- function(query,
model = "gpt-3.5-turbo",
service = "openai",
model = "gpt-4",
index = NULL,
add_context = TRUE,
chat_history = NULL,
Expand All @@ -173,11 +176,19 @@ chat_with_context <- function(query,
k_context = 4,
k_history = 4,
save_history = TRUE,
overwrite = FALSE) {
overwrite = FALSE,
local = FALSE) {
arg_match(task, c("Context Only", "Permissive Chat"))

if (rlang::is_true(add_context) || rlang::is_true(add_history)) {
query_embedding <- get_query_embedding(query)
if (local) {
model <- get_transformer_model()
query_embedding <- create_text_embeddings(query, model = model) |>
dplyr::pull("embedding") |>
unlist()

Check warning on line 188 in R/history.R

View check run for this annotation

Codecov / codecov/patch

R/history.R#L184-L188

Added lines #L184 - L188 were not covered by tests
} else {
query_embedding <- get_query_embedding(query)

Check warning on line 190 in R/history.R

View check run for this annotation

Codecov / codecov/patch

R/history.R#L190

Added line #L190 was not covered by tests
}
}

if (rlang::is_true(add_context)) {
Expand Down Expand Up @@ -286,7 +297,14 @@ chat_with_context <- function(query,

cli::cat_print(prompt)

answer <- query_openai(body = prompt)
answer <-
gptstudio:::gptstudio_create_skeleton(
service = service,
model = model,
prompt = prompt,
stream = FALSE

Check warning on line 305 in R/history.R

View check run for this annotation

Codecov / codecov/patch

R/history.R#L300-L305

Added lines #L300 - L305 were not covered by tests
) |>
gptstudio:::gptstudio_request_perform()

Check warning on line 307 in R/history.R

View check run for this annotation

Codecov / codecov/patch

R/history.R#L307

Added line #L307 was not covered by tests

if (rlang::is_true(save_history)) {
purrr::map(prompt, \(x) {
Expand Down
46 changes: 40 additions & 6 deletions inst/retriever/app.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ make_chat_history <- function(chats) {
purrr::list_flatten()
}

indices <- gpttools::list_index() |> tools::file_path_sans_ext()
api_services <- utils::methods("gptstudio_request_perform") %>%
stringr::str_remove(pattern = "gptstudio_request_perform.gptstudio_request_") %>%
purrr::discard(~ .x == "gptstudio_request_perform.default")

ui <- page_fluid(
waiter::use_waiter(),
Expand Down Expand Up @@ -86,7 +88,7 @@ ui <- page_fluid(
icon = bs_icon("robot", class = "ms-auto"),
selectInput(
"source", "Data Source",
choices = c("All", indices)
choices = NULL
),
selectInput(
"task", "Task",
Expand All @@ -98,15 +100,25 @@ ui <- page_fluid(
accordion_panel(
"Preferences",
icon = bs_icon("sliders", class = "ms-auto"),
selectInput(
"service", "AI Service",
choices = api_services
),
selectInput("model", "Model",
choices = gptstudio::get_available_models(service = "openai")
choices = NULL
),
radioButtons(
"save_history", "Save & Use History",
choiceNames = c("Yes", "No"),
choiceValues = c(TRUE, FALSE),
selected = TRUE, inline = TRUE,
),
radioButtons(
"local", "Local Embeddings",
choiceNames = c("Yes", "No"),
choiceValues = c(TRUE, FALSE),
selected = FALSE, inline = TRUE,
),
sliderInput(
"n_docs", "Docs to Include (#)",
min = 0, max = 20, value = 3
Expand Down Expand Up @@ -148,7 +160,28 @@ server <- function(input, output, session) {
r$all_chats_formatted <- NULL
r$all_chats <- NULL
height <- window_height_server("height")
index <- reactive(load_index(input$source))
index <- reactive({
if (input$local == TRUE) {
load_index(glue::glue("local/{input$source}"))
} else {
load_index(input$source)
}
})
indices <- reactive({
if (input$local == TRUE) {
list_index(dir = "index/local") |> tools::file_path_sans_ext()
} else {
list_index() |> tools::file_path_sans_ext()
}
})
observe(
updateSelectInput(
session,
"model",
choices = gptstudio::get_available_models(service = input$service)
)
)
observe(updateSelectInput(session, "source", choices = indices()))
observe({
waiter_show(
html = tagList(
Expand All @@ -158,7 +191,7 @@ server <- function(input, output, session) {
color = waiter::transparent(0.5)
)
if (is.null(input$model)) {
input$model <- "gpt-3.5-turbo"
input$model <- "gpt-4"
}

interim <- chat_with_context(
Expand All @@ -173,7 +206,8 @@ server <- function(input, output, session) {
k_context = input$n_docs,
k_history = input$n_history,
save_history = input$save_history,
overwrite = FALSE
overwrite = FALSE,
local = input$local
)
new_response <- interim[[3]]$choices
r$context_links <- c(r$context_links, interim[[2]]$link)
Expand Down

0 comments on commit 0042b93

Please sign in to comment.