	Update app.py
app.py CHANGED
@@ -33,9 +33,8 @@ MAX_SEED = np.iinfo(np.int32).max
 
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
-# -----------------------
 # PROGRESS BAR HELPER
-# -----------------------
+
 def progress_bar_html(label: str) -> str:
     """
     Returns an HTML snippet for a thin progress bar with a label.
@@ -56,9 +55,8 @@ def progress_bar_html(label: str) -> str:
 </style>
     '''
 
-# -----------------------
 # TEXT & TTS MODELS
-# -----------------------
+
 model_id = "prithivMLmods/FastThink-0.5B-Tiny"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForCausalLM.from_pretrained(
@@ -73,9 +71,8 @@ TTS_VOICES = [
     "en-US-GuyNeural",    # @tts2
 ]
 
-# -----------------------
 # MULTIMODAL (OCR) MODELS
-# -----------------------
+
 MODEL_ID_VL = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 processor = AutoProcessor.from_pretrained(MODEL_ID_VL, trust_remote_code=True)
 model_m = Qwen2VLForConditionalGeneration.from_pretrained(
@@ -84,15 +81,6 @@ model_m = Qwen2VLForConditionalGeneration.from_pretrained(
     torch_dtype=torch.float16
 ).to("cuda").eval()
 
-# -----------------------
-# GEMMA3-4B MODEL SETUP (NEW FEATURE)
-# -----------------------
-gemma3_model_id = "google/gemma-3-4b-it"
-gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
-    gemma3_model_id, device_map="auto"
-).eval()
-gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
-
 async def text_to_speech(text: str, voice: str, output_file="output.mp3"):
     communicate = edge_tts.Communicate(text, voice)
     await communicate.save(output_file)
@@ -130,9 +118,9 @@ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
 
 dtype = torch.float16 if device.type == "cuda" else torch.float32
 
-# -----------------------
+
 # STABLE DIFFUSION IMAGE GENERATION MODELS
-# -----------------------
+
 if torch.cuda.is_available():
     # Lightning 5 model
     pipe = StableDiffusionXLPipeline.from_pretrained(
@@ -218,9 +206,18 @@ def save_image(img: Image.Image) -> str:
     img.save(unique_name)
     return unique_name
 
-# -----------------------
+
+# GEMMA3-4B MULTIMODAL MODEL
+
+gemma3_model_id = "google/gemma-3-4b-it"
+gemma3_model = Gemma3ForConditionalGeneration.from_pretrained(
+    gemma3_model_id, device_map="auto"
+).eval()
+gemma3_processor = AutoProcessor.from_pretrained(gemma3_model_id)
+
+
 # MAIN GENERATION FUNCTION
-# -----------------------
+
 @spaces.GPU
 def generate(
     input_dict: dict,
@@ -235,8 +232,8 @@ def generate(
     files = input_dict.get("files", [])
 
     lower_text = text.lower().strip()
-
-    # …
+
+    # Image Generation Branch (Stable Diffusion models)
     if (lower_text.startswith("@lightningv5") or 
         lower_text.startswith("@lightningv4") or 
         lower_text.startswith("@turbov3")):
@@ -288,52 +285,53 @@ def generate(
         yield gr.Image(image_path)
         return
 
-    # …
+    # GEMMA3-4B Branch for Multimodal/Text Generation with Streaming
     if lower_text.startswith("@gemma3-4b"):
-        # Remove the flag from the …
+        # Remove the gemma3 flag from the prompt.
        prompt_clean = re.sub(r"@gemma3-4b", "", text, flags=re.IGNORECASE).strip().strip('"')
-        # Build messages: include a system message and user message.
-        messages = []
-        messages.append({
-            "role": "system",
-            "content": [{"type": "text", "text": "You are a helpful assistant."}]
-        })
-        user_content = []
         if files:
-            # If …
-            images = [load_image(…
-            …
+            # If image files are provided, load them.
+            images = [load_image(f) for f in files]
+            messages = [{
+                "role": "user",
+                "content": [
+                    *[{"type": "image", "image": image} for image in images],
+                    {"type": "text", "text": prompt_clean},
+                ]
+            }]
+        else:
+            messages = [
+                {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
+                {"role": "user", "content": [{"type": "text", "text": prompt_clean}]}
+            ]
         inputs = gemma3_processor.apply_chat_template(
             messages, add_generation_prompt=True, tokenize=True,
             return_dict=True, return_tensors="pt"
         ).to(gemma3_model.device, dtype=torch.bfloat16)
-        …
+        streamer = TextIteratorStreamer(
+            gemma3_processor.tokenizer, timeout=20.0, skip_prompt=True, skip_special_tokens=True
+        )
+        generation_kwargs = {
+            **inputs,
+            "streamer": streamer,
+            "max_new_tokens": max_new_tokens,
+            "do_sample": True,
+            "temperature": temperature,
+            "top_p": top_p,
+            "top_k": top_k,
+            "repetition_penalty": repetition_penalty,
+        }
         thread = Thread(target=gemma3_model.generate, kwargs=generation_kwargs)
         thread.start()
-
         buffer = ""
-        yield progress_bar_html("Processing with Gemma3-…
+        yield progress_bar_html("Processing with Gemma3-4b")
         for new_text in streamer:
             buffer += new_text
+            time.sleep(0.01)
             yield buffer
-        final_response = buffer
-        yield final_response
         return
 
-    # …
+    # Otherwise, handle text/chat (and TTS) generation.
     tts_prefix = "@tts"
     is_tts = any(text.strip().lower().startswith(f"{tts_prefix}{i}") for i in range(1, 3))
     voice_index = next((i for i in range(1, 3) if text.strip().lower().startswith(f"{tts_prefix}{i}")), None)
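The heart of the reworked @gemma3-4b branch is the streaming pattern: generate() blocks until decoding finishes, so it runs in a background Thread while the handler drains a TextIteratorStreamer and yields the growing buffer back to the Gradio chat. The sketch below isolates that pattern so it can be run on its own; it deliberately substitutes the small "gpt2" checkpoint for google/gemma-3-4b-it to stay lightweight, and stream_reply is an illustrative helper name, not a function in app.py.

from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Small stand-in model so the sketch runs anywhere; the Space itself uses
# Gemma3ForConditionalGeneration with google/gemma-3-4b-it and its processor.
model_name = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)


def stream_reply(prompt: str, max_new_tokens: int = 64):
    """Yield the growing response text, chunk by chunk, the way the Gradio handler does."""
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {**inputs, "streamer": streamer, "max_new_tokens": max_new_tokens}

    # generate() runs in a background thread; this generator drains the
    # streamer concurrently and yields partial output as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()


if __name__ == "__main__":
    for partial in stream_reply("The quick brown fox"):
        print(partial)

In the Space, this branch is triggered by prefixing a chat message with @gemma3-4b (optionally with attached image files), and the first yield is the progress_bar_html snippet so the UI shows a progress bar until the first streamed tokens replace it.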