gokilashree's picture
Update app.py
0501446 verified
raw
history blame
4.16 kB
import torch
from transformers import MBartForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
import requests
import io
from PIL import Image
import os
# Set up the Hugging Face API key from environment variables.
# The token is read from the "new_hf_token" environment variable and sent as
# a Bearer token with every inference request.
hf_api_key = os.getenv("new_hf_token")
if not hf_api_key:
    # Name the variable that is actually read above so the message is
    # actionable for whoever deploys the app.
    raise ValueError(
        "Hugging Face API key not found! "
        "Please set the 'new_hf_token' environment variable."
    )
headers = {"Authorization": f"Bearer {hf_api_key}"}
# Inference endpoints of the available text-to-image models, keyed by a
# short model name.
model_urls = dict(
    stable_diffusion_v1_4="https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4",
    stable_diffusion_v1_5="https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5",
)

# Endpoint that image generation requests are sent to.
API_URL = model_urls["stable_diffusion_v1_4"]
# Translation setup for multilingual text inputs: the model weights and the
# matching tokenizer are both loaded from the same mBART checkpoint.
translation_model_name = "facebook/mbart-large-50-many-to-one-mmt"
translation_model = MBartForConditionalGeneration.from_pretrained(
    translation_model_name
)
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
# Text generation setup: load the GPT-Neo tokenizer and causal LM, then wrap
# them in a ready-to-call "text-generation" pipeline.
text_generation_model_name = "EleutherAI/gpt-neo-2.7B"
text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
text_model = AutoModelForCausalLM.from_pretrained(
    text_generation_model_name,
    device_map="auto",
    torch_dtype=torch.float32,
)
text_generator = pipeline(
    task="text-generation",
    model=text_model,
    tokenizer=text_tokenizer,
)
# Function to generate an image using Hugging Face's text-to-image model
def generate_image_from_text(translated_text, timeout=300):
    """Request an image for ``translated_text`` from the text-to-image endpoint.

    The prompt is POSTed to ``API_URL`` with the module-level ``headers``;
    the ``wait_for_model`` option is set in the request payload.

    Args:
        translated_text: Prompt sent as the request's ``inputs``.
        timeout: Maximum number of seconds to wait for the HTTP request.

    Returns:
        The image opened from the response body when the status code is 200;
        otherwise a human-readable error message string. Callers must check
        the type of the result before using it as an image.
    """
    payload = {"inputs": translated_text, "options": {"wait_for_model": True}}
    try:
        # Without a timeout a stalled connection would block the app forever.
        response = requests.post(
            API_URL, headers=headers, json=payload, timeout=timeout
        )
    except requests.RequestException as exc:
        # Report network failures the same way as HTTP failures: as a message.
        return f"Failed to generate image. Error: {exc}"

    if response.status_code == 200:
        return Image.open(io.BytesIO(response.content))

    if response.status_code == 503:
        # If the model is loading, report the estimated wait time. The body
        # is not guaranteed to be a JSON object, so fall back to "Unknown".
        try:
            error_message = response.json()
        except ValueError:
            error_message = None
        estimated_time = "Unknown"
        if isinstance(error_message, dict):
            estimated_time = error_message.get("estimated_time", "Unknown")
        return f"Model is currently loading. Estimated wait time: {estimated_time} seconds. Try again later."

    return f"Failed to generate image. Error: {response.status_code}, Message: {response.text}"
# Function to translate text using the MBart model
# The mBART-50 tokenizer identifies languages by codes such as "en_XX", not
# by bare two-letter codes, so the short codes used as this function's
# default and by the UI are mapped to the tokenizer's codes here.
_MBART50_LANG_CODES = {
    "en": "en_XX",
    "fr": "fr_XX",
    "de": "de_DE",
    "es": "es_XX",
}


def translate_text(input_text, src_lang="en"):
    """Translate ``input_text`` with the module-level mBART model.

    Args:
        input_text: Text to translate.
        src_lang: Language of ``input_text``. Short codes listed in
            ``_MBART50_LANG_CODES`` are converted to the tokenizer's codes;
            any other value is passed to the tokenizer unchanged.

    Returns:
        The decoded translation with special tokens removed, or an empty
        string when ``input_text`` is empty or only whitespace.
    """
    # Nothing to translate: skip the model call entirely.
    if not input_text or not input_text.strip():
        return ""

    # Tokenize and translate
    tokenizer.src_lang = _MBART50_LANG_CODES.get(src_lang, src_lang)
    encoded_input = tokenizer(input_text, return_tensors="pt")
    translated_tokens = translation_model.generate(**encoded_input)
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
# Function to generate text using the GPT-Neo model
def generate_text(prompt, max_length=50):
    """Run the text-generation pipeline on ``prompt`` and return its text.

    Args:
        prompt: Text passed to the pipeline as the generation prompt.
        max_length: Forwarded to the pipeline as ``max_length``.

    Returns:
        The ``"generated_text"`` field of the first (and only requested)
        generated sequence.
    """
    sequences = text_generator(
        prompt,
        max_length=max_length,
        num_return_sequences=1,
    )
    first_sequence = sequences[0]
    return first_sequence["generated_text"]
# Define the Gradio Interface
def app_interface(input_text, src_language="en"):
    """Translate the input, then generate an image and a text from it.

    Args:
        input_text: Text entered by the user.
        src_language: Language code of ``input_text``, forwarded to
            ``translate_text`` as ``src_lang``.

    Returns:
        A ``(generated_text, generated_image)`` tuple, in the order of the
        output components the click handler is wired to.

    Raises:
        gr.Error: If image generation failed; the error carries the failure
            message returned by ``generate_image_from_text``.
    """
    translated_text = translate_text(input_text, src_lang=src_language)
    generated_image = generate_image_from_text(translated_text)
    if isinstance(generated_image, str):
        # generate_image_from_text reports failures as a message string.
        # That string is not an image, so surface it as a UI error instead
        # of handing it to the image output component.
        raise gr.Error(generated_image)
    generated_text = generate_text(translated_text)
    return generated_text, generated_image
# Build the Gradio UI with the Blocks API, then start the app.
with gr.Blocks() as demo:
    gr.Markdown("# Multilingual Text-to-Image & Text Generation")

    # Inputs: the text to process and the language it is written in.
    input_text = gr.Textbox(lines=2, placeholder="Enter text here...")
    src_language = gr.Dropdown(
        ["en", "fr", "de", "es"],
        value="en",
        label="Source Language",
    )

    # Outputs: one box for the generated text, one for the generated image.
    generated_text_output = gr.Textbox(label="Generated Text")
    generated_image_output = gr.Image(label="Generated Image")

    # Clicking the button runs app_interface on the inputs and routes its
    # two return values to the two output components, in order.
    generate_button = gr.Button("Generate")
    generate_button.click(
        fn=app_interface,
        inputs=[input_text, src_language],
        outputs=[generated_text_output, generated_image_output],
    )

# Run the app
demo.launch()