File size: 2,829 Bytes
93cb70c
 
 
 
3a6b0f7
93cb70c
 
3f10588
93cb70c
3a6b0f7
 
 
3f10588
3a6b0f7
 
93cb70c
 
3a6b0f7
 
 
93cb70c
 
2816496
3a6b0f7
 
3f10588
3a6b0f7
 
93cb70c
 
 
3a6b0f7
 
 
93cb70c
 
 
3a6b0f7
 
 
 
 
93cb70c
2816496
3a6b0f7
 
 
 
 
2816496
3a6b0f7
 
 
 
 
 
 
 
93cb70c
2816496
1786f21
93cb70c
2816496
93cb70c
 
 
 
 
2816496
 
93cb70c
3a6b0f7
1786f21
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import requests
import io
from PIL import Image
import gradio as gr
from transformers import MarianMTModel, MarianTokenizer, AutoModelForCausalLM, AutoTokenizer
import os

# Both models are loaded once, when this module is imported, and then shared
# by every request instead of being re-created inside the handler functions.

# MarianMT checkpoint used by translate_text() to produce English text.
model_name = "Helsinki-NLP/opus-mt-mul-en"
translation_tokenizer = MarianTokenizer.from_pretrained(model_name)
translation_model = MarianMTModel.from_pretrained(model_name)

# GPT-Neo causal language model used by query_gpt_neo() to continue the story.
gpt_model_name = "EleutherAI/gpt-neo-1.3B"
gpt_model = AutoModelForCausalLM.from_pretrained(gpt_model_name)
gpt_tokenizer = AutoTokenizer.from_pretrained(gpt_model_name)

def translate_text(tamil_text):
    """Translate *tamil_text* to English with the shared MarianMT model.

    Args:
        tamil_text: Source-language text (Tamil in this app).

    Returns:
        The decoded translation of the first generated sequence, with special
        tokens removed.
    """
    # truncation=True caps the input at the tokenizer's maximum length; without
    # it an over-long paste can exceed what the model accepts and make
    # generate() fail instead of translating the leading part of the text.
    inputs = translation_tokenizer(tamil_text, return_tensors="pt", truncation=True)
    translated_tokens = translation_model.generate(**inputs)
    return translation_tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def query_gpt_neo(translated_text, max_new_tokens=50):
    """Continue *translated_text* as a short story using GPT-Neo.

    Args:
        translated_text: English seed text (the output of translate_text()).
        max_new_tokens: Maximum number of tokens generated after the prompt.
            Defaults to 50, the original length budget.

    Returns:
        Only the newly generated continuation, decoded with special tokens
        removed. The instruction prompt is not echoed back.
    """
    prompt = f"Continue the story based on the following text: {translated_text}"
    inputs = gpt_tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = gpt_model.generate(
        inputs["input_ids"],
        # Pass the mask explicitly rather than letting generate() infer it.
        attention_mask=inputs["attention_mask"],
        # max_new_tokens counts only generated tokens. The previous
        # max_length=50 also counted the prompt, so a long translation left
        # little or no room for any continuation.
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        # The GPT-Neo tokenizer defines no pad token; reuse EOS, as generate()
        # would otherwise do with a warning.
        pad_token_id=gpt_tokenizer.eos_token_id,
    )

    # outputs[0] holds prompt + continuation; drop the prompt tokens so the
    # "Creative Text" box does not repeat the instruction.
    continuation_ids = outputs[0][prompt_length:]
    return gpt_tokenizer.decode(continuation_ids, skip_special_tokens=True).strip()

def query_image(payload, timeout=120):
    """Request an image from the Hugging Face Inference API (FLUX.1-dev).

    Args:
        payload: JSON-serialisable request body, e.g. {"inputs": "a prompt"}.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The raw response body (bytes) on HTTP 200. On any failure -- missing
        HUGGINGFACE_API_KEY, a non-200 status, or a network error/timeout -- a
        str starting with "Error:" is returned instead of raising, so callers
        must check the type before treating the result as image bytes.
    """
    huggingface_api_key = os.getenv('HUGGINGFACE_API_KEY')
    if not huggingface_api_key:
        return "Error: Hugging Face API key not set."

    api_url = "https://api-inference.huggingface.co/models/black-forest-labs/FLUX.1-dev"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    try:
        # Without a timeout, requests can wait indefinitely on a stalled
        # connection, hanging the request handler.
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        # Keep this function's contract of reporting failures as strings.
        return f"Error: request failed - {exc}"

    if response.status_code == 200:
        return response.content
    return f"Error: {response.status_code} - {response.text}"

def process_input(tamil_input):
    """Run the full pipeline for one Gradio request.

    Translates *tamil_input* to English, continues the translation with
    GPT-Neo, and generates an image from the English translation.

    Args:
        tamil_input: Text from the Gradio input textbox.

    Returns:
        A (translated_text, creative_text, image) tuple matching the three
        Gradio outputs. On any failure it returns
        ("Error occurred: <message>", "", None).
    """
    # An empty textbox arrives as ""; skip both models and the API call.
    if not tamil_input or not tamil_input.strip():
        return "Error occurred: input text is empty", "", None

    try:
        # Translate the input text
        translated_output = translate_text(tamil_input)

        # Generate creative text using GPT-Neo
        creative_output = query_gpt_neo(translated_output)

        # Generate an image using Hugging Face's FLUX model
        image_bytes = query_image({"inputs": translated_output})
        # query_image reports failures as an "Error: ..." string rather than
        # raising. Surface that message instead of letting BytesIO fail with
        # an unrelated "a bytes-like object is required" TypeError.
        if isinstance(image_bytes, str):
            raise RuntimeError(image_bytes)

        image = Image.open(io.BytesIO(image_bytes))
        # Image.open is lazy; decode now so a corrupt or non-image response
        # is caught below instead of failing later inside Gradio.
        image.load()

        return translated_output, creative_output, image
    except Exception as e:
        return f"Error occurred: {e}", "", None

# Gradio UI: one text box in, three outputs in the same order as the
# (translated_text, creative_text, image) tuple returned by process_input().
interface = gr.Interface(
    fn=process_input,
    inputs=[gr.Textbox(label="Input Tamil Text")],
    outputs=[
        gr.Textbox(label="Translated Text"),
        gr.Textbox(label="Creative Text"),
        gr.Image(label="Generated Image"),
    ],
    title="TRANSART",
    description="Enter Tamil text to translate to English and generate an image based on the translated text.",
)

# Launch only when executed as a script, so importing this module (for tests
# or other tooling) does not start and block on a web server.
if __name__ == "__main__":
    interface.launch()