from pydrive2.auth import GoogleAuth
from pydrive2.drive import GoogleDrive
import os
import gradio as gr
from datasets import load_dataset, Dataset, concatenate_datasets
import pandas as pd
from PIL import Image
from tqdm import tqdm
import logging
import yaml

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load settings
with open('settings.yaml', 'r') as file:
    settings = yaml.safe_load(file)


class DatasetManager:
    def __init__(self, local_images_dir="downloaded_cards"):
        self.local_images_dir = local_images_dir
        self.drive = None
        self.dataset_name = "GotThatData/sports-cards"
        # Create local directory if it doesn't exist
        os.makedirs(local_images_dir, exist_ok=True)

    def authenticate_drive(self):
        """Authenticate with Google Drive."""
        try:
            gauth = GoogleAuth()
            gauth.settings['client_config_file'] = settings['client_secrets_file']

            # Try to load saved credentials; fall back to a browser flow
            gauth.LoadCredentialsFile("credentials.txt")
            if gauth.credentials is None:
                gauth.LocalWebserverAuth()
            elif gauth.access_token_expired:
                gauth.Refresh()
            else:
                gauth.Authorize()
            gauth.SaveCredentialsFile("credentials.txt")

            self.drive = GoogleDrive(gauth)
            return True, "Successfully authenticated with Google Drive"
        except Exception as e:
            return False, f"Authentication failed: {str(e)}"

    def download_and_rename_files(self, drive_folder_id, naming_convention):
        """Download files from Google Drive and rename them."""
        if not self.drive:
            return False, "Google Drive not authenticated", []

        try:
            # Treat the ID as a folder first
            query = f"'{drive_folder_id}' in parents and trashed=false"
            file_list = self.drive.ListFile({'q': query}).GetList()

            if not file_list:
                # The ID may point to a single file rather than a folder;
                # FetchMetadata raises if the ID is invalid (CreateFile itself
                # never returns None, so it cannot be used as the check)
                try:
                    file = self.drive.CreateFile({'id': drive_folder_id})
                    file.FetchMetadata()
                    file_list = [file]
                except Exception:
                    return False, "No files found with the specified ID", []

            renamed_files = []

            # Continue numbering after the existing dataset, if one exists
            try:
                existing_dataset = load_dataset(self.dataset_name)
                logger.info(f"Loaded existing dataset: {self.dataset_name}")
                start_index = len(existing_dataset['train']) if 'train' in existing_dataset else 0
            except Exception as e:
                logger.info(f"No existing dataset found, starting fresh: {str(e)}")
                start_index = 0

            # Count only image files so the numbering has no gaps
            image_count = 0
            for file in tqdm(file_list, desc="Downloading files"):
                if file['mimeType'].startswith('image/'):
                    image_count += 1
                    new_filename = f"{naming_convention}_{start_index + image_count}.jpg"
                    file_path = os.path.join(self.local_images_dir, new_filename)
                    file.GetContentFile(file_path)

                    try:
                        # Verify the download is a readable image before keeping it
                        with Image.open(file_path) as img:
                            img.verify()
                        renamed_files.append({
                            'file_path': file_path,
                            'original_name': file['title'],
                            'new_name': new_filename,
                            'image': file_path
                        })
                    except Exception as e:
                        logger.error(f"Error processing image {file['title']}: {str(e)}")
                        if os.path.exists(file_path):
                            os.remove(file_path)

            return True, f"Successfully processed {len(renamed_files)} images", renamed_files
        except Exception as e:
            return False, f"Error downloading files: {str(e)}", []

    def update_huggingface_dataset(self, renamed_files):
        """Update the sports-cards dataset with new images."""
        if not renamed_files:
            return False, "No new images to add to the dataset"

        try:
            df = pd.DataFrame(renamed_files)
            new_dataset = Dataset.from_pandas(df)

            # Append to the existing train split if the dataset already exists
            try:
                existing_dataset = load_dataset(self.dataset_name)
                if 'train' in existing_dataset:
                    new_dataset = concatenate_datasets([existing_dataset['train'], new_dataset])
            except Exception:
                logger.info("Creating new dataset")

            new_dataset.push_to_hub(self.dataset_name, split="train")
            return True, f"Successfully updated dataset '{self.dataset_name}' with {len(renamed_files)} new images"
        except Exception as e:
            return False, f"Error updating Hugging Face dataset: {str(e)}"
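
# A minimal sketch of the settings.yaml this script expects. Only the
# `client_secrets_file` key is read (in authenticate_drive above); the
# path below is a placeholder for your own Google OAuth client secrets:
#
#   client_secrets_file: client_secrets.json
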
def process_pipeline(folder_id, naming_convention):
    """Main pipeline to process images and update dataset."""
    manager = DatasetManager()

    auth_success, auth_message = manager.authenticate_drive()
    if not auth_success:
        return auth_message

    success, message, renamed_files = manager.download_and_rename_files(folder_id, naming_convention)
    if not success:
        return message

    success, hf_message = manager.update_huggingface_dataset(renamed_files)
    return f"{message}\n{hf_message}"


# Custom CSS for web-safe fonts and clean styling
custom_css = """
.gradio-container {
    font-family: Arial, sans-serif !important;
}
h1, h2, h3 {
    font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif !important;
    font-weight: 600 !important;
}
.gr-button {
    font-family: Arial, sans-serif !important;
}
.gr-input {
    font-family: 'Courier New', Courier, monospace !important;
}
.gr-box {
    border-radius: 8px !important;
    border: 1px solid #e5e5e5 !important;
}
.gr-padded {
    padding: 16px !important;
}
"""

# Gradio interface with custom theme
with gr.Blocks(css=custom_css) as demo:
    gr.Markdown("# Sports Cards Dataset Processor")

    with gr.Box():
        gr.Markdown("""
        ### Instructions
        1. Enter the Google Drive folder/file ID
        2. Choose a naming convention for your cards
        3. Click Process to start
        """)

    with gr.Row():
        with gr.Column():
            folder_id = gr.Textbox(
                label="Google Drive File/Folder ID",
                placeholder="Enter the ID from your Google Drive URL",
                value="151VOxPO91mg0C3ORiioGUd4hogzP1ujm"
            )
            naming = gr.Textbox(
                label="Naming Convention",
                placeholder="e.g., sports_card",
                value="sports_card"
            )
            process_btn = gr.Button("Process Images", variant="primary")

    with gr.Box():
        output = gr.Textbox(
            label="Processing Status",
            show_label=True,
            lines=5
        )

    process_btn.click(
        fn=process_pipeline,
        inputs=[folder_id, naming],
        outputs=output
    )

if __name__ == "__main__":
    demo.launch()
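
# Headless usage, bypassing the Gradio UI (a sketch; the folder ID is a
# placeholder, not a real Drive ID):
#
#   status = process_pipeline("YOUR_DRIVE_FOLDER_ID", "sports_card")
#   print(status)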