	Update app.py
app.py CHANGED
@@ -20,6 +20,7 @@ from transformers import (
     TextIteratorStreamer,
     Qwen2VLForConditionalGeneration,
     AutoProcessor,
+    Gemma3ForConditionalGeneration,  # New import for Gemma3-4B
 )
 from transformers.image_utils import load_image
 from diffusers import StableDiffusionXLPipeline, EulerAncestralDiscreteScheduler
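
Note: Gemma3ForConditionalGeneration only exists in recent transformers releases, so this import fails if the Space pins an older version. A minimal guard sketch, assuming nothing beyond the standard top-level transformers exports (the error wording is illustrative):

import transformers

# Fail early with a clear message if this environment predates Gemma 3 support.
if not hasattr(transformers, "Gemma3ForConditionalGeneration"):
    raise ImportError(
        f"transformers {transformers.__version__} has no Gemma3ForConditionalGeneration; "
        "upgrade transformers to a release with Gemma 3 support."
    )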
@@ -208,6 +209,15 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
+# -----------------------
+# GEMMA3-4B MULTIMODAL MODEL
+# -----------------------
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
 # -----------------------
 # MAIN GENERATION FUNCTION
 # -----------------------
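
Note: this loads the checkpoint with its default weight dtype, while the generation branch below casts inputs to torch.bfloat16. A sketch of loading with a matching dtype up front, assuming a GPU with bfloat16 support (torch_dtype is a standard from_pretrained argument):

import torch
from transformers import AutoProcessor, Gemma3ForConditionalGeneration

gemma3_model_id = "google/gemma-3-4b-it"
# Load weights directly in bfloat16 so model and inputs agree on dtype.
gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
    gemma3_model_id, device_map="auto", torch_dtype=torch.bfloat16
).eval()
gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)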
@@ -225,7 +235,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or
         lower_text.startswith("@lightningv4") or
         lower_text.startswith("@turbov3")):
@@ -277,6 +288,52 @@ def generate(
         yield gr.Image(image_path)
         return
 
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
+    if lower_text.startswith("@gemma3-4b"):
+        # Remove the gemma3 flag from the prompt.
+        prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
+        if files:
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
+        inputs = gemma3_processor.apply_chat_template(
+            messages, add_generation_prompt=True, tokenize=True,
+            return_dict=True, return_tensors="pt"
+        ).to(gemma3_model.device, dtype=torch.bfloat16)
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
+        thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
+        thread.start()
+        buffer = ""
+        yield progress_bar_html("Processing with Gemma3-4b")
+        for new_text in streamer:
+            buffer += new_text
+            time.sleep(0.01)
+            yield buffer
+        return
+
     # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
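
The new branch follows the standard transformers streaming pattern: run generate() on a background thread and iterate the TextIteratorStreamer on the caller's thread, yielding the growing buffer to Gradio. A self-contained sketch of the same pattern outside the Space, assuming the model and processor are loaded as in the hunk above (the prompt text is a made-up example):

from threading import Thread
from transformers import TextIteratorStreamer

messages = [{"role": "user",
             "content": [{"type": "text", "text": "Explain schedulers in diffusers briefly."}]}]
inputs = gemma3_processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt"
).to(gemma3_model.device)

# The streamer yields decoded text chunks as generate() emits tokens.
streamer = TextIteratorStreamer(
    gemma3_processor.tokenizer, skip_prompt=True, skip_special_tokens=True
)
thread = Thread(
    target=gemma3_model.generate,
    kwargs={**inputs, "streamer": streamer, "max_new_tokens": 256},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)  # print instead of yielding to a UI
thread.join()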
@@ -391,7 +448,7 @@ demo = gr.ChatInterface(
     description=DESCRIPTION,
     css=css,
     fill_height=True,
-    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 for
+    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple", placeholder="use the tags @lightningv5 @lightningv4 @turbov3 or @gemma3-4b for multimodal gen !"),
     stop_btn="Stop Generation",
     multimodal=True,
 )
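
With this change, any chat message that starts with the @gemma3-4b tag is routed to the new Gemma3 branch before the TTS/default chat handling: a text-only prompt such as "@gemma3-4b explain the difference between the lightning and turbo models" streams a plain answer, while a prompt with attached images is packed into a multimodal message for the processor.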
