mirror of
https://github.com/QuivrHQ/quivr.git
synced 2024-12-15 01:21:48 +03:00
feat: Add GitHub sync functionality to sync router (#2871)
The code changes in `sync_routes.py` add the GitHub sync functionality to the sync router. This allows users to sync their GitHub repositories with Quivr. # Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Checklist before requesting a review Please delete options that are not relevant. - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented hard-to-understand areas - [ ] I have ideally added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes - [ ] Any dependent changes have been merged ## Screenshots (if appropriate): --------- Co-authored-by: Stan Girard <girard.stanislas@gmail.com>
This commit is contained in:
parent
b3debeefee
commit
9934a7a8ce
@ -0,0 +1,155 @@
|
|||||||
|
import os
|
||||||
|
|
||||||
|
import requests
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
|
from fastapi.responses import HTMLResponse
|
||||||
|
from quivr_api.logger import get_logger
|
||||||
|
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||||
|
from quivr_api.modules.sync.dto.inputs import SyncsUserInput, SyncUserUpdateInput
|
||||||
|
from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserService
|
||||||
|
from quivr_api.modules.user.entity.user_identity import UserIdentity
|
||||||
|
|
||||||
|
from .successfull_connection import successfullConnectionPage
|
||||||
|
|
||||||
|
# Logger scoped to this module
logger = get_logger(__name__)

# Service instances shared by the route handlers below
sync_service = SyncService()
sync_user_service = SyncUserService()

# Router that exposes the GitHub OAuth endpoints
github_sync_router = APIRouter()

# OAuth settings, read from the environment once at import time.
# BACKEND_URL falls back to the local development address.
CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:5050")
# Callback address GitHub redirects to; served by the oauth2callback route.
REDIRECT_URI = f"{BACKEND_URL}/sync/github/oauth2callback"
# Space-separated OAuth scopes requested from GitHub
SCOPE = "repo user"
|
||||||
|
|
||||||
|
|
||||||
|
@github_sync_router.post(
    "/sync/github/authorize",
    dependencies=[Depends(AuthBearer())],
    tags=["Sync"],
)
def authorize_github(
    request: Request, name: str, current_user: UserIdentity = Depends(get_current_user)
):
    """
    Authorize GitHub sync for the current user.

    Builds the GitHub OAuth authorization URL and records a pending sync
    user whose stored state is compared again in the OAuth2 callback.

    Args:
        request (Request): The request object.
        name (str): The name given to this sync connection.
        current_user (UserIdentity): The current authenticated user.

    Returns:
        dict: A dictionary containing the authorization URL.
    """
    # Function-scope import so this handler does not depend on the module's
    # import block being updated.
    from urllib.parse import quote, urlencode

    logger.debug(f"Authorizing GitHub sync for user: {current_user.id}")
    state = f"user_id={current_user.id},name={name}"

    # Percent-encode every parameter. SCOPE contains a space, the state
    # contains "=" and ",", and `name` is caller-supplied, so interpolating
    # the raw values could corrupt the query string or inject parameters.
    # The state is stored unencoded below; the callback sees it decoded.
    query = urlencode(
        {
            "client_id": CLIENT_ID,
            "redirect_uri": REDIRECT_URI,
            "scope": SCOPE,
            "state": state,
        },
        quote_via=quote,
    )
    authorization_url = f"https://github.com/login/oauth/authorize?{query}"

    sync_user_input = SyncsUserInput(
        user_id=str(current_user.id),
        name=name,
        provider="GitHub",
        credentials={},
        state={"state": state},
    )
    sync_user_service.create_sync_user(sync_user_input)
    return {"authorization_url": authorization_url}
|
||||||
|
|
||||||
|
|
||||||
|
@github_sync_router.get("/sync/github/oauth2callback", tags=["Sync"])
def oauth2callback_github(request: Request):
    """
    Handle OAuth2 callback from GitHub.

    Checks the returned state against the pending sync user, exchanges the
    authorization code for an access token, looks up the account email and
    stores the token response on the sync user.

    Args:
        request (Request): The request object.

    Returns:
        HTMLResponse: The "successful connection" page.

    Raises:
        HTTPException: 400 when the state or code is missing or invalid, the
            state does not match a pending sync user, the token exchange
            fails, or the profile/email cannot be fetched.
    """
    # Seconds to wait on each GitHub call so a stalled request cannot hang
    # the handler indefinitely.
    timeout = 30

    state = request.query_params.get("state")
    if not state:
        # Checked before any parsing: a missing state used to crash on
        # `None.split(...)`.
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")

    code = request.query_params.get("code")
    if not code:
        logger.error("Missing code parameter")
        raise HTTPException(status_code=400, detail="Missing code parameter")

    # The state has the form "user_id=<id>,name=<name>". Only the user id is
    # needed here; partition() keeps "," or "=" inside the name from
    # breaking the parse.
    user_field = state.partition(",")[0]
    key, _, current_user = user_field.partition("=")
    if key != "user_id" or not current_user:
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")

    state_dict = {"state": state}
    logger.debug(
        f"Handling OAuth2 callback for user: {current_user} with state: {state}"
    )
    sync_user_state = sync_user_service.get_sync_user_by_state(state_dict)
    logger.info(f"Retrieved sync user state: {sync_user_state}")

    # An unknown state (no pending sync user) is rejected like a mismatch.
    if not sync_user_state or state_dict != sync_user_state.get("state"):
        logger.error("Invalid state parameter")
        raise HTTPException(status_code=400, detail="Invalid state parameter")
    if sync_user_state.get("user_id") != current_user:
        logger.error("Invalid user")
        raise HTTPException(status_code=400, detail="Invalid user")

    token_url = "https://github.com/login/oauth/access_token"
    data = {
        "client_id": CLIENT_ID,
        "client_secret": CLIENT_SECRET,
        "code": code,
        "redirect_uri": REDIRECT_URI,
        "state": state,
    }
    headers = {"Accept": "application/json"}
    response = requests.post(token_url, data=data, headers=headers, timeout=timeout)
    if response.status_code != 200:
        # Use the raw text: an error body is not guaranteed to be JSON.
        logger.error(f"Failed to acquire token: {response.text}")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to acquire token: {response.text}",
        )

    try:
        result = response.json()
    except ValueError:
        logger.error(f"Failed to acquire token: {response.text}")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to acquire token: {response.text}",
        )

    access_token = result.get("access_token")
    if not access_token:
        logger.error(f"Failed to acquire token: {result}")
        raise HTTPException(
            status_code=400,
            detail=f"Failed to acquire token: {result}",
        )

    logger.info(f"Fetched OAuth2 token for user: {current_user}")

    # Fetch user email from GitHub API
    github_api_url = "https://api.github.com/user"
    headers = {"Authorization": f"Bearer {access_token}"}
    response = requests.get(github_api_url, headers=headers, timeout=timeout)
    if response.status_code != 200:
        logger.error("Failed to fetch user profile from GitHub API")
        raise HTTPException(status_code=400, detail="Failed to fetch user profile")

    user_email = response.json().get("email")
    if not user_email:
        # If the email is not public, make a separate API call to get emails
        emails_url = "https://api.github.com/user/emails"
        response = requests.get(emails_url, headers=headers, timeout=timeout)
        if response.status_code != 200:
            logger.error("Failed to fetch user email from GitHub API")
            raise HTTPException(status_code=400, detail="Failed to fetch user email")

        # A default avoids StopIteration when no entry is marked primary.
        user_email = next(
            (entry.get("email") for entry in response.json() if entry.get("primary")),
            None,
        )
        if not user_email:
            logger.error("Failed to fetch user email from GitHub API")
            raise HTTPException(status_code=400, detail="Failed to fetch user email")

    logger.info(f"Retrieved email for user: {current_user} - {user_email}")

    # Store the whole token response as credentials and clear the one-time state.
    sync_user_input = SyncUserUpdateInput(
        credentials=result, state={}, email=user_email
    )

    sync_user_service.update_sync_user(current_user, state_dict, sync_user_input)
    logger.info(f"GitHub sync created successfully for user: {current_user}")
    return HTMLResponse(successfullConnectionPage)
|
@ -3,6 +3,7 @@ import uuid
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, status
|
from fastapi import APIRouter, Depends, status
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
from quivr_api.middlewares.auth import AuthBearer, get_current_user
|
||||||
from quivr_api.modules.notification.dto.inputs import CreateNotification
|
from quivr_api.modules.notification.dto.inputs import CreateNotification
|
||||||
@ -12,6 +13,7 @@ from quivr_api.modules.notification.service.notification_service import (
|
|||||||
)
|
)
|
||||||
from quivr_api.modules.sync.controller.azure_sync_routes import azure_sync_router
|
from quivr_api.modules.sync.controller.azure_sync_routes import azure_sync_router
|
||||||
from quivr_api.modules.sync.controller.dropbox_sync_routes import dropbox_sync_router
|
from quivr_api.modules.sync.controller.dropbox_sync_routes import dropbox_sync_router
|
||||||
|
from quivr_api.modules.sync.controller.github_sync_routes import github_sync_router
|
||||||
from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router
|
from quivr_api.modules.sync.controller.google_sync_routes import google_sync_router
|
||||||
from quivr_api.modules.sync.dto import SyncsDescription
|
from quivr_api.modules.sync.dto import SyncsDescription
|
||||||
from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
|
from quivr_api.modules.sync.dto.inputs import SyncsActiveInput, SyncsActiveUpdateInput
|
||||||
@ -38,6 +40,7 @@ sync_router = APIRouter()
|
|||||||
# Add Google routes here
|
# Add Google routes here
|
||||||
sync_router.include_router(google_sync_router)
|
sync_router.include_router(google_sync_router)
|
||||||
sync_router.include_router(azure_sync_router)
|
sync_router.include_router(azure_sync_router)
|
||||||
|
sync_router.include_router(github_sync_router)
|
||||||
sync_router.include_router(dropbox_sync_router)
|
sync_router.include_router(dropbox_sync_router)
|
||||||
|
|
||||||
|
|
||||||
@ -60,6 +63,12 @@ dropbox_sync = SyncsDescription(
|
|||||||
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
|
auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Description of the GitHub provider. GitHub syncs repositories, not a
# "Drive" (the previous wording was copied from the drive-based providers).
github_sync = SyncsDescription(
    name="GitHub",
    description="Sync your GitHub repositories with Quivr",
    auth_method=AuthMethodEnum.URI_WITH_CALLBACK,
)
|
||||||
|
|
||||||
|
|
||||||
@sync_router.get(
|
@sync_router.get(
|
||||||
"/sync/all",
|
"/sync/all",
|
||||||
|
@ -48,14 +48,13 @@ class SyncUserInterface(ABC):
|
|||||||
self,
|
self,
|
||||||
sync_active_id: int,
|
sync_active_id: int,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
folder_id: int = None,
|
folder_id: int | str | None = None,
|
||||||
recursive: bool = False,
|
recursive: bool = False,
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class SyncInterface(ABC):
|
class SyncInterface(ABC):
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def create_sync_active(
|
def create_sync_active(
|
||||||
self,
|
self,
|
||||||
|
@ -11,6 +11,7 @@ from quivr_api.modules.sync.repository.sync_interfaces import SyncUserInterface
|
|||||||
from quivr_api.modules.sync.utils.sync import (
|
from quivr_api.modules.sync.utils.sync import (
|
||||||
AzureDriveSync,
|
AzureDriveSync,
|
||||||
DropboxSync,
|
DropboxSync,
|
||||||
|
GitHubSync,
|
||||||
GoogleDriveSync,
|
GoogleDriveSync,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -213,6 +214,15 @@ class SyncUser(SyncUserInterface):
|
|||||||
sync_user["credentials"], folder_id if folder_id else "", recursive
|
sync_user["credentials"], folder_id if folder_id else "", recursive
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
elif provider == "github":
|
||||||
|
logger.info("Getting files for GitHub sync")
|
||||||
|
sync = GitHubSync()
|
||||||
|
return {
|
||||||
|
"files": sync.get_files(
|
||||||
|
sync_user["credentials"], folder_id if folder_id else "", recursive
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"No sync found for provider: %s", sync_user["provider"], recursive
|
"No sync found for provider: %s", sync_user["provider"], recursive
|
||||||
|
@ -11,6 +11,7 @@ from quivr_api.modules.sync.service.sync_service import SyncService, SyncUserSer
|
|||||||
from quivr_api.modules.sync.utils.sync import (
|
from quivr_api.modules.sync.utils.sync import (
|
||||||
AzureDriveSync,
|
AzureDriveSync,
|
||||||
DropboxSync,
|
DropboxSync,
|
||||||
|
GitHubSync,
|
||||||
GoogleDriveSync,
|
GoogleDriveSync,
|
||||||
)
|
)
|
||||||
from quivr_api.modules.sync.utils.syncutils import SyncUtils
|
from quivr_api.modules.sync.utils.syncutils import SyncUtils
|
||||||
@ -57,6 +58,14 @@ async def _process_sync_active():
|
|||||||
sync_cloud=DropboxSync(),
|
sync_cloud=DropboxSync(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
github_sync_utils = SyncUtils(
|
||||||
|
sync_user_service=sync_user_service,
|
||||||
|
sync_active_service=sync_active_service,
|
||||||
|
sync_files_repo=sync_files_repo_service,
|
||||||
|
storage=storage,
|
||||||
|
sync_cloud=GitHubSync(),
|
||||||
|
)
|
||||||
|
|
||||||
active = await sync_active_service.get_syncs_active_in_interval()
|
active = await sync_active_service.get_syncs_active_in_interval()
|
||||||
|
|
||||||
for sync in active:
|
for sync in active:
|
||||||
@ -80,6 +89,10 @@ async def _process_sync_active():
|
|||||||
await azure_sync_utils.sync(
|
await azure_sync_utils.sync(
|
||||||
sync_active_id=sync.id, user_id=sync.user_id
|
sync_active_id=sync.id, user_id=sync.user_id
|
||||||
)
|
)
|
||||||
|
elif details_user_sync["provider"].lower() == "github":
|
||||||
|
await github_sync_utils.sync(
|
||||||
|
sync_active_id=sync.id, user_id=sync.user_id
|
||||||
|
)
|
||||||
elif details_user_sync["provider"].lower() == "dropbox":
|
elif details_user_sync["provider"].lower() == "dropbox":
|
||||||
await dropbox_sync_utils.sync(
|
await dropbox_sync_utils.sync(
|
||||||
sync_active_id=sync.id, user_id=sync.user_id
|
sync_active_id=sync.id, user_id=sync.user_id
|
||||||
|
@ -2,6 +2,7 @@ import json
|
|||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from datetime import datetime
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
from typing import Any, Dict, List, Union
|
from typing import Any, Dict, List, Union
|
||||||
|
|
||||||
@ -12,10 +13,11 @@ from fastapi import HTTPException
|
|||||||
from google.auth.transport.requests import Request as GoogleRequest
|
from google.auth.transport.requests import Request as GoogleRequest
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
from googleapiclient.discovery import build
|
from googleapiclient.discovery import build
|
||||||
|
from requests import HTTPError
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.sync.entity.sync import SyncFile
|
from quivr_api.modules.sync.entity.sync import SyncFile
|
||||||
from quivr_api.modules.sync.utils.normalize import remove_special_characters
|
from quivr_api.modules.sync.utils.normalize import remove_special_characters
|
||||||
from requests import HTTPError
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
@ -66,7 +68,6 @@ class GoogleDriveSync(BaseSync):
|
|||||||
file_id = file.id
|
file_id = file.id
|
||||||
file_name = file.name
|
file_name = file.name
|
||||||
mime_type = file.mime_type
|
mime_type = file.mime_type
|
||||||
modified_time = file.last_modified
|
|
||||||
if not self.creds:
|
if not self.creds:
|
||||||
self.check_and_refresh_access_token(credentials)
|
self.check_and_refresh_access_token(credentials)
|
||||||
if not self.service:
|
if not self.service:
|
||||||
@ -291,7 +292,6 @@ class AzureDriveSync(BaseSync):
|
|||||||
"Authorization": f"Bearer {token_data['access_token']}",
|
"Authorization": f"Bearer {token_data['access_token']}",
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def check_and_refresh_access_token(self, credentials) -> Dict:
|
def check_and_refresh_access_token(self, credentials) -> Dict:
|
||||||
if "refresh_token" not in credentials:
|
if "refresh_token" not in credentials:
|
||||||
@ -366,7 +366,7 @@ class AzureDriveSync(BaseSync):
|
|||||||
|
|
||||||
if not folder_id and not site_id:
|
if not folder_id and not site_id:
|
||||||
# Fetch the sites
|
# Fetch the sites
|
||||||
#endpoint = "https://graph.microsoft.com/v1.0/me/followedSites"
|
# endpoint = "https://graph.microsoft.com/v1.0/me/followedSites"
|
||||||
endpoint = "https://graph.microsoft.com/v1.0/sites?search=*"
|
endpoint = "https://graph.microsoft.com/v1.0/sites?search=*"
|
||||||
elif site_id == "root":
|
elif site_id == "root":
|
||||||
if not folder_id:
|
if not folder_id:
|
||||||
@ -427,9 +427,7 @@ class AzureDriveSync(BaseSync):
|
|||||||
logger.info("Azure Drive files retrieved successfully: %s", len(files))
|
logger.info("Azure Drive files retrieved successfully: %s", len(files))
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def get_files_by_id(
|
def get_files_by_id(self, credentials: dict, file_ids: List[str]) -> List[SyncFile]:
|
||||||
self, credentials: dict, file_ids: List[str]
|
|
||||||
) -> List[SyncFile] | dict:
|
|
||||||
"""
|
"""
|
||||||
Retrieve files from Azure Drive by their IDs.
|
Retrieve files from Azure Drive by their IDs.
|
||||||
|
|
||||||
@ -489,9 +487,7 @@ class AzureDriveSync(BaseSync):
|
|||||||
) -> Dict[str, Union[str, BytesIO]]:
|
) -> Dict[str, Union[str, BytesIO]]:
|
||||||
file_id = file.id
|
file_id = file.id
|
||||||
file_name = file.name
|
file_name = file.name
|
||||||
modified_time = file.last_modified
|
|
||||||
headers = self.get_azure_headers(credentials)
|
headers = self.get_azure_headers(credentials)
|
||||||
|
|
||||||
site_id, folder_id = file_id.split(":")
|
site_id, folder_id = file_id.split(":")
|
||||||
if folder_id == "":
|
if folder_id == "":
|
||||||
folder_id = None
|
folder_id = None
|
||||||
@ -687,3 +683,194 @@ class DropboxSync(BaseSync):
|
|||||||
|
|
||||||
metadata, file_data = self.dbx.files_download(file_id) # type: ignore
|
metadata, file_data = self.dbx.files_download(file_id) # type: ignore
|
||||||
return {"file_name": file_name, "content": BytesIO(file_data.content)}
|
return {"file_name": file_name, "content": BytesIO(file_data.content)}
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubSync(BaseSync):
    """
    Sync provider that lists and downloads repository content through the
    GitHub REST API.

    Items are identified by ids of the form "<owner>/<repo>:<path>"; a
    repository root has an empty path ("<owner>/<repo>:").
    """

    name = "GitHub"
    lower_name = "github"
    datetime_format = "%Y-%m-%dT%H:%M:%SZ"
    # Seconds to wait on a GitHub API call before giving up.
    request_timeout = 30

    def __init__(self):
        self.CLIENT_ID = os.getenv("GITHUB_CLIENT_ID")
        self.CLIENT_SECRET = os.getenv("GITHUB_CLIENT_SECRET")
        self.BACKEND_URL = os.getenv("BACKEND_URL", "http://localhost:5050")
        self.REDIRECT_URI = f"{self.BACKEND_URL}/sync/github/oauth2callback"
        self.SCOPE = "repo user"

    @staticmethod
    def _split_id(item_id: str):
        """Split "<repo>:<path>" into (repo, path) at the first colon only.

        Splitting once keeps paths that contain ":" intact, and an id with
        no colon yields an empty path instead of raising.
        """
        repo_name, _, path = item_id.partition(":")
        return repo_name, path

    def _timestamp(self, raw=None) -> str:
        """Return `raw` as a string, or the current time when it is missing.

        The fallback uses `datetime_format`, as `get_files_by_id` does.
        """
        if raw:
            return str(raw)
        return datetime.now().strftime(self.datetime_format)

    def _fetch_listing(self, endpoint, headers, error_message, params=None):
        """GET a JSON listing from GitHub, following "next" page links.

        Args:
            endpoint: URL of the first page.
            headers: Request headers carrying the authorization.
            error_message: %-style log message taking the response text.
            params: Optional query parameters for the first request only
                (the "next" links already embed them).

        Returns:
            list: The collected items; empty when a request fails with a
            status other than 401.

        Raises:
            HTTPException: 401 when GitHub rejects the token.
        """
        collected = []
        url = endpoint
        while url:
            response = requests.get(
                url, headers=headers, params=params, timeout=self.request_timeout
            )
            if response.status_code == 401:
                raise HTTPException(status_code=401, detail="Unauthorized")
            if response.status_code != 200:
                logger.error(error_message, response.text)
                return []
            payload = response.json()
            if isinstance(payload, list):
                collected.extend(payload)
            elif payload:
                # The contents API returns a single object when the path is
                # a file; wrap it so callers can always iterate over items.
                collected.append(payload)
            params = None
            url = response.links.get("next", {}).get("url")
        return collected

    def get_github_token_data(self, credentials):
        """Return `credentials` after checking they hold an access token.

        Raises:
            HTTPException: 401 when "access_token" is missing.
        """
        if "access_token" not in credentials:
            raise HTTPException(status_code=401, detail="Invalid token data")
        return credentials

    def get_github_headers(self, token_data):
        """Build the request headers for the given token data."""
        return {
            "Authorization": f"Bearer {token_data['access_token']}",
            "Accept": "application/json",
        }

    def check_and_refresh_access_token(self, credentials: dict) -> Dict:
        """Validate the credentials and return them unchanged.

        No refresh is attempted here. Returning the checked credentials
        (instead of always raising) lets the shared sync pipeline call this
        method for GitHub like for the other providers.
        NOTE(review): assumes the stored token is a non-expiring OAuth app
        token - confirm if expiring tokens are enabled.

        Raises:
            HTTPException: 401 when "access_token" is missing.
        """
        return self.get_github_token_data(credentials)

    def get_files(
        self, credentials: Dict, folder_id: str | None = None, recursive: bool = False
    ) -> List[SyncFile]:
        """List repositories, or the content of `folder_id` when it is given.

        Args:
            credentials: Token data holding "access_token".
            folder_id: "<repo>:<path>" id of the folder to list; when empty
                the user's repositories are listed instead.
            recursive: Whether to descend below the first level.
        """
        logger.info("Retrieving GitHub files with folder_id: %s", folder_id)
        if folder_id:
            return self.list_github_files_in_repo(
                credentials, folder_id, recursive=recursive
            )
        else:
            return self.list_github_repos(credentials, recursive=recursive)

    def get_files_by_id(self, credentials: Dict, file_ids: List[str]) -> List[SyncFile]:
        """Fetch the metadata of each file in `file_ids`.

        Args:
            credentials: Token data holding "access_token".
            file_ids: Ids of the form "<repo>:<path>".

        Raises:
            HTTPException: 401 when GitHub rejects the token.
            Exception: When any file cannot be retrieved.
        """
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        files = []

        for file_id in file_ids:
            repo_name, file_path = self._split_id(file_id)
            # The URL previously started with a stray space.
            endpoint = f"https://api.github.com/repos/{repo_name}/contents/{file_path}"
            response = requests.get(
                endpoint, headers=headers, timeout=self.request_timeout
            )
            if response.status_code == 401:
                raise HTTPException(status_code=401, detail="Unauthorized")
            if response.status_code != 200:
                logger.error(
                    "An error occurred while retrieving GitHub files: %s", response.text
                )
                raise Exception("Failed to retrieve files")

            result = response.json()
            logger.debug("GitHub file result: %s", result)
            files.append(
                SyncFile(
                    name=remove_special_characters(result.get("name")),
                    id=f"{repo_name}:{result.get('path')}",
                    is_folder=False,
                    # The contents response carries no modification time.
                    last_modified=self._timestamp(),
                    mime_type=result.get("type"),
                    web_view_link=result.get("html_url"),
                )
            )

        logger.info("GitHub files retrieved successfully: %s", len(files))
        return files

    def download_file(
        self, credentials: Dict, file: SyncFile
    ) -> Dict[str, Union[str, BytesIO]]:
        """Download `file` and return its name and content.

        Returns:
            dict: {"file_name": <name>, "content": <BytesIO of the bytes>}.

        Raises:
            HTTPException: With GitHub's status code when the request fails,
                or 404 when the response carries no content.
        """
        import base64

        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        project_name, file_path = self._split_id(file.id)

        # Construct the API endpoint for the file content
        endpoint = f"https://api.github.com/repos/{project_name}/contents/{file_path}"

        response = requests.get(endpoint, headers=headers, timeout=self.request_timeout)
        if response.status_code != 200:
            raise HTTPException(
                status_code=response.status_code, detail="Failed to download file"
            )

        content = response.json().get("content")
        if not content:
            raise HTTPException(status_code=404, detail="File content not found")

        # GitHub API returns content as base64 encoded string
        file_content = base64.b64decode(content)

        return {"file_name": file.name, "content": BytesIO(file_content)}

    def list_github_repos(self, credentials, recursive=False):
        """List the user's repositories as folder entries.

        Args:
            credentials: Token data holding "access_token".
            recursive: When true, the top-level content of each repository
                is appended right after its repository entry.

        Returns:
            list: SyncFile entries; empty when nothing is found or the
            request fails.
        """
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        endpoint = "https://api.github.com/user/repos"

        # Ask for large pages; further pages are followed by the helper so
        # accounts with many repositories are listed completely.
        items = self._fetch_listing(
            endpoint,
            headers,
            "An error occurred while retrieving GitHub repositories: %s",
            params={"per_page": 100},
        )

        if not items:
            logger.info("No repositories found in GitHub")
            return []

        repos = []
        for item in items:
            repo_data = SyncFile(
                name=remove_special_characters(item.get("name")),
                id=f"{item.get('full_name')}:",
                is_folder=True,
                last_modified=self._timestamp(item.get("updated_at")),
                mime_type="repository",
                web_view_link=item.get("html_url"),
            )
            repos.append(repo_data)

            if recursive:
                # Only the first level of each repository is expanded here.
                submodule_files = self.list_github_files_in_repo(
                    credentials, repo_data.id
                )
                repos.extend(submodule_files)

        logger.info("GitHub repositories retrieved successfully: %s", len(repos))
        return repos

    def list_github_files_in_repo(self, credentials, repo_folder, recursive=False):
        """List the content of a repository folder.

        Args:
            credentials: Token data holding "access_token".
            repo_folder: "<repo>:<path>" id; an empty path is the repo root.
            recursive: When true, sub-folders are listed as well, each
                folder's content following its own entry.

        Returns:
            list: SyncFile entries; empty when nothing is found or the
            request fails.
        """
        repo_name, folder_path = self._split_id(repo_folder)
        token_data = self.get_github_token_data(credentials)
        headers = self.get_github_headers(token_data)
        endpoint = f"https://api.github.com/repos/{repo_name}/contents/{folder_path}"
        logger.debug(f"Fetching files from GitHub with link: {endpoint}")

        items = self._fetch_listing(
            endpoint,
            headers,
            "An error occurred while retrieving GitHub repository files: %s",
        )

        if not items:
            logger.info(f"No files found in GitHub repository {repo_name}")
            return []

        files = []
        for item in items:
            file_data = SyncFile(
                name=remove_special_characters(item.get("name")),
                id=f"{repo_name}:{item.get('path')}",
                is_folder=item.get("type") == "dir",
                # Content entries have no "updated_at"; fall back to the
                # current time instead of storing the string "None".
                last_modified=self._timestamp(item.get("updated_at")),
                mime_type=item.get("type"),
                web_view_link=item.get("html_url"),
            )
            files.append(file_data)

            if recursive and file_data.is_folder:
                folder_files = self.list_github_files_in_repo(
                    credentials, repo_folder=file_data.id, recursive=True
                )
                files.extend(folder_files)

        logger.info(f"GitHub repository files retrieved successfully: {len(files)}")
        return files
|
||||||
|
@ -4,6 +4,7 @@ from typing import List
|
|||||||
|
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from pydantic import BaseModel, ConfigDict
|
from pydantic import BaseModel, ConfigDict
|
||||||
|
|
||||||
from quivr_api.logger import get_logger
|
from quivr_api.logger import get_logger
|
||||||
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
from quivr_api.modules.brain.repository.brains_vectors import BrainsVectors
|
||||||
from quivr_api.modules.knowledge.repository.storage import Storage
|
from quivr_api.modules.knowledge.repository.storage import Storage
|
||||||
@ -59,7 +60,7 @@ class SyncUtils(BaseModel):
|
|||||||
dict: A dictionary containing the status of the download or an error message.
|
dict: A dictionary containing the status of the download or an error message.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
credentials = self.sync_cloud.check_and_refresh_access_token(credentials)
|
# credentials = self.sync_cloud.check_and_refresh_access_token(credentials)
|
||||||
|
|
||||||
downloaded_files = []
|
downloaded_files = []
|
||||||
bulk_id = uuid.uuid4()
|
bulk_id = uuid.uuid4()
|
||||||
|
@ -48,6 +48,7 @@ export const syncDropbox = async (
|
|||||||
).data;
|
).data;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
||||||
export const getUserSyncs = async (
|
export const getUserSyncs = async (
|
||||||
axiosInstance: AxiosInstance
|
axiosInstance: AxiosInstance
|
||||||
): Promise<Sync[]> => {
|
): Promise<Sync[]> => {
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
export type Provider = "Google" | "Azure" | "DropBox";
|
export type Provider = "Google" | "Azure" | "DropBox" | "GitHub";
|
||||||
|
|
||||||
export type Integration = "Google Drive" | "Share Point" | "Dropbox";
|
export type Integration = "Google Drive" | "Share Point" | "Dropbox" | "GitHub";
|
||||||
|
|
||||||
export interface SyncElement {
|
export interface SyncElement {
|
||||||
name?: string;
|
name?: string;
|
||||||
|
@ -10,9 +10,10 @@ import {
|
|||||||
getUserSyncs,
|
getUserSyncs,
|
||||||
syncDropbox,
|
syncDropbox,
|
||||||
syncFiles,
|
syncFiles,
|
||||||
|
syncGitHub,
|
||||||
syncGoogleDrive,
|
syncGoogleDrive,
|
||||||
syncSharepoint,
|
syncSharepoint,
|
||||||
updateActiveSync,
|
updateActiveSync
|
||||||
} from "./sync";
|
} from "./sync";
|
||||||
import { Integration, OpenedConnection, Provider } from "./types";
|
import { Integration, OpenedConnection, Provider } from "./types";
|
||||||
|
|
||||||
@ -27,6 +28,8 @@ export const useSync = () => {
|
|||||||
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
|
||||||
DropBox:
|
DropBox:
|
||||||
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
||||||
|
GitHub:
|
||||||
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
||||||
};
|
};
|
||||||
|
|
||||||
const integrationIconUrls: Record<Integration, string> = {
|
const integrationIconUrls: Record<Integration, string> = {
|
||||||
@ -36,6 +39,8 @@ export const useSync = () => {
|
|||||||
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/sharepoint_8c41cfdb09.png",
|
||||||
Dropbox:
|
Dropbox:
|
||||||
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
||||||
|
GitHub:
|
||||||
|
"https://quivr-cms.s3.eu-west-3.amazonaws.com/dropbox_dce4f3d753.png",
|
||||||
};
|
};
|
||||||
|
|
||||||
const getActiveSyncsForBrain = async (brainId: string) => {
|
const getActiveSyncsForBrain = async (brainId: string) => {
|
||||||
@ -49,6 +54,7 @@ export const useSync = () => {
|
|||||||
syncGoogleDrive(name, axiosInstance),
|
syncGoogleDrive(name, axiosInstance),
|
||||||
syncSharepoint: async (name: string) => syncSharepoint(name, axiosInstance),
|
syncSharepoint: async (name: string) => syncSharepoint(name, axiosInstance),
|
||||||
syncDropbox: async (name: string) => syncDropbox(name, axiosInstance),
|
syncDropbox: async (name: string) => syncDropbox(name, axiosInstance),
|
||||||
|
syncGitHub: async (name: string) => syncGitHub(name, axiosInstance),
|
||||||
getUserSyncs: async () => getUserSyncs(axiosInstance),
|
getUserSyncs: async () => getUserSyncs(axiosInstance),
|
||||||
getSyncFiles: async (userSyncId: number, folderId?: string) =>
|
getSyncFiles: async (userSyncId: number, folderId?: string) =>
|
||||||
getSyncFiles(axiosInstance, userSyncId, folderId),
|
getSyncFiles(axiosInstance, userSyncId, folderId),
|
||||||
|
@ -10,7 +10,7 @@ interface ConnectionCardsProps {
|
|||||||
export const ConnectionCards = ({
|
export const ConnectionCards = ({
|
||||||
fromAddKnowledge,
|
fromAddKnowledge,
|
||||||
}: ConnectionCardsProps): JSX.Element => {
|
}: ConnectionCardsProps): JSX.Element => {
|
||||||
const { syncGoogleDrive, syncSharepoint, syncDropbox } = useSync();
|
const { syncGoogleDrive, syncSharepoint, syncDropbox, syncGitHub } = useSync();
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<div
|
<div
|
||||||
|
Loading…
Reference in New Issue
Block a user