File size: 7,210 Bytes
ef37daa
 
305d245
 
 
 
 
 
45b720d
 
 
 
 
9d0ab74
ef37daa
5ac6df3
 
45b720d
f147126
9d0ab74
 
 
 
 
 
 
 
 
 
 
 
 
5ac6df3
45b720d
 
 
 
9d0ab74
 
 
 
 
 
 
 
 
45b720d
9d0ab74
45b720d
 
9d0ab74
45b720d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ac6df3
 
 
 
69bd0b3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305d245
69bd0b3
 
305d245
69bd0b3
 
305d245
ef37daa
5ac6df3
 
 
 
 
ef37daa
9d0ab74
ef37daa
a387258
 
 
 
 
ef37daa
5ac6df3
 
 
9d0ab74
5ac6df3
9d0ab74
 
5ac6df3
 
 
 
 
 
 
 
9d0ab74
305d245
f147126
ef37daa
464da3a
69bd0b3
 
 
 
 
464da3a
 
a387258
f147126
5ac6df3
ef37daa
5ac6df3
ef37daa
 
 
 
 
 
 
 
5ac6df3
 
 
 
 
 
 
f147126
5ac6df3
 
ef37daa
9d0ab74
464da3a
 
 
a387258
69bd0b3
a387258
 
 
 
 
9d0ab74
a387258
 
 
 
 
 
 
 
 
 
ef37daa
 
 
 
 
a387258
ef37daa
9d0ab74
 
 
 
 
 
a387258
d598d13
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import gradio as gr
from huggingface_hub import InferenceClient
from deep_translator import GoogleTranslator
from indic_transliteration import sanscript
from indic_transliteration.detect import detect as detect_script
from indic_transliteration.sanscript import transliterate
import langdetect
import re
import requests
import json
import base64
from PIL import Image
import io
import time

# Initialize clients
# Chat model served through the HF Inference API (streaming chat_completion).
text_client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
# Gradio Space used for image generation (queue/join + run/predict endpoints).
SPACE_URL = "https://ijohn07-dalle-4k.hf.space"

# Add image style options
# Display label -> value passed to generate_image_space(). The two
# resolution entries are appended to the prompt as "... resolution";
# "(No style)" leaves the prompt unmodified.
IMAGE_STYLES = {
    "3840 x 2160": "3840 x 2160",
    "2560 x 1440": "2560 x 1440",
    "Photo": "Photo",
    "Cinematic": "Cinematic",
    "Anime": "Anime",
    "3D Model": "3D Model",
    "No style": "(No style)"
}

def generate_image_space(prompt: str, style: str) -> Image.Image:
    """Generate an image using the DALLE-4K Space with the specified style.

    Args:
        prompt: Text description of the desired image.
        style: One of the IMAGE_STYLES values; "(No style)" leaves the
            prompt unmodified.

    Returns:
        The generated PIL image, or None on any failure or timeout
        (errors are logged, never raised to the caller).
    """
    # Bound every network call and the polling loop so a stuck Space
    # queue cannot hang the chat handler forever.
    REQUEST_TIMEOUT = 30  # seconds per HTTP request
    POLL_INTERVAL = 1.0   # seconds between status checks
    MAX_WAIT = 120.0      # overall deadline for the polling loop
    try:
        # Join the Space's queue to obtain a session hash for this job.
        response = requests.post(f"{SPACE_URL}/queue/join", timeout=REQUEST_TIMEOUT)
        session_hash = response.json().get('session_hash')

        # Fold the selected style into the prompt text.
        if style != "(No style)":
            if style in ["3840 x 2160", "2560 x 1440"]:
                # Resolution styles: ask for that output resolution.
                prompt = f"{prompt}, {style} resolution"
            else:
                # Named styles (Photo, Anime, ...): append as a style hint.
                prompt = f"{prompt}, {style.lower()} style"

        # Submit the generation request (fire-and-forget; results arrive
        # via the status endpoint below).
        requests.post(f"{SPACE_URL}/run/predict", json={
            "data": [
                prompt,  # Prompt with style
                "",      # Negative prompt
                7.5,     # Guidance scale
                30,      # Steps
                "DPM++ SDE Karras",  # Scheduler
                False,   # High resolution
                False,   # Image to image
                None,    # Image upload
                1        # Batch size
            ],
            "session_hash": session_hash
        }, timeout=REQUEST_TIMEOUT)

        # Poll for results, bounded by MAX_WAIT.
        deadline = time.time() + MAX_WAIT
        while time.time() < deadline:
            status_response = requests.post(
                f"{SPACE_URL}/queue/status",
                json={"session_hash": session_hash},
                timeout=REQUEST_TIMEOUT,
            )
            status_data = status_response.json()

            if status_data.get('status') == 'complete':
                # Payload is a base64 data URL ("data:image/...;base64,XXX");
                # strip the header, decode, and wrap as a PIL image.
                image_data = status_data['data']['image']
                image_bytes = base64.b64decode(image_data.split(',')[1])
                return Image.open(io.BytesIO(image_bytes))
            elif status_data.get('status') == 'error':
                raise Exception(f"Image generation failed: {status_data.get('error')}")

            time.sleep(POLL_INTERVAL)  # Wait before polling again

        # Deadline expired without reaching a terminal status.
        print("Image generation error: timed out waiting for the Space")
        return None

    except Exception as e:
        print(f"Image generation error: {e}")
        return None

def romanized_to_bengali(text: str) -> str:
    """Convert romanized (Latin-script) Bengali text to Bengali script.

    A small lookup table handles frequent conversational words via
    whole-word substitution. If no table entry matched, the whole string
    is passed to the generic ITRANS -> Bengali transliterator as a
    fallback; on any transliteration failure the input is returned
    unchanged.

    Args:
        text: Input text, possibly romanized Bengali.

    Returns:
        Text with known words converted to Bengali script (lower-cased
        where substitutions occurred), a full transliteration, or the
        original text.
    """
    # Frequent conversational words; matched whole-word and case-insensitively.
    bengali_mappings = {
        'ami': 'আমি',
        'tumi': 'তুমি',
        'apni': 'আপনি',
        'kemon': 'কেমন',
        'achen': 'আছেন',
        'acchen': 'আছেন',
        'bhalo': 'ভালো',
        'achi': 'আছি',
        'ki': 'কি',
        'kothay': 'কোথায়',
        'keno': 'কেন',
    }

    text_lower = text.lower()
    for roman, bengali in bengali_mappings.items():
        # \b keeps short keys like 'ki' from matching inside longer words;
        # re.escape guards against future keys containing regex metacharacters.
        text_lower = re.sub(r'\b' + re.escape(roman) + r'\b', bengali, text_lower)

    if text_lower == text.lower():
        # No dictionary hit: fall back to generic transliteration.
        try:
            return transliterate(text, sanscript.ITRANS, sanscript.BENGALI)
        except Exception:
            # Transliteration can fail on mixed/invalid input; return the
            # original text rather than crash the chat flow. (Narrowed from
            # a bare except so KeyboardInterrupt/SystemExit propagate.)
            return text

    return text_lower

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
    image_style: str,  # Style selection from the Gradio radio control
):
    """Main chat handler: routes custom replies, image requests, and chat.

    Yields either a string (text reply) or an (image, caption) tuple for
    image-generation requests. Non-English users get their reply
    translated back to their language on a best-effort basis.
    """
    # Hard-coded/custom responses take precedence over the model.
    custom_response = check_custom_responses(message)
    if custom_response:
        yield custom_response
        return

    # Image-generation branch.
    if is_image_request(message):
        try:
            image = generate_image_space(message, image_style)
            if image:
                style_info = f" using {image_style} style" if image_style != "(No style)" else ""
                yield (image, f"Here's your generated image based on: {message}{style_info}")
                return
            else:
                yield "Sorry, I couldn't generate the image. Please try again."
                return
        except Exception as e:
            yield f"An error occurred while generating the image: {str(e)}"
            return

    # Text-chat branch: normalize the user message to English if needed.
    translated_msg, original_lang, was_transliterated = translate_text(message)
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            # Only translate multi-word user turns; very short ones
            # (greetings, yes/no) are passed through unchanged.
            if len(user_msg.split()) > 2:
                trans_user_msg, _, _ = translate_text(user_msg)
                messages.append({"role": "user", "content": trans_user_msg})
            else:
                messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    messages.append({"role": "user", "content": translated_msg})

    response = ""
    # Use a distinct loop variable: the original shadowed `message`, which
    # made the post-loop `message.split()` crash on the stream-chunk object.
    for chunk in text_client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content
        if token:  # delta.content can be None on some stream events
            response += token

    # Translate the reply back to the user's language for longer messages.
    if original_lang != 'en' and len(message.split()) > 2:
        try:
            translator = GoogleTranslator(source='en', target=original_lang)
            yield translator.translate(response)
        except Exception:
            # Best-effort: fall back to the English response on failure.
            yield response
    else:
        yield response

# Updated Gradio interface with image style selector
# Wires `respond` into a chat UI; each additional input maps positionally
# to a `respond` parameter after (message, history).
demo = gr.ChatInterface(
    respond,
    additional_inputs=[
        # system_message: editable system prompt.
        gr.Textbox(
            value="You are a friendly Chatbot who always responds in English unless the user specifically uses another language.",
            label="System message"
        ),
        # max_tokens: generation length cap passed to chat_completion.
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=2048,
            step=1,
            label="Max new tokens"
        ),
        # temperature: sampling temperature.
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature"
        ),
        # top_p: nucleus-sampling cutoff.
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)"
        ),
        # image_style: value forwarded to generate_image_space().
        gr.Radio(
            choices=list(IMAGE_STYLES.values()),
            value="3840 x 2160",
            label="Image Style",
            info="Select the style for generated images"
        ),
    ]
)

if __name__ == "__main__":
    # share=True exposes a public Gradio link in addition to localhost.
    demo.launch(share=True)