Update app.py
Browse files
app.py
CHANGED
@@ -158,6 +158,23 @@ def stage1_process(
|
|
158 |
print('<<== stage1_process')
|
159 |
return LQ, gr.update(visible = True)
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
def stage2_process(*args, **kwargs):
|
162 |
try:
|
163 |
return restore_in_Xmin(*args, **kwargs)
|
|
|
158 |
print('<<== stage1_process')
|
159 |
return LQ, gr.update(visible = True)
|
160 |
|
161 |
+
@spaces.GPU(duration=20)  # request a short GPU slot from HF Spaces for this call
@torch.no_grad()
def llave_process(input_image, temperature, top_p, qs=None):
    """Generate a caption for *input_image* with the LLaVA agent.

    Parameters
    ----------
    input_image : array-like
        Image data; normalized to HWC uint8 via ``HWC3`` before captioning.
    temperature : float
        Sampling temperature forwarded to ``llava_agent.gen_image_caption``.
    top_p : float
        Nucleus-sampling parameter forwarded to the caption generator.
    qs : str | None, optional
        Optional question/prompt forwarded to the agent (default ``None``).

    Returns
    -------
    str
        The generated caption, or a fixed fallback message when LLaVA is
        unavailable (``use_llava`` false or ``llava_agent`` is ``None``).
    """
    # NOTE(review): relies on module-level globals LLaVA_device, use_llava,
    # llava_agent, HWC3, Image — all defined elsewhere in app.py.
    torch.cuda.set_device(LLaVA_device)
    # torch.cuda.amp.autocast is deprecated; torch.autocast is the
    # supported spelling for CUDA mixed-precision inference.
    with torch.autocast(device_type='cuda'):
        if use_llava and llava_agent is not None:
            LQ = HWC3(input_image)
            LQ = Image.fromarray(LQ.astype('uint8'))
            captions = llava_agent.gen_image_caption(
                [LQ], temperature=temperature, top_p=top_p, qs=qs
            )
        else:
            captions = ['LLaVA is not available. Please add text manually.']
    torch.cuda.empty_cache()  # release cached GPU memory after inference
    return captions[0]
176 |
+
|
177 |
+
|
178 |
def stage2_process(*args, **kwargs):
|
179 |
try:
|
180 |
return restore_in_Xmin(*args, **kwargs)
|