diff --git a/.gitignore b/.gitignore index 2c05ed1..0e323d0 100644 --- a/.gitignore +++ b/.gitignore @@ -48,6 +48,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +results/ # Translations *.mo diff --git a/scripts/test_spider2_eval.py b/scripts/test_spider2_eval.py new file mode 100755 index 0000000..4837dbc --- /dev/null +++ b/scripts/test_spider2_eval.py @@ -0,0 +1,576 @@ +#!/usr/bin/env python3 +""" +Script de Avaliação do Agente contra Spider 2.0 Lite Dataset. + +Testa o agente Text-to-Insight contra perguntas reais do Spider 2.0 Lite dataset, +usando a classe InsightEngine do pacote text_to_insight. Foca especificamente +nas instâncias "local" que correspondem a bancos SQLite. + +Uso: + python scripts/test_spider2_eval.py --sample-size 10 --seed 42 + python scripts/test_spider2_eval.py --db-filter E_commerce --output reports/eval_spider2.csv +""" + +import argparse +import json +import os +import sys +import time +from datetime import datetime +from pathlib import Path +import random +import glob +import pandas as pd + +# Add project root to path +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from dotenv import load_dotenv + +# Importar InsightEngine do pacote text_to_insight +from text_to_insight import InsightEngine + +from src.spider.csv_reporter import CSVReporter +from src.spider.metrics import ( + build_comparison_row, + results_exact_match, + results_f1_score, + sql_similarity_score, +) +from src.spider.query_executor import SpiderQueryExecutor +from src.spider.analise_empirica import gerar_relatorio_empirico_completo + +load_dotenv() + + +class Spider2QueryExecutor(SpiderQueryExecutor): + """ + Executor adaptado para o Spider 2.0 Lite, + onde os bancos locais geralmente estão na raiz da pasta. + """ + def get_db_path(self, db_id: str) -> Path: + # Tenta na raiz + db_path = self.database_dir / f"{db_id}.sqlite" + if not db_path.exists(): + # Tenta na subpasta como no Spider 1.0 + db_path = self.database_dir / db_id / f"{db_id}.sqlite" + if not db_path.exists(): + raise FileNotFoundError(f"Banco não encontrado: {db_path} (nem na subpasta)") + return db_path + + +def load_spider2_examples(data_dir: str) -> list[dict]: + """Carrega as instâncias do spider2-lite.jsonl""" + jsonl_path = Path(data_dir) / "spider2-lite.jsonl" + if not jsonl_path.exists(): + raise FileNotFoundError(f"Arquivo não encontrado: {jsonl_path}") + + examples = [] + with open(jsonl_path, "r", encoding="utf-8") as f: + for line in f: + if line.strip(): + examples.append(json.loads(line)) + return examples + + +def get_gold_sql(data_dir: str, instance_id: str) -> str: + """Lê a query gold da pasta evaluation_suite/gold/sql/""" + sql_path = Path(data_dir) / "evaluation_suite" / "gold" / "sql" / f"{instance_id}.sql" + if not sql_path.exists(): + return "" + with open(sql_path, "r", encoding="utf-8") as f: + return f.read().strip() + + +def get_gold_results(data_dir: str, instance_id: str) -> list[list[dict]]: + """Carrega os resultados gold (CSVs) para um dado instance_id. + Pode haver múltiplos CSVs (e.g. local040_a.csv, local040_b.csv). + """ + exec_result_dir = Path(data_dir) / "evaluation_suite" / "gold" / "exec_result" + + # Try exact match first + exact_match = exec_result_dir / f"{instance_id}.csv" + if exact_match.exists(): + try: + return [pd.read_csv(exact_match).to_dict(orient="records")] + except Exception: + pass + + # Try multiple (e.g. _a, _b) + pattern = str(exec_result_dir / f"{instance_id}_*.csv") + files = sorted(glob.glob(pattern)) + results = [] + for f in files: + try: + results.append(pd.read_csv(f).to_dict(orient="records")) + except Exception: + pass + + return results + + +def _gerar_relatorio_md( + report_path: str, + summary: dict, + f1_medio: float, + exact_match_rate: float, + all_rows: list[dict], + mismatches: list[dict], + model: str, + sample_size: int, + seed: int, + data_dir: str, +) -> None: + """Gera um relatório textual em Markdown com estatísticas e detalhes de mismatches.""" + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + Path(report_path).parent.mkdir(parents=True, exist_ok=True) + + lines = [] + lines.append("# Spider 2.0 Lite Evaluation Report") + lines.append("") + lines.append(f"**Gerado em:** {timestamp}") + lines.append("") + + # --- Configuração --- + lines.append("## Configuração") + lines.append("") + lines.append(f"| Parâmetro | Valor |") + lines.append(f"|-----------|-------|") + lines.append(f"| Modelo | `{model}` |") + lines.append(f"| Sample size | {sample_size} |") + lines.append(f"| Seed | {seed} |") + lines.append(f"| Data dir | `{data_dir}` |") + lines.append("") + + # --- Resumo --- + lines.append("## Resumo") + lines.append("") + lines.append(f"| Métrica | Valor |") + lines.append(f"|---------|-------|") + lines.append(f"| Total de perguntas | {summary['total_perguntas']} |") + lines.append(f"| Total de tentativas | {summary['total_tentativas']} |") + lines.append(f"| Perguntas aprovadas (crítico) | {summary['perguntas_aprovadas']} |") + lines.append(f"| Taxa de aprovação | {summary['taxa_aprovacao']:.1%} |") + lines.append(f"| Taxa de sucesso na 1ª tentativa | {summary['taxa_1a_tentativa']:.1%} |") + lines.append(f"| Tentativas médias por pergunta | {summary['tentativas_media']:.2f} |") + lines.append(f"| Similarity score médio (SQL) | {summary['similarity_media']:.4f} |") + lines.append(f"| F1 score médio (resultados) | {f1_medio:.4f} |") + lines.append(f"| Exact match rate | {exact_match_rate:.1%} |") + lines.append(f"| Mismatches | {len(mismatches)}/{len(all_rows)} |") + lines.append(f"| Tempo médio por tentativa | {summary['tempo_medio_ms']:.0f} ms |") + lines.append("") + + # --- Tabela por pergunta --- + lines.append("## Resultados por Pergunta") + lines.append("") + lines.append("| Instance ID | DB | Pergunta | Match | F1 | Similarity | Veredito |") + lines.append("|-------------|----|----------|-------|----|------------|----------|") + for r in all_rows: + pergunta_curta = str(r['pergunta_usuario'])[:50] + match_icon = "✅" if r['resultado_exato_match'] is True else ("❌" if r['resultado_exato_match'] is False else "⚠️") + lines.append( + f"| {r['id_exemplo']} " + f"| {r['db_id']} " + f"| {pergunta_curta}... " + f"| {match_icon} " + f"| {r.get('resultado_f1', 0):.2f} " + f"| {r['similarity_score_sql']:.2f} " + f"| {r['veredito_critico']} |" + ) + lines.append("") + + # --- Detalhes dos mismatches --- + if mismatches: + lines.append("## Detalhes dos Mismatches") + lines.append("") + lines.append(f"Total: **{len(mismatches)}** perguntas não obtiveram exact match.") + lines.append("") + + for i, m in enumerate(mismatches, 1): + lines.append(f"### Mismatch {i} — Instância `{m['id']}` (`{m['db_id']}`)") + lines.append("") + lines.append(f"**Pergunta:** {m['pergunta']}") + lines.append("") + lines.append(f"**F1:** {m['f1']:.4f} | **Precision:** {m['precision']:.4f} | **Recall:** {m['recall']:.4f}") + lines.append("") + + lines.append("**Query Ouro (Spider):**") + lines.append(f"```sql\n{m['query_ouro']}\n```\n") + + lines.append("**Query Agente:**") + lines.append(f"```sql\n{m['query_agente']}\n```\n") + + lines.append("**Resultado Ouro** (primeiras 20 linhas):") + lines.append("") + ouro_sample = m['resultado_ouro'][:20] + if ouro_sample: + cols = list(ouro_sample[0].keys()) + lines.append("| " + " | ".join(cols) + " |") + lines.append("| " + " | ".join(["---"] * len(cols)) + " |") + for row in ouro_sample: + vals = [str(row.get(c, "")) for c in cols] + lines.append("| " + " | ".join(vals) + " |") + if len(m['resultado_ouro']) > 20: + lines.append(f"*... e mais {len(m['resultado_ouro']) - 20} linhas*") + else: + lines.append("*(vazio)*") + lines.append("") + + lines.append("**Resultado Agente** (primeiras 20 linhas):") + lines.append("") + agent_sample = m['resultado_agente'][:20] + if agent_sample: + cols = list(agent_sample[0].keys()) + lines.append("| " + " | ".join(cols) + " |") + lines.append("| " + " | ".join(["---"] * len(cols)) + " |") + for row in agent_sample: + vals = [str(row.get(c, "")) for c in cols] + lines.append("| " + " | ".join(vals) + " |") + if len(m['resultado_agente']) > 20: + lines.append(f"*... e mais {len(m['resultado_agente']) - 20} linhas*") + else: + lines.append("*(vazio)*") + lines.append("") + lines.append("---") + lines.append("") + else: + lines.append("## Detalhes dos Mismatches") + lines.append("") + lines.append("🎉 **Nenhum mismatch!** Todos os resultados foram exact match.") + lines.append("") + + with open(report_path, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + +def main(): + parser = argparse.ArgumentParser(description="Avaliar agente Text-to-Insight contra Spider 2.0 Lite") + parser.add_argument("--sample-size", type=int, default=10, help="Quantas perguntas testar") + parser.add_argument("--seed", type=int, default=42, help="Seed para reproducibilidade") + parser.add_argument("--db-filter", type=str, help="Filtrar por banco específico (ex: E_commerce)") + parser.add_argument("--output", type=str, help="Caminho para salvar CSV") + parser.add_argument("--max-attempts", type=int, default=3, help="Máximo de tentativas por pergunta") + parser.add_argument("--data-dir", type=str, default="spider2-lite", help="Diretório base do Spider 2 Lite") + parser.add_argument("--sqlite-dir", type=str, default="spider2-lite/resource/databases/spider2-localdb", help="Diretório contendo os bancos sqlite do Spider 2") + parser.add_argument("--question-filter", type=str, help="Filtrar por um trecho da pergunta") + parser.add_argument("--model", type=str, default="gpt-4o-mini", help="Modelo LLM a utilizar") + parser.add_argument("--with-graphs", action="store_true", help="Ativar a geração de gráficos e salvamento de CSV") + parser.add_argument("--report-dir", type=str, default="", help="Pasta dentro de 'reports' para salvar os relatórios .md") + + args = parser.parse_args() + + # Validar API key + model = args.model + + api_key = os.getenv("OPENAI_API_KEY") if "gpt" in model.lower() else os.getenv("GOOGLE_API_KEY") + if not api_key: + print("❌ Erro: Chave API não encontrada em .env") + sys.exit(1) + + print(f"\n📂 Carregando exemplos do Spider 2 Lite de {args.data_dir}...") + try: + exemplos = load_spider2_examples(args.data_dir) + print(f"✓ Carregados {len(exemplos)} exemplos totais") + except FileNotFoundError as e: + print(f"❌ {e}") + sys.exit(1) + + # Filtrar apenas os locais (SQLite) já que o framework suporta SQLite nativamente + exemplos = [ex for ex in exemplos if ex.get("instance_id", "").startswith("local")] + print(f"✓ Filtrados para {len(exemplos)} exemplos baseados em SQLite (prefixo 'local')") + + if args.db_filter: + exemplos = [ex for ex in exemplos if ex.get("db") == args.db_filter] + print(f"✓ Filtrados por db_id={args.db_filter}: {len(exemplos)} exemplos") + + if args.question_filter: + exemplos = [ex for ex in exemplos if args.question_filter.lower() in ex.get("question", "").lower()] + print(f"✓ Filtrados pela pergunta '{args.question_filter}': {len(exemplos)} exemplos") + + if args.seed is not None: + random.seed(args.seed) + if args.sample_size is not None and args.sample_size < len(exemplos): + exemplos = random.sample(exemplos, k=args.sample_size) + + print(f"✓ Selecionados {len(exemplos)} exemplos para teste.") + + print("\n🔧 Inicializando componentes...") + executor = Spider2QueryExecutor(database_dir=args.sqlite_dir) + print("✓ Query executor inicializado") + + csv_path = args.output if args.output else f"reports/{CSVReporter.generate_timestamped_filename('spider2_eval')}" + reporter = CSVReporter(csv_path) + print(f"✓ CSV reporter inicializado: {csv_path}") + + print(f"\n🚀 Iniciando avaliação com {len(exemplos)} perguntas...\n") + print("=" * 100) + + all_rows = [] + mismatches = [] + engine_cache = {} + + for idx, ex in enumerate(exemplos, 1): + instance_id = ex.get("instance_id") + db_id = ex.get("db", "") + pergunta = ex.get("question", "") + + # Recuperar query ouro e/ou csvs ouro + query_ouro = get_gold_sql(args.data_dir, instance_id) + gold_results_list = get_gold_results(args.data_dir, instance_id) + + if not query_ouro and not gold_results_list: + print(f"\n[{idx}/{len(exemplos)}] ⚠️ Nenhuma query ouro nem resultado CSV encontrados para {instance_id}. Pulando.") + continue + + print(f"\n[{idx}/{len(exemplos)}] Instance: {instance_id} | DB: {db_id} | Pergunta: {pergunta[:60]}...") + if query_ouro: + print(f" → Query Ouro: {query_ouro[:50]}...") + else: + print(f" → Query Ouro não fornecida (avaliando via CSVs oficiais).") + + # Configurar resultado ouro (para relatórios e fallback) + if gold_results_list: + resultado_ouro_primeiro = {"success": True, "results": gold_results_list[0], "row_count": len(gold_results_list[0])} + print(f" ✓ CSV Ouro carregado ({len(gold_results_list)} variantes, usando a primeira para display com {len(gold_results_list[0])} linhas)") + else: + resultado_ouro_primeiro = executor.execute_query(db_id, query_ouro) + if not resultado_ouro_primeiro["success"]: + print(f" ⚠️ Erro na query ouro ou db ausente: {resultado_ouro_primeiro['error']}") + print(" (Aviso: Certifique-se de baixar e extrair os bancos locais em spider2-localdb)") + continue + gold_results_list = [resultado_ouro_primeiro["results"]] + print(f" ✓ Query ouro retornou {resultado_ouro_primeiro['row_count']} linhas") + + # Inicializar engine + try: + db_path = str(executor.get_db_path(db_id)) + except FileNotFoundError as e: + print(f" ❌ {e}") + continue + + if db_id not in engine_cache: + try: + engine_cache[db_id] = InsightEngine( + api_key=api_key, + model=model, + db_path=db_path, + hitl=False, + show_output=False, + enable_graphs=args.with_graphs, + ) + except Exception as e: + print(f" ❌ Erro ao inicializar InsightEngine: {e}") + continue + + engine = engine_cache[db_id] + + print(f" → Invocando agente...") + inicio_agente = time.time() + try: + resultado = engine.run(thread_id=f"spider2_test_{instance_id}", query=pergunta) + except Exception as e: + print(f" ⚠️ Erro ao processar pergunta: {str(e)}") + continue + + tempo_total = (time.time() - inicio_agente) * 1000 + + query_agente = resultado.get("sql_gerada", "") + veredito = resultado.get("status", "") + feedback_estado = resultado.get("feedback_critico", "") + erro_exec = resultado.get("erro_execucao", "") + tentativas = resultado.get("tentativas_loop", 1) + historico_tent = resultado.get("historico_tentativas", []) + + # Extrair métricas de tokens acumuladas + tokens_input = resultado.get("tokens_input", 0) or 0 + tokens_output = resultado.get("tokens_output", 0) or 0 + tokens_total = resultado.get("tokens_total", 0) or 0 + + # Extrair dados do agente de visualização + viz_acionado = "grafico_gerado" in resultado + viz_sucesso = resultado.get("grafico_gerado", False) + + # Extrair SQL da 1ª tentativa para ablação do Crítico + query_1a_tentativa = "" + if historico_tent and isinstance(historico_tent, list) and len(historico_tent) > 0: + query_1a_tentativa = historico_tent[0].get("sql", "") + if not query_1a_tentativa: + query_1a_tentativa = query_agente # fallback: se só houve 1 tentativa + + if veredito == "aprovado": + veredito_critico = "aprovado" + feedback_critico = feedback_estado if feedback_estado else "Aprovado" + elif veredito == "reprovado": + veredito_critico = "reprovado" + feedback_critico = feedback_estado if feedback_estado else "Reprovado pelo crítico" + else: + veredito_critico = "erro" + feedback_critico = feedback_estado if feedback_estado else "Erro na avaliação" + + resultado_exato_match = None + resultado_exato_match_1a = None + resultado_f1_1a = 0.0 + similarity_score = 0.0 + f1_scores = {"f1": 0.0, "precision": 0.0, "recall": 0.0} + + if query_agente and not erro_exec: + resultado_agente = executor.execute_query(db_id, query_agente) + if resultado_agente["success"]: + if query_ouro: + similarity_score = sql_similarity_score(query_ouro, query_agente) + + # Testar contra todas as variantes de ouro e pegar a melhor pontuação + best_match = False + best_f1 = {"f1": 0.0, "precision": 0.0, "recall": 0.0} + + for gold_res in gold_results_list: + match_atual = results_exact_match(gold_res, resultado_agente["results"]) + f1_atual = results_f1_score(gold_res, resultado_agente["results"]) + + if match_atual: + best_match = True + + if f1_atual["f1"] > best_f1["f1"]: + best_f1 = f1_atual + + resultado_exato_match = best_match + f1_scores = best_f1 + + # Ablação: calcular exact match da 1ª tentativa + if query_1a_tentativa and query_1a_tentativa != query_agente: + res_1a = executor.execute_query(db_id, query_1a_tentativa) + if res_1a["success"]: + best_match_1a = False + best_f1_1a = 0.0 + for gold_res in gold_results_list: + if results_exact_match(gold_res, res_1a["results"]): + best_match_1a = True + f1_atual_1a = results_f1_score(gold_res, res_1a["results"]) + if f1_atual_1a["f1"] > best_f1_1a: + best_f1_1a = f1_atual_1a["f1"] + resultado_exato_match_1a = best_match_1a + resultado_f1_1a = best_f1_1a + else: + resultado_exato_match_1a = False + resultado_f1_1a = 0.0 + else: + resultado_exato_match_1a = resultado_exato_match + resultado_f1_1a = f1_scores["f1"] + + print( + f" Resultado final ({tentativas} tentativa(s)): " + f"similarity={similarity_score:.2f}, " + f"match={resultado_exato_match}, " + f"match_1a={resultado_exato_match_1a}, " + f"F1={f1_scores['f1']:.2f}, " + f"veredito={veredito_critico}" + ) + if not resultado_exato_match: + mismatches.append({ + "id": instance_id, + "db_id": db_id, + "pergunta": pergunta, + "query_ouro": query_ouro, + "query_agente": query_agente, + "resultado_ouro": gold_results_list[0], + "resultado_agente": resultado_agente["results"], + "f1": f1_scores["f1"], + "precision": f1_scores["precision"], + "recall": f1_scores["recall"], + }) + else: + erro_exec = resultado_agente["error"] + else: + print(f" Resultado final ({tentativas} tentativa(s)): sem query gerada ou com erro") + + row = build_comparison_row( + id_exemplo=instance_id, + tentativa_numero=tentativas, + db_id=db_id, + pergunta=pergunta, + query_ouro=query_ouro, + query_agente=query_agente, + tempo_agente_ms=tempo_total, + veredito_critico=veredito_critico, + feedback_critico=feedback_critico, + erro_execucao=erro_exec, + resultado_exato_match=resultado_exato_match, + similarity_score=similarity_score, + resultado_f1=f1_scores["f1"], + resultado_precision=f1_scores["precision"], + resultado_recall=f1_scores["recall"], + tokens_input=tokens_input, + tokens_output=tokens_output, + tokens_total=tokens_total, + viz_acionado=viz_acionado, + viz_sucesso=viz_sucesso, + resultado_exato_match_1a_tentativa=resultado_exato_match_1a, + resultado_f1_1a_tentativa=resultado_f1_1a, + query_1a_tentativa=query_1a_tentativa, + ) + + reporter.append_row(row) + all_rows.append(row) + + if veredito_critico == "aprovado": + print(f" ✅ APROVADO após {tentativas} tentativa(s)") + else: + print(f" ❌ NÃO APROVADO após {tentativas} tentativa(s)") + + print("\n" + "=" * 100) + print("📊 RESUMO FINAL SPIDER 2 LITE") + print("=" * 100) + + if all_rows: + summary = reporter.generate_summary(all_rows) + f1_values = [float(r.get("resultado_f1", 0.0) or 0.0) for r in all_rows] + f1_medio = sum(f1_values) / len(f1_values) if f1_values else 0.0 + match_values = [r.get("resultado_exato_match") for r in all_rows] + exact_matches = sum(1 for v in match_values if v is True) + exact_match_rate = exact_matches / len(all_rows) if all_rows else 0.0 + + print(f"Total de perguntas processadas: {summary['total_perguntas']}") + print(f"Taxa de aprovação: {summary['taxa_aprovacao']:.1%}") + print(f"F1 score médio (resultados): {f1_medio:.4f}") + print(f"Exact match rate: {exact_match_rate:.1%}") + print(f"\n✅ CSV salvo em: {csv_path}") + + if args.report_dir: + md_dir = Path("reports") / args.report_dir + md_dir.mkdir(parents=True, exist_ok=True) + report_path = str(md_dir / f"{Path(csv_path).stem}_report.md") + empirico_path = str(md_dir / f"{Path(csv_path).stem}_empirico.md") + else: + report_path = csv_path.replace(".csv", "_report.md") + empirico_path = csv_path.replace(".csv", "_empirico.md") + + _gerar_relatorio_md( + report_path=report_path, + summary=summary, + f1_medio=f1_medio, + exact_match_rate=exact_match_rate, + all_rows=all_rows, + mismatches=mismatches, + model=model, + sample_size=args.sample_size, + seed=args.seed, + data_dir=args.data_dir, + ) + print(f"✅ Relatório salvo em: {report_path}") + + # 9. Gerar relatório empírico completo (análises do orientador) + empirico_dir = str((Path(csv_path).parent / Path(csv_path).stem).absolute()) + "_empirico" + gerar_relatorio_empirico_completo( + report_path=empirico_path, + dataset_label="Spider 2.0 Lite", + all_rows=all_rows, + output_dir=empirico_dir, + ) + print(f"✅ Relatório empírico salvo em: {empirico_path}") + print(f" Gráficos e CSVs auxiliares em: {Path(empirico_dir).relative_to(Path.cwd())}/") + else: + print("❌ Nenhum resultado para salvar. (Verificou os bancos na pasta spider2-localdb?)") + + +if __name__ == "__main__": + main() diff --git a/scripts/test_spider_eval.py b/scripts/test_spider_eval.py index e8415f2..fe881f4 100644 --- a/scripts/test_spider_eval.py +++ b/scripts/test_spider_eval.py @@ -39,6 +39,7 @@ sql_similarity_score, ) from src.spider.query_executor import SpiderQueryExecutor +from src.spider.analise_empirica import gerar_relatorio_empirico_completo load_dotenv() @@ -220,7 +221,7 @@ def main(): parser.add_argument( "--data-dir", type=str, - default="data/spider_data/spider_data", + default="spider_data", help="Diretório com dados do Spider", ) @@ -230,12 +231,28 @@ def main(): type=str, help="Filtrar por um trecho específico da pergunta em inglês", ) + parser.add_argument( + "--model", + type=str, + default="gpt-4o-mini", + help="Modelo LLM a utilizar (default: gpt-4o-mini)", + ) + parser.add_argument( + "--with-graphs", + action="store_true", + help="Ativar a geração de gráficos e salvamento de CSV", + ) + parser.add_argument( + "--report-dir", + type=str, + default="", + help="Pasta dentro de 'reports' para salvar os relatórios .md (CSVs continuam fora)", + ) args = parser.parse_args() # Validar API key - model = "gpt-4o-mini" - # model = "gemini-2.5-flash" + model = args.model api_key = os.getenv("OPENAI_API_KEY") if "gpt" in model.lower() else os.getenv("GOOGLE_API_KEY") if not api_key: @@ -325,6 +342,7 @@ def main(): db_path=db_path, hitl=False, show_output=False, + enable_graphs=args.with_graphs, ) print(f" ✓ InsightEngine inicializado para db={db_id}") except Exception as e: @@ -354,6 +372,23 @@ def main(): feedback_estado = resultado.get("feedback_critico", "") erro_exec = resultado.get("erro_execucao", "") tentativas = resultado.get("tentativas_loop", 1) + historico_tent = resultado.get("historico_tentativas", []) + + # Extrair métricas de tokens acumuladas + tokens_input = resultado.get("tokens_input", 0) or 0 + tokens_output = resultado.get("tokens_output", 0) or 0 + tokens_total = resultado.get("tokens_total", 0) or 0 + + # Extrair dados do agente de visualização + viz_acionado = "grafico_gerado" in resultado + viz_sucesso = resultado.get("grafico_gerado", False) + + # Extrair SQL da 1ª tentativa para ablação do Crítico + query_1a_tentativa = "" + if historico_tent and isinstance(historico_tent, list) and len(historico_tent) > 0: + query_1a_tentativa = historico_tent[0].get("sql", "") + if not query_1a_tentativa: + query_1a_tentativa = query_agente # fallback: se só houve 1 tentativa # Mapear status para veredito e definir feedback if veredito == "aprovado": @@ -368,6 +403,8 @@ def main(): # Comparar resultados se query agente foi gerada resultado_exato_match = None + resultado_exato_match_1a = None + resultado_f1_1a = 0.0 similarity_score = 0.0 f1_scores = {"f1": 0.0, "precision": 0.0, "recall": 0.0} @@ -387,10 +424,30 @@ def main(): resultado_ouro["results"], resultado_agente["results"], ) + + # Ablação: calcular exact match e F1 da 1ª tentativa + if query_1a_tentativa and query_1a_tentativa != query_agente: + res_1a = executor.execute_query(db_id, query_1a_tentativa) + if res_1a["success"]: + resultado_exato_match_1a = results_exact_match( + resultado_ouro["results"], res_1a["results"] + ) + f1_1a = results_f1_score( + resultado_ouro["results"], res_1a["results"] + ) + resultado_f1_1a = f1_1a["f1"] + else: + resultado_exato_match_1a = False + resultado_f1_1a = 0.0 + else: + resultado_exato_match_1a = resultado_exato_match + resultado_f1_1a = f1_scores["f1"] + print( f" Resultado final ({tentativas} tentativa(s)): " f"similarity={similarity_score:.2f}, " f"match={resultado_exato_match}, " + f"match_1a={resultado_exato_match_1a}, " f"F1={f1_scores['f1']:.2f}, " f"veredito={veredito_critico}" ) @@ -416,7 +473,7 @@ def main(): f"sem query gerada ou com erro de execução" ) - # Construir linha para CSV + # Construir linha para CSV (com campos empíricos adicionais) row = build_comparison_row( id_exemplo=ex_id, tentativa_numero=tentativas, @@ -433,6 +490,14 @@ def main(): resultado_f1=f1_scores["f1"], resultado_precision=f1_scores["precision"], resultado_recall=f1_scores["recall"], + tokens_input=tokens_input, + tokens_output=tokens_output, + tokens_total=tokens_total, + viz_acionado=viz_acionado, + viz_sucesso=viz_sucesso, + resultado_exato_match_1a_tentativa=resultado_exato_match_1a, + resultado_f1_1a_tentativa=resultado_f1_1a, + query_1a_tentativa=query_1a_tentativa, ) reporter.append_row(row) @@ -454,7 +519,7 @@ def main(): if all_rows: summary = reporter.generate_summary(all_rows) # Calcular F1 médio - f1_values = [float(r.get("resultado_f1", 0)) for r in all_rows if r.get("resultado_f1")] + f1_values = [float(r.get("resultado_f1", 0.0) or 0.0) for r in all_rows] f1_medio = sum(f1_values) / len(f1_values) if f1_values else 0.0 # Calcular exact match rate match_values = [r.get("resultado_exato_match") for r in all_rows] @@ -475,7 +540,15 @@ def main(): print(f"\n✅ CSV salvo em: {csv_path}") # 8. Gerar relatório textual em Markdown - report_path = csv_path.replace(".csv", "_report.md") + if args.report_dir: + md_dir = Path("reports") / args.report_dir + md_dir.mkdir(parents=True, exist_ok=True) + report_path = str(md_dir / f"{Path(csv_path).stem}_report.md") + empirico_path = str(md_dir / f"{Path(csv_path).stem}_empirico.md") + else: + report_path = csv_path.replace(".csv", "_report.md") + empirico_path = csv_path.replace(".csv", "_empirico.md") + _gerar_relatorio_md( report_path=report_path, summary=summary, @@ -489,6 +562,17 @@ def main(): data_dir=args.data_dir, ) print(f"✅ Relatório salvo em: {report_path}") + + # 9. Gerar relatório empírico completo (análises do orientador) + empirico_dir = str((Path(csv_path).parent / Path(csv_path).stem).absolute()) + "_empirico" + gerar_relatorio_empirico_completo( + report_path=empirico_path, + dataset_label="Spider", + all_rows=all_rows, + output_dir=empirico_dir, + ) + print(f"✅ Relatório empírico salvo em: {empirico_path}") + print(f" Gráficos e CSVs auxiliares em: {Path(empirico_dir).relative_to(Path.cwd())}/") else: print("❌ Nenhum resultado para salvar") diff --git a/src/spider/analise_empirica.py b/src/spider/analise_empirica.py new file mode 100644 index 0000000..c2228a3 --- /dev/null +++ b/src/spider/analise_empirica.py @@ -0,0 +1,713 @@ +""" +Módulo de Análise Empírica para avaliação Spider / Spider 2.0 Lite. + +Contém funções de pós-processamento para gerar: +1. Distribuição de tentativas e ablação do Crítico +2. Matriz de confusão do Crítico +3. Taxonomia de erros SQL +4. Tabela de métricas operacionais +5. Estatísticas do Agente de Visualização +""" + +import csv +import re +from collections import Counter +from datetime import datetime +from pathlib import Path +from typing import Any + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt + +from .metrics import normalize_sql + + +# --------------------------------------------------------------------------- +# 1. Distribuição de tentativas e ablação do Crítico +# --------------------------------------------------------------------------- + +def calcular_distribuicao_tentativas(all_rows: list[dict]) -> dict[str, Any]: + """ + Para cada pergunta, identifica em qual tentativa o Crítico aprovou. + Retorna tabela de frequência e dados para ablação. + """ + freq = {"1a_tentativa": 0, "2a_tentativa": 0, "3a_tentativa": 0, "falha": 0} + + for r in all_rows: + tent = r.get("tentativa_numero", 1) + veredito = r.get("veredito_critico", "") + try: + tent = int(tent) + except (ValueError, TypeError): + tent = 1 + + if veredito == "aprovado": + if tent == 1: + freq["1a_tentativa"] += 1 + elif tent == 2: + freq["2a_tentativa"] += 1 + elif tent >= 3: + freq["3a_tentativa"] += 1 + else: + freq["falha"] += 1 + else: + freq["falha"] += 1 + + return freq + + +def calcular_ablacao_critico(all_rows: list[dict]) -> dict[str, float]: + """ + Calcula exact match COM e SEM o mecanismo de autocorreção (Crítico). + + - COM Crítico: exact match final (como já calculado). + - SEM Crítico: exact match considerando APENAS o resultado da 1ª tentativa + (campo `resultado_exato_match_1a_tentativa`). + """ + total = len(all_rows) if all_rows else 1 + + em_com_critico = sum( + 1 for r in all_rows if r.get("resultado_exato_match") is True + ) + em_sem_critico = sum( + 1 for r in all_rows if r.get("resultado_exato_match_1a_tentativa") is True + ) + + f1_com_critico = sum( + float(r.get("resultado_f1", 0.0) or 0.0) for r in all_rows + ) + f1_sem_critico = sum( + float(r.get("resultado_f1_1a_tentativa", 0.0) or 0.0) for r in all_rows + ) + + return { + "exact_match_com_critico": em_com_critico / total, + "exact_match_sem_critico": em_sem_critico / total, + "f1_com_critico": f1_com_critico / total, + "f1_sem_critico": f1_sem_critico / total, + "total_perguntas": total, + "acertos_com_critico": em_com_critico, + "acertos_sem_critico": em_sem_critico, + } + + +def gerar_grafico_ablacao(ablacao: dict, output_path: str, dataset_label: str = "Spider") -> str: + """Gera gráfico de barras comparando exact match com e sem Crítico.""" + labels = ["Com Crítico\n(autocorreção)", "Sem Crítico\n(1ª tentativa)"] + valores = [ + ablacao["exact_match_com_critico"] * 100, + ablacao["exact_match_sem_critico"] * 100, + ] + cores = ["#2ecc71", "#e74c3c"] + + fig, ax = plt.subplots(figsize=(7, 5)) + bars = ax.bar(labels, valores, color=cores, width=0.5, edgecolor="white", linewidth=1.5) + + for bar, val in zip(bars, valores): + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 1, + f"{val:.1f}%", ha="center", va="bottom", fontweight="bold", fontsize=13) + + ax.set_ylabel("Exact Match (%)", fontsize=12) + ax.set_title(f"Ablação do Crítico — {dataset_label}", fontsize=14, fontweight="bold") + ax.set_ylim(0, max(valores) * 1.25 if max(valores) > 0 else 100) + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + return output_path + + +def gerar_grafico_distribuicao_tentativas(freq: dict, output_path: str, dataset_label: str = "Spider") -> str: + """Gera gráfico de barras com a distribuição de tentativas.""" + labels = ["1ª tentativa", "2ª tentativa", "3ª tentativa", "Falha (3 tent.)"] + valores = [freq["1a_tentativa"], freq["2a_tentativa"], freq["3a_tentativa"], freq["falha"]] + cores = ["#27ae60", "#f39c12", "#e67e22", "#e74c3c"] + + fig, ax = plt.subplots(figsize=(8, 5)) + bars = ax.bar(labels, valores, color=cores, width=0.55, edgecolor="white", linewidth=1.5) + + for bar, val in zip(bars, valores): + if val > 0: + ax.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.3, + str(val), ha="center", va="bottom", fontweight="bold", fontsize=12) + + ax.set_ylabel("Número de Perguntas", fontsize=12) + ax.set_title(f"Distribuição de Tentativas até Aprovação — {dataset_label}", fontsize=13, fontweight="bold") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + return output_path + + +# --------------------------------------------------------------------------- +# 1.5 Transições de Autocorreção +# --------------------------------------------------------------------------- + +def calcular_transicoes_autocorrecao(all_rows: list[dict]) -> dict[str, int]: + """ + Analisa as transições de Exact Match da 1ª tentativa para a final, + apenas para perguntas que tiveram > 1 tentativa. + """ + transicoes = { + "ajudou": 0, # False -> True + "manteve_certo": 0, # True -> True + "manteve_errado": 0, # False -> False + "atrapalhou": 0 # True -> False + } + + for r in all_rows: + tent = r.get("tentativa_numero", 1) + try: + tent = int(tent) + except (ValueError, TypeError): + tent = 1 + + if tent > 1: + em_1a = r.get("resultado_exato_match_1a_tentativa") is True + em_final = r.get("resultado_exato_match") is True + + if not em_1a and em_final: + transicoes["ajudou"] += 1 + elif em_1a and em_final: + transicoes["manteve_certo"] += 1 + elif not em_1a and not em_final: + transicoes["manteve_errado"] += 1 + elif em_1a and not em_final: + transicoes["atrapalhou"] += 1 + + return transicoes + + +def exportar_detalhes_transicoes(all_rows: list[dict], output_csv: str) -> int: + """ + Exporta os dados de TODAS as queries que passaram por autocorreção (>1 tentativa), + classificando o tipo de transição. + """ + detalhes = [] + for r in all_rows: + tent = r.get("tentativa_numero", 1) + try: + tent = int(tent) + except (ValueError, TypeError): + tent = 1 + + if tent > 1: + em_1a = r.get("resultado_exato_match_1a_tentativa") is True + em_final = r.get("resultado_exato_match") is True + + tipo_transicao = "" + if not em_1a and em_final: + tipo_transicao = "EM=False -> EM=True (Ajudou)" + elif em_1a and em_final: + tipo_transicao = "EM=True -> EM=True (Manteve Certo)" + elif not em_1a and not em_final: + tipo_transicao = "EM=False -> EM=False (Manteve Errado)" + elif em_1a and not em_final: + tipo_transicao = "EM=True -> EM=False (Atrapalhou)" + + detalhes.append({ + "id_exemplo": r.get("id_exemplo", ""), + "db_id": r.get("db_id", ""), + "tipo_transicao": tipo_transicao, + "pergunta": r.get("pergunta_usuario", ""), + "query_ouro": r.get("query_ouro_spider", ""), + "query_1a_tentativa": r.get("query_1a_tentativa", ""), + "query_final": r.get("query_agente_tentativa", ""), + "feedback_critico": r.get("feedback_critico_recebido", ""), + "f1_1a_tentativa": r.get("resultado_f1_1a_tentativa", 0), + "f1_final": r.get("resultado_f1", 0), + "tentativas": tent + }) + + if detalhes: + Path(output_csv).parent.mkdir(parents=True, exist_ok=True) + with open(output_csv, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=detalhes[0].keys()) + writer.writeheader() + writer.writerows(detalhes) + + return len(detalhes) + + +# --------------------------------------------------------------------------- +# 2. Matriz de confusão do Crítico +# --------------------------------------------------------------------------- + +def calcular_matriz_confusao(all_rows: list[dict]) -> dict[str, int]: + """ + Cruza veredito do Crítico com exact match real. + Retorna TP, FP, FN, TN. + """ + tp = fp = fn = tn = 0 + for r in all_rows: + aprovado = r.get("veredito_critico") == "aprovado" + match = r.get("resultado_exato_match") is True + + if aprovado and match: + tp += 1 + elif aprovado and not match: + fp += 1 + elif not aprovado and match: + fn += 1 + else: + tn += 1 + + return {"TP": tp, "FP": fp, "FN": fn, "TN": tn} + + +def exportar_falsos_positivos(all_rows: list[dict], output_csv: str) -> int: + """ + Exporta os casos de falso positivo (Crítico aprovou, mas exact match incorreto). + Retorna a quantidade de falsos positivos. + """ + fps = [] + for r in all_rows: + aprovado = r.get("veredito_critico") == "aprovado" + match = r.get("resultado_exato_match") + if aprovado and match is not True: + fps.append({ + "id_exemplo": r.get("id_exemplo", ""), + "db_id": r.get("db_id", ""), + "pergunta": r.get("pergunta_usuario", ""), + "query_ouro": r.get("query_ouro_spider", ""), + "query_agente": r.get("query_agente_tentativa", ""), + "veredito_critico": r.get("veredito_critico", ""), + "feedback_critico": r.get("feedback_critico_recebido", ""), + "resultado_exato_match": match, + "f1": r.get("resultado_f1", 0), + }) + + if fps: + Path(output_csv).parent.mkdir(parents=True, exist_ok=True) + with open(output_csv, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=fps[0].keys()) + writer.writeheader() + writer.writerows(fps) + + return len(fps) + + +# --------------------------------------------------------------------------- +# 3. Taxonomia de erros +# --------------------------------------------------------------------------- + +_CATEGORIAS_ERRO = [ + ("DISTINCT", r"\bDISTINCT\b"), + ("GROUP BY", r"\bGROUP\s+BY\b"), + ("ORDER BY", r"\bORDER\s+BY\b"), + ("LIMIT", r"\bLIMIT\b"), + ("HAVING", r"\bHAVING\b"), + ("Subconsulta", r"\(\s*SELECT\b"), + ("JOIN", r"\bJOIN\b"), + ("WHERE", r"\bWHERE\b"), + ("Agregação (SUM/AVG/COUNT/MIN/MAX)", r"\b(SUM|AVG|COUNT|MIN|MAX)\s*\("), + ("UNION", r"\bUNION\b"), +] + + +def _detectar_divergencias(sql_ouro: str, sql_agente: str) -> list[str]: + """Detecta categorias de divergência entre SQL ouro e SQL agente.""" + ouro_norm = normalize_sql(sql_ouro) + agente_norm = normalize_sql(sql_agente) + + divergencias = [] + for nome, padrao in _CATEGORIAS_ERRO: + ouro_tem = bool(re.search(padrao, ouro_norm, re.IGNORECASE)) + agente_tem = bool(re.search(padrao, agente_norm, re.IGNORECASE)) + if ouro_tem != agente_tem: + divergencias.append(nome) + + # Checar diferença nas colunas do SELECT + def _extrair_colunas_select(sql_norm: str) -> set[str]: + m = re.match(r"SELECT\s+(.*?)\s+FROM\b", sql_norm, re.IGNORECASE | re.DOTALL) + if m: + cols = m.group(1).split(",") + return {c.strip() for c in cols} + return set() + + cols_ouro = _extrair_colunas_select(ouro_norm) + cols_agente = _extrair_colunas_select(agente_norm) + if cols_ouro and cols_agente and cols_ouro != cols_agente: + divergencias.append("Colunas SELECT diferentes") + + if not divergencias: + divergencias.append("Outro (valores/lógica)") + + return divergencias + + +def classificar_erros(all_rows: list[dict]) -> tuple[list[dict], Counter]: + """ + Para as perguntas sem exact match, classifica o tipo de erro. + Retorna lista de registros detalhados e Counter por categoria. + """ + erros_detalhados = [] + contagem = Counter() + + for r in all_rows: + if r.get("resultado_exato_match") is True: + continue + sql_ouro = r.get("query_ouro_spider", "") + sql_agente = r.get("query_agente_tentativa", "") + + if not sql_ouro or not sql_agente: + categorias = ["Sem SQL (ouro ou agente)"] + else: + categorias = _detectar_divergencias(sql_ouro, sql_agente) + + for cat in categorias: + contagem[cat] += 1 + + erros_detalhados.append({ + "id_exemplo": r.get("id_exemplo", ""), + "db_id": r.get("db_id", ""), + "pergunta": r.get("pergunta_usuario", ""), + "query_ouro": sql_ouro, + "query_agente": sql_agente, + "categorias_erro": "; ".join(categorias), + }) + + return erros_detalhados, contagem + + +def exportar_taxonomia_erros_csv(erros_detalhados: list[dict], output_csv: str) -> None: + """Exporta erros detalhados em CSV.""" + if not erros_detalhados: + return + Path(output_csv).parent.mkdir(parents=True, exist_ok=True) + with open(output_csv, "w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=erros_detalhados[0].keys()) + writer.writeheader() + writer.writerows(erros_detalhados) + + +def gerar_grafico_taxonomia_erros(contagem: Counter, output_path: str, dataset_label: str = "Spider") -> str: + """Gera gráfico de barras horizontais com contagem por categoria de erro.""" + if not contagem: + return "" + + cats = contagem.most_common() + labels = [c[0] for c in cats] + valores = [c[1] for c in cats] + + fig, ax = plt.subplots(figsize=(9, max(4, len(labels) * 0.55))) + cores = plt.cm.RdYlGn_r([i / max(len(labels), 1) for i in range(len(labels))]) + bars = ax.barh(labels[::-1], valores[::-1], color=cores[::-1], edgecolor="white", linewidth=1.2) + + for bar, val in zip(bars, valores[::-1]): + ax.text(bar.get_width() + 0.3, bar.get_y() + bar.get_height() / 2, + str(val), va="center", fontweight="bold", fontsize=11) + + ax.set_xlabel("Contagem", fontsize=12) + ax.set_title(f"Taxonomia de Erros SQL — {dataset_label}", fontsize=13, fontweight="bold") + ax.spines["top"].set_visible(False) + ax.spines["right"].set_visible(False) + plt.tight_layout() + plt.savefig(output_path, dpi=150, bbox_inches="tight") + plt.close() + return output_path + + +# --------------------------------------------------------------------------- +# 4. Tabela de métricas operacionais +# --------------------------------------------------------------------------- + +def calcular_metricas_operacionais(all_rows: list[dict]) -> dict[str, Any]: + """ + Calcula métricas operacionais por pergunta e segmenta por + resolvidas na 1ª tentativa vs múltiplas tentativas. + """ + grupo_1a = [] # resolvidas na 1ª tentativa + grupo_multi = [] # 2+ tentativas + + for r in all_rows: + tent = r.get("tentativa_numero", 1) + try: + tent = int(tent) + except (ValueError, TypeError): + tent = 1 + + dados = { + "tokens_input": int(r.get("tokens_input", 0) or 0), + "tokens_output": int(r.get("tokens_output", 0) or 0), + "tokens_total": int(r.get("tokens_total", 0) or 0), + "tempo_ms": float(r.get("tempo_agente_ms", 0) or 0), + } + + if tent <= 1: + grupo_1a.append(dados) + else: + grupo_multi.append(dados) + + def _media(grupo: list[dict], chave: str) -> float: + vals = [g[chave] for g in grupo] + return sum(vals) / len(vals) if vals else 0.0 + + def _soma(grupo: list[dict], chave: str) -> int: + return sum(g[chave] for g in grupo) + + return { + "geral": { + "tokens_input_medio": _media(grupo_1a + grupo_multi, "tokens_input"), + "tokens_output_medio": _media(grupo_1a + grupo_multi, "tokens_output"), + "tokens_total_soma": _soma(grupo_1a + grupo_multi, "tokens_total"), + "tempo_medio_ms": _media(grupo_1a + grupo_multi, "tempo_ms"), + "n": len(grupo_1a + grupo_multi), + }, + "1a_tentativa": { + "tokens_input_medio": _media(grupo_1a, "tokens_input"), + "tokens_output_medio": _media(grupo_1a, "tokens_output"), + "tokens_total_soma": _soma(grupo_1a, "tokens_total"), + "tempo_medio_ms": _media(grupo_1a, "tempo_ms"), + "n": len(grupo_1a), + }, + "multiplas_tentativas": { + "tokens_input_medio": _media(grupo_multi, "tokens_input"), + "tokens_output_medio": _media(grupo_multi, "tokens_output"), + "tokens_total_soma": _soma(grupo_multi, "tokens_total"), + "tempo_medio_ms": _media(grupo_multi, "tempo_ms"), + "n": len(grupo_multi), + }, + } + + +# --------------------------------------------------------------------------- +# 5. Estatísticas do Agente de Visualização +# --------------------------------------------------------------------------- + +def calcular_estatisticas_visualizacao(all_rows: list[dict]) -> dict[str, int]: + """Contabiliza quantas queries acionaram o agente de visualização.""" + total = len(all_rows) + acionaram = sum(1 for r in all_rows if r.get("viz_acionado") is True) + sucesso = sum(1 for r in all_rows if r.get("viz_sucesso") is True) + falha = acionaram - sucesso + + return { + "total_queries": total, + "acionaram_agente": acionaram, + "graficos_sucesso": sucesso, + "graficos_falha": falha, + } + + +# --------------------------------------------------------------------------- +# 6. Relatório consolidado em Markdown +# --------------------------------------------------------------------------- + +def gerar_secao_distribuicao_tentativas(freq: dict, ablacao: dict, grafico_dist_path: str, grafico_abl_path: str) -> list[str]: + """Gera seção do relatório com distribuição de tentativas e ablação.""" + lines = [] + lines.append("## 1. Distribuição de Tentativas e Ablação do Crítico") + lines.append("") + lines.append("### Tabela de Frequência") + lines.append("") + lines.append("| Tentativa | Quantidade |") + lines.append("|-----------|-----------|") + lines.append(f"| 1ª tentativa | {freq['1a_tentativa']} |") + lines.append(f"| 2ª tentativa | {freq['2a_tentativa']} |") + lines.append(f"| 3ª tentativa | {freq['3a_tentativa']} |") + lines.append(f"| Falha (todas as 3) | {freq['falha']} |") + lines.append("") + lines.append("### Ablação do Crítico") + lines.append("") + lines.append("| Configuração | Exact Match | Taxa EM | F1 Médio |") + lines.append("|-------------|-------------|---------|----------|") + lines.append(f"| Com Crítico (autocorreção) | {ablacao['acertos_com_critico']}/{ablacao['total_perguntas']} | {ablacao['exact_match_com_critico']:.1%} | {ablacao['f1_com_critico']:.4f} |") + lines.append(f"| Sem Crítico (1ª tentativa) | {ablacao['acertos_sem_critico']}/{ablacao['total_perguntas']} | {ablacao['exact_match_sem_critico']:.1%} | {ablacao['f1_sem_critico']:.4f} |") + lines.append("") + if grafico_dist_path: + lines.append(f"![Distribuição de Tentativas]({grafico_dist_path})") + lines.append("") + if grafico_abl_path: + lines.append(f"![Ablação do Crítico]({grafico_abl_path})") + lines.append("") + return lines + + +def gerar_secao_transicoes(transicoes: dict, n_total_transicoes: int, csv_path: str) -> list[str]: + """Gera seção do relatório com a tabela de transições de autocorreção.""" + lines = [] + lines.append("## 1.5. Transições de Autocorreção (>1 tentativa)") + lines.append("") + total_transicoes = sum(transicoes.values()) + + if total_transicoes == 0: + lines.append("Nenhuma pergunta precisou de autocorreção (todas resolvidas na 1ª tentativa).") + lines.append("") + return lines + + lines.append("| Transição | Quantidade | Significado |") + lines.append("|-----------|------------|-------------|") + lines.append(f"| EM=False → EM=True | {transicoes['ajudou']} | Autocorreção ajudou |") + lines.append(f"| EM=True → EM=True | {transicoes['manteve_certo']} | Já era certo, continuou certo |") + lines.append(f"| EM=False → EM=False | {transicoes['manteve_errado']} | Já era errado, continuou errado |") + lines.append(f"| EM=True → EM=False | {transicoes['atrapalhou']} | Autocorreção atrapalhou |") + lines.append("") + + if n_total_transicoes > 0: + lines.append(f"> Detalhes completos (Queries, Feedbacks e F1) de todas as **{n_total_transicoes} transições** exportados em: `{csv_path}`") + lines.append("") + + return lines + + +def gerar_secao_matriz_confusao(mc: dict, n_fps: int, fps_csv: str) -> list[str]: + """Gera seção do relatório com matriz de confusão.""" + lines = [] + lines.append("## 2. Matriz de Confusão do Crítico") + lines.append("") + lines.append("| | Exact Match Correto | Exact Match Incorreto |") + lines.append("|--|--------------------|-----------------------|") + lines.append(f"| **Crítico Aprovou** | TP = {mc['TP']} | FP = {mc['FP']} |") + lines.append(f"| **Crítico Reprovou** | FN = {mc['FN']} | TN = {mc['TN']} |") + lines.append("") + total = mc["TP"] + mc["FP"] + mc["FN"] + mc["TN"] + if total > 0: + acc = (mc["TP"] + mc["TN"]) / total + lines.append(f"**Acurácia do Crítico:** {acc:.1%}") + lines.append("") + if n_fps > 0: + lines.append(f"> **{n_fps} falso(s) positivo(s)** exportados para análise qualitativa: `{fps_csv}`") + lines.append("") + return lines + + +def gerar_secao_taxonomia_erros(contagem: Counter, n_erros: int, erros_csv: str, grafico_path: str) -> list[str]: + """Gera seção do relatório com taxonomia de erros.""" + lines = [] + lines.append("## 3. Taxonomia de Erros SQL") + lines.append("") + lines.append(f"Total de perguntas sem exact match: **{n_erros}**") + lines.append("") + if contagem: + lines.append("| Categoria de Divergência | Contagem |") + lines.append("|--------------------------|----------|") + for cat, cnt in contagem.most_common(): + lines.append(f"| {cat} | {cnt} |") + lines.append("") + if erros_csv: + lines.append(f"> Detalhes exportados em: `{erros_csv}`") + lines.append("") + if grafico_path: + lines.append(f"![Taxonomia de Erros]({grafico_path})") + lines.append("") + return lines + + +def gerar_secao_metricas_operacionais(metricas: dict) -> list[str]: + """Gera seção do relatório com métricas operacionais.""" + lines = [] + lines.append("## 4. Métricas Operacionais") + lines.append("") + lines.append("| Métrica | Geral | 1ª Tentativa | 2+ Tentativas |") + lines.append("|---------|-------|--------------|---------------|") + + g = metricas["geral"] + t1 = metricas["1a_tentativa"] + tm = metricas["multiplas_tentativas"] + + lines.append(f"| N (perguntas) | {g['n']} | {t1['n']} | {tm['n']} |") + lines.append(f"| Tokens input médios | {g['tokens_input_medio']:.0f} | {t1['tokens_input_medio']:.0f} | {tm['tokens_input_medio']:.0f} |") + lines.append(f"| Tokens output médios | {g['tokens_output_medio']:.0f} | {t1['tokens_output_medio']:.0f} | {tm['tokens_output_medio']:.0f} |") + lines.append(f"| Tokens total (soma) | {g['tokens_total_soma']} | {t1['tokens_total_soma']} | {tm['tokens_total_soma']} |") + lines.append(f"| Tempo médio (ms) | {g['tempo_medio_ms']:.0f} | {t1['tempo_medio_ms']:.0f} | {tm['tempo_medio_ms']:.0f} |") + lines.append("") + return lines + + +def gerar_secao_visualizacao(stats: dict) -> list[str]: + """Gera seção do relatório com estatísticas de visualização.""" + lines = [] + lines.append("## 5. Estatísticas do Agente de Visualização") + lines.append("") + lines.append("| Métrica | Valor |") + lines.append("|---------|-------|") + lines.append(f"| Total de queries | {stats['total_queries']} |") + lines.append(f"| Acionaram agente de gráfico | {stats['acionaram_agente']} |") + lines.append(f"| Gráficos gerados com sucesso | {stats['graficos_sucesso']} |") + lines.append(f"| Gráficos com falha | {stats['graficos_falha']} |") + lines.append("") + return lines + + +def gerar_relatorio_empirico_completo( + report_path: str, + dataset_label: str, + all_rows: list[dict], + output_dir: str, +) -> str: + """ + Gera o relatório empírico completo em Markdown com todos os 5 módulos. + + Args: + report_path: Caminho para salvar o relatório .md + dataset_label: "Spider" ou "Spider 2.0 Lite" + all_rows: Lista de dicts com resultados por pergunta + output_dir: Diretório para salvar gráficos e CSVs auxiliares + + Returns: + Caminho do relatório gerado. + """ + Path(output_dir).mkdir(parents=True, exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + + lines = [] + lines.append(f"# Relatório Empírico — {dataset_label}") + lines.append("") + lines.append(f"**Gerado em:** {timestamp}") + lines.append("") + + # 1. Distribuição de tentativas + ablação + freq = calcular_distribuicao_tentativas(all_rows) + ablacao = calcular_ablacao_critico(all_rows) + + grafico_dist = str(Path(output_dir) / "distribuicao_tentativas.png") + gerar_grafico_distribuicao_tentativas(freq, grafico_dist, dataset_label) + + grafico_abl = str(Path(output_dir) / "ablacao_critico.png") + gerar_grafico_ablacao(ablacao, grafico_abl, dataset_label) + + lines.extend(gerar_secao_distribuicao_tentativas(freq, ablacao, grafico_dist, grafico_abl)) + + # 1.5 Transições de Autocorreção + transicoes = calcular_transicoes_autocorrecao(all_rows) + detalhes_csv = str(Path(output_dir) / "detalhes_transicoes.csv") + n_detalhes = exportar_detalhes_transicoes(all_rows, detalhes_csv) + lines.extend(gerar_secao_transicoes(transicoes, n_detalhes, detalhes_csv)) + + # 2. Matriz de confusão + mc = calcular_matriz_confusao(all_rows) + fps_csv = str(Path(output_dir) / "falsos_positivos.csv") + n_fps = exportar_falsos_positivos(all_rows, fps_csv) + lines.extend(gerar_secao_matriz_confusao(mc, n_fps, fps_csv)) + + # 3. Taxonomia de erros + erros_detalhados, contagem_erros = classificar_erros(all_rows) + erros_csv = str(Path(output_dir) / "taxonomia_erros.csv") + exportar_taxonomia_erros_csv(erros_detalhados, erros_csv) + + grafico_erros = "" + if contagem_erros: + grafico_erros = str(Path(output_dir) / "taxonomia_erros.png") + gerar_grafico_taxonomia_erros(contagem_erros, grafico_erros, dataset_label) + + lines.extend(gerar_secao_taxonomia_erros(contagem_erros, len(erros_detalhados), erros_csv, grafico_erros)) + + # 4. Métricas operacionais + metricas = calcular_metricas_operacionais(all_rows) + lines.extend(gerar_secao_metricas_operacionais(metricas)) + + # 5. Estatísticas de visualização + stats_viz = calcular_estatisticas_visualizacao(all_rows) + lines.extend(gerar_secao_visualizacao(stats_viz)) + + # Salvar + Path(report_path).parent.mkdir(parents=True, exist_ok=True) + with open(report_path, "w", encoding="utf-8") as f: + f.write("\n".join(lines)) + + return report_path diff --git a/src/spider/csv_reporter.py b/src/spider/csv_reporter.py index 249a687..b42fb50 100644 --- a/src/spider/csv_reporter.py +++ b/src/spider/csv_reporter.py @@ -32,6 +32,15 @@ class CSVReporter: "resultado_f1", "resultado_precision", "resultado_recall", + # Campos adicionais para análise empírica + "tokens_input", + "tokens_output", + "tokens_total", + "viz_acionado", + "viz_sucesso", + "resultado_exato_match_1a_tentativa", + "resultado_f1_1a_tentativa", + "query_1a_tentativa", ] def __init__(self, filepath: str | Path): @@ -101,18 +110,25 @@ def generate_summary(self, rows: list[dict[str, Any]]) -> dict[str, Any]: total_perguntas = len(by_exemplo) perguntas_aprovadas = 0 perguntas_1a_tentativa = 0 - total_tentativas = len(rows) + total_tentativas = 0 similarities = [] tempos = [] for ex_id, tentativas in by_exemplo.items(): # Última tentativa desta pergunta ultima = tentativas[-1] + + try: + qtd_tentativas = int(ultima.get("tentativa_numero", len(tentativas))) + except (ValueError, TypeError): + qtd_tentativas = len(tentativas) + + total_tentativas += qtd_tentativas if ultima["veredito_critico"] == "aprovado": perguntas_aprovadas += 1 - if len(tentativas) == 1 and ultima["veredito_critico"] == "aprovado": + if qtd_tentativas == 1 and ultima["veredito_critico"] == "aprovado": perguntas_1a_tentativa += 1 # Coletar similarity scores (de tentativas bem-sucedidas) diff --git a/src/spider/metrics.py b/src/spider/metrics.py index 5fd0ad9..aed7cc2 100644 --- a/src/spider/metrics.py +++ b/src/spider/metrics.py @@ -3,19 +3,34 @@ Fornece: - Similarity score entre duas queries (difflib-based) -- Comparação de resultados (exato match) -- F1 score de resultados (row-level precision/recall) +- Comparação de resultados no estilo Spider 2.0 (column-level matching) +- F1 score de resultados (column-level precision/recall) - Normalização de SQL para comparação + +A lógica de comparação segue a métrica oficial do Spider 2.0: +- Transpõe ambas as tabelas para obter vetores-coluna +- Para cada coluna do gold, verifica se alguma coluna do pred bate +- Usa tolerância de 1e-2 para comparações numéricas +- Trata NaN/None com pd.isna """ import difflib +import math import re -from collections import Counter from typing import Any -import numpy as np +import pandas as pd + + +# --------------------------------------------------------------------------- +# Constantes +# --------------------------------------------------------------------------- +_TOLERANCE = 1e-2 +# --------------------------------------------------------------------------- +# Normalização de SQL +# --------------------------------------------------------------------------- def normalize_sql(sql: str) -> str: """ Normaliza SQL para comparação mais robusta. @@ -75,114 +90,192 @@ def sql_similarity_score(sql1: str, sql2: str) -> float: return matcher.ratio() -# def results_exact_match( -# results_gold: list[dict[str, Any]], -# results_agent: list[dict[str, Any]], -# ) -> bool: -# """ -# Compara se dois conjuntos de resultados são exatamente iguais. - -# Compara: -# - Número de linhas -# - Valores de cada linha (insensível a ordem das colunas) - -# Args: -# results_gold: Resultados da query ouro -# results_agent: Resultados da query do agente - -# Returns: -# True se resultados são iguais -# """ -# if len(results_gold) != len(results_agent): -# return False - -# # Converter dicts para conjuntos de tuplas para comparação -# # (para serem agnósticos à ordem das colunas) -# def result_set(results: list[dict[str, Any]]) -> set: -# converted = [] -# for row in results: -# # Converter valores para strings para lidar com tipos diferentes -# items = [] -# for k in sorted(row.keys()): -# # Normalizar None/NULL -# v = row[k] -# if v is None: -# v = "NULL" -# items.append((k, str(v))) -# converted.append(tuple(items)) -# return set(converted) - -# return result_set(results_gold) == result_set(results_agent) +# --------------------------------------------------------------------------- +# Helpers internos — lógica Spider 2.0 +# --------------------------------------------------------------------------- +def _vectors_match(v1: list, v2: list, tol: float = _TOLERANCE, ignore_order: bool = False) -> bool: + """ + Compara dois vetores (colunas transpostas) elemento a elemento, + seguindo a lógica oficial do Spider 2.0. + + - Aceita tolerância absoluta ``tol`` para pares numéricos. + - Trata ``pd.isna`` como iguais entre si. + - Se ``ignore_order`` estiver ativado, ordena ambos os vetores antes + de comparar. + """ + if ignore_order: + v1 = sorted(v1, key=lambda x: (x is None, str(x), isinstance(x, (int, float)))) + v2 = sorted(v2, key=lambda x: (x is None, str(x), isinstance(x, (int, float)))) + + if len(v1) != len(v2): + return False + + for a, b in zip(v1, v2): + if pd.isna(a) and pd.isna(b): + continue + elif isinstance(a, (int, float)) and isinstance(b, (int, float)): + if not math.isclose(float(a), float(b), abs_tol=tol): + return False + elif a != b: + return False + return True + + +def _results_to_dataframe(results: list[dict[str, Any]]) -> pd.DataFrame: + """Converte list[dict] (formato do SpiderQueryExecutor) para DataFrame.""" + if not results: + return pd.DataFrame() + return pd.DataFrame(results) + + +# --------------------------------------------------------------------------- +# Comparação principal — Spider 2.0 +# --------------------------------------------------------------------------- +def compare_pandas_table( + pred: pd.DataFrame, + gold: pd.DataFrame, + condition_cols: list[int] | None = None, + ignore_order: bool = False, +) -> int: + """ + Compara pred vs gold seguindo a métrica oficial do Spider 2.0. + + Para cada coluna do gold (opcionalmente filtrada por ``condition_cols``), + verifica se **alguma** coluna do pred é equivalente (dentro de tolerância + numérica e tratando NaN). + + Args: + pred: DataFrame com resultados do agente. + gold: DataFrame com resultados esperados. + condition_cols: Índices das colunas do gold a avaliar. + Se ``None`` ou vazio, avalia todas. + ignore_order: Se True, ordena os valores de cada coluna antes + de comparar (útil quando não há ORDER BY). + + Returns: + 1 se todas as colunas do gold foram encontradas no pred, 0 caso contrário. + """ + if condition_cols: + gold_cols = gold.iloc[:, condition_cols] + else: + gold_cols = gold + + t_gold_list = gold_cols.transpose().values.tolist() + t_pred_list = pred.transpose().values.tolist() + + for gold_vec in t_gold_list: + if not any(_vectors_match(gold_vec, pred_vec, ignore_order=ignore_order) + for pred_vec in t_pred_list): + return 0 + return 1 + + +def compare_multi_pandas_table( + pred: pd.DataFrame, + multi_gold: list[pd.DataFrame], + multi_condition_cols: list | None = None, + multi_ignore_order: bool = False, +) -> int: + """ + Compara pred contra *múltiplas* respostas ouro válidas (Spider 2.0). + + Retorna 1 se o pred bater com **pelo menos uma** das respostas ouro. + + Args: + pred: DataFrame com resultados do agente. + multi_gold: Lista de DataFrames de respostas ouro. + multi_condition_cols: Lista de listas de índices, uma por gold. + multi_ignore_order: Se True, aplica ignore_order em todas. + Returns: + 1 se match com alguma resposta ouro, 0 caso contrário. + """ + if ( + multi_condition_cols is None + or multi_condition_cols == [] + or multi_condition_cols == [[]] + or multi_condition_cols == [None] + ): + multi_condition_cols = [[] for _ in range(len(multi_gold))] + elif len(multi_gold) > 1 and not all(isinstance(s, list) for s in multi_condition_cols): + multi_condition_cols = [multi_condition_cols for _ in range(len(multi_gold))] + + for i, gold in enumerate(multi_gold): + if compare_pandas_table(pred, gold, multi_condition_cols[i], multi_ignore_order): + return 1 + return 0 + + +# --------------------------------------------------------------------------- +# Funções de interface pública (mantêm assinatura list[dict]) +# --------------------------------------------------------------------------- def results_exact_match( results_gold: list[dict[str, Any]], results_agent: list[dict[str, Any]], + ignore_order: bool = True, ) -> bool: """ - Compara se dois conjuntos de resultados são iguais baseando-se APENAS nos valores. - Ignora os nomes das colunas e a ordem das linhas. + Compara se dois conjuntos de resultados são equivalentes usando + a métrica oficial do Spider 2.0 (column-level matching). + + Ignora nomes de colunas — apenas os *valores* das colunas importam. + Usa tolerância de 1e-2 para números e trata None/NaN como iguais. + + Args: + results_gold: Resultados da query ouro (list[dict]). + results_agent: Resultados da query do agente (list[dict]). + ignore_order: Se True (padrão), a ordem das linhas é ignorada. + + Returns: + True se todas as colunas do gold foram encontradas no pred. """ - # Se não têm o mesmo número de linhas, já é False - if len(results_gold) != len(results_agent): - return False - - # Se as duas listas vierem vazias (0 linhas), é True - if not results_gold: + # Ambos vazios + if not results_gold and not results_agent: return True - def extract_values_to_numpy(results: list[dict[str, Any]]) -> np.ndarray: - matrix = [] - for row in results: - # Pega APENAS os valores, ignora as chaves - # Converte tudo para string (evita falsos negativos entre 0 inteiro e 0.0 float) - row_values = [str(v) if v is not None else "NULL" for v in row.values()] - matrix.append(row_values) - - # Converte a matriz nativa do Python para um Array NumPy - arr = np.array(matrix) - - # Como as queries podem retornar as linhas em ordens diferentes (se não houver ORDER BY), - # precisamos ordenar as linhas do array numpy lexograficamente para uma comparação justa. - # np.lexsort ordena pelas colunas, da última para a primeira, então passamos transposto e invertido - sorted_indices = np.lexsort(arr.T[::-1]) - return arr[sorted_indices] - - # Extrai, processa e ordena os arrays - gold_array = extract_values_to_numpy(results_gold) - agent_array = extract_values_to_numpy(results_agent) - - # np.array_equal compara a estrutura (dimensões) e o conteúdo. - # Usamos bool() para garantir que retorne um booleano nativo do Python e não um np.bool_ - return bool(np.array_equal(gold_array, agent_array)) + # Um vazio e outro não + if not results_gold or not results_agent: + return False + + gold_df = _results_to_dataframe(results_gold) + pred_df = _results_to_dataframe(results_agent) + + # Número de linhas diferente → impossível match + if len(gold_df) != len(pred_df): + return False + + return compare_pandas_table(pred_df, gold_df, ignore_order=ignore_order) == 1 def results_f1_score( results_gold: list[dict[str, Any]], results_agent: list[dict[str, Any]], + ignore_order: bool = True, ) -> dict[str, float]: """ - Calcula Precision, Recall e F1 row-level entre resultados gold e agent. + Calcula Precision, Recall e F1 a nível de coluna, seguindo a lógica + do Spider 2.0. + + Para cada coluna do gold, verifica se alguma coluna do pred é + equivalente (tolerância numérica de 1e-2, NaN-aware). - Cada linha é convertida em uma tupla canônica (valores ordenados, como string) - e tratada como membro de um multiset (Counter). Isso permite medir parcialmente - quantas linhas o agente acertou, mesmo que não tenha acertado todas. + - Precision: das colunas que o agente retornou, quantas batem com + alguma coluna do gold? + - Recall: das colunas do gold, quantas foram cobertas pelo agente? + - F1: média harmônica de precision e recall. - - Precision: das linhas que o agente retornou, quantas estão no gold? - - Recall: das linhas do gold, quantas o agente retornou? - - F1: média harmônica de precision e recall. + Quando o número de linhas difere, as linhas excedentes são tratadas + como colunas não-matching, penalizando precision ou recall conforme + o caso. Args: - results_gold: Resultados da query ouro - results_agent: Resultados da query do agente + results_gold: Resultados da query ouro (list[dict]). + results_agent: Resultados da query do agente (list[dict]). + ignore_order: Se True (padrão), a ordem das linhas é ignorada. Returns: - Dict com chaves: precision, recall, f1 (floats de 0 a 1) + Dict com chaves: precision, recall, f1 (floats de 0 a 1). """ - def _row_to_canonical(row: dict[str, Any]) -> tuple: - """Converte uma linha em tupla canônica de valores (ordenados, stringificados).""" - values = [str(v) if v is not None else "NULL" for v in row.values()] - return tuple(sorted(values)) - # Ambos vazios → match perfeito if not results_gold and not results_agent: return {"precision": 1.0, "recall": 1.0, "f1": 1.0} @@ -193,25 +286,62 @@ def _row_to_canonical(row: dict[str, Any]) -> tuple: if not results_agent: return {"precision": 1.0, "recall": 0.0, "f1": 0.0} - gold_bag = Counter(_row_to_canonical(r) for r in results_gold) - agent_bag = Counter(_row_to_canonical(r) for r in results_agent) - - # Interseção: min(count_gold, count_agent) para cada tupla - true_positives = sum((gold_bag & agent_bag).values()) - total_agent = sum(agent_bag.values()) - total_gold = sum(gold_bag.values()) - - precision = true_positives / total_agent if total_agent > 0 else 0.0 - recall = true_positives / total_gold if total_gold > 0 else 0.0 + gold_df = _results_to_dataframe(results_gold) + pred_df = _results_to_dataframe(results_agent) + + # Número de linhas diferente — padroniza para o mesmo tamanho usando NaN + # para viabilizar a comparação coluna-a-coluna. As colunas onde os NaNs + # extras forem injetados não vão bater, o que penaliza corretamente. + max_rows = max(len(gold_df), len(pred_df)) + if len(gold_df) < max_rows: + padding = pd.DataFrame( + [[None] * gold_df.shape[1]] * (max_rows - len(gold_df)), + columns=gold_df.columns, + ) + gold_df = pd.concat([gold_df, padding], ignore_index=True) + if len(pred_df) < max_rows: + padding = pd.DataFrame( + [[None] * pred_df.shape[1]] * (max_rows - len(pred_df)), + columns=pred_df.columns, + ) + pred_df = pd.concat([pred_df, padding], ignore_index=True) + + t_gold_list = gold_df.transpose().values.tolist() + t_pred_list = pred_df.transpose().values.tolist() + + total_gold = len(t_gold_list) + total_pred = len(t_pred_list) + + # Recall: quantas colunas do gold batem com alguma coluna do pred? + gold_matched = sum( + 1 for g in t_gold_list + if any(_vectors_match(g, p, ignore_order=ignore_order) for p in t_pred_list) + ) + + # Precision: quantas colunas do pred batem com alguma coluna do gold? + pred_matched = sum( + 1 for p in t_pred_list + if any(_vectors_match(p, g, ignore_order=ignore_order) for g in t_gold_list) + ) + + recall = gold_matched / total_gold if total_gold > 0 else 0.0 + precision = pred_matched / total_pred if total_pred > 0 else 0.0 if precision + recall == 0: f1 = 0.0 else: f1 = 2 * (precision * recall) / (precision + recall) - return {"precision": round(precision, 4), "recall": round(recall, 4), "f1": round(f1, 4)} + return { + "precision": round(precision, 4), + "recall": round(recall, 4), + "f1": round(f1, 4), + } +# --------------------------------------------------------------------------- +# Construtor de linha para CSV de avaliação +# --------------------------------------------------------------------------- def build_comparison_row( id_exemplo: int, tentativa_numero: int, @@ -228,6 +358,14 @@ def build_comparison_row( resultado_f1: float = 0.0, resultado_precision: float = 0.0, resultado_recall: float = 0.0, + tokens_input: int = 0, + tokens_output: int = 0, + tokens_total: int = 0, + viz_acionado: bool = False, + viz_sucesso: bool = False, + resultado_exato_match_1a_tentativa: bool | None = None, + resultado_f1_1a_tentativa: float = 0.0, + query_1a_tentativa: str = "", ) -> dict[str, Any]: """ Constrói uma linha para o CSV de avaliação. @@ -248,9 +386,17 @@ def build_comparison_row( resultado_f1: F1 score row-level (0-1) resultado_precision: Precision row-level (0-1) resultado_recall: Recall row-level (0-1) + tokens_input: Total de tokens de entrada consumidos + tokens_output: Total de tokens de saída consumidos + tokens_total: Total de tokens consumidos + viz_acionado: Se o agente de visualização foi acionado + viz_sucesso: Se o gráfico foi gerado com sucesso + resultado_exato_match_1a_tentativa: Exact match da 1ª tentativa (para ablação) + resultado_f1_1a_tentativa: F1 score da 1ª tentativa (para ablação) + query_1a_tentativa: SQL gerada na 1ª tentativa Returns: - Dict com 15 chaves para CSV + Dict com chaves para CSV """ return { "id_exemplo": id_exemplo, @@ -268,5 +414,14 @@ def build_comparison_row( "resultado_f1": resultado_f1, "resultado_precision": resultado_precision, "resultado_recall": resultado_recall, + "tokens_input": tokens_input, + "tokens_output": tokens_output, + "tokens_total": tokens_total, + "viz_acionado": viz_acionado, + "viz_sucesso": viz_sucesso, + "resultado_exato_match_1a_tentativa": resultado_exato_match_1a_tentativa if resultado_exato_match_1a_tentativa is not None else "", + "resultado_f1_1a_tentativa": resultado_f1_1a_tentativa, + "query_1a_tentativa": query_1a_tentativa, } + diff --git a/tests/test_componentes.py b/tests/test_componentes.py index ffc0e15..ced1bb3 100644 --- a/tests/test_componentes.py +++ b/tests/test_componentes.py @@ -208,7 +208,9 @@ def test_roteador_sandbox_muitas_tentativas(): from text_to_insight.routers.edges import roteador_sandbox estado = {"status": "exec_erro", "tentativas_loop": 5} - assert roteador_sandbox(estado) == "planejador" + ## alteração: com muitas tentativas, o roteador deve direcionar para "critico" para forçar o fim do loop, não para "planejador" + #asert roteador_sandbox(estado) == "planejador" + assert roteador_sandbox(estado) == "critico" def test_roteador_planejador_sem_schema(): diff --git a/tests/test_main_engine_integracao.py b/tests/test_main_engine_integracao.py index 04e533a..525db7a 100644 --- a/tests/test_main_engine_integracao.py +++ b/tests/test_main_engine_integracao.py @@ -77,9 +77,8 @@ def update_state(self, config, values): state = self._thread_state(thread_id) state["values"].update(values) - class _FakeGraph: - def __init__(self, api_key, model, hitl=True): + def __init__(self, api_key, model, hitl=True, **kwargs): self.grafo_text_to_insight = _FakeCompiledGraph() diff --git a/text_to_insight/InsightEngine.py b/text_to_insight/InsightEngine.py index 474bf2c..ffdf210 100644 --- a/text_to_insight/InsightEngine.py +++ b/text_to_insight/InsightEngine.py @@ -16,18 +16,20 @@ class InsightEngine: - `resume(...)` continua uma consulta que ficou pausada em HITL. """ - def __init__(self, api_key: str, model: str, db_path: str, hitl: bool = False, show_output: bool = False): + def __init__(self, api_key: str, model: str, db_path: str, hitl: bool = False, show_output: bool = False, enable_graphs: bool = True): self._hitl_ativado = hitl # `show_output` controla se a engine imprime o resultado final no terminal. # Em cenários com CLI, normalmente deixamos False para evitar saída duplicada. self._show_output = show_output + self._enable_graphs = enable_graphs self._model = model self._db_path = db_path # O grafo compila os nós/roteadores e guarda memória por thread_id. - self._grafo = Graph(api_key=api_key, model=self._model, hitl=self._hitl_ativado) + self._grafo = Graph(api_key=api_key, model=self._model, hitl=self._hitl_ativado, enable_graphs=self._enable_graphs) print(f"[CONFIG] HITL: {'ATIVADO' if self._hitl_ativado else 'DESATIVADO'}") print(f"[CONFIG] SHOW_OUTPUT: {'ATIVADO' if self._show_output else 'DESATIVADO'}") + print(f"[CONFIG] GRÁFICOS: {'ATIVADO' if self._enable_graphs else 'DESATIVADO'}") def _config(self, thread_id: str) -> dict[str, Any]: # O LangGraph usa esse bloco "configurable" para identificar a conversa. diff --git a/text_to_insight/graph.py b/text_to_insight/graph.py index 0c938fc..b66647e 100644 --- a/text_to_insight/graph.py +++ b/text_to_insight/graph.py @@ -7,7 +7,10 @@ 3. Agente de Código: Gera SQL (LLM) 4. Executor: Executa SQL no banco real 5. Crítico: Avalia qualidade (LLM) -6. Roteadores condicionais decidem continuação ou conclusão +6. Salvar CSV: Exporta resultado para CSV +7. Roteador Gráfico: Decide se gera visualização (LLM) +8. Gerador Gráfico: Gera gráfico matplotlib (LLM + subprocess) +9. Resposta: Gera resposta em linguagem natural (LLM) """ from functools import partial @@ -23,8 +26,10 @@ nos_nodo_sandbox, nos_nodo_critico, nos_nodo_resposta, + nos_nodo_salvar_csv, + nos_nodo_gerador_grafico, ) -from .routers import roteador_sandbox, roteador_planejador +from .routers import roteador_sandbox, roteador_planejador, roteador_grafico from .model_selection import get_model def nos_nodo_espera_humana(estado: EstadoTextToInsight): @@ -32,9 +37,10 @@ def nos_nodo_espera_humana(estado: EstadoTextToInsight): return estado class Graph: - def __init__(self, api_key: str, model: str, hitl: bool = True): + def __init__(self, api_key: str, model: str, hitl: bool = True, enable_graphs: bool = True): self.llm = get_model(model, api_key) self.memory = MemorySaver() + self.enable_graphs = enable_graphs self.grafo_text_to_insight = self._compilar_grafo(hitl) def _construir_grafo_text_to_insight(self, hitl: bool) -> StateGraph: @@ -50,6 +56,8 @@ def _construir_grafo_text_to_insight(self, hitl: bool) -> StateGraph: construtor_grafo.add_node("agente_codigo", partial(nos_nodo_agente_codigo, llm=self.llm)) construtor_grafo.add_node("sandbox", nos_nodo_sandbox) construtor_grafo.add_node("critico", partial(nos_nodo_critico, llm=self.llm)) + construtor_grafo.add_node("salvar_csv", nos_nodo_salvar_csv) + construtor_grafo.add_node("gerador_grafico", partial(nos_nodo_gerador_grafico, llm=self.llm)) construtor_grafo.add_node("resposta", partial(nos_nodo_resposta, llm=self.llm)) # 2. ARESTAS FIXAS @@ -57,7 +65,10 @@ def _construir_grafo_text_to_insight(self, hitl: bool) -> StateGraph: construtor_grafo.add_edge("espera_humana", "planejador") construtor_grafo.add_edge("esquema", "planejador") construtor_grafo.add_edge("agente_codigo", "sandbox") - + + # Gerador de gráfico sempre vai para resposta (sucesso ou falha) + construtor_grafo.add_edge("gerador_grafico", "resposta") + # 3. ARESTAS CONDICIONAIS construtor_grafo.add_conditional_edges( "sandbox", @@ -86,13 +97,16 @@ def _construir_grafo_text_to_insight(self, hitl: bool) -> StateGraph: def roteador_critico(estado: EstadoTextToInsight) -> str: status = estado.get("status", "") tentativas = estado.get("tentativas_loop", 0) - # Se aprovado, enviar para nó de resposta + + next_step = "salvar_csv" if self.enable_graphs else "resposta" + + # Se aprovado, enviar para proximo passo if status == "aprovado": - return "resposta" + return next_step # Se atingiu limite de tentativas, encerrar mesmo reprovado if tentativas >= MAX_TENTATIVAS_CRITICO: - print(f"[ROTEADOR_CRITICO] Limite de {MAX_TENTATIVAS_CRITICO} tentativas atingido → resposta (forçado)") - return "resposta" + print(f"[ROTEADOR_CRITICO] Limite de {MAX_TENTATIVAS_CRITICO} tentativas atingido → {next_step} (forçado)") + return next_step return "planejador" construtor_grafo.add_conditional_edges( @@ -100,6 +114,17 @@ def roteador_critico(estado: EstadoTextToInsight) -> str: roteador_critico, { "planejador": "planejador", + "salvar_csv": "salvar_csv", + "resposta": "resposta", + } + ) + + # Após salvar CSV, o roteador de gráfico decide se gera visualização + construtor_grafo.add_conditional_edges( + "salvar_csv", + partial(roteador_grafico, llm=self.llm), + { + "gerador_grafico": "gerador_grafico", "resposta": "resposta", } ) @@ -136,3 +161,4 @@ def stream(self, estado: EstadoTextToInsight, config: dict = None): if config is None: config = {} return self.grafo_text_to_insight.stream(estado, config) + diff --git a/text_to_insight/nodes/__init__.py b/text_to_insight/nodes/__init__.py index d099077..437ad6a 100644 --- a/text_to_insight/nodes/__init__.py +++ b/text_to_insight/nodes/__init__.py @@ -10,6 +10,8 @@ - code_agent: Gera código Python baseado no plano. - sandbox: Executa o código de forma segura. - critic: Avalia a saída e fornece feedback. + - csv_saver: Salva resultado da query em CSV. + - graph_generator: Gera gráfico matplotlib a partir do CSV. """ from .planner import nos_nodo_planejador @@ -18,6 +20,8 @@ from .sandbox import nos_nodo_sandbox from .critic import nos_nodo_critico from .response import nos_nodo_resposta +from .csv_saver import nos_nodo_salvar_csv +from .graph_generator import nos_nodo_gerador_grafico __all__ = [ "nos_nodo_planejador", @@ -26,4 +30,7 @@ "nos_nodo_sandbox", "nos_nodo_critico", "nos_nodo_resposta", + "nos_nodo_salvar_csv", + "nos_nodo_gerador_grafico", ] + diff --git a/text_to_insight/nodes/code_agent/code_sql.py b/text_to_insight/nodes/code_agent/code_sql.py index e737e23..33b1977 100644 --- a/text_to_insight/nodes/code_agent/code_sql.py +++ b/text_to_insight/nodes/code_agent/code_sql.py @@ -52,13 +52,17 @@ def validar_sql_segura(sql: str) -> tuple[bool, str]: return True, "" +import time + def executar_sql_sqlite( db_path: str, sql: str, limite_preview: int = 5, + timeout_segundos: float = 15.0, ) -> dict[str, Any]: """ Executa SQL validada em SQLite modo read-only e retorna resultado estruturado. + Possui um timeout embutido para evitar queries infinitas (ex: cross joins enormes). """ ok, erro_validacao = validar_sql_segura(sql) if not ok: @@ -86,6 +90,17 @@ def executar_sql_sqlite( try: conn = sqlite3.connect(f"file:{caminho}?mode=ro", uri=True) conn.row_factory = sqlite3.Row + + # Define um handler para monitorar o tempo de execução e abortar se passar do limite + start_time = time.time() + def _progress_handler(): + if time.time() - start_time > timeout_segundos: + return 1 # abortar query + return 0 + + # Invoca a cada 1000 instruções da máquina virtual do SQLite + conn.set_progress_handler(_progress_handler, 1000) + try: cur = conn.cursor() cur.execute(sql) @@ -107,6 +122,18 @@ def executar_sql_sqlite( } finally: conn.close() + except sqlite3.OperationalError as e: + erro_msg = str(e) + if "interrupted" in erro_msg.lower(): + erro_msg = f"Query abortada por timeout (> {timeout_segundos}s)." + return { + "ok": False, + "erro_execucao": f"Falha ao executar SQL: {erro_msg}", + "linhas_resultado_preview": [], + "linhas_resultado_completo": [], + "total_linhas_resultado": 0, + "saida_terminal": f"[SANDBOX] Erro de execucao: {erro_msg}", + } except Exception as e: return { "ok": False, diff --git a/text_to_insight/nodes/critic.py b/text_to_insight/nodes/critic.py index 6f103b2..0acacd6 100644 --- a/text_to_insight/nodes/critic.py +++ b/text_to_insight/nodes/critic.py @@ -18,6 +18,9 @@ === PERGUNTA DO USUÁRIO === {pergunta} +=== SCHEMA DO BANCO === +{schema} + === CONVERSA COM O AGENTE (se houver) === {conversa_previa} @@ -36,37 +39,77 @@ === TENTATIVAS ANTERIORES === {historico_tentativas_section} -Avalie: -1. A SQL responde à pergunta do usuário? -2. Os resultados fazem sentido? -3. Há algum erro lógico ou de interpretação? -4. Se houve tentativas anteriores, verifique se os mesmos problemas persistem. - -Ao avaliar, priorize utilidade prática e correção semântica da resposta, -não perfeição formal. - -Diferenças de formato, representação ou precisão que não alterem -substancialmente a resposta NÃO devem causar reprovação. - -Exemplos de casos que normalmente devem ser APROVADOS: -- Ano médio retornado como float em vez de inteiro/data -- Pequenas diferenças de arredondamento -- Colunas extras irrelevantes -- Nomes/aliases diferentes -- Resultado parcialmente correto mas ainda útil -- Agregações corretas com precisão numérica diferente da esperada - -REPROVE apenas quando houver falha material, por exemplo: -- A query responde outra pergunta -- O dado necessário para responder não está presentes -- Filtros importantes estão errados ou ausentes -- JOIN incorreto altera significativamente os resultados +=== EXEMPLOS DE AVALIAÇÃO === + +-- EXEMPLO 1: REPROVADO (escopo incompleto) -- +Pergunta: "Which airport has the least number of flights?" +SQL: SELECT SourceAirport FROM flights GROUP BY SourceAirport ORDER BY COUNT(*) ASC LIMIT 1 +Resultado: [('AID',)] +VEREDITO: REPROVADO +Razão: A query conta apenas voos com partida (SourceAirport) e ignora voos com chegada (DestAirport). +O escopo da pergunta é "flights" em geral — a query responde a uma pergunta diferente. + +-- EXEMPLO 2: REPROVADO (erro semântico: MIN vs MAX) -- +Pergunta: "Which Asian countries have a population larger than any country in Africa?" +SQL: SELECT Name FROM country WHERE Continent='Asia' AND Population > (SELECT MAX(Population) FROM country WHERE Continent='Africa') +Resultado: [] (vazio) +VEREDITO: REPROVADO +Razão: "Larger than any country in Africa" significa maior que pelo menos um país africano (MIN), +não maior que todos os países africanos (MAX). A lógica está semanticamente errada. + +-- EXEMPLO 3: REPROVADO (resultado vazio suspeito) -- +Pergunta: "Find the last name of students who live in North Carolina and are not enrolled in any degree." +SQL: SELECT last_name FROM Students WHERE state_province_county = 'North Carolina' AND ... +Resultado: [] (vazio) +VEREDITO: REPROVADO +Razão: Resultado vazio quando a pergunta espera dados reais é suspeito. Verifique se o filtro +de string corresponde exatamente ao valor no banco (ex: 'NorthCarolina' vs 'North Carolina'). + +-- EXEMPLO 4: REPROVADO (JOIN incorreto muda o que está sendo contado) -- +Pergunta: "Find the name of makers that produced some cars in 1970." +SQL: SELECT DISTINCT Maker FROM car_makers JOIN car_names ON car_makers.Id = car_names.MakeId JOIN cars_data ON car_names.MakeId = cars_data.Id WHERE cars_data.Year = 1970 +Resultado: [('chevrolet',), ('buick',)] +VEREDITO: REPROVADO +Razão: O JOIN usa car_names.MakeId para conectar a cars_data, mas cars_data.Id refere-se +ao ID do carro, não do fabricante. O caminho correto seria via model_list. Os resultados +parecem plausíveis mas derivam de uma junção incorreta. + +-- EXEMPLO 5: APROVADO (formato diferente, resposta correta) -- +Pergunta: "On average, when were the transcripts printed?" +SQL: SELECT AVG(transcript_date) AS average_transcript_date FROM Transcripts +Resultado: [('1989.9333333333334',)] +VEREDITO: APROVADO +Razão: O resultado é um número que representa a média das datas (formato numérico do SQLite). +Embora não seja uma data formatada, responde corretamente à pergunta. Diferença de +representação não é motivo de reprovação. + +-- EXEMPLO 6: APROVADO (query mais simples que o gold, resultado equivalente) -- +Pergunta: "Which model of car has the minimum horsepower?" +SQL: SELECT Model FROM car_names JOIN cars_data ON car_names.MakeId = cars_data.Id WHERE Horsepower = (SELECT MIN(Horsepower) FROM cars_data) LIMIT 1 +Resultado: [('triumph',)] +VEREDITO: APROVADO +Razão: A query retorna corretamente o modelo com menor potência. O LIMIT 1 garante unicidade +e o resultado é semanticamente correto. Aprovar. + +=== CRITÉRIOS DE AVALIAÇÃO === + +REPROVE quando houver: +- Escopo incompleto: query cobre apenas parte do que a pergunta pede +- Erro semântico: lógica correta na forma mas errada no significado (MIN vs MAX, ANY vs ALL) +- JOIN incorreto que altera os dados sendo agregados ou filtrados +- Resultado vazio quando a pergunta claramente espera dados +- Filtro com valor literal diferente do que está no banco - Métrica errada (SUM vs AVG, COUNT vs COUNT DISTINCT, etc.) -- Resultado vazio inesperado -- Erro SQL ou inconsistência lógica grave +- Erro de execução SQL + +APROVE quando: +- O resultado responde à pergunta, mesmo com formato ou representação diferente +- Há colunas extras que não prejudicam a resposta +- A precisão numérica difere mas o valor está correto +- A query é mais simples que o esperado mas semanticamente equivalente -Considere o custo de retentativas. Em caso de dúvida entre APROVADO -e REPROVADO, prefira APROVADO se a resposta ainda for útil para o usuário. Leve em consideração que ainda tem um agente depois de você que irá interpretar o resultado da query e criar uma resposta em linguagem natural. +Avalie com rigor semântico. Resultados que parecem plausíveis mas derivam de lógica +incorreta devem ser reprovados. Não presuma que uma query bem-formada está correta. Responda no formato: VEREDITO: APROVADO ou REPROVADO @@ -101,6 +144,7 @@ def nos_nodo_critico(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) - or estado.get("pergunta_usuario", "") ) sql = estado.get("sql_gerada", "") + schema = estado.get("contexto_schema", "") preview = estado.get("linhas_resultado_preview", []) total = estado.get("total_linhas_resultado", 0) saida = estado.get("saida_terminal", "") @@ -128,6 +172,7 @@ def nos_nodo_critico(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) - prompt = PROMPT_CRITIC.format( pergunta=pergunta, + schema=schema, sql=sql, status_exec=status_exec, conversa_previa=conversa_previa if conversa_previa else "Nenhuma", diff --git a/text_to_insight/nodes/csv_saver.py b/text_to_insight/nodes/csv_saver.py new file mode 100644 index 0000000..1ad022a --- /dev/null +++ b/text_to_insight/nodes/csv_saver.py @@ -0,0 +1,42 @@ +""" +Nó Salvar CSV do grafo de agentes Text-to-Insight. + +Responsabilidade única: salvar o resultado da query em um CSV em local padrão +e armazenar o caminho no estado para uso posterior (ex: geração de gráficos). +""" + +import csv +from datetime import datetime +from pathlib import Path + +from ..state import EstadoTextToInsight + +# Diretório padrão para salvar os resultados CSV +RESULTS_DIR = Path(__file__).parent.parent.parent / "results" + + +def nos_nodo_salvar_csv(estado: EstadoTextToInsight) -> dict: + """ + Nó Salvar CSV: exporta linhas_resultado_completo para um arquivo CSV + e registra o caminho no estado. + """ + linhas = estado.get("linhas_resultado_completo", []) or [] + + if not linhas: + print("[SALVAR_CSV] Nenhuma linha para exportar — pulando.") + return {"caminho_csv_resultado": ""} + + RESULTS_DIR.mkdir(exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + csv_path = RESULTS_DIR / f"query_{timestamp}.csv" + + with csv_path.open("w", newline="", encoding="utf-8") as f: + writer = csv.DictWriter(f, fieldnames=list(linhas[0].keys())) + writer.writeheader() + writer.writerows(linhas) + + caminho_absoluto = str(csv_path.resolve()) + total = len(linhas) + print(f"[SALVAR_CSV] {total} linhas salvas em: {caminho_absoluto}") + + return {"caminho_csv_resultado": caminho_absoluto} diff --git a/text_to_insight/nodes/graph_generator.py b/text_to_insight/nodes/graph_generator.py new file mode 100644 index 0000000..c29fb2c --- /dev/null +++ b/text_to_insight/nodes/graph_generator.py @@ -0,0 +1,184 @@ +""" +Nó Gerador de Gráficos do grafo de agentes Text-to-Insight. + +Responsabilidade: usar o LLM para gerar código Python (matplotlib) que +visualize os dados do CSV de resultado da query, executar o código gerado, +e armazenar o caminho da imagem resultante no estado. +""" + +import re +import subprocess +import sys +import tempfile +from datetime import datetime +from pathlib import Path + +from langchain_google_genai import ChatGoogleGenerativeAI + +from ..state import EstadoTextToInsight +from ..utils import extrair_tokens + +# Diretório padrão para salvar os gráficos gerados +GRAPHS_DIR = Path(__file__).parent.parent.parent / "graphs" + +PROMPT_GRAPH_GENERATOR = """Você é um especialista em visualização de dados com Python e matplotlib. + +Sua tarefa: gerar APENAS o bloco de código Python que cria um gráfico matplotlib +para visualizar os dados descritos abaixo. + +=== PERGUNTA DO USUÁRIO === +{pergunta} + +=== SQL EXECUTADA === +{sql} + +=== COLUNAS DO RESULTADO === +{colunas} + +=== AMOSTRA DOS DADOS (primeiras linhas do CSV) === +{amostra} + +=== TOTAL DE LINHAS === +{total_linhas} + +Regras: +- O código será inserido dentro de um script que já importou pandas, matplotlib e já + carregou o DataFrame com `df = pd.read_csv(...)`. +- Você NÃO deve importar nada nem carregar dados. Apenas use a variável `df`. +- Use `plt` (já importado como `import matplotlib.pyplot as plt`). +- Escolha o tipo de gráfico mais adequado à pergunta e aos dados (barras, linhas, pizza, dispersão, etc.). +- Adicione título, labels nos eixos, e legenda quando relevante. +- Use cores visualmente agradáveis. +- Se necessário, rotacione labels do eixo X para legibilidade. +- O gráfico será salvo automaticamente, você NÃO deve chamar plt.savefig() nem plt.show(). +- Responda APENAS com o código Python puro, sem markdown, sem explicações. +- Se os dados tiverem muitas categorias (>15), mostre apenas o top 10-15 mais relevantes. + +Não faça um gráfico basico visualmente, faça ele bonito, use cores agradaveis e que ajude o usuario a entender os dados. Tenta fazer algo com cara profissional! Feito por um analista apresentando para um grande cliente que julga o livro pela capa. +""" + + +def _extrair_codigo_python(texto: str) -> str: + """Extrai código Python puro da resposta do LLM, removendo markdown.""" + match = re.search(r"```(?:python)?\s*\n?(.*?)```", texto, re.DOTALL) + if match: + return match.group(1).strip() + return texto.strip() + + +def _construir_script(csv_path: str, output_path: str, codigo_visualizacao: str) -> str: + """Monta o script Python completo que será executado.""" + return f'''import pandas as pd +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt + +df = pd.read_csv("{csv_path}") + +{codigo_visualizacao} + +plt.tight_layout() +plt.savefig("{output_path}", dpi=150, bbox_inches='tight') +plt.close() +print("GRAPH_OK") +''' + + +def _executar_script(script: str) -> tuple[bool, str]: + """Executa o script Python em um subprocesso e retorna (sucesso, saída/erro).""" + with tempfile.NamedTemporaryFile( + mode="w", suffix=".py", delete=False, encoding="utf-8" + ) as tmp: + tmp.write(script) + tmp_path = tmp.name + + try: + resultado = subprocess.run( + [sys.executable, tmp_path], + capture_output=True, + text=True, + timeout=30, + ) + stdout = resultado.stdout.strip() + stderr = resultado.stderr.strip() + + if resultado.returncode == 0 and "GRAPH_OK" in stdout: + return True, stdout + else: + erro = stderr if stderr else stdout + return False, erro + except subprocess.TimeoutExpired: + return False, "Timeout: o script de geração do gráfico excedeu 30 segundos." + except Exception as e: + return False, f"Erro ao executar script: {e}" + finally: + Path(tmp_path).unlink(missing_ok=True) + + +def nos_nodo_gerador_grafico(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) -> dict: + """ + Nó Gerador de Gráficos: usa o LLM para gerar código matplotlib e o executa. + """ + csv_path = estado.get("caminho_csv_resultado", "") + pergunta = estado.get("pergunta_usuario", "") + sql = estado.get("sql_gerada", "") + preview = estado.get("linhas_resultado_preview", []) + total = estado.get("total_linhas_resultado", 0) + + print("[GERADOR_GRAFICO] Iniciando geração de gráfico...") + + if not csv_path or not Path(csv_path).exists(): + print("[GERADOR_GRAFICO] CSV não encontrado — pulando geração.") + return {"grafico_gerado": False, "caminho_grafico": ""} + + # Extrair colunas e amostra para o prompt + colunas = list(preview[0].keys()) if preview and isinstance(preview[0], dict) else [] + amostra_str = str(preview[:5]) if preview else "(vazio)" + + prompt = PROMPT_GRAPH_GENERATOR.format( + pergunta=pergunta, + sql=sql, + colunas=", ".join(colunas) if colunas else "(desconhecidas)", + amostra=amostra_str, + total_linhas=total, + ) + + # Chamada ao LLM para gerar o código de visualização + resposta = llm.invoke(prompt) + codigo_bruto = resposta.content.strip() + codigo_visualizacao = _extrair_codigo_python(codigo_bruto) + + print(f"[GERADOR_GRAFICO] Código gerado ({len(codigo_visualizacao)} bytes)") + + # Montar caminho de saída do gráfico + GRAPHS_DIR.mkdir(exist_ok=True) + timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + output_path = str((GRAPHS_DIR / f"grafico_{timestamp}.png").resolve()) + + # Montar e executar o script completo + script = _construir_script(csv_path, output_path, codigo_visualizacao) + print(script) + print(f"[GERADOR_GRAFICO] Executando script...") + + sucesso, saida = _executar_script(script) + + in_tokens, out_tokens, total_tokens = extrair_tokens(resposta) + + if sucesso and Path(output_path).exists(): + print(f"[GERADOR_GRAFICO] ✓ Gráfico salvo em: {output_path}") + return { + "grafico_gerado": True, + "caminho_grafico": output_path, + "tokens_input": in_tokens, + "tokens_output": out_tokens, + "tokens_total": total_tokens, + } + else: + print(f"[GERADOR_GRAFICO] ✗ Falha na geração do gráfico: {saida[:200]}") + return { + "grafico_gerado": False, + "caminho_grafico": "", + "tokens_input": in_tokens, + "tokens_output": out_tokens, + "tokens_total": total_tokens, + } diff --git a/text_to_insight/nodes/response.py b/text_to_insight/nodes/response.py index cf0a403..22e7b2c 100644 --- a/text_to_insight/nodes/response.py +++ b/text_to_insight/nodes/response.py @@ -22,6 +22,7 @@ - Inclua um resumo do que os resultados indicam e, quando relevante, uma interpretação simples (por exemplo: totais, médias, top N, ausência de dados, etc.). - Seja claro sobre quaisquer limitações (por exemplo: amostra limitada de linhas). +- Se um gráfico foi gerado, mencione que um gráfico acompanha a resposta. - Responda em Português, no máximo 3-5 frases, sem mostrar a SQL completa nem blocos de código. Contexto: @@ -30,6 +31,7 @@ Total de linhas: {total} Amostra de resultados: {preview} Saída resumida: {saida} +Gráfico gerado: {grafico_info} Gere APENAS a resposta final para o usuário (sem títulos, sem marcas, sem explicações sobre o que você está fazendo). """ @@ -54,7 +56,12 @@ def nos_nodo_resposta(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) total = estado.get("total_linhas_resultado", None) saida = estado.get("saida_terminal", "") + grafico_gerado = estado.get("grafico_gerado", False) + caminho_grafico = estado.get("caminho_grafico", "") + print(f"[RESPOSTA] Executando nó de resposta — status atual: {status}") + if grafico_gerado: + print(f"[RESPOSTA] Gráfico disponível em: {caminho_grafico}") # Só gera resposta natural se o crítico aprovou if status != "aprovado": @@ -64,6 +71,7 @@ def nos_nodo_resposta(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) # Formata preview de forma compacta para o prompt preview_str = str(preview[:10]) if preview else "(sem amostra)" total_str = str(total) if total is not None else "desconhecido" + grafico_info = f"Sim — gráfico salvo em {caminho_grafico}" if grafico_gerado else "Não" # Durante execução de testes (pytest) evitamos invocar a API externa # para não depender de cassetes adicionais. Detectamos pytest através @@ -89,6 +97,7 @@ def nos_nodo_resposta(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) total=total_str, preview=preview_str, saida=(saida if saida else "Nenhuma saída resumida"), + grafico_info=grafico_info, ) try: diff --git a/text_to_insight/nodes/sandbox.py b/text_to_insight/nodes/sandbox.py index ed29409..a5dc89b 100644 --- a/text_to_insight/nodes/sandbox.py +++ b/text_to_insight/nodes/sandbox.py @@ -50,4 +50,5 @@ def nos_nodo_sandbox(estado: EstadoTextToInsight) -> dict: "saida_terminal": resultado["saida_terminal"], "erro_execucao": resultado["erro_execucao"], "status": "exec_erro", + "historico_tentativas": [{"sql": sql, "erro": resultado["erro_execucao"]}], } diff --git a/text_to_insight/routers/__init__.py b/text_to_insight/routers/__init__.py index 40591a1..d81850a 100644 --- a/text_to_insight/routers/__init__.py +++ b/text_to_insight/routers/__init__.py @@ -7,8 +7,10 @@ Roteadores disponíveis: - roteador_sandbox: Define o fluxo após execução do código - roteador_planejador: Define o fluxo após planejamento + - roteador_grafico: Define se gera gráfico ou vai direto para resposta """ -from .edges import roteador_sandbox, roteador_planejador +from .edges import roteador_sandbox, roteador_planejador, roteador_grafico + +__all__ = ["roteador_sandbox", "roteador_planejador", "roteador_grafico"] -__all__ = ["roteador_sandbox", "roteador_planejador"] diff --git a/text_to_insight/routers/edges.py b/text_to_insight/routers/edges.py index 673f780..475b2cb 100644 --- a/text_to_insight/routers/edges.py +++ b/text_to_insight/routers/edges.py @@ -6,7 +6,43 @@ """ from typing import Literal +from langchain_google_genai import ChatGoogleGenerativeAI + from ..state import EstadoTextToInsight +from ..utils import extrair_tokens + +PROMPT_ROTEADOR_GRAFICO = """Você é um assistente que decide se os resultados de uma consulta SQL +devem ser acompanhados de um gráfico (visualização). + +Analise a pergunta do usuário e as características dos dados retornados. + +=== PERGUNTA DO USUÁRIO === +{pergunta} + +=== COLUNAS DO RESULTADO === +{colunas} + +=== TOTAL DE LINHAS === +{total_linhas} + +=== AMOSTRA DOS DADOS === +{amostra} + +Um gráfico é útil quando: +- A pergunta envolve comparações entre categorias (ex: vendas por região) +- Há dados temporais ou tendências (ex: evolução ao longo dos meses) +- Há distribuições ou rankings (ex: top 10 produtos) +- Há agregações numéricas que se beneficiam de visualização +- O resultado tem mais de 1 linha com pelo menos uma coluna numérica + +Um gráfico NÃO é útil quando: +- O resultado é um único valor escalar (ex: total geral) +- A pergunta pede um dado específico pontual (ex: nome de um cliente) +- O resultado tem apenas 1 linha +- Não há colunas numéricas para plotar + +Responda APENAS com uma palavra: SIM ou NAO +""" def roteador_sandbox(estado: EstadoTextToInsight) -> Literal["critico", "planejador"]: @@ -15,7 +51,7 @@ def roteador_sandbox(estado: EstadoTextToInsight) -> Literal["critico", "planeja - exec_ok → critico (avaliar resultado) - exec_erro + tentativas < 3 → planejador (reconsiderar) - - tentativas >= 3 → planejador (desistir/reiniciar) + - tentativas >= 3 → crítico (desistir/reiniciar -> encerrar loop) """ status = estado.get("status", "") tentativas = estado.get("tentativas_loop", 0) @@ -30,8 +66,8 @@ def roteador_sandbox(estado: EstadoTextToInsight) -> Literal["critico", "planeja print("[ROTEADOR_SANDBOX] Erro detectado → planejador para retry") return "planejador" - print("[ROTEADOR_SANDBOX] Muitas tentativas ou erro → planejador") - return "planejador" + print("[ROTEADOR_SANDBOX] Muitas tentativas ou erro → critico (para forçar o fim do loop)") + return "critico" def roteador_planejador(estado: EstadoTextToInsight) -> Literal["esquema", "agente_codigo", "critico", "fim"]: @@ -67,3 +103,52 @@ def roteador_planejador(estado: EstadoTextToInsight) -> Literal["esquema", "agen # Default: gera código print("[ROTEADOR_PLANEJADOR] Default → planejador") return "planejador" + + +def roteador_grafico(estado: EstadoTextToInsight, llm: ChatGoogleGenerativeAI) -> Literal["gerador_grafico", "resposta"]: + """ + Roteador após salvar CSV: decide se gera gráfico ou vai direto para resposta. + + Usa o LLM para avaliar se a pergunta e os dados justificam uma visualização. + """ + pergunta = estado.get("pergunta_usuario", "") + preview = estado.get("linhas_resultado_preview", []) + total = estado.get("total_linhas_resultado", 0) + csv_path = estado.get("caminho_csv_resultado", "") + + # Se não há CSV ou dados, pular gráfico + if not csv_path or not preview or total == 0: + print("[ROTEADOR_GRAFICO] Sem dados para gráfico → resposta") + return "resposta" + + # Se resultado é uma única linha, provavelmente não precisa de gráfico + if total == 1: + print("[ROTEADOR_GRAFICO] Apenas 1 linha → resposta") + return "resposta" + + colunas = list(preview[0].keys()) if preview and isinstance(preview[0], dict) else [] + amostra_str = str(preview[:5]) if preview else "(vazio)" + + prompt = PROMPT_ROTEADOR_GRAFICO.format( + pergunta=pergunta, + colunas=", ".join(colunas) if colunas else "(desconhecidas)", + total_linhas=total, + amostra=amostra_str, + ) + + try: + resposta = llm.invoke(prompt) + texto = resposta.content.strip().upper() + + in_tokens, out_tokens, total_tokens = extrair_tokens(resposta) + + decisao = "SIM" in texto + print(f"[ROTEADOR_GRAFICO] Decisão LLM: {'SIM' if decisao else 'NAO'} → {'gerador_grafico' if decisao else 'resposta'}") + + if decisao: + return "gerador_grafico" + else: + return "resposta" + except Exception as e: + print(f"[ROTEADOR_GRAFICO] Erro ao consultar LLM: {e} → resposta (fallback)") + return "resposta" diff --git a/text_to_insight/runtime.py b/text_to_insight/runtime.py index 09a3541..60759b6 100644 --- a/text_to_insight/runtime.py +++ b/text_to_insight/runtime.py @@ -367,6 +367,15 @@ def exibir_resultado_console(resultado: dict[str, Any]) -> None: resposta_natural = str(resultado.get("resposta_natural", "")).strip() print(resposta_natural if resposta_natural else "[Nenhuma resposta natural]") + # Exibir info do gráfico gerado (se houver) + grafico_gerado = resultado.get("grafico_gerado", False) + caminho_grafico = resultado.get("caminho_grafico", "") + if grafico_gerado and caminho_grafico: + print("\n" + "-" * 70) + print("GRAFICO GERADO:") + print("-" * 70) + print(f"✓ Gráfico salvo em: {caminho_grafico}") + print("\n" + "=" * 70 + "\n") diff --git a/text_to_insight/state.py b/text_to_insight/state.py index fd53538..08419ad 100644 --- a/text_to_insight/state.py +++ b/text_to_insight/state.py @@ -70,6 +70,11 @@ class EstadoTextToInsight(EstadoEntrada, total = False): historico_tentativas: Annotated[list[dict[str, str]], operator.add] linhas_resultado_completo: list[dict[str, Any]] + # Campos para geração de gráficos + caminho_csv_resultado: str + caminho_grafico: str + grafico_gerado: bool + # Campos exclusivos para métricas. Possibilita a soma automática dos tokens utilizados # por cada chamada do Gemini nos vários diferentes nós. tokens_input: Annotated[int, operator.add]