# Hugging Face Spaces status when this source was captured: Build error
import gradio as gr | |
import torch | |
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification | |
from diffusers import DiffusionPipeline | |
import requests | |
from PIL import Image | |
from io import BytesIO | |
# Load every model once at import time; the handlers below reuse these objects.
#
# NOTE(review): the SmilingWolf/wd-v1-4-vit-tagger repo appears to ship an ONNX
# tagger rather than a diffusers pipeline, so DiffusionPipeline.from_pretrained
# is a likely cause of the Space's build error — confirm and switch loaders.
# NOTE(review): confirm "facebook/florence-base-in21k-retrieval" exists on the
# Hub and loads as a zero-shot image-classification model with a processor.
ANIME_TAGGER_REPO = "SmilingWolf/wd-v1-4-vit-tagger"
PHOTO_MODEL_REPO = "facebook/florence-base-in21k-retrieval"

anime_model = DiffusionPipeline.from_pretrained(ANIME_TAGGER_REPO)
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(PHOTO_MODEL_REPO)
processor = AutoProcessor.from_pretrained(PHOTO_MODEL_REPO)
def get_booru_image(booru, image_id, timeout=10):
    """Fetch an image and its tags from an image board ("booru").

    This is still a placeholder: the URL pattern is a guess and the tags are
    hard-coded. A real implementation needs each booru's actual API.

    Args:
        booru: Board name as shown in the UI (e.g. "Gelbooru"). It is trimmed,
            lower-cased and used as the domain label in the URL.
        image_id: Identifier of the image on that board.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        A ``(PIL.Image.Image, list[str])`` tuple: the image and its tags.

    Raises:
        ValueError: If ``booru`` or ``image_id`` is empty.
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If no response arrives within ``timeout`` seconds.
    """
    booru = (booru or "").strip().lower()
    image_id = str(image_id or "").strip()
    if not booru or not image_id:
        # Otherwise we would request e.g. "https://api.None.org/images/".
        raise ValueError("Both a booru and an image ID are required.")

    url = f"https://api.{booru}.org/images/{image_id}"
    # Without a timeout, requests can block the Gradio worker indefinitely.
    response = requests.get(url, timeout=timeout)
    # Surface the HTTP error instead of letting PIL choke on an error page.
    response.raise_for_status()
    img = Image.open(BytesIO(response.content))
    # Image.open is lazy; decode now so a corrupt payload fails here.
    img.load()
    tags = ["tag1", "tag2", "tag3"]  # Placeholder
    return img, tags
def transcribe_image(image, image_type, transcriber, booru_tags=None, top_k=50):
    """Produce a comma-separated tag string for ``image``.

    Args:
        image: The image to tag, as delivered by a Gradio image component.
        image_type: "Anime" or "Photo/Other", chosen in step 1.
        transcriber: Model choice from step 2 ("Anime" or "Photo/Other").
            When set it takes precedence over ``image_type``; when empty,
            ``image_type`` decides which model runs.
        booru_tags: Optional tags fetched from a booru. They are appended
            after the model's own tags.
        top_k: Maximum number of the photo model's highest-scoring labels
            to keep.

    Returns:
        The tags joined with ", ", with duplicates removed (first occurrence
        kept, so the order is deterministic).

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        raise gr.Error("No image to transcribe.")

    model_choice = transcriber or image_type
    with torch.no_grad():  # inference only; don't build an autograd graph
        if model_choice == "Anime":
            # NOTE(review): anime_model is loaded as a DiffusionPipeline, whose
            # output is not a list of tag strings; this branch needs a real
            # tagger before the join below can succeed.
            tags = anime_model(image)
        else:
            inputs = processor(images=image, return_tensors="pt")
            # NOTE(review): zero-shot classifiers (e.g. CLIP-style models)
            # usually expose logits_per_image and require text inputs — confirm
            # this model returns .logits for pixel-only input.
            logits = photo_model(**inputs).logits
            k = min(top_k, logits.shape[-1])  # topk fails if k > num classes
            # Take batch row 0 rather than squeeze(): squeeze turns a
            # one-element result into a bare int, which is not iterable.
            top_ids = logits.topk(k).indices[0].tolist()
            # The label map lives on the model config; processors have none.
            id2label = photo_model.config.id2label
            tags = [id2label[i] for i in top_ids]

    if booru_tags:
        # dict.fromkeys dedupes while preserving order; set() order is arbitrary.
        tags = list(dict.fromkeys(list(tags) + list(booru_tags)))
    return ", ".join(tags)
def update_image(image_type, booru, image_id, uploaded_image):
    """Load the image to transcribe, from a booru or from an upload.

    A booru is queried only for anime images, and only when a board other
    than "Upload" is selected and an image ID has been entered. Otherwise
    the uploaded file, if any, is used.

    Args:
        image_type: Value of the step-1 "Image type" dropdown.
        booru: Selected board, or "Upload"/None.
        image_id: Image identifier typed by the user.
        uploaded_image: Value of the upload button, passed through as-is.
            NOTE(review): its type depends on the Gradio version (path string
            vs. tempfile wrapper) — confirm gr.Image accepts it.

    Returns:
        A 3-tuple for ``[image_display, transcribe_btn, booru_tags]``:
        - an update that sets the image and shows or hides the display;
        - an update that shows or hides the transcribe button;
        - the booru tags, or None.
    """
    has_id = image_id is not None and str(image_id).strip() != ""
    use_booru = image_type == "Anime" and booru not in (None, "", "Upload") and has_id

    if use_booru:
        image, tags = get_booru_image(booru, image_id)
    elif uploaded_image is not None:
        image, tags = uploaded_image, None
    else:
        image, tags = None, None

    shown = image is not None
    # image_display is created hidden, so reveal it together with its value.
    return gr.update(value=image, visible=shown), gr.update(visible=shown), tags
def on_image_type_change(image_type):
    """Adjust step-1 and step-2 widgets after the image type changes.

    Returns updates for ``[anime_options, upload_btn, transcriber]``:
    - the booru options are visible only for "Anime";
    - the upload button is always made visible;
    - the transcriber's choices are reordered so the matching type comes first.
    """
    is_anime = image_type == "Anime"
    if is_anime:
        transcriber_choices = ["Anime", "Photo/Other"]
    else:
        transcriber_choices = ["Photo/Other", "Anime"]
    return (
        gr.update(visible=is_anime),
        gr.update(visible=True),
        gr.update(choices=transcriber_choices),
    )
with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        # Booru controls; shown only for anime images (see on_image_type_change).
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        booru_tags = gr.State(None)  # tags returned by the last booru fetch
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    def _show_upload(uploaded_image):
        """Use the freshly uploaded file, ignoring any booru selection."""
        return uploaded_image, gr.update(visible=uploaded_image is not None), None

    def _reveal_loaded(image, tags):
        """Show the image, and the tags button only when booru tags exist."""
        return gr.update(visible=image is not None), gr.update(visible=bool(tags))

    def transcribe_and_update(image, image_type, transcriber, booru_tags):
        """Tag ``image`` and copy it into the step-2 image slot."""
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    def _transcribe_without_tags(image, image_type, transcriber):
        """Like transcribe_and_update, but without merging booru tags."""
        return transcribe_and_update(image, image_type, transcriber, None)

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    load_outputs = [image_display, transcribe_btn, booru_tags]
    reveal_inputs = [image_display, booru_tags]
    reveal_outputs = [image_display, transcribe_with_tags_btn]
    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=load_outputs
                        ).then(_reveal_loaded, inputs=reveal_inputs, outputs=reveal_outputs)
    # A dedicated handler: routing uploads through update_image would fetch
    # from the booru instead whenever a board and image ID are filled in.
    upload_btn.upload(_show_upload, inputs=[upload_btn], outputs=load_outputs
                      ).then(_reveal_loaded, inputs=reveal_inputs, outputs=reveal_outputs)

    transcribe_outputs = [transcribe_image_display, tags_output]
    transcribe_btn.click(_transcribe_without_tags,
                         inputs=[image_display, image_type, transcriber],
                         outputs=transcribe_outputs)
    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=transcribe_outputs)
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

if __name__ == "__main__":
    app.launch()