hatmanstack commited on
Commit
bd05777
·
1 Parent(s): 1c930ab

changed rate_limit to env variable

Browse files
Files changed (1) hide show
  1. generate.py +10 -6
generate.py CHANGED
@@ -32,9 +32,10 @@ def handle_bedrock_errors(func):
32
 
33
  aws_id = os.getenv('AWS_ID')
34
  aws_secret = os.getenv('AWS_SECRET')
 
35
  nova_image_bucket='nova-image-data'
36
  bucket_region='us-west-2'
37
- rate_limit_message = """<div style='text-align: center;'>{} rate limit exceeded. Check back later, use the
38
  <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
39
  try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>"""
40
 
@@ -166,15 +167,18 @@ def check_rate_limit(body):
166
  # Clean up old entries
167
  rate_data['premium'] = [t for t in rate_data['premium'] if t > twenty_minutes_ago]
168
  rate_data['standard'] = [t for t in rate_data['standard'] if t > twenty_minutes_ago]
169
-
 
 
 
170
  # Check limits based on quality
171
  if quality == 'premium':
172
- if len(rate_data['premium']) >= 3:
173
- raise ImageError(rate_limit_message.format('Premium'))
174
  rate_data['premium'].append(current_time)
175
  else: # standard
176
- if len(rate_data['standard']) >= 100:
177
- raise ImageError(rate_limit_message.format('Standard'))
178
  rate_data['standard'].append(current_time)
179
 
180
  # Update rate limit file
 
32
 
33
  aws_id = os.getenv('AWS_ID')
34
  aws_secret = os.getenv('AWS_SECRET')
35
+ rate_limit = os.getenv('RATE_LIMIT')
36
  nova_image_bucket='nova-image-data'
37
  bucket_region='us-west-2'
38
+ rate_limit_message = """<div style='text-align: center;'>Rate limit exceeded. Check back later, use the
39
  <a href='https://docs.aws.amazon.com/bedrock/latest/userguide/playgrounds.html'>Bedrock Playground</a> or
40
  try it out without an AWS account on <a href='https://partyrock.aws/'>PartyRock</a>.</div>"""
41
 
 
167
  # Clean up old entries
168
  rate_data['premium'] = [t for t in rate_data['premium'] if t > twenty_minutes_ago]
169
  rate_data['standard'] = [t for t in rate_data['standard'] if t > twenty_minutes_ago]
170
+
171
+ # Calculate the total count of requests in the last 20 minutes
172
+ total_count = len(rate_data['premium']) * 2 + len(rate_data['standard'])
173
+
174
  # Check limits based on quality
175
  if quality == 'premium':
176
+ if total_count + 2 > rate_limit: # Check if adding 2 would exceed the threshold
177
+ raise ImageError(rate_limit_message)
178
  rate_data['premium'].append(current_time)
179
  else: # standard
180
+ if total_count + 1 > rate_limit: # Check if adding 1 would exceed the threshold
181
+ raise ImageError(rate_limit_message)
182
  rate_data['standard'].append(current_time)
183
 
184
  # Update rate limit file