Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .github/workflows/linters.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Lint workflow: runs Ruff on every pull request.
name: Linting

on:
  pull_request:

jobs:
  ruff:
    name: Ruff Linting
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
      - name: Run Ruff
        uses: astral-sh/ruff-action@v1
        with:
          # "github" output format lets findings show up as annotations on the PR.
          args: check --output-format=github
11 changes: 10 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
venv
.venv
.env
__pycache__/
*.pyc
.DS_Store
.idea/
__pycache__/
data/
chroma_db/
chroma_db/
ml/mlruns
ml/mlflow.db
mlflow.db
mlartifacts

.pytest_cache/
.mypy_cache/
.ruff_cache/
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
12 changes: 10 additions & 2 deletions app/RAG/ingest.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import os

from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter

load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")
data_path = "data/data.pdf"


Expand Down Expand Up @@ -33,7 +38,10 @@ def save_splits_to_chroma(
:param chroma_db_path: Description
:type chroma_db_path: str
"""
embedding_function = HuggingFaceEmbeddings(model_name="BAAI/bge-small-en-v1.5")
embedding_function = HuggingFaceEmbeddings(
model_name="BAAI/bge-small-en-v1.5",
model_kwargs={"token": HF_TOKEN},
)
print("Creating vector database... this may take a moment.")
Chroma.from_documents(
documents=splits, embedding=embedding_function, persist_directory=chroma_db_path
Expand Down
50 changes: 36 additions & 14 deletions app/RAG/query.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
import os

from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_chroma import Chroma
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_huggingface import HuggingFaceEmbeddings

load_dotenv()
GOOGLE_API_KEY = os.getenv("GOOGLE_API_KEY")
HF_TOKEN = os.getenv("HF_TOKEN")


class ITSmartAssistant:
def __init__(self):
self.embedding_function = HuggingFaceEmbeddings(
model_name="BAAI/bge-small-en-v1.5"
model_name="BAAI/bge-small-en-v1.5",
model_kwargs={"token": HF_TOKEN},
)
self.db = Chroma(
persist_directory="./chroma_db", embedding_function=self.embedding_function
Expand All @@ -33,8 +35,11 @@ def _format_docs(self, docs):
def _build_chain(self):
system_prompt = (
"You are an IT Support Assistant. Use the following context to answer the question. "
"If you don't know the answer, say you don't know. Keep it brief.\n\n"
"{context}"
'- If the answer is not in the context, respond: "I don’t know." '
"- Provide step-by-step instructions if applicable. "
"- Include citations for any information taken from the context in the format: [source, page]. "
"- Keep answers clear and concise.\n\n"
"Context:\n{context}"
)

prompt = ChatPromptTemplate.from_messages(
Expand All @@ -44,17 +49,34 @@ def _build_chain(self):
]
)

return prompt | self.llm | StrOutputParser()

def ask(self, question: str):
retriever = self.db.as_retriever()
docs = retriever.invoke(question)

return (
{"context": retriever | self._format_docs, "input": RunnablePassthrough()}
| prompt
| self.llm
| StrOutputParser()
context = self._format_docs(docs)
answer = self.chain.invoke(
{
"context": context,
"input": question,
}
)

def ask(self, question: str):
return {"query": question, "answer": self.chain.invoke(question)}
chunks = [
{
"content": doc.page_content,
"source": doc.metadata.get("source"),
"page": doc.metadata.get("page"),
}
for doc in docs
]

return {
"query": question,
"answer": answer,
"chunks": chunks,
}


# query = "How to reset my IT support password?"
Expand Down
19 changes: 9 additions & 10 deletions app/api/routers/auth.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
from fastapi import Depends, HTTPException, status, Request
from datetime import timedelta

from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
from fastapi.security import OAuth2PasswordRequestForm
from sqlalchemy.orm import Session

from ...authentication.auth import (
get_password_hash,
authenticate_user,
create_access_token,
get_current_user,
get_password_hash,
)
from ...schemas.user_schema import UserSchema, UserCreate
from ...models.user_model import User
from ...db.database import get_db
from fastapi.security import OAuth2PasswordRequestForm
from datetime import timedelta
from fastapi import Response
from fastapi import APIRouter
from ...config import settings

from ...db.database import get_db
from ...models.user_model import User
from ...schemas.user_schema import UserCreate, UserSchema

router = APIRouter(prefix="/api/v1/auth", tags=["Authentication routes"])

Expand Down
83 changes: 60 additions & 23 deletions app/api/routers/rag.py
Original file line number Diff line number Diff line change
@@ -1,41 +1,78 @@
from fastapi import APIRouter, Depends
import os
import time
from pathlib import Path

import mlflow
from fastapi import APIRouter, Depends, HTTPException, Request
from sqlalchemy.orm import Session

from app.authentication.auth import get_current_user
from ...schemas.user_schema import UserSchema, QueryRequest, QueryResponse
from ml.cluster_model import ClusterModel

from ...db.database import get_db
from sqlalchemy.orm import Session
from ...RAG.query import ITSmartAssistant
from ...models.user_model import Query
from ...schemas.user_schema import QueryRequest, QueryResponse, UserSchema

router = APIRouter(prefix="/api/v1/rag", tags=["RAG Routes"])

rag_instance = ITSmartAssistant()
BASE_DIR = Path(__file__).resolve().parent.parent.parent.parent
MODEL_DIR = os.path.join(BASE_DIR, "ml", "models")


cluster = ClusterModel(model_path=os.path.join(BASE_DIR, "ml", "model"))


@router.post("/query", response_model=QueryResponse)
async def get_rag_answer(
    query: QueryRequest,
    request: Request,
    db: Session = Depends(get_db),
    current_user: UserSchema = Depends(get_current_user),
):
    """Answer a question with the RAG assistant, track the call in MLflow, and store it.

    The assistant is read from ``request.app.state.rag_assistant``. The question
    is answered, assigned a cluster via ``cluster.predict_cluster``, logged as an
    MLflow run named ``rag_query``, and saved as a ``Query`` row.

    :param query: Request body carrying the user's ``question``.
    :param request: Incoming request; used to reach the application state.
    :param db: Database session provided by ``get_db``.
    :param current_user: Authenticated user provided by ``get_current_user``.
    :return: The persisted ``Query`` row (serialized through ``QueryResponse``).
    :raises HTTPException: 503 when no assistant is available on the app state.
    """
    assistant = getattr(request.app.state, "rag_assistant", None)
    if not assistant:
        raise HTTPException(status_code=503, detail="RAG Assistant is still loading.")

    with mlflow.start_run(run_name="rag_query"):
        mlflow.log_param("user_id", current_user.id)
        mlflow.log_param("rag_version", "v1")

        # NOTE(review): these three values are hard-coded labels, not read from
        # the assistant — confirm they match its actual configuration.
        mlflow.log_param("llm_model", "gemini-2.5-flash")
        mlflow.log_param("temperature", 0.7)
        mlflow.log_param("top_k", 5)
        mlflow.log_text(query.question, "input_question.txt")

        # perf_counter is monotonic, so the measured latency cannot go negative
        # or jump when the system clock is adjusted.
        time_start = time.perf_counter()
        rag_answer = assistant.ask(query.question)
        cluster_id = cluster.predict_cluster(query.question)
        # Latency covers both the RAG call and the cluster prediction.
        latency_ms = (time.perf_counter() - time_start) * 1000

        mlflow.log_metric("latency_ms", latency_ms)
        mlflow.log_text(rag_answer["answer"], "generated_answer.txt")

        # Only touch "chunks" when the assistant returned it, so a result
        # without retrieved chunks does not raise KeyError.
        if "chunks" in rag_answer:
            retrieved = rag_answer["chunks"]
            mlflow.log_metric("num_chunks", len(retrieved))
            mlflow.log_dict(retrieved, "retrieved_chunks.json")

        # NOTE(review): persistence is kept inside the MLflow run; the original
        # nesting is not recoverable from the flattened diff — confirm.
        new_query = Query(
            user_id=current_user.id,
            question=query.question,
            answer=rag_answer["answer"],
            cluster=cluster_id,
            latency_ms=latency_ms,
        )
        db.add(new_query)
        db.commit()
        db.refresh(new_query)

    return new_query


@router.get("/history")
Expand Down
25 changes: 13 additions & 12 deletions app/authentication/auth.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
from datetime import UTC, datetime, timedelta

import jwt
from fastapi import Depends, HTTPException, status, Request
from fastapi import Depends, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from ..models.user_model import User
from ..schemas.user_schema import TokenData
from ..db import database
from sqlalchemy.orm import Session
from datetime import timedelta, datetime, timezone
from typing import Optional
from jwt.exceptions import InvalidTokenError

from ..config import settings
from ..db import database
from ..models.user_model import User
from ..schemas.user_schema import TokenData

pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
oauth2_schema = OAuth2PasswordBearer(tokenUrl="token")
Expand All @@ -31,12 +32,12 @@ def authenticate_user(db: Session, username: str, password: str):
return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
def create_access_token(data: dict, expires_delta: timedelta | None = None):
to_encode = data.copy()
if expires_delta:
expire = datetime.now(timezone.utc) + expires_delta
expire = datetime.now(UTC) + expires_delta
else:
expire = datetime.now(timezone.utc) + timedelta(minutes=20)
expire = datetime.now(UTC) + timedelta(minutes=20)

to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(
Expand Down Expand Up @@ -64,8 +65,8 @@ def get_current_user(
if username is None:
raise credentials_exception
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception
except InvalidTokenError as err:
raise credentials_exception from err

user = db.query(User).filter(User.username == token_data.username).first()
if user is None:
Expand Down
4 changes: 2 additions & 2 deletions app/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from pydantic import ConfigDict, Field
from pydantic_settings import BaseSettings
from pydantic import Field
from pydantic import ConfigDict


class Settings(BaseSettings):
Expand All @@ -15,6 +14,7 @@ class Settings(BaseSettings):
FRONTEND_URL: str = Field(...)
GOOGLE_API_KEY: str = Field(...)
HF_TOKEN: str = Field(...)
MLFLOW_TRACKING_URI: str = Field(...)

model_config = ConfigDict(env_file=".env", env_file_encoding="utf-8")

Expand Down
3 changes: 1 addition & 2 deletions app/db/database.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
import os
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from ..config import settings

from ..config import settings

SQLALCHEMY_DATABASE_URL = f"postgresql://{settings.DATABASE_USER}:{settings.DATABASE_PASSWORD}@{settings.DATABASE_HOST}:{settings.DATABASE_PORT}/{settings.DATABASE_NAME}"

Expand Down
Loading