class HistoryDB:
    """SQLite-backed store for query history and favorites.

    Each record holds the natural-language question, the SQL generated for
    it, a free-form execution-time string, a creation timestamp and a
    favorite flag. Every operation opens its own short-lived connection and
    always closes it, including when a statement raises.
    """

    # Column list shared by every SELECT so it stays in sync with _row_to_dict.
    _COLUMNS = ('id, natural_language, generated_sql, execution_time, '
                'created_at, is_favorite')

    def __init__(self, db_path=None):
        """Create the store and make sure the schema exists.

        Args:
            db_path: Path of the SQLite file. When omitted, ``history.db`` is
                placed in the directory of ``config_dict['db_path']``.
        """
        if db_path is None:
            db_path = os.path.join(os.path.dirname(config_dict['db_path']), 'history.db')
        self.db_path = db_path
        self._init_db()

    def _get_connection(self):
        """Return a new connection; the caller is responsible for closing it."""
        return sqlite3.connect(self.db_path)

    @staticmethod
    def _row_to_dict(row):
        """Map a row selected with ``_COLUMNS`` to the public dict shape."""
        return {
            'id': row[0],
            'natural_language': row[1],
            'generated_sql': row[2],
            'execution_time': row[3],
            'created_at': row[4],
            'is_favorite': bool(row[5]),
        }

    def _select(self, sql, params=()):
        """Run a read-only query and return the rows as a list of dicts."""
        conn = self._get_connection()
        try:
            return [self._row_to_dict(row) for row in conn.execute(sql, params)]
        finally:
            conn.close()

    def _write(self, sql, params=()):
        """Run one modifying statement, commit, return (rowcount, lastrowid)."""
        conn = self._get_connection()
        try:
            cursor = conn.execute(sql, params)
            conn.commit()
            # Read the counters before the connection is closed.
            return cursor.rowcount, cursor.lastrowid
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``query_history`` table and its indexes if missing."""
        conn = self._get_connection()
        try:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS query_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    natural_language TEXT NOT NULL,
                    generated_sql TEXT,
                    execution_time TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    is_favorite BOOLEAN DEFAULT 0
                )
            ''')
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_created_at ON query_history(created_at)
            ''')
            conn.execute('''
                CREATE INDEX IF NOT EXISTS idx_is_favorite ON query_history(is_favorite)
            ''')
            conn.commit()
        finally:
            conn.close()

    def add_history(self, natural_language, generated_sql=None, execution_time=None):
        """Insert a record stamped with the current local time; return its id."""
        created_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        _, history_id = self._write(
            '''
            INSERT INTO query_history (natural_language, generated_sql, execution_time, created_at)
            VALUES (?, ?, ?, ?)
            ''',
            (natural_language, generated_sql, execution_time, created_at),
        )
        return history_id

    def get_history(self, history_id):
        """Return the record with ``history_id`` as a dict, or ``None``."""
        rows = self._select(
            f'SELECT {self._COLUMNS} FROM query_history WHERE id = ?',
            (history_id,),
        )
        return rows[0] if rows else None

    def get_all_history(self, limit=100, offset=0):
        """Return records newest first, paginated by ``limit``/``offset``."""
        # created_at only has one-second resolution; id breaks ties so that
        # records written within the same second keep a stable order.
        return self._select(
            f'''
            SELECT {self._COLUMNS}
            FROM query_history
            ORDER BY created_at DESC, id DESC
            LIMIT ? OFFSET ?
            ''',
            (limit, offset),
        )

    def get_favorites(self, limit=100, offset=0):
        """Return favorite records newest first, paginated."""
        return self._select(
            f'''
            SELECT {self._COLUMNS}
            FROM query_history
            WHERE is_favorite = 1
            ORDER BY created_at DESC, id DESC
            LIMIT ? OFFSET ?
            ''',
            (limit, offset),
        )

    def toggle_favorite(self, history_id):
        """Flip the favorite flag of a record.

        Returns:
            The new flag as a bool, or ``None`` when the record does not exist.
        """
        conn = self._get_connection()
        try:
            # Flip in a single statement so two concurrent toggles cannot both
            # read the same old value and overwrite each other.
            cursor = conn.execute(
                '''
                UPDATE query_history
                SET is_favorite = CASE WHEN is_favorite THEN 0 ELSE 1 END
                WHERE id = ?
                ''',
                (history_id,),
            )
            if cursor.rowcount == 0:
                return None
            row = conn.execute(
                'SELECT is_favorite FROM query_history WHERE id = ?',
                (history_id,),
            ).fetchone()
            conn.commit()
            return bool(row[0])
        finally:
            conn.close()

    def set_favorite(self, history_id, is_favorite):
        """Set the favorite flag explicitly; return True if a row was updated."""
        rows_affected, _ = self._write(
            'UPDATE query_history SET is_favorite = ? WHERE id = ?',
            (1 if is_favorite else 0, history_id),
        )
        return rows_affected > 0

    def delete_history(self, history_id):
        """Delete one record; return True if it existed."""
        rows_affected, _ = self._write(
            'DELETE FROM query_history WHERE id = ?',
            (history_id,),
        )
        return rows_affected > 0

    def clear_all_history(self):
        """Delete every record (favorites included); return the number removed."""
        rows_affected, _ = self._write('DELETE FROM query_history')
        return rows_affected

    def update_history(self, history_id, generated_sql=None, execution_time=None):
        """Update the SQL and/or execution time of a record.

        Arguments left as ``None`` are not modified. Returns True when a row
        was updated, False when nothing was requested or the id is unknown.
        """
        updates = []
        params = []

        if generated_sql is not None:
            updates.append('generated_sql = ?')
            params.append(generated_sql)

        if execution_time is not None:
            updates.append('execution_time = ?')
            params.append(execution_time)

        if not updates:
            return False

        params.append(history_id)
        # Only fixed column fragments are interpolated; values stay bound.
        rows_affected, _ = self._write(
            f'UPDATE query_history SET {", ".join(updates)} WHERE id = ?',
            params,
        )
        return rows_affected > 0

    def search_history(self, keyword, limit=50):
        """Return records whose question or SQL contains ``keyword`` literally."""
        # Escape LIKE metacharacters so that '%' and '_' typed by the user are
        # matched as ordinary characters instead of acting as wildcards.
        escaped = (
            keyword.replace('\\', '\\\\')
            .replace('%', '\\%')
            .replace('_', '\\_')
        )
        pattern = f'%{escaped}%'
        return self._select(
            f'''
            SELECT {self._COLUMNS}
            FROM query_history
            WHERE natural_language LIKE ? ESCAPE '\\'
               OR generated_sql LIKE ? ESCAPE '\\'
            ORDER BY created_at DESC, id DESC
            LIMIT ?
            ''',
            (pattern, pattern, limit),
        )
execution_time=execution_time + ) + + return chatbot, history, refresh_history_list(), refresh_favorites_list() def reset_user_input(): @@ -97,29 +105,190 @@ def reset_user_input(): def reset_state(): return [], [] -with gr.Blocks() as demo: - gr.HTML("""

🤖ChatSQL-GLM

def _truncate(text, limit):
    """Return ``text`` cut to ``limit`` characters, adding "..." only if cut."""
    return text[:limit] + "..." if len(text) > limit else text


def format_history_item(item, show_favorite=True):
    """Render one history record as the multi-line label shown in the lists.

    Args:
        item: Record dict with the keys ``id``, ``natural_language``,
            ``generated_sql``, ``execution_time``, ``created_at`` and
            ``is_favorite``.
        show_favorite: When true, prefix the label with a star (favorite) or
            a blank (not favorite).

    Returns:
        The label text. It always contains ``[<id>]`` so that
        ``parse_history_id_from_selection`` can recover the record id.
    """
    favorite_icon = "⭐ " if item['is_favorite'] else " "
    natural_language = _truncate(item['natural_language'], 50)
    sql = item['generated_sql']
    sql_preview = _truncate(sql, 40) if sql else "N/A"
    execution_time = item['execution_time'] or "N/A"

    label = (
        f"[{item['id']}] {item['created_at']}\n"
        f"自然语言: {natural_language}\n"
        f"SQL: {sql_preview}\n"
        f"执行时间: {execution_time}"
    )
    return f"{favorite_icon}{label}" if show_favorite else label


def get_history_list():
    """Return labels for the 50 most recent records, or a placeholder entry."""
    history_items = history_db.get_all_history(limit=50)
    if not history_items:
        return ["暂无历史记录"]
    return [format_history_item(item) for item in history_items]


def get_favorites_list():
    """Return labels for up to 50 favorite records, or a placeholder entry."""
    favorite_items = history_db.get_favorites(limit=50)
    if not favorite_items:
        return ["暂无收藏记录"]
    return [format_history_item(item) for item in favorite_items]


def refresh_history_list():
    """Build a Gradio update that reloads the history choices and clears the selection."""
    return gr.update(choices=get_history_list(), value=None)


def refresh_favorites_list():
    """Build a Gradio update that reloads the favorites choices and clears the selection."""
    return gr.update(choices=get_favorites_list(), value=None)


def parse_history_id_from_selection(selection):
    """Extract the record id from a label produced by ``format_history_item``.

    Returns:
        The integer inside the first ``[...]`` group, or ``None`` when the
        selection is empty or carries no id (e.g. a placeholder entry).
    """
    if not selection:
        return None
    match = re.search(r'\[(\d+)\]', selection)
    return int(match.group(1)) if match else None


def reuse_history_item(selection, chatbot, history):
    """Load the question of the selected record back into the input box.

    Returns:
        ``(chatbot, history, input_box_update, status_message)``; the chat
        state is passed through unchanged.
    """
    history_id = parse_history_id_from_selection(selection)
    if not history_id:
        return chatbot, history, gr.update(value=''), "请选择有效的历史记录"

    item = history_db.get_history(history_id)
    if not item:
        return chatbot, history, gr.update(value=''), "历史记录不存在"

    question = item['natural_language']
    # Append the ellipsis only when the preview was actually shortened.
    status = f"已加载历史记录: {_truncate(question, 30)}"
    return chatbot, history, gr.update(value=question), status
def toggle_favorite_for_history(selection):
    """Flip the favorite flag of the selected record.

    Returns:
        ``(status_message, history_list_update, favorites_list_update)``.
    """
    record_id = parse_history_id_from_selection(selection)
    if not record_id:
        message = "请选择有效的历史记录"
    else:
        now_favorite = history_db.toggle_favorite(record_id)
        if now_favorite is None:
            message = "历史记录不存在"
        elif now_favorite:
            message = "已添加到收藏"
        else:
            message = "已取消收藏"
    return message, refresh_history_list(), refresh_favorites_list()


def delete_history_item(selection):
    """Delete the selected record and refresh both lists.

    Returns:
        ``(status_message, history_list_update, favorites_list_update)``.
    """
    record_id = parse_history_id_from_selection(selection)
    if not record_id:
        message = "请选择有效的历史记录"
    elif history_db.delete_history(record_id):
        message = "已删除历史记录"
    else:
        message = "删除失败"
    return message, refresh_history_list(), refresh_favorites_list()


def clear_all_history():
    """Remove every stored record and refresh both lists."""
    removed = history_db.clear_all_history()
    message = f"已清除 {removed} 条历史记录"
    return message, refresh_history_list(), refresh_favorites_list()


def search_history(keyword):
    """Filter the history list by a keyword.

    Returns:
        ``(history_list_update, status_message)``. A blank keyword reloads
        the unfiltered list and asks the user to enter a search term.
    """
    if not keyword or keyword.strip() == "":
        return refresh_history_list(), "请输入搜索关键词"

    matches = history_db.search_history(keyword.strip(), limit=50)
    if not matches:
        return gr.update(choices=["未找到相关记录"], value=None), "未找到相关记录"

    labels = [format_history_item(match) for match in matches]
    return gr.update(choices=labels, value=None), f"找到 {len(matches)} 条相关记录"


with gr.Blocks(title="ChatSQL-GLM") as demo:
    gr.HTML("""

🤖ChatSQL-GLM

""")
    with gr.Row():
        # Left column: history / favorites browser.
        with gr.Column(scale=1, min_width=300):
            with gr.Tabs():
                with gr.TabItem("历史记录"):
                    history_list = gr.Radio(
                        choices=get_history_list(),
                        label="查询历史记录",
                        info="点击选择可复用,支持搜索和收藏",
                        interactive=True
                    )
                    search_input = gr.Textbox(
                        label="搜索历史",
                        placeholder="输入关键词搜索...",
                        lines=1
                    )
                    with gr.Row():
                        search_btn = gr.Button("搜索", variant="secondary")
                        refresh_btn = gr.Button("刷新", variant="secondary")

                    with gr.Row():
                        reuse_btn = gr.Button("复用查询", variant="primary")
                        favorite_btn = gr.Button("收藏/取消", variant="secondary")

                    with gr.Row():
                        delete_btn = gr.Button("删除选中", variant="secondary")
                        clear_all_btn = gr.Button("清空历史", variant="stop")

                    history_status = gr.Textbox(
                        label="状态",
                        interactive=False,
                        lines=2
                    )

                with gr.TabItem("收藏夹"):
                    favorites_list = gr.Radio(
                        choices=get_favorites_list(),
                        label="收藏的查询",
                        info="已收藏的SQL查询,可快速复用",
                        interactive=True
                    )
                    with gr.Row():
                        reuse_favorite_btn = gr.Button("复用收藏", variant="primary")
                        refresh_favorites_btn = gr.Button("刷新", variant="secondary")

                    favorites_status = gr.Textbox(
                        label="状态",
                        interactive=False,
                        lines=2
                    )

        # Right column: chat window with the input box and its buttons.
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=600)
            with gr.Row():
                with gr.Column(scale=4):
                    with gr.Column(scale=12):
                        user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
                            container=False)
                    with gr.Column(min_width=32, scale=1):
                        submitBtn = gr.Button("Submit", variant="primary")
                with gr.Column(scale=1):
                    emptyBtn = gr.Button("Clear History")
                    # max_length = gr.Slider(0, 4096, value=2048, step=1.0, label="Maximum length", interactive=True)
                    # top_p = gr.Slider(0, 1, value=0.7, step=0.01, label="Top P", interactive=True)
                    # temperature = gr.Slider(0, 1, value=0.95, step=0.01, label="Temperature", interactive=True)

    history = gr.State([])

    # Chat: predict also pushes refreshed choices into both side lists.
    submitBtn.click(predict, [user_input, chatbot, history], [chatbot, history, history_list, favorites_list],
                    show_progress=True)
    submitBtn.click(reset_user_input, [], [user_input])

    emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)

    # List maintenance.
    refresh_btn.click(refresh_history_list, outputs=[history_list])
    refresh_favorites_btn.click(refresh_favorites_list, outputs=[favorites_list])

    search_btn.click(search_history, [search_input], [history_list, history_status])

    # Both tabs share the same reuse handler; only the status box differs.
    reuse_btn.click(reuse_history_item, [history_list, chatbot, history], [chatbot, history, user_input, history_status])
    reuse_favorite_btn.click(reuse_history_item, [favorites_list, chatbot, history], [chatbot, history, user_input, favorites_status])

    favorite_btn.click(toggle_favorite_for_history, [history_list], [history_status, history_list, favorites_list])

    delete_btn.click(delete_history_item, [history_list], [history_status, history_list, favorites_list])
    clear_all_btn.click(clear_all_history, outputs=[history_status, history_list, favorites_list])

demo.queue().launch(share=False, inbrowser=True)