From ae70777c4a0ce9d479fc54c0defb0359bf7c4a32 Mon Sep 17 00:00:00 2001 From: Vlad Bogolin Date: Mon, 1 May 2023 22:44:35 +0300 Subject: [PATCH] MariaDB Knowledge Base Chat --- kb_chat/README.md | 31 ++++++++++++ kb_chat/chat.py | 70 ++++++++++++++++++++++++++ kb_chat/create_vectorstore.py | 92 +++++++++++++++++++++++++++++++++++ 3 files changed, 193 insertions(+) create mode 100644 kb_chat/README.md create mode 100644 kb_chat/chat.py create mode 100644 kb_chat/create_vectorstore.py diff --git a/kb_chat/README.md b/kb_chat/README.md new file mode 100644 index 00000000..7bcabed3 --- /dev/null +++ b/kb_chat/README.md @@ -0,0 +1,31 @@ +# MariaDB KB Chat and Vector Store Generator + +This script scrapes web pages from the MariaDB Knowledge Base, cleans and processes the content, and then generates a FAISS index using the OpenAI embeddings for each document. The vector store is saved as a pickle file and then used by a chatbot to answer questions about the MariaDB server. + +## Requirements + +Install the required packages with the following command: + +pip install argparse bs4 dotenv faiss-cpu openai requests numpy streamlit + +## Setup + +1. Download the MariaDB KB CSV file from https://github.com/Icerath/mariadb_kb_server/blob/main/kb_urls.csv +2. Create a `.env` file in the same directory as the script. +3. Add your OpenAI API key to the `.env` file as follows: + +OPENAI_API_KEY=your_api_key_here + +## Preprocessing + +Run the script with the following command: + +python create_vectorestore.py --csv-file kb_urls.csv --tmp-dir tmp --md-dir md --vectorstore-path vectorstore.pkl --chunk-size 4000 --chunk-overlap 200 + +This will create a file `vectorestore.pkl` which is used to answer questions + +## Run chat + +streamlit run chat.py + +Now, you will have a self hosted version of the chat over the MariaDB KB. diff --git a/kb_chat/chat.py b/kb_chat/chat.py new file mode 100644 index 00000000..e85170ce --- /dev/null +++ b/kb_chat/chat.py @@ -0,0 +1,70 @@ +import streamlit as st +import pickle +from langchain.vectorstores import FAISS +from dotenv import load_dotenv +import openai +import os + +load_dotenv() + +openai.api_key = os.getenv("OPENAI_API_KEY") + +def gen_prompts(content, question): + system_msg_content = "You are a questioning answering expert about MariaDB. You only respond based on the facts that are given to you and ignore your prior knowledge." + user_msg_content = f"{content}\n---\n\nGiven the above content about MariaDB along with the URL of the content, respond to this question {question} and mention the URL as a source. If the question is not about MariaDB and you cannot answer it based on the provided content, politely decline to answer. Simply state that you couldn't find any relevant information instead of going into details. Do not say the phrase 'in the provided content'. If the information I provide contains the word obsolete, emphasize that the response is obsolete. Also, suggest newer MariaDB versions if the question is about versions older than 10.3 and say that the others are no longer maintained. Do not add the URL as a source if you cannot answer based on the provided content. If there are exceptions for particular MariaDB version, specify the exceptions that apply. Also, if the provided score is lower than 0.2 decline to answer and say you found no relevant information. If the source URL repeats, only use it once." + system_msg = {"role": "system", "content": system_msg_content} + user_msg = {"role": "user", "content": user_msg_content} + + return system_msg, user_msg + +def process_doc(content, question, model_type="gpt-4", max_tokens=30000): + if len(content) > max_tokens: + print('Trimmed') + content = content[:max_tokens] + system_msg, user_msg = gen_prompts(content, question) + + try: + response = openai.ChatCompletion.create( + model=model_type, + messages=[system_msg, user_msg], + ) + except Exception as e: + return "Sorry, there was an error. Please try again!" + + result = response.choices[0].message['content'] + return result + +with open("vectorstore.pkl", "rb") as f: + faiss_index = pickle.load(f) + +def search_similar_docs(question, k=4): + docs = faiss_index.similarity_search_with_score(question, k=k) + docs_with_url = [] + for doc in docs: + url = doc[0].metadata["source"] + doc[0].page_content = f"URL: {url}\n{doc[0].page_content}\nSCORE:{doc[1]}\n" + docs_with_url.append(doc[0]) + print(docs) + return docs_with_url + +def main(): + st.title("MariaDB KB Chatbot") + + if 'chat_history' not in st.session_state: + st.session_state.chat_history = [] + + user_input = st.text_input("Ask a question:", "") + if st.button("Send"): + st.session_state.chat_history.append(("User", user_input)) + results = process_doc(search_similar_docs(user_input), user_input) + + st.session_state.chat_history.append(("Bot", results)) + + for role, message in st.session_state.chat_history: + if role == "User": + st.markdown(f"> **{role}**: {message}") + else: + st.markdown(f"**{role}**: {message}") + +if __name__ == "__main__": + main() diff --git a/kb_chat/create_vectorstore.py b/kb_chat/create_vectorstore.py new file mode 100644 index 00000000..84775fcb --- /dev/null +++ b/kb_chat/create_vectorstore.py @@ -0,0 +1,92 @@ +import argparse +import pickle +import os +import csv +import openai +import re +import requests + +from langchain.document_loaders import BSHTMLLoader +from langchain.vectorstores import FAISS +from langchain.text_splitter import CharacterTextSplitter +from langchain.embeddings.openai import OpenAIEmbeddings +from dotenv import load_dotenv + +load_dotenv() +openai.api_key = os.getenv("OPENAI_API_KEY") + +def parse_args(): + parser = argparse.ArgumentParser(description='MariaDB KB Vector Store Generator') + parser.add_argument('--csv-file', type=str, default='kb_urls.csv', help='Path to the input CSV file containing the URLs') + parser.add_argument('--tmp-dir', type=str, default='tmp', help='Directory where the temporary HTML files will be stored') + parser.add_argument('--md-dir', type=str, default='md', help='Directory where the output Markdown files will be stored') + parser.add_argument('--vectorstore-path', type=str, default='vectorstore.pkl', help='Path to save the generated FAISS vector store pickle file') + parser.add_argument('--chunk-size', type=int, default=4000, help='Chunk size for splitting the documents') + parser.add_argument('--chunk-overlap', type=int, default=200, help='Overlap size between chunks when splitting documents') + return parser.parse_args() + +def download_web_page(url): + response = requests.get(url) + + if response.status_code == 200: + content = response.text + filename = url.replace('://', '_').replace('/', '_') + '.html' + + with open('./tmp/' + filename, 'w', encoding='utf-8') as file: + file.write(content) + else: + print(f"Error: Unable to fetch the web page. Status code: {response.status_code}") + +def read_csv(csv_file): + urls = [] + + with open(csv_file, newline='', encoding='utf-8') as csvfile: + csv_reader = csv.reader(csvfile) + for row in csv_reader: + if row[0].strip(): + urls.append(row[0]) + + return urls[1:] + +def main(): + args = parse_args() + + urls = read_csv(args.csv_file) + all_docs = [] + idx = 0 + for url in urls: + filename = url.replace('://', '_').replace('/', '_').strip() + '.html' + doc_path = args.tmp_dir + '/' + filename + if not os.path.exists(doc_path): + download_web_page(url) + loader = BSHTMLLoader(doc_path) + doc = loader.load()[0] + + content = re.sub(r'\s+', ' ', doc.page_content) + doc.page_content = content + doc.metadata["source"] = url + + md_filename = os.path.join(args.md_dir, f'{filename}.md') + + with open(md_filename, 'w', encoding='utf-8') as md_file: + md_file.write(doc.page_content) + + all_docs.append(doc) + + text_splitter = CharacterTextSplitter( + separator = " ", + chunk_size = args.chunk_size, + chunk_overlap = args.chunk_overlap, + length_function = len, + ) + print("Loaded {} documents".format(len(all_docs))) + all_docs = text_splitter.split_documents(all_docs) + print("After split: {} documents".format(len(all_docs))) + + faiss_index = FAISS.from_documents(all_docs, OpenAIEmbeddings()) + + with open(args.vectorstore_path, "wb") as f: + pickle.dump(faiss_index, f) + +if __name__ == "__main__": + main()