Spaces:
Build error
Build error
File size: 4,561 Bytes
9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 9d051b5 1376e14 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForZeroShotImageClassification
from diffusers import DiffusionPipeline
import requests
from PIL import Image
from io import BytesIO
# Model setup: every checkpoint is loaded once, at import time, so the
# request handlers below can share the same instances.
_ANIME_TAGGER_ID = "SmilingWolf/wd-v1-4-vit-tagger"
_PHOTO_MODEL_ID = "facebook/florence-base-in21k-retrieval"

# NOTE(review): the anime tagger is loaded through diffusers'
# DiffusionPipeline; confirm that this repository really ships a diffusers
# pipeline layout, otherwise this call fails at start-up.
anime_model = DiffusionPipeline.from_pretrained(_ANIME_TAGGER_ID)

# The classifier and its processor must come from the same checkpoint so the
# pre-processing matches what the model expects.
# NOTE(review): verify that this checkpoint id exists on the Hub.
photo_model = AutoModelForZeroShotImageClassification.from_pretrained(_PHOTO_MODEL_ID)
processor = AutoProcessor.from_pretrained(_PHOTO_MODEL_ID)
def get_booru_image(booru, image_id):
    """Download an image from a booru site and return it with its tags.

    This is still a placeholder: the URL scheme is generic and the returned
    tags are hard-coded. The real API calls for each booru have to be
    implemented separately.

    Args:
        booru: Name of the booru site; it is interpolated into the host name.
        image_id: Identifier of the image/post to download.

    Returns:
        A ``(image, tags)`` tuple where ``image`` is a ``PIL.Image.Image`` and
        ``tags`` is a list of tag strings.

    Raises:
        ValueError: If ``booru`` or ``image_id`` is empty.
        requests.RequestException: If the request fails, times out, or the
            server answers with an HTTP error status.
    """
    # Reject empty selections up front instead of firing a request at a
    # meaningless URL.
    if not booru:
        raise ValueError("booru must not be empty")
    post_id = "" if image_id is None else str(image_id).strip()
    if not post_id:
        raise ValueError("image_id must not be empty")

    url = f"https://api.{booru}.org/images/{post_id}"
    # A timeout keeps a stalled server from blocking the UI handler forever.
    response = requests.get(url, timeout=10)
    # Surface HTTP errors here rather than handing an error page to PIL.
    response.raise_for_status()

    img = Image.open(BytesIO(response.content))
    tags = ["tag1", "tag2", "tag3"]  # Placeholder
    return img, tags
def transcribe_image(image, image_type, transcriber, booru_tags=None):
    """Produce a comma-separated tag string for an image.

    Args:
        image: The image to tag. ``None`` (nothing loaded yet) yields an
            empty string.
        image_type: ``"Anime"`` selects the anime tagger; any other value
            selects the photo classifier.
        transcriber: Accepted for the UI wiring but currently unused; the
            model is chosen from ``image_type`` only.
            NOTE(review): confirm whether this should override ``image_type``.
        booru_tags: Optional iterable of extra tags to merge into the result.

    Returns:
        The tags joined with ``", "``. When ``booru_tags`` is given, duplicates
        are removed and first-seen order is kept.
    """
    if image is None:
        return ""

    if image_type == "Anime":
        with torch.no_grad():
            # NOTE(review): assumes the pipeline call returns an iterable of
            # tag strings -- confirm against the actual model output.
            tags = anime_model(image)
    else:
        inputs = processor(images=image, return_tensors="pt")
        # Inference only: skip gradient tracking, as in the anime branch.
        with torch.no_grad():
            outputs = photo_model(**inputs)
        # NOTE(review): assumes the model output exposes `.logits` for an
        # image-only call -- confirm for this checkpoint.
        logits = outputs.logits
        # Never ask for more entries than there are labels.
        top_k = min(50, logits.shape[-1])
        # reshape(-1) keeps a flat list even when top_k == 1.
        label_ids = logits.topk(top_k).indices.reshape(-1).tolist()
        # The label map is stored on the model config, not on the processor.
        id2label = photo_model.config.id2label
        tags = [id2label[label_id] for label_id in label_ids]

    if booru_tags:
        # dict.fromkeys de-duplicates while keeping a deterministic order.
        tags = list(dict.fromkeys([*tags, *booru_tags]))
    return ", ".join(tags)
def update_image(image_type, booru, image_id, uploaded_image):
    """Decide which image to show in Step 1 and whether it can be transcribed.

    Args:
        image_type: Value of the "Image type" dropdown.
        booru: Value of the "Boorus" dropdown; may be ``None`` when nothing is
            selected, or ``"Upload"`` to use the uploaded file instead.
        image_id: Text entered in the "Image ID" box.
        uploaded_image: Value of the upload button, or ``None``.

    Returns:
        A ``(image, transcribe_button_update, booru_tags)`` tuple. The button
        update makes the button visible only when an image is available;
        ``booru_tags`` is ``None`` unless the image came from a booru.
    """
    # Only contact a booru when one is actually selected and an ID was typed;
    # otherwise the fetch would hit a meaningless URL and raise.
    booru_selected = bool(booru) and booru != "Upload"
    has_image_id = bool(image_id) and bool(str(image_id).strip())

    if image_type == "Anime" and booru_selected and has_image_id:
        image, booru_tags = get_booru_image(booru, image_id)
        return image, gr.update(visible=True), booru_tags

    if uploaded_image is not None:
        return uploaded_image, gr.update(visible=True), None

    # Nothing to show: hide the transcribe button again.
    return None, gr.update(visible=False), None
def on_image_type_change(image_type):
    """Adapt the Step 1 controls to the selected image type.

    Returns three component updates, in this order: the anime options
    (visible only for ``"Anime"``), the upload button (always visible), and
    the transcriber dropdown, whose choices list the matching transcriber
    first.
    """
    is_anime = image_type == "Anime"
    if is_anime:
        transcriber_choices = ["Anime", "Photo/Other"]
    else:
        transcriber_choices = ["Photo/Other", "Anime"]

    anime_options_update = gr.update(visible=is_anime)
    upload_button_update = gr.update(visible=True)
    transcriber_update = gr.update(choices=transcriber_choices)
    return anime_options_update, upload_button_update, transcriber_update
# UI layout and event wiring for the two-step transcription workflow.
with gr.Blocks() as app:
    gr.Markdown("# Image Transcription App")

    # Step 1: choose an image, either fetched from a booru or uploaded.
    with gr.Tab("Step 1: Image"):
        image_type = gr.Dropdown(["Anime", "Photo/Other"], label="Image type")
        # Booru controls stay hidden until the "Anime" image type is chosen.
        with gr.Column(visible=False) as anime_options:
            booru = gr.Dropdown(["Gelbooru", "Danbooru", "Upload"], label="Boorus")
            image_id = gr.Textbox(label="Image ID")
            get_image_btn = gr.Button("Get image")
        upload_btn = gr.UploadButton("Upload Image", visible=False)
        image_display = gr.Image(label="Image to transcribe", visible=False)
        # Tags returned together with a booru image; None for uploads.
        booru_tags = gr.State(None)
        transcribe_btn = gr.Button("Transcribe", visible=False)
        transcribe_with_tags_btn = gr.Button("Transcribe with booru tags", visible=False)

    # Step 2: run (or re-run) the transcription and show the resulting tags.
    with gr.Tab("Step 2: Transcribe"):
        transcriber = gr.Dropdown(["Anime", "Photo/Other"], label="Transcriber")
        transcribe_image_display = gr.Image(label="Image to transcribe")
        transcribe_btn_final = gr.Button("Transcribe")
        tags_output = gr.Textbox(label="Transcribed tags")

    image_type.change(on_image_type_change, inputs=[image_type],
                      outputs=[anime_options, upload_btn, transcriber])

    def show_tags_button(tags):
        """Show the "with booru tags" button only when booru tags exist."""
        return gr.update(visible=bool(tags))

    # Both image sources share one handler. A follow-up step then reveals the
    # "with booru tags" button, which would otherwise never become visible.
    get_image_btn.click(update_image,
                        inputs=[image_type, booru, image_id, upload_btn],
                        outputs=[image_display, transcribe_btn, booru_tags]
                        ).then(show_tags_button,
                               inputs=[booru_tags],
                               outputs=[transcribe_with_tags_btn])
    upload_btn.upload(update_image,
                      inputs=[image_type, booru, image_id, upload_btn],
                      outputs=[image_display, transcribe_btn, booru_tags]
                      ).then(show_tags_button,
                             inputs=[booru_tags],
                             outputs=[transcribe_with_tags_btn])

    def transcribe_and_update(image, image_type, transcriber, booru_tags=None):
        """Transcribe the image and copy it into the Step 2 preview.

        Returns the image unchanged together with the tag string, so both the
        Step 2 image and the tags textbox are filled by a single click.
        """
        tags = transcribe_image(image, image_type, transcriber, booru_tags)
        return image, tags

    # Plain transcription: booru tags are deliberately not passed, so this
    # button differs from the "with booru tags" one below.
    transcribe_btn.click(transcribe_and_update,
                         inputs=[image_display, image_type, transcriber],
                         outputs=[transcribe_image_display, tags_output])
    transcribe_with_tags_btn.click(transcribe_and_update,
                                   inputs=[image_display, image_type, transcriber, booru_tags],
                                   outputs=[transcribe_image_display, tags_output])
    # Re-run on the Step 2 image, e.g. after picking another transcriber.
    transcribe_btn_final.click(transcribe_image,
                               inputs=[transcribe_image_display, image_type, transcriber],
                               outputs=[tags_output])

app.launch()