ginipick commited on
Commit
58da738
·
verified ·
1 Parent(s): cb17632

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -15
app.py CHANGED
@@ -154,27 +154,42 @@ def install_flash_attn():
154
 
155
  logging.info(f"Detected CUDA version: {cuda_version}")
156
 
157
- # CUDA 버전별 wheel 파일 선택
158
- if cuda_version.startswith("12.1"):
159
- flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.5/flash_attn-2.7.5+cu121torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
160
- elif cuda_version.startswith("11.8"):
161
- flash_attn_url = "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.3/flash_attn-2.7.3+cu11torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"
162
- else:
163
- logging.warning(f"Unsupported CUDA version: {cuda_version}, skipping flash-attn installation")
164
- return False
165
 
166
- subprocess.run(
167
- ["pip", "install", flash_attn_url],
168
- check=True,
169
- capture_output=True
170
- )
 
 
 
 
 
 
 
171
 
172
- logging.info("flash-attn installed successfully!")
173
- return True
174
  except Exception as e:
175
  logging.warning(f"Failed to install flash-attn: {e}")
176
  return False
177
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  def initialize_system():
179
  optimize_gpu_settings()
180
  has_flash_attn = install_flash_attn()
 
154
 
155
  logging.info(f"Detected CUDA version: {cuda_version}")
156
 
157
+ try:
158
+ import flash_attn
159
+ logging.info("flash-attn already installed")
160
+ return True
161
+ except ImportError:
162
+ logging.info("Installing flash-attn...")
 
 
163
 
164
+ # CUDA 12.1용 직접 설치 시도
165
+ try:
166
+ subprocess.run(
167
+ ["pip", "install", "flash-attn", "--no-build-isolation"],
168
+ check=True,
169
+ capture_output=True
170
+ )
171
+ logging.info("flash-attn installed successfully!")
172
+ return True
173
+ except subprocess.CalledProcessError:
174
+ logging.warning("Failed to install flash-attn via pip, skipping...")
175
+ return False
176
 
 
 
177
  except Exception as e:
178
  logging.warning(f"Failed to install flash-attn: {e}")
179
  return False
180
 
181
+ # ... (나머지 코드는 동일) ...
182
+
183
+ # 서버 설정으로 실행 부분만 수정
184
+ demo.queue(max_size=20).launch(
185
+ server_name="0.0.0.0",
186
+ server_port=7860,
187
+ share=True,
188
+ show_api=True,
189
+ show_error=True,
190
+ max_threads=2 # concurrency_count 대신 max_threads 사용
191
+ )
192
+
193
  def initialize_system():
194
  optimize_gpu_settings()
195
  has_flash_attn = install_flash_attn()