So I am trying to create a app which is Text To SQL LLM App with MYSQL , SteamLit & LangChain.
I am stuck as the code is showing errors which I cant figure out.
Here is my app.py
import streamlit as st
from app import *
st.title("Chat with Database")
question = st.text_input("Question: ")
if question:
chain, query = db_chain()
response_query = query.invoke({"question": f"{question}"})
response = chain.invoke({"question": f"{question}"})
st.header("SQL Query")
st.write(response_query +";")
st.header("Answer")
st.write(response)
Here is my main.py
from langchain.prompts import SemanticSimilarityExampleSelector
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma
from langchain.prompts import FewShotPromptTemplate
from langchain.chains.sql_database.prompt import PROMPT_SUFFIX
from langchain.prompts.prompt import PromptTemplate
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
from operator import itemgetter
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from examples import examples
import os
from dotenv import load_dotenv
import textwrap
import functools
class MyEmbeddings(HuggingFaceEmbeddings):
def __call__(self, input):
return super().__call__(input)
# take environment variables from .env (especially google api key)
load_dotenv
from langchain_google_genai import GoogleGenerativeAI
from langchain_community.utilities import SQLDatabase
from langchain.chains import create_sql_query_chain
def split_question(question, max_length=45000):
chunks = textwrap.wrap(question, max_length, break_long_words=False)
return chunks
def db_chain():
db_user = "root"
db_password = "***"
db_host = "localhost"
db_name = "test"
db = SQLDatabase.from_uri(f"mysql+pymysql://{db_user}:{db_password}@{db_host}/{db_name}",
sample_rows_in_table_info=3)
print(db.table_info)
llm = GoogleGenerativeAI(model="models/text-bison-001", google_api_key=os.environ["GOOGLE_API_KEY"], temperature=0.1)
embeddings = MyEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2')
#def embedding_fn(input):
#return embeddings.embed_query(input)
#to_vectorize = [" ".join(examples.values()) for example in examples]
to_vectorize = [" ".join(d.values()) for d in examples]
vectorstore = Chroma.from_texts(to_vectorize,embeddings, metadatas=examples, persist_directory='D:\Project_LangChain\Text_to_mYSQL_Google_Palm_v2\chroma_db')
example_selector = SemanticSimilarityExampleSelector(
vectorstore=vectorstore,
k=2
)
mysql_prompt = """You are a MySQL expert. Given an input question, first create a syntactically correct MySQL query to run, then look at the results of the query and return the answer to the input question.
You can order the results to return the most informative data in the database. Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
Pay attention to use CURDATE() function to get the current date, if the question involves "today".
Use the following format:
Question: Question here
SQLQuery: Query to run with no pre-amble
SQLResult: Result of the SQLQuery
Answer: Final answer here
No pre-amble.
Only use the following tables:
{table_info}
Question: {input}
SQLQuery: {sql_query}
SQLResult: {sql_result}
Answer: To limit the number of results to {top_k}:"""
mysql_prompt = PromptTemplate(input_variables=["input", "table_info", "top_k", "sql_query", "sql_result"],
template=mysql_prompt)
answer_prompt = PromptTemplate(
input_variables=["Question", "SQLQuery", "SQLResult", "Answer"],
template="nQuestion: {Question}nSQLQuery: {SQLQuery}nSQLResult: {SQLResult}nAnswer: {Answer}",
)
example_prompt = PromptTemplate(
input_variables=["Question", "SQLQuery", "SQLResult", "Answer", ],
template="nQuestion: {Question}nSQLQuery: {SQLQuery}nSQLResult: {SQLResult}nAnswer: {Answer}",
)
#few_shot_prompt = FewShotPromptTemplate(
#example_selector=example_selector,
#example_prompt=example_prompt,
#prefix=mysql_prompt,
#suffix=PROMPT_SUFFIX,
#input_variables=["input", "table", "top_k"],) # These variables are used in the prefix and suffix
few_shot_prompt = FewShotPromptTemplate(
example_selector=example_selector,
example_prompt=example_prompt,
prefix=mysql_prompt,
suffix="",
input_variables=["input", "table_info", "top_k"], # These variables are used in the prefix and suffix
)
execute_query = QuerySQLDataBaseTool(db=db)
write_query = create_sql_query_chain(llm, db, prompt=few_shot_prompt)
# chain = write_query | execute_query
answer = answer_prompt | llm | StrOutputParser()
query = write_query
def handle_question(question):
question_chunks = split_question(question)
combined_result = ""
for chunk in question_chunks:
response_query = query.invoke({"question": chunk})
combined_result += response_query
return combined_result
chain = (
RunnablePassthrough.assign(query=write_query).assign(
result=itemgetter("query") | execute_query
)
| answer
| functools.partial(handle_question)
)
return chain, write_query
Here is my examples.py
examples = [
{'Question' : "How many districts?",
'SQLQuery' : "SELECT COUNT(*) FROM districts",
'SQLResult' : "Result of the SQL Query",
'Answer' : "38"
}
]
The issue is I cant figure out what is the issue.Connection to database is established but throwing errors.
Let me know for any further clarification if needed.
Fix the code if needed