throaway2854's picture
Update app.py
b1c343c verified
raw
history blame
2.96 kB
# Import necessary libraries
import gradio as gr
import json
import os
import zipfile
# Define helper functions
def create_dataset(dataset_name):
    """Ensure a dataset archive named ``<dataset_name>.zip`` exists.

    A newly created archive contains an ``images/`` directory entry and an
    empty ``data.jsonl`` index. If a file already exists at that path it is
    left untouched.

    Args:
        dataset_name: Base name (optionally with a directory prefix) of the
            archive; ``.zip`` is appended.

    Returns:
        The archive path, ``f'{dataset_name}.zip'``.
    """
    archive = f'{dataset_name}.zip'
    if os.path.exists(archive):
        return archive
    with zipfile.ZipFile(archive, 'w') as zf:
        # Placeholder entries: the image folder, then the (empty) index file.
        for member in ('images/', 'data.jsonl'):
            zf.writestr(member, '')
    return archive
def upload_pair(dataset_path, image, prompt):
    """Add an image and its prompt to the dataset archive at ``dataset_path``.

    The image is stored as ``images/<basename>``. A ``{"image", "prompt"}``
    record is appended to ``data.jsonl``, which holds one JSON object per line.
    Uploading an image whose basename is already present replaces both the
    stored image and its record. The archive is created if it does not exist.

    Zip members cannot be edited in place. The archive is therefore rebuilt
    into a sibling ``.tmp`` file, which is swapped in with ``os.replace``.

    Args:
        dataset_path: Path of the dataset ``.zip``.
        image: Either a filesystem path (``str`` / ``os.PathLike``) or a
            readable file-like object with a ``name`` attribute. Gradio's
            ``gr.File`` yields one or the other depending on version.
        prompt: Prompt text for the image.

    Raises:
        ValueError: If ``image`` is None (nothing uploaded).
        zipfile.BadZipFile: If ``dataset_path`` exists but is not a zip archive.
        json.JSONDecodeError: If an existing ``data.jsonl`` line is malformed.
    """
    if image is None:
        raise ValueError('No image was provided.')
    if isinstance(image, (str, os.PathLike)):
        source_name = os.fspath(image)
        with open(source_name, 'rb') as fh:
            image_bytes = fh.read()
    else:
        source_name = image.name
        image_bytes = image.read()
    # ``name`` is typically an absolute temp-file path. Keep only the file
    # name so the member is ``images/cat.png``, not ``images//tmp/...``.
    image_path = f'images/{os.path.basename(source_name)}'

    tmp_path = f'{dataset_path}.tmp'
    records = []
    try:
        with zipfile.ZipFile(tmp_path, 'w') as dst:
            if os.path.exists(dataset_path):
                with zipfile.ZipFile(dataset_path, 'r') as src:
                    latest = {}
                    for info in src.infolist():
                        if info.filename == 'data.jsonl':
                            # Appending with mode 'a' adds a new data.jsonl
                            # member each time, so merge every copy found.
                            text = src.read(info).decode('utf-8')
                            records.extend(
                                json.loads(line) for line in text.splitlines() if line.strip()
                            )
                        elif info.filename != image_path:
                            # Keep the last duplicate, matching zipfile's own name lookup.
                            latest[info.filename] = info
                    for info in latest.values():
                        dst.writestr(info, src.read(info))
            # A re-uploaded image replaces its previous record.
            records = [r for r in records if r.get('image') != image_path]
            records.append({'image': image_path, 'prompt': prompt})
            dst.writestr(image_path, image_bytes)
            dst.writestr('data.jsonl', ''.join(json.dumps(r) + '\n' for r in records))
    except BaseException:
        # Don't leave a half-written temp archive behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, dataset_path)
def edit_prompt(dataset_path, image_path, new_prompt):
    """Set the prompt of every ``data.jsonl`` record whose image is ``image_path``.

    ``data.jsonl`` is JSON Lines, with one ``{"image", "prompt"}`` object per
    line. Zip members cannot be edited in place, so the archive is rebuilt
    into a sibling ``.tmp`` file. Every other member, including all images,
    is copied across unchanged, and the result is swapped in with
    ``os.replace``. If no record matches, the archive is left untouched.

    Args:
        dataset_path: Path of an existing dataset ``.zip``.
        image_path: Archive name of the image, e.g. ``'images/cat.png'``.
        new_prompt: Replacement prompt text.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
        zipfile.BadZipFile: If ``dataset_path`` is not a zip archive.
        json.JSONDecodeError: If a ``data.jsonl`` line is malformed.
    """
    records = []
    with zipfile.ZipFile(dataset_path, 'r') as src:
        for info in src.infolist():
            # Appending with ZipFile mode 'a' adds a new data.jsonl member
            # each time, so an archive may hold several; merge them all.
            if info.filename == 'data.jsonl':
                text = src.read(info).decode('utf-8')
                records.extend(json.loads(line) for line in text.splitlines() if line.strip())

    matched = False
    for record in records:
        if record.get('image') == image_path:
            record['prompt'] = new_prompt
            matched = True
    if not matched:
        return

    tmp_path = f'{dataset_path}.tmp'
    try:
        with zipfile.ZipFile(dataset_path, 'r') as src, zipfile.ZipFile(tmp_path, 'w') as dst:
            # Keep the last duplicate of each name, matching zipfile's own lookup.
            latest = {
                info.filename: info for info in src.infolist() if info.filename != 'data.jsonl'
            }
            for info in latest.values():
                dst.writestr(info, src.read(info))
            dst.writestr('data.jsonl', ''.join(json.dumps(r) + '\n' for r in records))
    except BaseException:
        # Don't leave a half-written temp archive behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, dataset_path)
def delete_pair(dataset_path, image_path):
    """Remove ``image_path`` and its ``data.jsonl`` record(s) from the archive.

    ``data.jsonl`` is JSON Lines, with one ``{"image", "prompt"}`` object per
    line. Zip members cannot be deleted in place, so the archive is rebuilt
    into a sibling ``.tmp`` file. Every other member is copied across
    unchanged, and the result is swapped in with ``os.replace``. If there is
    neither a matching record nor a member named ``image_path``, the archive
    is left untouched.

    Args:
        dataset_path: Path of an existing dataset ``.zip``.
        image_path: Archive name of the image, e.g. ``'images/cat.png'``.

    Raises:
        FileNotFoundError: If ``dataset_path`` does not exist.
        zipfile.BadZipFile: If ``dataset_path`` is not a zip archive.
        json.JSONDecodeError: If a ``data.jsonl`` line is malformed.
    """
    records = []
    with zipfile.ZipFile(dataset_path, 'r') as src:
        names = set(src.namelist())
        for info in src.infolist():
            # Appending with ZipFile mode 'a' adds a new data.jsonl member
            # each time, so an archive may hold several; merge them all.
            if info.filename == 'data.jsonl':
                text = src.read(info).decode('utf-8')
                records.extend(json.loads(line) for line in text.splitlines() if line.strip())

    kept = [r for r in records if r.get('image') != image_path]
    if len(kept) == len(records) and image_path not in names:
        return  # nothing to delete

    tmp_path = f'{dataset_path}.tmp'
    try:
        with zipfile.ZipFile(dataset_path, 'r') as src, zipfile.ZipFile(tmp_path, 'w') as dst:
            # Keep the last duplicate of each name, matching zipfile's own lookup.
            latest = {
                info.filename: info
                for info in src.infolist()
                if info.filename not in ('data.jsonl', image_path)
            }
            for info in latest.values():
                dst.writestr(info, src.read(info))
            dst.writestr('data.jsonl', ''.join(json.dumps(r) + '\n' for r in kept))
    except BaseException:
        # Don't leave a half-written temp archive behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    os.replace(tmp_path, dataset_path)
def download_dataset(dataset_path):
    """Return ``dataset_path`` unchanged.

    This is a pass-through, presumably meant to feed a file-download output
    component.

    Args:
        dataset_path: Path of the dataset archive.

    Returns:
        The very same object that was passed in.
    """
    archive = dataset_path
    return archive
# Define Gradio application
def _dataset_archive(name):
    """Return the archive path ``create_dataset`` uses for dataset ``name``."""
    return f'{name}.zip'


def _import_dataset(uploaded):
    """Copy an uploaded dataset archive into the working directory.

    ``uploaded`` is whatever ``gr.File`` delivers: either a filepath string or
    a tempfile wrapper exposing ``.name``, depending on the Gradio version.
    The archive is saved as ``<name>.zip``. ``<name>`` is returned so that the
    'Dataset Name' box selects the imported dataset for the other actions.

    Raises:
        ValueError: If nothing was uploaded or the file is not a zip archive.
    """
    import shutil

    if uploaded is None:
        raise ValueError('No dataset file was uploaded.')
    source = os.fspath(uploaded if isinstance(uploaded, (str, os.PathLike)) else uploaded.name)
    if not zipfile.is_zipfile(source):
        raise ValueError(f'{os.path.basename(source)} is not a zip archive.')
    # basename() drops any directory part, so the copy stays in the working directory.
    name = os.path.basename(source)
    if name.lower().endswith('.zip'):
        name = name[:-len('.zip')]
    shutil.copyfile(source, _dataset_archive(name))
    return name


demo = gr.Blocks()
with demo:
    # Create dataset
    dataset_name = gr.Textbox(label='Dataset Name')
    create_button = gr.Button('Create Dataset')
    create_button.click(create_dataset, inputs=[dataset_name], outputs=[])
    # Upload pair. This and the actions below operate on '<Dataset Name>.zip'
    # (the path create_dataset creates), not on the raw textbox value.
    image_upload = gr.File(label='Image')
    prompt = gr.Textbox(label='Prompt')
    upload_button = gr.Button('Upload Pair')
    upload_button.click(
        lambda name, image, text: upload_pair(_dataset_archive(name), image, text),
        inputs=[dataset_name, image_upload, prompt],
        outputs=[],
    )
    # Edit prompt
    image_path = gr.Textbox(label='Image Path')
    new_prompt = gr.Textbox(label='New Prompt')
    edit_button = gr.Button('Edit Prompt')
    edit_button.click(
        lambda name, path, text: edit_prompt(_dataset_archive(name), path, text),
        inputs=[dataset_name, image_path, new_prompt],
        outputs=[],
    )
    # Delete pair
    delete_button = gr.Button('Delete Pair')
    delete_button.click(
        lambda name, path: delete_pair(_dataset_archive(name), path),
        inputs=[dataset_name, image_path],
        outputs=[],
    )
    # Download dataset: the returned path is delivered through a File output.
    download_button = gr.Button('Download Dataset')
    download_file = gr.File(label='Dataset Archive')
    download_button.click(
        lambda name: download_dataset(_dataset_archive(name)),
        inputs=[dataset_name],
        outputs=[download_file],
    )
    # Upload dataset: import an existing archive and select it by name.
    dataset_upload = gr.File(label='Dataset')
    upload_dataset_button = gr.Button('Upload Dataset')
    upload_dataset_button.click(_import_dataset, inputs=[dataset_upload], outputs=[dataset_name])
    # Horizontal gallery. Components created inside ``with demo`` are added
    # automatically; Blocks has no ``append`` method.
    # NOTE(review): no event populates this gallery yet.
    gallery = gr.Gallery(label='Dataset Gallery')

# Launch Gradio application
if __name__ == '__main__':
    demo.launch()