Debito committed · verified
Commit c7663a8 · 1 parent: 368e060

Upload app.py

Files changed (1): app.py (+51, −14)
app.py CHANGED
@@ -606,10 +606,12 @@ class UltimateModelLoader:
             "do_sample": True,
             "pad_token_id": getattr(self.tokenizer, 'pad_token_id', 50256),
             "eos_token_id": getattr(self.tokenizer, 'eos_token_id', 50256),
-            "repetition_penalty": config["repetition_penalty"],
-            "no_repeat_ngram_size": config["no_repeat_ngram_size"],
-            "length_penalty": 1.0,
-            "early_stopping": True
+            "repetition_penalty": max(config["repetition_penalty"], 1.2),  # Increased to prevent repetition
+            "no_repeat_ngram_size": max(config["no_repeat_ngram_size"], 3),  # Increased to prevent repetition
+            "length_penalty": 1.1,  # Slight length penalty to encourage variety
+            "early_stopping": True,
+            "num_beams": 1,  # Use sampling instead of beam search for more variety
+            "top_k": 50  # Add top-k sampling to improve variety
         }
 
         return optimal_params
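
Note on the new sampling defaults: `length_penalty` and `early_stopping` are beam-search options in Hugging Face `generate()`, so with `num_beams=1` they should have no practical effect here; the behavioral change comes from the `repetition_penalty` / `no_repeat_ngram_size` floors and the added `top_k`. A minimal, standalone sketch of how a dict like `optimal_params` drives sampling (using `gpt2` purely as a stand-in checkpoint, not the model this app actually loads):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

optimal_params = {
    "max_new_tokens": 60,
    "do_sample": True,
    "temperature": 0.8,
    "top_p": 0.9,
    "top_k": 50,                        # added top-k sampling
    "repetition_penalty": 1.2,          # mirrors the max(..., 1.2) floor
    "no_repeat_ngram_size": 3,          # mirrors the max(..., 3) floor
    "num_beams": 1,                     # plain sampling, no beam search
    "pad_token_id": tokenizer.pad_token_id or tokenizer.eos_token_id,
    "eos_token_id": tokenizer.eos_token_id,
}

inputs = tokenizer.encode("Mamba-style models are useful because", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(inputs, **optimal_params)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
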
@@ -1465,6 +1467,11 @@ class UltimateMambaSwarm:
         if not prompt.strip():
             return "Please enter a prompt.", ""
 
+        # Add randomness to prevent identical responses
+        import random
+        random.seed(int(time.time() * 1000) % 2**32)  # Use current time as seed
+        np.random.seed(int(time.time() * 1000) % 2**32)
+
         try:
             # Handle model switching if requested
             if model_size != "auto" and model_size != self.current_model_size:
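
Note: the new seeding covers Python's and NumPy's RNGs, but token sampling inside `generate(..., do_sample=True)` draws from PyTorch's RNG, so per-request variety ultimately depends on the torch seed as well. A small self-contained sketch of seeding all three from the wall clock (illustrative only, not part of this commit):

import random
import time

import numpy as np
import torch

def reseed_all() -> int:
    """Derive a per-request seed from the clock and apply it to every RNG."""
    seed = int(time.time() * 1000) % 2**32
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # the RNG used by sampling in generate()
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    return seed

print(reseed_all())
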
@@ -1580,17 +1587,23 @@ COMPREHENSIVE RESPONSE:"""
 
             print(f"📝 Hybrid params: temp={gen_params['temperature']:.2f}, top_p={gen_params['top_p']:.2f}")
 
-            # Tokenize hybrid prompt
+            # Tokenize hybrid prompt with uniqueness
+            hybrid_prompt_unique = f"{hybrid_prompt} [Session: {int(time.time())}]"
             inputs = self.model_loader.tokenizer.encode(
-                hybrid_prompt,
+                hybrid_prompt_unique,
                 return_tensors="pt",
                 truncation=True,
-                max_length=700  # Larger context for web data
+                max_length=650,  # Smaller to account for session marker
+                add_special_tokens=True
             )
             inputs = inputs.to(self.model_loader.device)
 
             # Generate with hybrid intelligence
             with torch.no_grad():
+                # Clear any cached states to prevent repetition
+                if hasattr(self.model_loader.model, 'reset_cache'):
+                    self.model_loader.model.reset_cache()
+
                 outputs = self.model_loader.model.generate(inputs, **gen_params)
 
             # Decode and validate
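
Note: `reset_cache` is not a standard `transformers` API, which is why it is wrapped in `hasattr`; on models that do not define it (most stock checkpoints) the guard simply falls through. A toy illustration of that pattern with a hypothetical model class:

class ToyMambaModel:
    """Hypothetical model exposing a reset_cache() hook."""
    def __init__(self):
        self._state_cache = {"ssm_state": "stale"}

    def reset_cache(self):
        self._state_cache.clear()

for model in (ToyMambaModel(), object()):
    # Same guard as in the diff: only reset when the hook actually exists.
    if hasattr(model, "reset_cache"):
        model.reset_cache()
    print(type(model).__name__, "cache cleared" if hasattr(model, "reset_cache") else "no cache hook")
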
@@ -1599,11 +1612,16 @@ COMPREHENSIVE RESPONSE:"""
             # Extract response safely
             if "COMPREHENSIVE RESPONSE:" in generated_text:
                 response = generated_text.split("COMPREHENSIVE RESPONSE:")[-1].strip()
+            elif generated_text.startswith(hybrid_prompt_unique):
+                response = generated_text[len(hybrid_prompt_unique):].strip()
             elif generated_text.startswith(hybrid_prompt):
                 response = generated_text[len(hybrid_prompt):].strip()
             else:
                 response = generated_text.strip()
 
+            # Clean up any session markers
+            response = re.sub(r'\[Session: \d+\]', '', response).strip()
+
             # Enhanced validation for hybrid responses
             if self._is_inappropriate_content(response):
                 logger.warning("🛡️ Inappropriate hybrid content detected, using fallback")
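
For reference, the `[Session: ...]` trick only works if the marker is reliably removed again; a tiny round-trip check of the regex used above (standalone, the strings are made up):

import re
import time

marker = f"[Session: {int(time.time())}]"
raw_response = f"State space models scale linearly. {marker} They avoid quadratic attention."

clean = re.sub(r'\[Session: \d+\]', '', raw_response).strip()
print(clean)                      # marker removed; any doubled spaces remain
assert "[Session:" not in clean
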
@@ -1734,30 +1752,42 @@ COMPREHENSIVE RESPONSE:"""
             print(f"📝 Using prompt format: '{safe_prompt[:50]}...'")
             print(f"⚙️ Generation params: temp={gen_params['temperature']:.2f}, top_p={gen_params['top_p']:.2f}")
 
-            # Tokenize with safety
+            # Tokenize with safety and uniqueness
+            prompt_with_timestamp = f"{safe_prompt} [Time: {int(time.time())}]"  # Add timestamp to make each prompt unique
             inputs = self.model_loader.tokenizer.encode(
-                safe_prompt,
+                prompt_with_timestamp,
                 return_tensors="pt",
                 truncation=True,
-                max_length=512
+                max_length=500,  # Slightly smaller to account for timestamp
+                add_special_tokens=True
             )
             inputs = inputs.to(self.model_loader.device)
 
             # Generate with optimal parameters
             with torch.no_grad():
+                # Clear any cached states
+                if hasattr(self.model_loader.model, 'reset_cache'):
+                    self.model_loader.model.reset_cache()
+
                 outputs = self.model_loader.model.generate(inputs, **gen_params)
 
             # Decode and validate
             generated_text = self.model_loader.tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-            # Extract response safely
-            if generated_text.startswith(safe_prompt):
+            # Extract response safely and remove timestamp
+            if generated_text.startswith(prompt_with_timestamp):
+                response = generated_text[len(prompt_with_timestamp):].strip()
+            elif generated_text.startswith(safe_prompt):
                 response = generated_text[len(safe_prompt):].strip()
             elif generated_text.startswith(prompt):
                 response = generated_text[len(prompt):].strip()
             else:
                 response = generated_text.strip()
 
+            # Remove any remaining timestamp artifacts
+            import re
+            response = re.sub(r'\[Time: \d+\]', '', response).strip()
+
             # Content safety filtering
             if self._is_inappropriate_content(response):
                 logger.warning("🛡️ Inappropriate content detected, using domain-specific fallback")
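
The elif cascade above strips whichever prompt variant the model echoed, checking the longest (timestamped) form first so a shorter `safe_prompt` match cannot leave the `[Time: ...]` suffix behind. A compact, hypothetical helper expressing the same longest-prefix-first idea (not part of app.py):

import re

def strip_prompt_echo(generated_text: str, prompt_variants: list) -> str:
    """Drop the longest echoed prompt variant, then scrub marker leftovers."""
    for variant in sorted(prompt_variants, key=len, reverse=True):
        if generated_text.startswith(variant):
            generated_text = generated_text[len(variant):]
            break
    return re.sub(r'\[(?:Time|Session): \d+\]', '', generated_text).strip()

print(strip_prompt_echo("Q: what is Mamba? [Time: 1712345678] A: a state space model.",
                        ["Q: what is Mamba?", "Q: what is Mamba? [Time: 1712345678]"]))
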
@@ -2567,10 +2597,17 @@ def create_ultimate_interface():
         )
 
         # Event handlers
+        def generate_and_clear(prompt, max_length, temperature, top_p, num_encoders, model_size, show_routing, enable_search):
+            """Generate response and clear the input field"""
+            response, routing = swarm.generate_text_ultimate(
+                prompt, max_length, temperature, top_p, num_encoders, model_size, show_routing, enable_search
+            )
+            return response, routing, ""  # Return empty string to clear input
+
         generate_btn.click(
-            fn=swarm.generate_text_ultimate,
+            fn=generate_and_clear,
            inputs=[prompt_input, max_length, temperature, top_p, num_encoders, model_size, show_routing, enable_search],
-            outputs=[response_output, routing_output]
+            outputs=[response_output, routing_output, prompt_input]  # Include prompt_input in outputs to clear it
         )
 
         refresh_btn.click(
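
The handler swap works because Gradio maps the return tuple positionally onto `outputs`, so the extra empty string clears `prompt_input`. A stripped-down sketch of the same wiring (component names and the echo handler are invented for illustration):

import gradio as gr

def generate_and_clear(prompt):
    # Stand-in for swarm.generate_text_ultimate: returns (response, routing info).
    response = f"Echo: {prompt}"
    routing = "routed to 1 encoder"
    return response, routing, ""          # third value clears the input textbox

with gr.Blocks() as demo:
    prompt_input = gr.Textbox(label="Prompt")
    generate_btn = gr.Button("Generate")
    response_output = gr.Textbox(label="Response")
    routing_output = gr.Textbox(label="Routing")

    generate_btn.click(
        fn=generate_and_clear,
        inputs=[prompt_input],
        outputs=[response_output, routing_output, prompt_input],
    )

if __name__ == "__main__":
    demo.launch()
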
 