feat: Add GitHub sync functionality to sync router (#2871)

This change adds GitHub sync to Quivr: a new `github_sync_routes.py`
controller is registered on the sync router in `sync_routes.py`, and a
`GitHubSync` provider lets users sync their GitHub repositories with Quivr.

# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
Chloé Daems 2024-08-05 14:17:53 +02:00 committed by GitHub
parent b3debeefee
commit 9934a7a8ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 397 additions and 16 deletions

View File

@ -0,0 +1,155 @@
import os
from urllib.parse import quote, urlencode

import requests
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import HTMLResponse
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
from quivr_api.modules.user.entity.user_identity import UserIdentity

from .successfull_connection import successfullConnectionPage
# Module-level logger used by the GitHub sync endpoints.
logger = get_logger(__name__)

# Service objects created once at import time and shared by the handlers.
sync_service = SyncService()
sync_user_service = SyncUserService()

# Router on which the GitHub OAuth endpoints are registered.
github_sync_router = APIRouter()

# OAuth application settings, read from the environment at import time.
# BACKEND_URL falls back to the local development address when unset.
CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID")
CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET")
BACKEND_URL = os.environ.get("BACKEND_URL", "http://localhost:5050")
# Callback URL GitHub redirects to; served by the oauth2callback route.
REDIRECT_URI = f"{BACKEND_URL}/sync/github/oauth2callback"
# Space-separated OAuth scopes requested from GitHub.
SCOPE = "repo user"
@github_sync_router.post(
    "/sync/github/authorize",
    dependencies=[Depends(AuthBearer())],
    tags=["Sync"],
)
def authorize_github(
    request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
):
    """
    Start the GitHub OAuth flow for the current user.

    Builds the GitHub authorization URL and creates a sync-user record with
    empty credentials that carries the same ``state`` value, so that the
    OAuth callback can look the record up again.

    Args:
        request (Request): The request object (not used by this handler).
        name (str): Name given by the caller to this sync connection.
        current_user (UserIdentity): The current authenticated user.

    Returns:
        dict: A dictionary containing the authorization URL.
    """
    logger.debug(f"Authorizing GitHub sync for user: {current_user.id}")
    # NOTE(review): the state is derived from the user id and the connection
    # name, so it is predictable; consider adding a random nonce.
    state = f"user_id={current_user.id},name={name}"

    # Percent-encode every query value: the scope contains a space and the
    # state contains "=", "," and the caller-supplied name, any of which
    # would otherwise corrupt the query string.
    query = urlencode(
        {
            "client_id": CLIENT_ID,
            "redirect_uri": REDIRECT_URI,
            "scope": SCOPE,
            "state": state,
        },
        quote_via=quote,
    )
    authorization_url = f"https://github.com/login/oauth/authorize?{query}"

    # The raw (unencoded) state is stored; the callback receives the decoded
    # value back from GitHub and compares it against this record.
    sync_user_input = SyncsUserInput(
        user_id=str(current_user.id),
        name=name,
        provider="GitHub",
        credentials={},
        state={"state": state},
    )
    sync_user_service.create_sync_user(sync_user_input)
    return {"authorization_url": authorization_url}
# Seconds to wait for each outbound GitHub call before giving up, so a stalled
# connection cannot block the worker indefinitely.
_GITHUB_REQUEST_TIMEOUT = 10


def _json_or_text(response):
    """Return the decoded JSON body of *response*, or its raw text if it is not JSON."""
    try:
        return response.json()
    except ValueError:
        return response.text


@github_sync_router.get("/sync/github/oauth2callback", tags=["Sync"])
def oauth2callback_github(request: Request):
    """
    Handle OAuth2 callback from GitHub.

    Validates the ``state`` query parameter against the stored sync-user
    record, exchanges the ``code`` for an access token, resolves the user's
    email address and stores the credentials on the sync-user record.

    Args:
        request (Request): The request object carrying ``state`` and ``code``.

    Returns:
        HTMLResponse: The "successful connection" page.

    Raises:
        HTTPException: 400 when the state or code is missing or does not
            match the stored record, when the token exchange fails, or when
            the GitHub profile/email cannot be retrieved.
    """
    state = request.query_params.get("state")
    code = request.query_params.get("code")
    if not state:
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    if not code:
        logger.error("Missing code parameter")
        raise HTTPException(status_code=400, detail="Missing code parameter")

    # The state has the form "user_id=<id>,name=<name>"; only the user id is
    # needed here. partition() keeps this safe when the name itself contains
    # "," or "=".
    user_part, _, _ = state.partition(",")
    key, _, current_user = user_part.partition("=")
    if key != "user_id" or not current_user:
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")

    state_dict = {"state": state}
    logger.debug(
        f"Handling OAuth2 callback for user: {current_user} with state: {state}"
    )
    sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
    logger.info(f"Retrieved sync user state: {sync_user_state}")

    # A missing record is treated the same as a mismatching state.
    if not sync_user_state or state_dict != sync_user_state.get("state"):
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    if sync_user_state.get("user_id") != current_user:
        logger.error("Invalid user")
        raise HTTPException(status_code=400, detail="Invalid user")

    # Exchange the temporary code for an access token.
    token_url = "https://github.com/login/oauth/access_token"
    data = {
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
        "code": code,
        "redirect_uri": REDIRECT_URI,
        "state": state,
    }
    headers = {"Accept": "application/json"}
    response = requests.post(
        token_url, data=data, headers=headers, timeout=_GITHUB_REQUEST_TIMEOUT
    )
    result = _json_or_text(response)
    if response.status_code != 200:
        logger.error(f"Failed to acquire token: {result}")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to acquire token: {result}",
        )

    # A 200 response without an access_token is also handled as a failure.
    access_token = result.get("access_token") if isinstance(result, dict) else None
    if not access_token:
        logger.error(f"Failed to acquire token: {result}")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to acquire token: {result}",
        )
    logger.info(f"Fetched OAuth2 token for user: {current_user}")

    # Fetch user email from GitHub API
    github_api_url = "https://api.github.com/user"
    headers = {"Authorization": f"Bearer {access_token}"}
    response = requests.get(
        github_api_url, headers=headers, timeout=_GITHUB_REQUEST_TIMEOUT
    )
    if response.status_code != 200:
        logger.error("Failed to fetch user profile from GitHub API")
        raise HTTPException(status_code=400, detail="Failed to fetch user profile")
    user_info = response.json()
    user_email = user_info.get("email")

    if not user_email:
        # If the email is not public, make a separate API call to get emails
        emails_url = "https://api.github.com/user/emails"
        response = requests.get(
            emails_url, headers=headers, timeout=_GITHUB_REQUEST_TIMEOUT
        )
        if response.status_code != 200:
            logger.error("Failed to fetch user email from GitHub API")
            raise HTTPException(status_code=400, detail="Failed to fetch user email")
        # Default to None so a missing primary address becomes a 400 rather
        # than an unhandled StopIteration.
        user_email = next(
            (entry["email"] for entry in response.json() if entry.get("primary")),
            None,
        )
        if not user_email:
            logger.error("Failed to fetch user email from GitHub API")
            raise HTTPException(status_code=400, detail="Failed to fetch user email")

    logger.info(f"Retrieved email for user: {current_user} - {user_email}")
    # Store the token response as credentials and clear the pending state.
    sync_user_input = SyncUserUpdateInput(
        credentials=result, state={}, email=user_email
    )
    sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
    logger.info(f"GitHub sync created successfully for user: {current_user}")
    return HTMLResponse(successfullConnectionPage)

View File

@ -3,6 +3,7 @@ import uuid
from typing import List
from fastapi import APIRouter, Depends, status
from quivr_api.logger import get_logger
from quivr_api.middlewares.auth import AuthBearer, get_current_user
from quivr_api.modules.notification.dto.inputs import CreateNotification
@ -12,6 +13,7 @@ from quivr_api.modules.notification.service.notification_service import (
)
from quivr_api.modules.sync.controller.azure_sync_routes import azure_sync_router
from quivr_api.modules.sync.controller.dropbox_sync_routes import dropbox_sync_router
from quivr_api.modules.sync.controller.github_sync_routes import github_sync_router
from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router
from quivr_api.modules.sync.dto import SyncsDescription
from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
@ -38,6 +40,7 @@ sync_router = APIRouter()
# Add Google routes here
sync_router.include_router(google_sync_router)
sync_router.include_router(azure_sync_router)
sync_router.include_router(github_sync_router)
sync_router.include_router(dropbox_sync_router)
@ -60,6 +63,12 @@ dropbox_sync = SyncsDescription(
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
# Description of the GitHub provider. GitHub has repositories rather than a
# "Drive", so the user-facing text says so.
# NOTE(review): presumably returned with the other providers by the sync
# listing route below — confirm.
github_sync = SyncsDescription(
    name="GitHub",
    description="Sync your GitHub repositories with Quivr",
    auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
@sync_router.get(
"/sync/all",

View File

@ -48,14 +48,13 @@ class SyncUserInterface(ABC):
self,
sync_active_id: int,
user_id: str,
folder_id: int = None,
folder_id: int | str | None = None,
recursive: bool = False,
):
pass
class SyncInterface(ABC):
@abstractmethod
def create_sync_active(
self,

View File

@ -11,6 +11,7 @@ from quivr_api.modules.sync.repository.sync_interfaces import SyncUserInterface
from quivr_api.modules.sync.utils.sync import (
AzureDriveSync,
DropboxSync,
GitHubSync,
GoogleDriveSync,
)
@ -213,6 +214,15 @@ class SyncUser(SyncUserInterface):
sync_user["credentials"], folder_id if folder_id else "", recursive
)
}
elif provider == "github":
logger.info("Getting files for GitHub sync")
sync = GitHubSync()
return {
"files": sync.get_files(
sync_user["credentials"], folder_id if folder_id else "", recursive
)
}
else:
logger.warning(
"No sync found for provider: %s", sync_user["provider"], recursive

View File

@ -11,6 +11,7 @@ from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserSer
from quivr_api.modules.sync.utils.sync import (
AzureDriveSync,
DropboxSync,
GitHubSync,
GoogleDriveSync,
)
from quivr_api.modules.sync.utils.syncutils import SyncUtils
@ -57,6 +58,14 @@ async def _process_sync_active():
sync_cloud=DropboxSync(),
)
github_sync_utils = SyncUtils(
sync_user_service=sync_user_service,
sync_active_service=sync_active_service,
sync_files_repo=sync_files_repo_service,
storage=storage,
sync_cloud=GitHubSync(),
)
active = await sync_active_service.get_syncs_active_in_interval()
for sync in active:
@ -80,6 +89,10 @@ async def _process_sync_active():
await azure_sync_utils.sync(
sync_active_id=sync.id, user_id=sync.user_id
)
elif details_user_sync["provider"].lower() == "github":
await github_sync_utils.sync(
sync_active_id=sync.id, user_id=sync.user_id
)
elif details_user_sync["provider"].lower() == "dropbox":
await dropbox_sync_utils.sync(
sync_active_id=sync.id, user_id=sync.user_id

View File

@ -2,6 +2,7 @@ import json
import os
import time
from abc import ABC, abstractmethod
from datetime import datetime
from io import BytesIO
from typing import Any, Dict, List, Union
@ -12,10 +13,11 @@ from fastapi import HTTPException
from google.auth.transport.requests import Request as GoogleRequest
from google.oauth2.credentials import Credentials
from googleapiclient.discovery import build
from requests import HTTPError
from quivr_api.logger import get_logger
from quivr_api.modules.sync.entity.sync import SyncFile
from quivr_api.modules.sync.utils.normalize import remove_special_characters
from requests import HTTPError
logger = get_logger(__name__)
@ -66,7 +68,6 @@ class GoogleDriveSync(BaseSync):
file_id = file.id
file_name = file.name
mime_type = file.mime_type
modified_time = file.last_modified
if not self.creds:
self.check_and_refresh_access_token(credentials)
if not self.service:
@ -291,7 +292,6 @@ class AzureDriveSync(BaseSync):
"Authorization": f"Bearer {token_data['access_token']}",
"Accept": "application/json",
}
def check_and_refresh_access_token(self, credentials) -> Dict:
if "refresh_token" not in credentials:
@ -366,7 +366,7 @@ class AzureDriveSync(BaseSync):
if not folder_id and not site_id:
# Fetch the sites
#endpoint = "https://graph.microsoft.com/v1.0/me/followedSites"
# endpoint = "https://graph.microsoft.com/v1.0/me/followedSites"
endpoint = "https://graph.microsoft.com/v1.0/sites?search=*"
elif site_id == "root":
if not folder_id:
@ -427,9 +427,7 @@ class AzureDriveSync(BaseSync):
logger.info("Azure Drive files retrieved successfully: %s", len(files))
return files
def get_files_by_id(
self, credentials: dict, file_ids: List[str]
) -> List[SyncFile] | dict:
def get_files_by_id(self, credentials: dict, file_ids: List[str]) -> List[SyncFile]:
"""
Retrieve files from Azure Drive by their IDs.
@ -489,9 +487,7 @@ class AzureDriveSync(BaseSync):
) -> Dict[str, Union[str, BytesIO]]:
file_id = file.id
file_name = file.name
modified_time = file.last_modified
headers = self.get_azure_headers(credentials)
site_id, folder_id = file_id.split(":")
if folder_id == "":
folder_id = None
@ -687,3 +683,194 @@ class DropboxSync(BaseSync):
metadata, file_data = self.dbx.files_download(file_id) # type: ignore
return {"file_name": file_name, "content": BytesIO(file_data.content)}
class GitHubSync(BaseSync):
    """Sync provider that lists and downloads files through the GitHub REST API.

    Items are identified by ``"<repository full_name>:<path>"``; a repository
    root has an empty path (``"<full_name>:"``).
    """

    name = "GitHub"
    lower_name = "github"
    datetime_format = "%Y-%m-%dT%H:%M:%SZ"
    # Seconds to wait for each GitHub API call, so a stalled connection cannot
    # block a sync run indefinitely.
    request_timeout = 30

    def __init__(self):
        self.CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
        self.CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
        self.BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:5050")
        self.REDIRECT_URI = f"{self.BACKEND_URL}/sync/github/oauth2callback"
        self.SCOPE = "repo user"

    @staticmethod
    def _split_id(item_id: str) -> tuple[str, str]:
        """Split an item id into ``(repository, path)`` at the first colon only."""
        repo_name, _, path = item_id.partition(":")
        return repo_name, path

    def _timestamp(self, value=None) -> str:
        """Return *value* as a string, or the current time in ``datetime_format`` if empty."""
        if value:
            return str(value)
        return datetime.now().strftime(self.datetime_format)

    def _fetch_listing(self, endpoint: str, headers: dict, error_message: str) -> list:
        """GET *endpoint* and return its JSON payload as a list.

        Raises ``HTTPException(401)`` on an unauthorized response. Any other
        non-200 status is logged with *error_message* and yields ``[]``.
        """
        response = requests.get(
            endpoint, headers=headers, timeout=self.request_timeout
        )
        if response.status_code == 401:
            raise HTTPException(status_code=401, detail="Unauthorized")
        if response.status_code != 200:
            logger.error(error_message, response.text)
            return []
        payload = response.json()
        # The contents endpoint answers with a single object when the path is
        # a file; wrap it so callers can always iterate over items.
        return [payload] if isinstance(payload, dict) else payload

    def get_github_token_data(self, credentials):
        """Return *credentials* after checking that they contain an access token.

        Raises:
            HTTPException: 401 when ``access_token`` is missing.
        """
        if "access_token" not in credentials:
            raise HTTPException(status_code=401, detail="Invalid token data")
        return credentials

    def get_github_headers(self, token_data):
        """Build the request headers carrying the bearer token."""
        return {
            "Authorization": f"Bearer {token_data['access_token']}",
            "Accept": "application/json",
        }

    def check_and_refresh_access_token(self, credentials: dict) -> Dict:
        """Always raise: token refresh is not supported by this provider.

        Raises:
            HTTPException: 400, unconditionally.
        """
        # GitHub tokens do not support refresh token, usually need to re-authenticate
        raise HTTPException(
            status_code=400, detail="GitHub does not support token refresh"
        )

    def get_files(
        self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
    ) -> List[SyncFile]:
        """List the contents of *folder_id*, or the user's repositories when it is empty."""
        logger.info("Retrieving GitHub files with folder_id: %s", folder_id)
        if folder_id:
            return self.list_github_files_in_repo(
                credentials, folder_id, recursive=recursive
            )
        return self.list_github_repos(credentials, recursive=recursive)

    def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]:
        """Fetch metadata for each id in *file_ids*.

        Raises:
            HTTPException: 401 when GitHub rejects the token.
            Exception: when any other non-200 status is returned.
        """
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        files = []
        for file_id in file_ids:
            repo_name, file_path = self._split_id(file_id)
            endpoint = f"https://api.github.com/repos/{repo_name}/contents/{file_path}"
            response = requests.get(
                endpoint, headers=headers, timeout=self.request_timeout
            )
            if response.status_code == 401:
                raise HTTPException(status_code=401, detail="Unauthorized")
            if response.status_code != 200:
                logger.error(
                    "An error occurred while retrieving GitHub files: %s", response.text
                )
                raise Exception("Failed to retrieve files")

            result = response.json()
            logger.debug("GitHub file result: %s", result)
            files.append(
                SyncFile(
                    name=remove_special_characters(result.get("name")),
                    id=f"{repo_name}:{result.get('path')}",
                    is_folder=False,
                    # No modification time is read from the response, so the
                    # retrieval time is recorded instead.
                    last_modified=self._timestamp(),
                    mime_type=result.get("type"),
                    web_view_link=result.get("html_url"),
                )
            )
        logger.info("GitHub files retrieved successfully: %s", len(files))
        return files

    def download_file(
        self, credentials: Dict, file: SyncFile
    ) -> Dict[str, Union[str, BytesIO]]:
        """Download *file* and return its name and decoded content.

        Raises:
            HTTPException: with GitHub's status code when the request fails,
                or 404 when the response carries no content.
        """
        # Kept as a local import because this file's import block is not
        # fully visible from this definition.
        import base64

        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        project_name, file_path = self._split_id(file.id)

        # Construct the API endpoint for the file content
        endpoint = f"https://api.github.com/repos/{project_name}/contents/{file_path}"
        response = requests.get(
            endpoint, headers=headers, timeout=self.request_timeout
        )
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code, detail="Failed to download file"
            )

        content = response.json().get("content")
        if not content:
            raise HTTPException(status_code=404, detail="File content not found")

        # GitHub API returns content as base64 encoded string
        file_content = base64.b64decode(content)
        return {"file_name": file.name, "content": BytesIO(file_content)}

    def list_github_repos(self, credentials, recursive=False):
        """List the authenticated user's repositories as folder entries.

        With ``recursive=True`` the top-level contents of each repository are
        appended after its entry.
        """
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        # NOTE(review): only the first page of results is requested; users
        # with many repositories may not see all of them.
        items = self._fetch_listing(
            "https://api.github.com/user/repos",
            headers,
            "An error occurred while retrieving GitHub repositories: %s",
        )
        if not items:
            logger.info("No repositories found in GitHub")
            return []

        repos = []
        for item in items:
            repo_data = SyncFile(
                name=remove_special_characters(item.get("name")),
                id=f"{item.get('full_name')}:",
                is_folder=True,
                last_modified=self._timestamp(item.get("updated_at")),
                mime_type="repository",
                web_view_link=item.get("html_url"),
            )
            repos.append(repo_data)
            if recursive:
                # NOTE(review): this descends one level only (recursive is not
                # forwarded). Confirm whether a full walk is wanted.
                submodule_files = self.list_github_files_in_repo(
                    credentials, repo_data.id
                )
                repos.extend(submodule_files)
        logger.info("GitHub repositories retrieved successfully: %s", len(repos))
        return repos

    def list_github_files_in_repo(self, credentials, repo_folder, recursive=False):
        """List the entries under *repo_folder* (``"<full_name>:<path>"``).

        With ``recursive=True`` every sub-directory is walked as well and its
        entries are appended directly after the directory's own entry.
        """
        repo_name, folder_path = self._split_id(repo_folder)
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        endpoint = f"https://api.github.com/repos/{repo_name}/contents/{folder_path}"
        logger.debug(f"Fetching files from GitHub with link: {endpoint}")
        items = self._fetch_listing(
            endpoint,
            headers,
            "An error occurred while retrieving GitHub repository files: %s",
        )
        if not items:
            logger.info(f"No files found in GitHub repository {repo_name}")
            return []

        files = []
        for item in items:
            file_data = SyncFile(
                name=remove_special_characters(item.get("name")),
                id=f"{repo_name}:{item.get('path')}",
                is_folder=item.get("type") == "dir",
                # Fall back to the retrieval time when the item carries no
                # "updated_at", instead of storing the string "None".
                last_modified=self._timestamp(item.get("updated_at")),
                mime_type=item.get("type"),
                web_view_link=item.get("html_url"),
            )
            files.append(file_data)
            if recursive and file_data.is_folder:
                folder_files = self.list_github_files_in_repo(
                    credentials, repo_folder=file_data.id, recursive=True
                )
                files.extend(folder_files)
        logger.info(f"GitHub repository files retrieved successfully: {len(files)}")
        return files

View File

@ -4,6 +4,7 @@ from typing import List
from fastapi import UploadFile
from pydantic import BaseModel, ConfigDict
from quivr_api.logger import get_logger
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
from quivr_api.modules.knowledge.repository.storage import Storage
@ -59,7 +60,7 @@ class SyncUtils(BaseModel):
dict: A dictionary containing the status of the download or an error message.
"""
credentials = self.sync_cloud.check_and_refresh_access_token(credentials)
# credentials = self.sync_cloud.check_and_refresh_access_token(credentials)
downloaded_files = []
bulk_id = uuid.uuid4()

View File

@ -48,6 +48,7 @@ export const syncDropbox = async (
).data;
};
export const getUserSyncs = async (
axiosInstance: AxiosInstance
): Promise<Sync[]> => {

View File

@ -1,6 +1,6 @@
export type Provider = "Google" | "Azure" | "DropBox";
export type Provider = "Google" | "Azure" | "DropBox" | "GitHub";
export type Integration = "Google Drive" | "Share Point" | "Dropbox";
export type Integration = "Google Drive" | "Share Point" | "Dropbox" | "GitHub";
export interface SyncElement {
name?: string;

View File

@ -10,9 +10,10 @@ import {
getUserSyncs,
syncDropbox,
syncFiles,
syncGitHub,
syncGoogleDrive,
syncSharepoint,
updateActiveSync,
updateActiveSync
} from "./sync";
import { Integration, OpenedConnection, Provider } from "./types";
@ -27,6 +28,8 @@ export const useSync = () => {
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
DropBox:
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
GitHub:
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
};
const integrationIconUrls: Record<Integration, string> = {
@ -36,6 +39,8 @@ export const useSync = () => {
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
Dropbox:
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
GitHub:
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
};
const getActiveSyncsForBrain = async (brainId: string) => {
@ -49,6 +54,7 @@ export const useSync = () => {
syncGoogleDrive(name, axiosInstance),
syncSharepoint: async (name: string) => syncSharepoint(name, axiosInstance),
syncDropbox: async (name: string) => syncDropbox(name, axiosInstance),
syncGitHub: async (name: string) => syncGitHub(name, axiosInstance),
getUserSyncs: async () => getUserSyncs(axiosInstance),
getSyncFiles: async (userSyncId: number, folderId?: string) =>
getSyncFiles(axiosInstance, userSyncId, folderId),

View File

@ -10,7 +10,7 @@ interface ConnectionCardsProps {
export const ConnectionCards = ({
fromAddKnowledge,
}: ConnectionCardsProps): JSX.Element => {
const { syncGoogleDrive, syncSharepoint, syncDropbox } = useSync();
const { syncGoogleDrive, syncSharepoint, syncDropbox, syncGitHub } = useSync();
return (
<div