Skip to content

Commit

Permalink
MariaDB Knowledge Base Chat
Browse files Browse the repository at this point in the history
  • Loading branch information
vladbogo committed May 1, 2023
1 parent 801e8bc commit ae70777
Show file tree
Hide file tree
Showing 3 changed files with 193 additions and 0 deletions.
31 changes: 31 additions & 0 deletions kb_chat/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# MariaDB KB Chat and Vector Store Generator

This script scrapes web pages from the MariaDB Knowledge Base, cleans and processes the content, and then generates a FAISS index using the OpenAI embeddings for each document. The vector store is saved as a pickle file and then used by a chatbot to answer questions about the MariaDB server.

## Requirements

Install the required packages with the following command:

pip install argparse bs4 lxml python-dotenv faiss-cpu langchain openai tiktoken requests numpy streamlit

## Setup

1. Download the MariaDB KB CSV file from https://github.com/Icerath/mariadb_kb_server/blob/main/kb_urls.csv
2. Create a `.env` file in the same directory as the script.
3. Add your OpenAI API key to the `.env` file as follows:

OPENAI_API_KEY=your_api_key_here

## Preprocessing

Run the script with the following command:

python create_vectorstore.py --csv-file kb_urls.csv --tmp-dir tmp --md-dir md --vectorstore-path vectorstore.pkl --chunk-size 4000 --chunk-overlap 200

This will create a file `vectorstore.pkl`, which the chatbot uses to answer questions.

## Run chat

streamlit run chat.py

You now have a self-hosted chat over the MariaDB KB.
70 changes: 70 additions & 0 deletions kb_chat/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import streamlit as st
import pickle
from langchain.vectorstores import FAISS
from dotenv import load_dotenv
import openai
import os

# Read variables from a local .env file into the process environment.
load_dotenv()

# Hand the API key to the OpenAI client; os.getenv returns None when
# OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")

def gen_prompts(content, question):
    """Build the system and user chat messages for one question.

    Args:
        content: Retrieved knowledge-base text (including URL and score
            annotations) that the model must base its answer on.
        question: The question asked by the user.

    Returns:
        A ``(system_msg, user_msg)`` tuple of message dicts, each with
        ``"role"`` and ``"content"`` keys.
    """
    system_msg_content = "You are a question-answering expert about MariaDB. You only respond based on the facts that are given to you and ignore your prior knowledge."
    # NOTE(review): the score embedded in the content comes from a FAISS
    # similarity search, where a lower score presumably means a closer match.
    # The "lower than 0.2 decline to answer" rule below looks inverted;
    # confirm the intended threshold before changing it.
    user_msg_content = f"{content}\n---\n\nGiven the above content about MariaDB along with the URL of the content, respond to this question {question} and mention the URL as a source. If the question is not about MariaDB and you cannot answer it based on the provided content, politely decline to answer. Simply state that you couldn't find any relevant information instead of going into details. Do not say the phrase 'in the provided content'. If the information I provide contains the word obsolete, emphasize that the response is obsolete. Also, suggest newer MariaDB versions if the question is about versions older than 10.3 and say that the others are no longer maintained. Do not add the URL as a source if you cannot answer based on the provided content. If there are exceptions for particular MariaDB version, specify the exceptions that apply. Also, if the provided score is lower than 0.2 decline to answer and say you found no relevant information. If the source URL repeats, only use it once."
    system_msg = {"role": "system", "content": system_msg_content}
    user_msg = {"role": "user", "content": user_msg_content}

    return system_msg, user_msg

def process_doc(content, question, model_type="gpt-4", max_tokens=30000):
    """Ask the chat model to answer ``question`` from the retrieved content.

    Args:
        content: Either a string, or an iterable of documents exposing a
            ``page_content`` attribute (what ``search_similar_docs`` returns).
        question: The user's question.
        model_type: Name of the OpenAI chat model to call.
        max_tokens: Maximum length of the content. Despite the name, the
            limit is applied with ``len()`` and so counts characters.

    Returns:
        The model's answer, or a generic apology string when the request
        raises.
    """
    # Flatten documents to text first, so the limit below trims characters
    # rather than list items and the prompt contains the page text instead of
    # the repr of a list of objects.
    if not isinstance(content, str):
        content = "\n".join(
            getattr(doc, "page_content", str(doc)) for doc in content
        )

    if len(content) > max_tokens:
        print('Trimmed')
        content = content[:max_tokens]
    system_msg, user_msg = gen_prompts(content, question)

    try:
        response = openai.ChatCompletion.create(
            model=model_type,
            messages=[system_msg, user_msg],
        )
    except Exception as e:
        # The UI only shows a generic message; keep the real cause in the
        # server output so failures can be diagnosed.
        print(f"OpenAI request failed: {e}")
        return "Sorry, there was an error. Please try again!"

    result = response.choices[0].message['content']
    return result

# Load the pickled vector store from the working directory at import time.
# NOTE: pickle.load can execute arbitrary code from the file, so only load a
# vectorstore.pkl that you generated yourself.
with open("vectorstore.pkl", "rb") as vectorstore_file:
    faiss_index = pickle.load(vectorstore_file)

def search_similar_docs(question, k=4):
    """Return the ``k`` documents closest to ``question``, annotated for the prompt.

    Each returned document's ``page_content`` is prefixed with its source URL
    and suffixed with the score reported by the similarity search.

    Args:
        question: Text to search the vector store for.
        k: Number of documents to retrieve.

    Returns:
        A list of new document objects; the documents held by the index are
        left untouched.
    """
    annotated_docs = []
    for doc, score in faiss_index.similarity_search_with_score(question, k=k):
        url = doc.metadata["source"]
        # Build a fresh document instead of rewriting page_content in place:
        # the search hands back objects owned by the index, and editing them
        # would stack another URL/SCORE wrapper on every later search.
        annotated_docs.append(
            type(doc)(
                page_content=f"URL: {url}\n{doc.page_content}\nSCORE:{score}\n",
                metadata=dict(doc.metadata),
            )
        )
    return annotated_docs

def main():
    """Render the Streamlit chat page and answer submitted questions.

    The conversation is kept in ``st.session_state.chat_history`` as
    ``(role, message)`` tuples so it survives Streamlit's script reruns.
    """
    st.title("MariaDB KB Chatbot")

    if 'chat_history' not in st.session_state:
        st.session_state.chat_history = []

    user_input = st.text_input("Ask a question:", "")
    if st.button("Send"):
        question = user_input.strip()
        if question:
            st.session_state.chat_history.append(("User", question))
            results = process_doc(search_similar_docs(question), question)
            st.session_state.chat_history.append(("Bot", results))
        else:
            # Do not spend an embedding lookup and a chat completion on an
            # empty question.
            st.warning("Please enter a question first.")

    for role, message in st.session_state.chat_history:
        if role == "User":
            st.markdown(f"> **{role}**: {message}")
        else:
            st.markdown(f"**{role}**: {message}")

# Start the chat UI only when this file is run as a script.
if __name__ == "__main__":
    main()
92 changes: 92 additions & 0 deletions kb_chat/create_vectorstore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import argparse
import pickle
import os
import csv
import openai
import re
import requests

from langchain.document_loaders import BSHTMLLoader
from langchain.vectorstores import FAISS
from langchain.text_splitter import CharacterTextSplitter
from langchain.embeddings.openai import OpenAIEmbeddings
from dotenv import load_dotenv

# Read variables from a local .env file into the process environment.
load_dotenv()
# Hand the API key to the OpenAI client; os.getenv returns None when
# OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")

def parse_args():
    """Parse the command-line options of the vector store generator.

    Returns:
        An ``argparse.Namespace`` with ``csv_file``, ``tmp_dir``, ``md_dir``,
        ``vectorstore_path``, ``chunk_size`` and ``chunk_overlap``.
    """
    # (flag, type, default, help) for every supported option.
    option_specs = (
        ('--csv-file', str, 'kb_urls.csv', 'Path to the input CSV file containing the URLs'),
        ('--tmp-dir', str, 'tmp', 'Directory where the temporary HTML files will be stored'),
        ('--md-dir', str, 'md', 'Directory where the output Markdown files will be stored'),
        ('--vectorstore-path', str, 'vectorstore.pkl', 'Path to save the generated FAISS vector store pickle file'),
        ('--chunk-size', int, 4000, 'Chunk size for splitting the documents'),
        ('--chunk-overlap', int, 200, 'Overlap size between chunks when splitting documents'),
    )

    cli = argparse.ArgumentParser(description='MariaDB KB Vector Store Generator')
    for flag, value_type, fallback, description in option_specs:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)
    return cli.parse_args()

def _html_filename(url):
    """Map a URL to the flat file name used for its cached HTML copy."""
    return url.strip().replace('://', '_').replace('/', '_') + '.html'


def download_web_page(url, tmp_dir='tmp', timeout=30):
    """Download ``url`` and cache the HTML in ``tmp_dir``.

    Args:
        url: Address of the page to fetch.
        tmp_dir: Directory the HTML file is written to; created if missing.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The path of the written file, or ``None`` when the page could not be
        fetched (an error message is printed in that case).
    """
    try:
        response = requests.get(url.strip(), timeout=timeout)
    except requests.RequestException as exc:
        print(f"Error: Unable to fetch the web page {url}: {exc}")
        return None

    if response.status_code != 200:
        print(f"Error: Unable to fetch the web page. Status code: {response.status_code}")
        return None

    os.makedirs(tmp_dir, exist_ok=True)
    # Same naming rule as main() uses to look the file up again.
    path = os.path.join(tmp_dir, _html_filename(url))
    with open(path, 'w', encoding='utf-8') as file:
        file.write(response.text)
    return path


def read_csv(csv_file):
    """Read the URLs from the first column of ``csv_file``.

    Blank rows and rows whose first cell is empty are ignored. The first
    remaining row is treated as the header and dropped.

    Args:
        csv_file: Path to a UTF-8 encoded CSV file.

    Returns:
        A list of URL strings with surrounding whitespace removed.
    """
    with open(csv_file, newline='', encoding='utf-8') as csvfile:
        urls = [
            row[0].strip()
            for row in csv.reader(csvfile)
            # csv.reader yields [] for blank lines, so check the row first.
            if row and row[0].strip()
        ]

    return urls[1:]


def main():
    """Build the FAISS vector store from the URLs listed in the CSV file.

    For each URL the page is downloaded (unless already cached), its text is
    whitespace-normalised and written to the Markdown directory, and the
    document is collected. The documents are then split into chunks, embedded
    and pickled to the vector store path.
    """
    args = parse_args()
    os.makedirs(args.tmp_dir, exist_ok=True)
    os.makedirs(args.md_dir, exist_ok=True)

    all_docs = []
    for url in read_csv(args.csv_file):
        filename = _html_filename(url)
        doc_path = os.path.join(args.tmp_dir, filename)
        if not os.path.exists(doc_path):
            download_web_page(url, args.tmp_dir)
        if not os.path.exists(doc_path):
            # The download failed; skip the page instead of crashing the
            # whole run on a missing file.
            print(f"Skipping {url}: page is not available locally")
            continue

        doc = BSHTMLLoader(doc_path).load()[0]
        # Collapse every run of whitespace into a single space.
        doc.page_content = re.sub(r'\s+', ' ', doc.page_content)
        doc.metadata["source"] = url

        md_filename = os.path.join(args.md_dir, f'{filename}.md')
        with open(md_filename, 'w', encoding='utf-8') as md_file:
            md_file.write(doc.page_content)

        all_docs.append(doc)

    if not all_docs:
        raise SystemExit("No documents could be loaded; nothing to index.")

    text_splitter = CharacterTextSplitter(
        separator=" ",
        chunk_size=args.chunk_size,
        chunk_overlap=args.chunk_overlap,
        length_function=len,
    )
    print("Loaded {} documents".format(len(all_docs)))
    all_docs = text_splitter.split_documents(all_docs)
    print("After split: {} documents".format(len(all_docs)))

    faiss_index = FAISS.from_documents(all_docs, OpenAIEmbeddings())

    with open(args.vectorstore_path, "wb") as f:
        pickle.dump(faiss_index, f)

# Build the vector store only when this file is run as a script.
if __name__ == "__main__":
    main()

0 comments on commit ae70777

Please sign in to comment.