import torch
from transformers import MBartForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
import requests
import io
from PIL import Image
import os
# Read the Hugging Face API key from the environment
hf_api_key = os.getenv("new_hf_token")
if not hf_api_key:
    raise ValueError("Hugging Face API key not found! Please set the 'new_hf_token' environment variable.")
headers = {"Authorization": f"Bearer {hf_api_key}"}
# Inference API URLs for the available text-to-image models
model_urls = {
    "stable_diffusion_v1_4": "https://api-inference.huggingface.co/models/CompVis/stable-diffusion-v1-4",
    "stable_diffusion_v1_5": "https://api-inference.huggingface.co/models/runwayml/stable-diffusion-v1-5",
}
API_URL = model_urls["stable_diffusion_v1_4"]
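# Optional: choose the text-to-image backend via an env var rather than hardcoding it.
# A minimal sketch; the variable name TTI_MODEL is our own convention, not part of the app.
API_URL = model_urls.get(os.getenv("TTI_MODEL", "stable_diffusion_v1_4"), API_URL)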
# Translation model: MBart-50 many-to-one renders multilingual input into English
translation_model_name = "facebook/mbart-large-50-many-to-one-mmt"
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = MBartForConditionalGeneration.from_pretrained(translation_model_name)
# Load a text generation model from Hugging Face
text_generation_model_name = "EleutherAI/gpt-neo-2.7B"
text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
text_model = AutoModelForCausalLM.from_pretrained(text_generation_model_name, device_map="auto", torch_dtype=torch.float32)
# Create a pipeline for text generation
text_generator = pipeline("text-generation", model=text_model, tokenizer=text_tokenizer)
# Generate an image by posting the translated prompt to the Inference API
def generate_image_from_text(translated_text):
    payload = {"inputs": translated_text, "options": {"wait_for_model": True}}
    response = requests.post(API_URL, headers=headers, json=payload)
    if response.status_code == 200:
        return Image.open(io.BytesIO(response.content))
    if response.status_code == 503:
        # The model is still loading; the API reports an estimated wait time
        estimated_time = response.json().get("estimated_time", "unknown")
        raise gr.Error(f"Model is currently loading. Estimated wait time: {estimated_time} seconds. Try again later.")
    # Raising gr.Error surfaces the message in the UI; returning an error string
    # to a gr.Image output would crash the app instead
    raise gr.Error(f"Failed to generate image. Error: {response.status_code}, Message: {response.text}")
# Translate text into English with the MBart model
def translate_text(input_text, src_lang="en"):
    # MBart-50 expects full language codes such as "en_XX", not bare ISO codes
    mbart_lang_codes = {"en": "en_XX", "fr": "fr_XX", "de": "de_DE", "es": "es_XX"}
    tokenizer.src_lang = mbart_lang_codes.get(src_lang, src_lang)
    encoded_input = tokenizer(input_text, return_tensors="pt")
    translated_tokens = translation_model.generate(**encoded_input)
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
# Generate a text continuation with the GPT-Neo pipeline
def generate_text(prompt, max_length=150):
    generated_texts = text_generator(prompt, max_length=max_length, num_return_sequences=1)
    return generated_texts[0]["generated_text"]
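# Optional: a sampled variant for less repetitive continuations. A sketch; these
# particular decoding values are assumptions, not tuned for this app.
def generate_text_sampled(prompt, max_length=150):
    outputs = text_generator(prompt, max_length=max_length, do_sample=True,
                             temperature=0.9, top_p=0.95, num_return_sequences=1)
    return outputs[0]["generated_text"]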
# Gradio handler: translate the input, then generate an image and a text continuation
def app_interface(input_text, src_language="en"):
    translated_text = translate_text(input_text, src_lang=src_language)
    generated_image = generate_image_from_text(translated_text)
    generated_text = generate_text(translated_text)
    return generated_text, generated_image
# Build and launch the Gradio app (gr.inputs/gr.outputs were removed in Gradio 4.x)
gr.Interface(
    fn=app_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter text here..."),
        gr.Dropdown(choices=["en", "fr", "de", "es"], value="en", label="Source Language"),
    ],
    outputs=[
        gr.Textbox(label="Generated Text"),
        gr.Image(label="Generated Image"),
    ],
    title="Multilingual Text-to-Image & Text Generation",
).launch()