feat: improve image generation
Files changed:
- src/agent/llm.py +32 -1
- src/agent/llm_agent.py +58 -23
- src/audio/audio_generator.py +5 -5
- src/config.py +3 -0
- src/css.py +6 -7
- src/game_constructor.py +29 -12
- src/game_setting.py +13 -0
- src/game_state.py +3 -3
- src/images/image_generator.py +55 -27
- src/main.py +41 -15
src/agent/llm.py
CHANGED
@@ -12,7 +12,7 @@ def create_llm(temperature: float = settings.temperature, top_p: float = settings.top_p):
     global _google_api_keys_list, _current_google_key_idx
 
     if not _google_api_keys_list:
-        api_keys_str = settings.
+        api_keys_str = settings.gemini_api_keys.get_secret_value()
         if api_keys_str:
             _google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
 
@@ -38,6 +38,37 @@ def create_llm(temperature: float = settings.temperature, top_p: float = settings.top_p):
         top_p=top_p,
         thinking_budget=1024
     )
+
+
+def create_light_llm(temperature: float = settings.temperature, top_p: float = settings.top_p):
+    global _google_api_keys_list, _current_google_key_idx
+
+    if not _google_api_keys_list:
+        api_keys_str = settings.gemini_api_keys.get_secret_value()
+        if api_keys_str:
+            _google_api_keys_list = [key.strip() for key in api_keys_str.split(',') if key.strip()]
+
+    if not _google_api_keys_list:
+        logger.error("Google API keys are not configured or are empty in settings.")
+        raise ValueError("Google API keys are not configured or are invalid for round-robin.")
+
+    if not _google_api_keys_list:  # Safeguard, though previous block should handle it.
+        logger.error("No Google API keys available for round-robin.")
+        raise ValueError("No Google API keys available for round-robin.")
+
+    key_index_to_use = _current_google_key_idx
+    selected_api_key = _google_api_keys_list[key_index_to_use]
+
+    _current_google_key_idx = (key_index_to_use + 1) % len(_google_api_keys_list)
+
+    logger.debug(f"Using Google API key at index {key_index_to_use} (ending with ...{selected_api_key[-4:] if len(selected_api_key) > 4 else selected_api_key}) for round-robin.")
+
+    return ChatGoogleGenerativeAI(
+        model="gemini-2.0-flash",
+        google_api_key=selected_api_key,
+        temperature=temperature,
+        top_p=top_p
+    )
 
 def create_precise_llm():
     return create_llm(temperature=0, top_p=1)
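For illustration only (not part of this commit): a minimal sketch of how the new create_light_llm helper might be used, assuming the module is importable as agent.llm and gemini_api_keys holds a comma-separated list of keys. Each call binds the returned model to the next key in the rotation.

    import asyncio
    from agent.llm import create_light_llm

    async def ask_twice(question: str):
        # two calls -> two different API keys, courtesy of the round-robin index
        first = await create_light_llm().ainvoke(question)
        second = await create_light_llm().ainvoke(question)
        return first.content, second.content

    # asyncio.run(ask_twice("Describe the forest scene in one sentence."))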
src/agent/llm_agent.py
CHANGED
@@ -1,38 +1,73 @@
 from agent.llm import create_llm
 from pydantic import BaseModel, Field
-from typing import
+from typing import List
 import logging
+from agent.image_agent import ChangeScene
+import asyncio
+from agent.music_agent import generate_music_prompt
+from agent.image_agent import generate_scene_image
+import uuid
 
 logger = logging.getLogger(__name__)
 
-    change_scene: bool = Field(description="Whether the scene should be changed")
-    scene_description: Optional[str] = None
-
-class ChangeMusic(BaseModel):
-    change_music: bool = Field(description="Whether the music should be changed")
-    music_description: Optional[str] = None
-
+
 class PlayerOption(BaseModel):
-    option_description: str = Field(
+    option_description: str = Field(
+        description="The description of the option, Examples: [Change location] Go to the forest; [Say] Hello!"
+    )
+
+
 class LLMOutput(BaseModel):
-    player_options: List[PlayerOption] = Field(
+    game_message: str = Field(
+        description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
+    )
+    player_options: List[PlayerOption] = Field(
+        description="The list of up to 3 options for the player to choose from."
+    )
+
+
+class MultiAgentResponse(BaseModel):
+    game_message: str = Field(
+        description="The message to the player, Example: You entered the forest, and you see unknown scary creatures. What do you do?"
+    )
+    player_options: List[PlayerOption] = Field(
+        description="The list of up to 3 options for the player to choose from."
+    )
+    music_prompt: str = Field(description="The prompt for the music generation model.")
+    change_scene: ChangeScene = Field(description="The change to the scene.")
+
+
+llm = create_llm().with_structured_output(MultiAgentResponse)
 
+
+async def process_user_input(input: str) -> MultiAgentResponse:
     """
     Process user input and update the state.
     """
+    request_id = str(uuid.uuid4())
+    logger.info(f"LLM input received: {request_id}")
+
     response: LLMOutput = await llm.ainvoke(input)
 
+    # return response
+    current_state = f"""{input}
+
+Game reaction: {response.game_message}
+Player options: {response.player_options}
+"""
+
+    music_prompt_task = generate_music_prompt(current_state, request_id)
+    change_scene_task = generate_scene_image(current_state, request_id)
+
+    music_prompt, change_scene = await asyncio.gather(music_prompt_task, change_scene_task)
+
+    multi_agent_response = MultiAgentResponse(
+        game_message=response.game_message,
+        player_options=response.player_options,
+        music_prompt=music_prompt,
+        change_scene=change_scene,
+    )
+
+    logger.info(f"LLM responded: {request_id}")
+
+    return multi_agent_response
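For illustration only (not part of this commit): a sketch of driving the new multi-agent pipeline from a script, assuming the agent.music_agent and agent.image_agent helpers imported above are available and configured.

    import asyncio
    from agent.llm_agent import process_user_input

    async def demo():
        result = await process_user_input("You stand at the gates of an abandoned castle.")
        print(result.game_message)                                    # narrative shown to the player
        print([o.option_description for o in result.player_options])  # up to 3 choices
        print(result.music_prompt)                                    # prompt handed to the music model
        print(result.change_scene)                                    # scene-change decision from the image agent

    asyncio.run(demo())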
src/audio/audio_generator.py
CHANGED
@@ -13,10 +13,12 @@ logger = logging.getLogger(__name__)
 client = genai.Client(api_key=settings.gemini_api_key.get_secret_value(), http_options={'api_version': 'v1alpha'})
 
 async def generate_music(user_hash: str, music_tone: str, receive_audio):
+    if user_hash in sessions:
+        return
+    async with (
         client.aio.live.music.connect(model='models/lyria-realtime-exp') as session,
         asyncio.TaskGroup() as tg,
+    ):
         # Set up task to receive server messages.
         tg.create_task(receive_audio(session, user_hash))
 
@@ -31,10 +33,9 @@ async def generate_music(user_hash: str, music_tone: str, receive_audio):
         )
         await session.play()
         logger.info(f"Started music generation for user hash {user_hash}, music tone: {music_tone}")
-        await cleanup_music_session(user_hash)
         sessions[user_hash] = {
             'session': session,
-            'queue': queue.Queue(
+            'queue': queue.Queue()
         }
 
 async def change_music_tone(user_hash: str, new_tone):
@@ -43,7 +44,6 @@ async def change_music_tone(user_hash: str, new_tone):
     if not session:
         logger.error(f"No session found for user hash {user_hash}")
         return
-    await session.reset_context()
     await session.set_weighted_prompts(
         prompts=[types.WeightedPrompt(text=new_tone, weight=1.0)]
     )
src/config.py
CHANGED
@@ -21,8 +21,11 @@ class BaseAppSettings(BaseSettings):
 
 class AppSettings(BaseAppSettings):
     gemini_api_key: SecretStr
+    gemini_api_keys: SecretStr
+    # assistant_api_key: SecretStr
     top_p: float = 0.95
     temperature: float = 0.5
+    pregenerate_next_scene: bool = True
 
 
 settings = AppSettings()
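For illustration only (not part of this commit): with pydantic-settings, the two secrets and the new flag are typically supplied via the environment or a .env file. The variable names below assume the default field-name mapping with no env_prefix; treat them as a hypothetical sketch.

    # .env (hypothetical)
    GEMINI_API_KEY="key-used-for-image-and-music-calls"
    GEMINI_API_KEYS="key-one,key-two,key-three"   # comma-separated list consumed by the round-robin helper
    PREGENERATE_NEXT_SCENE=true                   # toggles pregeneration of follow-up scenes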
src/css.py
CHANGED
@@ -33,11 +33,11 @@ custom_css = """
     background: rgba(0,0,0,0.7) !important;
     border: none !important;
     color: white !important;
-    font-size:
+    font-size: 15px !important;
     line-height: 1.5 !important;
-    padding:
+    padding: 10px !important;
     border-radius: 10px !important;
-    margin-bottom:
+    margin-bottom: 10px !important;
 }
 
 img {
@@ -49,7 +49,7 @@ img {
     border: none !important;
     color: white !important;
     -webkit-text-fill-color: white !important;
-    font-size:
+    font-size: 15px !important;
     resize: none !important;
 }
 
@@ -57,13 +57,12 @@
 .choice-buttons {
     background: rgba(0,0,0,0.7) !important;
     border-radius: 10px !important;
-    padding:
+    padding: 10px !important;
 }
 
 .choice-buttons label {
     color: white !important;
-    font-size:
-    margin-bottom: 10px !important;
+    font-size: 14px !important;
 }
 
 /* Fix radio button backgrounds */
src/game_constructor.py
CHANGED
@@ -1,12 +1,14 @@
 import gradio as gr
 import json
 import uuid
-from game_setting import Character, GameSetting
+from game_setting import Character, GameSetting, get_user_story
 from game_state import story, state, get_current_scene
 from agent.llm_agent import process_user_input
 from images.image_generator import generate_image
 from audio.audio_generator import start_music_generation
 import asyncio
+from config import settings
+
 
 # Predefined suggestions for demo
 SETTING_SUGGESTIONS = [
@@ -107,6 +109,7 @@ def save_game_config(
     except Exception as e:
         return f"❌ Error saving configuration: {str(e)}"
 
+
 async def start_game_with_settings(
     user_hash: str,
     setting_desc: str,
@@ -155,27 +158,41 @@ Genre: {game_setting.genre}
 
 You find yourself at the beginning of your adventure. The world around you feels alive with possibilities. What do you choose to do first?
 
-NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE
+NOTE FOR THE ASSISTANT: YOU HAVE TO GENERATE A NEW IMAGE FOR THE START SCENE.
 """
 
     response = await process_user_input(initial_story)
-    music_tone = response.
+
+    music_tone = response.music_prompt
+
     asyncio.create_task(start_music_generation(user_hash, music_tone))
 
     img = "forest.jpg"
+    img_description = ""
+
+    img_path, img_description = await generate_image(
+        response.change_scene.scene_description
+    )
+    if img_path:
+        img = img_path
 
     story["start"] = {
         "text": response.game_message,
         "image": img,
-        "choices":
+        "choices": {
+            option.option_description: asyncio.create_task(
+                process_user_input(
+                    get_user_story(
+                        response.game_message,
+                        response.change_scene.scene_description,
+                        option.option_description,
+                    )
+                )
+            ) if settings.pregenerate_next_scene else None
+            for option in response.player_options
+        },
+        "music_tone": response.music_prompt,
+        "img_description": img_description,
     }
     state["scene"] = "start"
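For illustration only (not part of this commit): the "choices" mapping now stores, per option text, either a pre-started asyncio.Task that is already computing the next scene or None when pregenerate_next_scene is disabled. A consumer can reuse the pregenerated response or fall back to computing it on demand, roughly like this hypothetical helper (process_user_input and get_user_story as imported in this file):

    async def resolve_choice(scene: dict, choice: str):
        pending = scene["choices"][choice]                # asyncio.Task or None
        if pending is not None:
            return await pending                          # reuse the pregenerated response
        fallback_prompt = get_user_story(scene["text"], scene["img_description"], choice)
        return await process_user_input(fallback_prompt)  # compute lazily when pregeneration is off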
src/game_setting.py
CHANGED
@@ -1,12 +1,25 @@
 from pydantic import BaseModel
 
+
 class Character(BaseModel):
     name: str
     age: str
     background: str
     personality: str
 
+
 class GameSetting(BaseModel):
     character: Character
     setting: str
     genre: str
+
+
+def get_user_story(
+    scene_description: str, scene_image_description: str, user_choice: str
+) -> str:
+    return f"""Current scene description:
+{scene_description}
+Current scene image description: {scene_image_description}
+
+User's choice: {user_choice}
+"""
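For illustration only (not part of this commit): get_user_story simply interpolates the three pieces of state into the prompt that is fed back to the agent, e.g.

    print(get_user_story("You are in a clearing.", "sunlit forest clearing", "[Change location] Go to the forest"))
    # Current scene description:
    # You are in a clearing.
    # Current scene image description: sunlit forest clearing
    #
    # User's choice: [Change location] Go to the forest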
src/game_state.py
CHANGED
@@ -1,10 +1,10 @@
-
 story = {
     "start": {
         "text": "You wake up in a mysterious forest. What do you do?",
         "image": "forest.jpg",
-        "choices":
+        "choices": {"Explore": None, "Wait": None},
         "music_tone": "neutral",
+        "img_description": "forest in the fog",
     },
 }
 
@@ -12,4 +12,4 @@ state = {"scene": "start"}
 
 def get_current_scene():
     scene = story[state["scene"]]
-    return scene["text"], scene["image"], scene["choices"]
+    return scene["text"], scene["image"], scene["choices"].keys()
src/images/image_generator.py
CHANGED
@@ -6,25 +6,47 @@ from io import BytesIO
 from datetime import datetime
 from config import settings
 import logging
+import asyncio
+import gradio as gr
 
 logger = logging.getLogger(__name__)
 
 client = genai.Client(api_key=settings.gemini_api_key.get_secret_value()).aio
 
+safety_settings = [
+    types.SafetySetting(
+        category="HARM_CATEGORY_HARASSMENT",
+        threshold="BLOCK_NONE",  # Block none
+    ),
+    types.SafetySetting(
+        category="HARM_CATEGORY_HATE_SPEECH",
+        threshold="BLOCK_NONE",  # Block none
+    ),
+    types.SafetySetting(
+        category="HARM_CATEGORY_SEXUALLY_EXPLICIT",
+        threshold="BLOCK_NONE",  # Block none
+    ),
+    types.SafetySetting(
+        category="HARM_CATEGORY_DANGEROUS_CONTENT",
+        threshold="BLOCK_NONE",  # Block none
+    ),
+]
+
+
 async def generate_image(prompt: str) -> tuple[str, str] | None:
     """
     Generate an image using Google's Gemini model and save it to generated/images directory.
 
     Args:
         prompt (str): The text prompt to generate the image from
 
     Returns:
         str: Path to the generated image file, or None if generation failed
     """
     # Ensure the generated/images directory exists
     output_dir = "generated/images"
     os.makedirs(output_dir, exist_ok=True)
 
     logger.info(f"Generating image with prompt: {prompt}")
 
     try:
@@ -32,8 +54,9 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
             model="gemini-2.0-flash-preview-image-generation",
             contents=prompt,
             config=types.GenerateContentConfig(
-                response_modalities=[
+                response_modalities=["TEXT", "IMAGE"],
+                safety_settings=safety_settings,
+            ),
         )
 
         # Process the response parts
@@ -44,19 +67,20 @@ async def generate_image(prompt: str) -> tuple[str, str] | None:
                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                 filename = f"gemini_{timestamp}.png"
                 filepath = os.path.join(output_dir, filename)
 
                 # Save the image
                 image = Image.open(BytesIO(part.inline_data.data))
-                image.save
+                await asyncio.to_thread(image.save, filepath, "PNG")
                 logger.info(f"Image saved to: {filepath}")
                 image_saved = True
 
-                return filepath,
+                return filepath, prompt
 
         if not image_saved:
+            gr.Warning("Image was censored by Google!")
             logger.error("No image was generated in the response.")
             return None, None
 
     except Exception as e:
         logger.error(f"Error generating image: {e}")
         return None, None
@@ -65,38 +89,41 @@
 async def modify_image(image_path: str, modification_prompt: str) -> str | None:
     """
     Modify an existing image using Google's Gemini model based on a text prompt.
 
     Args:
         image_path (str): Path to the existing image file
         modification_prompt (str): The text prompt describing how to modify the image
 
     Returns:
         str: Path to the modified image file, or None if modification failed
     """
     # Ensure the generated/images directory exists
     output_dir = "generated/images"
     os.makedirs(output_dir, exist_ok=True)
 
+    logger.info(f"Modifying current scene image with prompt: {modification_prompt}")
+
     # Check if the input image exists
     if not os.path.exists(image_path):
         logger.error(f"Error: Image file not found at {image_path}")
         return None
 
     key = settings.gemini_api_key.get_secret_value()
 
     client = genai.Client(api_key=key).aio
 
     try:
         # Load the input image
         input_image = Image.open(image_path)
 
         # Make the API call with both text and image
         response = await client.models.generate_content(
             model="gemini-2.0-flash-preview-image-generation",
             contents=[modification_prompt, input_image],
             config=types.GenerateContentConfig(
-                response_modalities=[
+                response_modalities=["TEXT", "IMAGE"],
+                safety_settings=safety_settings,
+            ),
         )
 
         # Process the response parts
@@ -107,19 +134,20 @@ async def modify_image(image_path: str, modification_prompt: str) -> str | None:
                 timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                 filename = f"gemini_modified_{timestamp}.png"
                 filepath = os.path.join(output_dir, filename)
 
                 # Save the modified image
                 modified_image = Image.open(BytesIO(part.inline_data.data))
-                modified_image.save
+                await asyncio.to_thread(modified_image.save, filepath, "PNG")
                 logger.info(f"Modified image saved to: {filepath}")
                 image_saved = True
 
-                return filepath,
+                return filepath, modification_prompt
 
         if not image_saved:
+            gr.Warning("Updated image was censored by Google!")
             logger.error("No modified image was generated in the response.")
             return None, None
 
     except Exception as e:
         logger.error(f"Error modifying image: {e}")
         return None, None
@@ -129,10 +157,10 @@ if __name__ == "__main__":
     # Example usage
     sample_prompt = "A Luke Skywalker half height sprite with white background for visual novel game"
     generated_image_path = generate_image(sample_prompt)
 
     # if generated_image_path:
     #     # Example modification
     #     modification_prompt = "Now the house is destroyed, and the jawas are running away"
     #     modified_image_path = modify_image(generated_image_path, modification_prompt)
     #     if modified_image_path:
     #         print(f"Successfully modified image: {modified_image_path}")
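For illustration only (not part of this commit): both helpers are coroutines that return a (path, description) pair, or (None, None) when generation fails or is blocked, so a quick manual test could look like this:

    import asyncio
    from images.image_generator import generate_image, modify_image

    async def demo():
        path, description = await generate_image("A misty forest at dawn, painterly style")
        if path:
            # reuse the saved file as the base image for an edit
            new_path, _ = await modify_image(path, "Add a ruined watchtower in the background")
            print(new_path)

    asyncio.run(demo())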
src/main.py
CHANGED
@@ -7,7 +7,7 @@ from audio.audio_generator import (
 )
 import logging
 from agent.llm_agent import process_user_input
-from images.image_generator import
+from images.image_generator import modify_image
 import uuid
 from game_state import story, state
 from game_constructor import (
@@ -19,6 +19,8 @@ from game_constructor import (
     start_game_with_settings,
 )
 import asyncio
+from game_setting import get_user_story
+from config import settings
 
 logger = logging.getLogger(__name__)
 
@@ -43,29 +45,53 @@ async def update_scene(user_hash: str, choice):
     }
     state["scene"] = new_scene
 
-    user_story =
-    """
+    user_story = get_user_story(
+        story[old_scene]["text"], story[old_scene]["img_description"], choice
+    )
 
-    response = await
+    response = await (
+        story[old_scene]["choices"][choice] or process_user_input(user_story)
+    )
 
     story[new_scene]["text"] = response.game_message
 
-    story[new_scene]["choices"] =
-        option.option_description
+    story[new_scene]["choices"] = {
+        option.option_description: asyncio.create_task(
+            process_user_input(
+                get_user_story(
+                    response.game_message,
+                    response.change_scene.scene_description,
+                    option.option_description,
+                )
+            )
+        )
+        if settings.pregenerate_next_scene
+        else None
+        for option in response.player_options
+    }
+
+    img_task = None
+    # always modify the image to avoid hallucinations in which image is being generated in entirely different style
+    if (
+        response.change_scene.change_scene == "change_completely"
+        or response.change_scene.change_scene == "modify"
+    ):
+        img_task = modify_image(
+            story[old_scene]["image"], response.change_scene.scene_description
+        )
+    else:
+        img_task = asyncio.sleep(0)
+
     # run both tasks in parallel
     img_res, _ = await asyncio.gather(
-        change_music_tone(user_hash, response.change_music.music_description) if response.change_music.change_music else asyncio.sleep(0)
+        img_task, change_music_tone(user_hash, response.music_prompt)
     )
+
     if img_res and response.change_scene.change_scene:
-        img_path,
+        img_path, img_description = img_res
         if img_path:
             story[new_scene]["image"] = img_path
+            story[new_scene]["img_description"] = img_description
 
     scene = story[state["scene"]]
     return (
@@ -136,7 +162,7 @@ with gr.Blocks(
     # Fullscreen Loading Indicator (hidden by default)
     with gr.Column(visible=False, elem_id="loading-indicator") as loading_indicator:
         gr.HTML("<div class='loading-text'>🚀 Starting your adventure...</div>")
 
     local_storage = gr.BrowserState(str(uuid.uuid4()), "user_hash")
 
     # Constructor Interface (visible by default)