# doc-mcp/rag/github_file_loader.py
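"""Asynchronous GitHub file loader for the doc-mcp RAG pipeline.

Fetches selected files from a GitHub repository via the REST API and wraps
each successfully loaded file as a LlamaIndex Document.
"""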
import asyncio
import base64
import logging
import re
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import aiohttp
import requests
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
class GithubFileLoader:
"""
GitHub file loader that fetches specific files asynchronously.
Returns LlamaIndex Document objects for each successfully loaded file.
"""
def __init__(
self,
github_token: Optional[str] = None,
concurrent_requests: int = 10,
timeout: int = 30,
retries: int = 3,
):
"""
Initialize GitHub file loader.
Args:
github_token: GitHub API token for higher rate limits
concurrent_requests: Number of concurrent requests
timeout: Request timeout in seconds
retries: Number of retry attempts for failed requests
"""
self.github_token = github_token
self.concurrent_requests = concurrent_requests
self.timeout = timeout
self.retries = retries
# Setup headers
self.headers = {
"Accept": "application/vnd.github.v3+json",
"User-Agent": "LlamaIndex-GitHub-Loader/1.0",
}
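        # Note: an authenticated request gets a substantially higher GitHub API
        # rate limit than an anonymous one (roughly 5,000 vs 60 requests/hour).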
if self.github_token:
self.headers["Authorization"] = f"token {self.github_token}"
def fetch_repository_files(
self,
repo_url: str,
file_extensions: List[str] = [".md", ".mdx"],
branch: str = "main",
) -> Tuple[List[str], str]:
"""
        Fetch the list of matching file paths from a GitHub repository using the GitHub API.
Args:
repo_url: GitHub repository URL or owner/repo format
file_extensions: List of file extensions to filter (e.g., [".md", ".mdx", ".txt"])
branch: Branch name to fetch from
Returns:
Tuple of (list_of_file_paths, status_message)
"""
try:
# Parse GitHub URL to extract owner and repo
repo_name = self._parse_repo_name(repo_url)
if not repo_name:
return (
[],
"Invalid GitHub URL format. Use: https://github.com/owner/repo or owner/repo",
)
# GitHub API endpoint for repository tree
api_url = f"https://api.github.com/repos/{repo_name}/git/trees/{branch}?recursive=1"
# Make request with authentication if token is available
response = requests.get(api_url, headers=self.headers, timeout=self.timeout)
if response.status_code == 200:
data = response.json()
filtered_files = []
# Filter for specified file extensions
for item in data.get("tree", []):
if item["type"] == "blob":
file_path = item["path"]
# Check if file has any of the specified extensions
if any(
file_path.lower().endswith(ext.lower())
for ext in file_extensions
):
filtered_files.append(file_path)
if filtered_files:
ext_str = ", ".join(file_extensions)
return (
filtered_files,
f"Found {len(filtered_files)} files with extensions ({ext_str}) in {repo_name}/{branch}",
)
else:
ext_str = ", ".join(file_extensions)
return (
[],
f"No files with extensions ({ext_str}) found in repository {repo_name}/{branch}",
)
elif response.status_code == 404:
return (
[],
f"Repository '{repo_name}' not found or branch '{branch}' doesn't exist",
)
elif response.status_code == 403:
if "rate limit" in response.text.lower():
return (
[],
"GitHub API rate limit exceeded. Consider using a GitHub token.",
)
else:
return (
[],
"Access denied. Repository may be private or require authentication.",
)
else:
return (
[],
f"GitHub API Error: {response.status_code} - {response.text[:200]}",
)
except requests.exceptions.Timeout:
return [], f"Request timeout after {self.timeout} seconds"
except requests.exceptions.RequestException as e:
return [], f"Network error: {str(e)}"
except Exception as e:
return [], f"Unexpected error: {str(e)}"
def _parse_repo_name(self, repo_url: str) -> Optional[str]:
"""
Parse repository URL to extract owner/repo format
Args:
repo_url: GitHub repository URL or owner/repo format
Returns:
Repository name in "owner/repo" format or None if invalid
"""
if "github.com" in repo_url:
# Extract from full URL
parts = (
repo_url.replace("https://github.com/", "")
.replace("http://github.com/", "")
.strip("/")
.split("/")
)
            if len(parts) >= 2:
                # Drop a trailing ".git" so clone-style URLs are also accepted
                repo = parts[1][:-4] if parts[1].endswith(".git") else parts[1]
                return f"{parts[0]}/{repo}"
else:
# Assume format is owner/repo
parts = repo_url.strip().split("/")
if len(parts) == 2 and all(part.strip() for part in parts):
return repo_url.strip()
return None
def fetch_markdown_files(
self, repo_url: str, branch: str = "main"
) -> Tuple[List[str], str]:
"""
Fetch markdown files from GitHub repository (backward compatibility method)
Args:
repo_url: GitHub repository URL or owner/repo format
branch: Branch name to fetch from
Returns:
Tuple of (list_of_markdown_files, status_message)
"""
return self.fetch_repository_files(
repo_url=repo_url, file_extensions=[".md", ".mdx"], branch=branch
)
async def load_files(
self, repo_name: str, file_paths: List[str], branch: str = "main"
) -> Tuple[List[Document], List[str]]:
"""
Load files from GitHub repository asynchronously.
Args:
repo_name: Repository name in format "owner/repo"
file_paths: List of file paths to load
branch: Branch name to load from
Returns:
Tuple of (successfully_loaded_documents, failed_file_paths)
"""
if not file_paths:
return [], []
# Validate repo name format
if not re.match(r"^[^/]+/[^/]+$", repo_name):
raise ValueError(f"Invalid repo format: {repo_name}. Expected 'owner/repo'")
# Create semaphore to limit concurrent requests
semaphore = asyncio.Semaphore(self.concurrent_requests)
# Create session
connector = aiohttp.TCPConnector(limit=self.concurrent_requests)
timeout_config = aiohttp.ClientTimeout(total=self.timeout)
async with aiohttp.ClientSession(
headers=self.headers, connector=connector, timeout=timeout_config
) as session:
# Create tasks for all files
tasks = []
for file_path in file_paths:
task = asyncio.create_task(
self._fetch_file_with_retry(
session, semaphore, repo_name, file_path, branch
)
)
tasks.append(task)
# Wait for all tasks to complete
results = await asyncio.gather(*tasks, return_exceptions=True)
# Process results
documents = []
failed_files = []
for i, result in enumerate(results):
file_path = file_paths[i]
if isinstance(result, Exception):
logger.error(f"Failed to load {file_path}: {result}")
failed_files.append(file_path)
elif result is None:
logger.warning(f"No content returned for {file_path}")
failed_files.append(file_path)
else:
documents.append(result)
logger.info(
f"Successfully loaded {len(documents)} files, failed: {len(failed_files)}"
)
return documents, failed_files
async def _fetch_file_with_retry(
self,
session: aiohttp.ClientSession,
semaphore: asyncio.Semaphore,
repo_name: str,
file_path: str,
branch: str,
) -> Optional[Document]:
"""Fetch a single file with retry logic."""
async with semaphore:
for attempt in range(self.retries + 1):
try:
return await self._fetch_single_file(
session, repo_name, file_path, branch
)
except Exception as e:
if attempt == self.retries:
logger.error(
f"Failed to fetch {file_path} after {self.retries + 1} attempts: {e}"
)
raise
else:
logger.warning(
f"Attempt {attempt + 1} failed for {file_path}: {e}"
)
await asyncio.sleep(2**attempt) # Exponential backoff
return None
async def _fetch_single_file(
self,
session: aiohttp.ClientSession,
repo_name: str,
file_path: str,
branch: str,
) -> Document:
"""Fetch a single file from GitHub API."""
# Clean file path
clean_path = file_path.strip("/")
# Build API URL
api_url = f"https://api.github.com/repos/{repo_name}/contents/{clean_path}"
params = {"ref": branch}
logger.debug(f"Fetching: {api_url}")
async with session.get(api_url, params=params) as response:
if response.status == 404:
raise FileNotFoundError(f"File not found: {file_path}")
elif response.status == 403:
raise PermissionError("API rate limit exceeded or access denied")
elif response.status != 200:
raise Exception(f"HTTP {response.status}: {await response.text()}")
data = await response.json()
# Handle directory case
if isinstance(data, list):
raise ValueError(f"Path {file_path} is a directory, not a file")
# Decode file content
if data.get("encoding") == "base64":
try:
content_bytes = base64.b64decode(data["content"])
content_text = content_bytes.decode("utf-8")
except Exception as e:
logger.warning(f"Failed to decode {file_path}: {e}")
# Try to decode as latin-1 as fallback
content_text = content_bytes.decode("latin-1", errors="ignore")
else:
raise ValueError(f"Unsupported encoding: {data.get('encoding')}")
# Create Document
document = self._create_document(
content=content_text,
file_path=clean_path,
repo_name=repo_name,
branch=branch,
file_data=data,
)
return document
def _create_document(
self, content: str, file_path: str, repo_name: str, branch: str, file_data: Dict
) -> Document:
"""Create a LlamaIndex Document from file content and metadata."""
# Extract file info
filename = Path(file_path).name
file_extension = Path(file_path).suffix.lower()
directory = (
str(Path(file_path).parent) if Path(file_path).parent != Path(".") else ""
)
# Build URLs
html_url = f"https://github.com/{repo_name}/blob/{branch}/{file_path}"
raw_url = file_data.get("download_url", "")
# Create metadata
metadata = {
"file_path": file_path,
"file_name": filename,
"file_extension": file_extension,
"directory": directory,
"repo": repo_name,
"branch": branch,
"sha": file_data.get("sha", ""),
"size": file_data.get("size", 0),
"url": html_url,
"raw_url": raw_url,
"type": file_data.get("type", "file"),
}
# Create document with unique ID
doc_id = f"{repo_name}:{branch}:{file_path}"
document = Document(
text=content,
doc_id=doc_id,
metadata=metadata, # For backward compatibility
)
return document
def load_files_sync(
self, repo_name: str, file_paths: List[str], branch: str = "main"
) -> Tuple[List[Document], List[str]]:
"""
Synchronous wrapper for load_files.
Args:
repo_name: Repository name in format "owner/repo"
file_paths: List of file paths to load
branch: Branch name to load from
Returns:
Tuple of (successfully_loaded_documents, failed_file_paths)
"""
return asyncio.run(self.load_files(repo_name, file_paths, branch))
# Convenience functions
async def load_github_files_async(
repo_name: str,
file_paths: List[str],
branch: str = "main",
github_token: Optional[str] = None,
concurrent_requests: int = 10,
) -> Tuple[List[Document], List[str]]:
"""
Convenience function to load GitHub files asynchronously.
Args:
repo_name: Repository name in format "owner/repo"
file_paths: List of file paths to load
branch: Branch name to load from
github_token: GitHub API token
concurrent_requests: Number of concurrent requests
Returns:
Tuple of (documents, failed_files)
"""
loader = GithubFileLoader(
github_token=github_token, concurrent_requests=concurrent_requests
)
return await loader.load_files(repo_name, file_paths, branch)
def load_github_files(
repo_name: str,
file_paths: List[str],
branch: str = "main",
github_token: Optional[str] = None,
concurrent_requests: int = 10,
) -> Tuple[List[Document], List[str]]:
"""
Convenience function to load GitHub files synchronously.
Args:
repo_name: Repository name in format "owner/repo"
file_paths: List of file paths to load
branch: Branch name to load from
github_token: GitHub API token
concurrent_requests: Number of concurrent requests
Returns:
Tuple of (documents, failed_files)
"""
loader = GithubFileLoader(
github_token=github_token, concurrent_requests=concurrent_requests
)
return loader.load_files_sync(repo_name, file_paths, branch)
def fetch_markdown_files(
repo_url: str, github_token: Optional[str] = None, branch: str = "main"
) -> Tuple[List[str], str]:
"""
Convenience function to fetch markdown files from GitHub repository
Args:
repo_url: GitHub repository URL or owner/repo format
github_token: GitHub API token for higher rate limits
branch: Branch name to fetch from
Returns:
Tuple of (list_of_files, status_message)
"""
loader = GithubFileLoader(github_token=github_token)
return loader.fetch_markdown_files(repo_url, branch)
def fetch_repository_files(
repo_url: str,
file_extensions: List[str] = [".md", ".mdx"],
github_token: Optional[str] = None,
branch: str = "main",
) -> Tuple[List[str], str]:
"""
Convenience function to fetch files with specific extensions from GitHub repository
Args:
repo_url: GitHub repository URL or owner/repo format
file_extensions: List of file extensions to filter
github_token: GitHub API token for higher rate limits
branch: Branch name to fetch from
Returns:
Tuple of (list_of_files, status_message)
"""
loader = GithubFileLoader(github_token=github_token)
return loader.fetch_repository_files(repo_url, file_extensions, branch)
# Example usage
if __name__ == "__main__":
# Example file paths
file_paths = [
"docs/contribute/docs.mdx",
"docs/contribute/ml-handlers.mdx",
"docs/contribute/community.mdx",
"docs/contribute/python-coding-standards.mdx",
"docs/features/data-integrations.mdx",
"docs/features/ai-integrations.mdx",
"docs/integrations/ai-engines/langchain_embedding.mdx",
"docs/integrations/ai-engines/langchain.mdx",
"docs/integrations/ai-engines/google_gemini.mdx",
"docs/integrations/ai-engines/anomaly.mdx",
"docs/integrations/ai-engines/amazon-bedrock.mdx",
]
# Load files synchronously
documents, failed = load_github_files(
repo_name="mindsdb/mindsdb",
file_paths=file_paths,
branch="main", # Optional
)
print(f"Loaded {len(documents)} documents")
print(f"Failed to load {len(failed)} files: {failed}")
# Print first document info
if documents:
doc = documents[0]
print("\nFirst document:")
print(f"ID: {doc.doc_id}")
print(f"File: {doc.metadata['file_path']}")
print(f"Size: {len(doc.text)} characters")
print(f"Content preview: {doc.text[:200]}...")