# app.py from the Hugging Face Space by gokilashree.
# Last commit: e515701 "Update app.py" (verified); file size 3.75 kB.
import torch
from transformers import MBartForConditionalGeneration, AutoTokenizer, AutoModelForCausalLM, pipeline
import gradio as gr
import requests
import io
from PIL import Image
import os
# Set up the Hugging Face API key from environment variables.
# The Space stores the token in "new_hf_token"; "HF_API_KEY" is accepted as a
# fallback so the variable named in the error message actually works.
hf_api_key = os.getenv("new_hf_token") or os.getenv("HF_API_KEY")
if not hf_api_key:
    raise ValueError(
        "Hugging Face API key not found! Please set the 'new_hf_token' "
        "(or 'HF_API_KEY') environment variable."
    )
# Bearer-token header sent with every Inference API request.
headers = {"Authorization": f"Bearer {hf_api_key}"}
# Hosted Inference API endpoint for the CompVis/stable-diffusion-v1-4
# text-to-image model.
API_URL = (
    "https://api-inference.huggingface.co/models/"
    "CompVis/stable-diffusion-v1-4"
)
# Translation stack: the mBART-50 "many-to-one" checkpoint.
translation_model_name = "facebook/mbart-large-50-many-to-one-mmt"
# AutoTokenizer (rather than a hard-coded tokenizer class) is used to avoid
# tokenizer/model mismatch warnings.
tokenizer = AutoTokenizer.from_pretrained(translation_model_name)
translation_model = MBartForConditionalGeneration.from_pretrained(
    translation_model_name
)
# Text-generation stack: GPT-Neo 2.7B wrapped in a transformers pipeline.
text_generation_model_name = "EleutherAI/gpt-neo-2.7B"
text_tokenizer = AutoTokenizer.from_pretrained(text_generation_model_name)
# device_map="auto" delegates weight placement to accelerate. The weights stay
# in full float32 precision, so there is no reduced-precision memory saving.
text_model = AutoModelForCausalLM.from_pretrained(
    text_generation_model_name,
    torch_dtype=torch.float32,
    device_map="auto",
)
# The model is already placed by device_map, so no device is passed here.
text_generator = pipeline(
    "text-generation",
    model=text_model,
    tokenizer=text_tokenizer,
)
# Function to generate an image using Hugging Face's text-to-image model
def generate_image_from_text(translated_text, timeout=120):
    """Request an image for *translated_text* from the hosted text-to-image model.

    Args:
        translated_text: English prompt, sent as the ``inputs`` field of the
            JSON request body.
        timeout: Seconds to wait for the HTTP response before giving up.

    Returns:
        ``(image, None)`` on success, where ``image`` is a fully decoded
        ``PIL.Image.Image``. ``(None, message)`` on any failure (non-200
        status, network error or timeout, undecodable response body).
        Errors are never raised.
    """
    try:
        response = requests.post(
            API_URL,
            headers=headers,
            json={"inputs": translated_text},
            # Without a timeout, a stalled connection blocks this worker forever.
            timeout=timeout,
        )
        if response.status_code != 200:
            return None, f"Error generating image: {response.text}"
        image = Image.open(io.BytesIO(response.content))
        # Image.open is lazy. Decode now so a corrupt payload is reported here
        # instead of failing later, outside this try, when the image is used.
        image.load()
        return image, None
    except Exception as e:
        return None, f"Error during image generation: {e}"
# Define the function to translate Tamil text, generate an image, and create a descriptive text
def translate_generate_image_and_text(tamil_text):
    """Translate Tamil to English, then generate an image and a text continuation.

    This is the Gradio callback. Its three return values feed the three output
    components in order.

    Args:
        tamil_text: Text entered in the Tamil input textbox.

    Returns:
        ``(translated_text, image, descriptive_text)``. Failures are reported in
        place of a value rather than raised:

        * empty input or translation failure -> ``(message, None, None)``
        * image failure -> ``(translated_text, None, message)``
        * text-generation failure -> ``(translated_text, image, message)``
    """
    # Blank input would still be "translated" into arbitrary text and trigger an
    # image request, so stop early.
    if not tamil_text or not tamil_text.strip():
        return "Please enter some Tamil text to translate.", None, None

    try:
        tokenizer.src_lang = "ta_IN"
        inputs = tokenizer(tamil_text, return_tensors="pt")
        # convert_tokens_to_ids is part of the base tokenizer API, so resolving
        # the English language token this way does not depend on the
        # class-specific lang_code_to_id mapping.
        english_token_id = tokenizer.convert_tokens_to_ids("en_XX")
        translated_tokens = translation_model.generate(
            **inputs, forced_bos_token_id=english_token_id
        )
        translated_text = tokenizer.batch_decode(
            translated_tokens, skip_special_tokens=True
        )[0]
    except Exception as e:
        return f"Error during translation: {e}", None, None

    try:
        image, error_message = generate_image_from_text(translated_text)
    except Exception as e:
        return translated_text, None, f"Error during image generation: {e}"
    if error_message:
        return translated_text, None, error_message

    try:
        outputs = text_generator(
            translated_text,
            # Budget counted in new tokens. max_length would include the prompt,
            # so a long translation could leave no room for any continuation.
            max_new_tokens=100,
            num_return_sequences=1,
            # temperature/top_p are ignored under greedy decoding; sampling must
            # be enabled for them to have any effect.
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
        )
        descriptive_text = outputs[0]["generated_text"]
    except Exception as e:
        return translated_text, image, f"Error during text generation: {e}"
    return translated_text, image, descriptive_text
# Gradio interface setup: one Tamil input, three outputs in the order returned
# by translate_generate_image_and_text (translation, image, descriptive text).
iface = gr.Interface(
    fn=translate_generate_image_and_text,
    inputs=gr.Textbox(lines=2, placeholder="Enter Tamil text here..."),
    outputs=[
        gr.Textbox(label="Translated English Text"),
        gr.Image(label="Generated Image"),
        gr.Textbox(label="Generated Descriptive Text"),
    ],
    title="Tamil to English Translation, Image Creation, and Descriptive Text Generation",
    description="Translate Tamil text to English using Facebook's mbart-large-50 model, create an image using the translated text, and generate a descriptive text based on the translated content.",
)

# Launch the Gradio app only when run as a script (`python app.py`), so that
# importing this module does not start a blocking web server.
if __name__ == "__main__":
    iface.launch()