cli.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327
  1. import json
  2. import os
  3. import shutil
  4. import signal
  5. import subprocess
  6. import sys
  7. import tempfile
  8. import time
  9. import zipfile
  10. from pathlib import Path
  11. import click
  12. import requests
  13. from rich.console import Console
  14. from embedchain.telemetry.posthog import AnonymousTelemetry
  15. from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
  16. deploy_hf_spaces, deploy_modal,
  17. deploy_render, deploy_streamlit,
  18. get_pkg_path_from_name, setup_fly_io_app,
  19. setup_gradio_app, setup_hf_app,
  20. setup_modal_com_app, setup_render_com_app,
  21. setup_streamlit_io_app)
  22. console = Console()
  23. api_process = None
  24. ui_process = None
  25. anonymous_telemetry = AnonymousTelemetry()
  26. def signal_handler(sig, frame):
  27. """Signal handler to catch termination signals and kill server processes."""
  28. global api_process, ui_process
  29. console.print("\n🛑 [bold yellow]Stopping servers...[/bold yellow]")
  30. if api_process:
  31. api_process.terminate()
  32. console.print("🛑 [bold yellow]API server stopped.[/bold yellow]")
  33. if ui_process:
  34. ui_process.terminate()
  35. console.print("🛑 [bold yellow]UI server stopped.[/bold yellow]")
  36. sys.exit(0)
  37. @click.group()
  38. def cli():
  39. pass
  40. @cli.command()
  41. @click.argument("app_name")
  42. @click.option("--docker", is_flag=True, help="Use docker to create the app.")
  43. @click.pass_context
  44. def create_app(ctx, app_name, docker):
  45. if Path(app_name).exists():
  46. console.print(
  47. f"❌ [red]Directory '{app_name}' already exists. Try using a new directory name, or remove it.[/red]"
  48. )
  49. return
  50. os.makedirs(app_name)
  51. os.chdir(app_name)
  52. # Step 1: Download the zip file
  53. zip_url = "http://github.com/embedchain/ec-admin/archive/main.zip"
  54. console.print(f"Creating a new embedchain app in [green]{Path().resolve()}[/green]\n")
  55. try:
  56. response = requests.get(zip_url)
  57. response.raise_for_status()
  58. with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
  59. tmp_file.write(response.content)
  60. zip_file_path = tmp_file.name
  61. console.print("✅ [bold green]Fetched template successfully.[/bold green]")
  62. except requests.RequestException as e:
  63. console.print(f"❌ [bold red]Failed to download zip file: {e}[/bold red]")
  64. anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
  65. return
  66. # Step 2: Extract the zip file
  67. try:
  68. with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
  69. # Get the name of the root directory inside the zip file
  70. root_dir = Path(zip_ref.namelist()[0])
  71. for member in zip_ref.infolist():
  72. # Build the path to extract the file to, skipping the root directory
  73. target_file = Path(member.filename).relative_to(root_dir)
  74. source_file = zip_ref.open(member, "r")
  75. if member.is_dir():
  76. # Create directory if it doesn't exist
  77. os.makedirs(target_file, exist_ok=True)
  78. else:
  79. with open(target_file, "wb") as file:
  80. # Write the file
  81. shutil.copyfileobj(source_file, file)
  82. console.print("✅ [bold green]Extracted zip file successfully.[/bold green]")
  83. anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": True})
  84. except zipfile.BadZipFile:
  85. console.print("❌ [bold red]Error in extracting zip file. The file might be corrupted.[/bold red]")
  86. anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
  87. return
  88. if docker:
  89. subprocess.run(["docker-compose", "build"], check=True)
  90. else:
  91. ctx.invoke(install_reqs)
  92. @cli.command()
  93. def install_reqs():
  94. try:
  95. console.print("Installing python requirements...\n")
  96. time.sleep(2)
  97. os.chdir("api")
  98. subprocess.run(["pip", "install", "-r", "requirements.txt"], check=True)
  99. os.chdir("..")
  100. console.print("\n ✅ [bold green]Installed API requirements successfully.[/bold green]\n")
  101. except Exception as e:
  102. console.print(f"❌ [bold red]Failed to install API requirements: {e}[/bold red]")
  103. anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
  104. return
  105. try:
  106. os.chdir("ui")
  107. subprocess.run(["yarn"], check=True)
  108. console.print("\n✅ [bold green]Successfully installed frontend requirements.[/bold green]")
  109. anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": True})
  110. except Exception as e:
  111. console.print(f"❌ [bold red]Failed to install frontend requirements. Error: {e}[/bold red]")
  112. anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
  113. @cli.command()
  114. @click.option("--docker", is_flag=True, help="Run inside docker.")
  115. def start(docker):
  116. if docker:
  117. subprocess.run(["docker-compose", "up"], check=True)
  118. return
  119. # Set up signal handling
  120. signal.signal(signal.SIGINT, signal_handler)
  121. signal.signal(signal.SIGTERM, signal_handler)
  122. # Step 1: Start the API server
  123. try:
  124. os.chdir("api")
  125. api_process = subprocess.Popen(["python", "-m", "main"], stdout=None, stderr=None)
  126. os.chdir("..")
  127. console.print("✅ [bold green]API server started successfully.[/bold green]")
  128. except Exception as e:
  129. console.print(f"❌ [bold red]Failed to start the API server: {e}[/bold red]")
  130. anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
  131. return
  132. # Sleep for 2 seconds to give the user time to read the message
  133. time.sleep(2)
  134. # Step 2: Install UI requirements and start the UI server
  135. try:
  136. os.chdir("ui")
  137. subprocess.run(["yarn"], check=True)
  138. ui_process = subprocess.Popen(["yarn", "dev"])
  139. console.print("✅ [bold green]UI server started successfully.[/bold green]")
  140. anonymous_telemetry.capture(event_name="ec_start", properties={"success": True})
  141. except Exception as e:
  142. console.print(f"❌ [bold red]Failed to start the UI server: {e}[/bold red]")
  143. anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
  144. # Keep the script running until it receives a kill signal
  145. try:
  146. api_process.wait()
  147. ui_process.wait()
  148. except KeyboardInterrupt:
  149. console.print("\n🛑 [bold yellow]Stopping server...[/bold yellow]")
  150. @cli.command()
  151. @click.option("--template", default="fly.io", help="The template to use.")
  152. @click.argument("extra_args", nargs=-1, type=click.UNPROCESSED)
  153. def create(template, extra_args):
  154. anonymous_telemetry.capture(event_name="ec_create", properties={"template_used": template})
  155. template_dir = template
  156. if "/" in template_dir:
  157. template_dir = template.split("/")[1]
  158. src_path = get_pkg_path_from_name(template_dir)
  159. shutil.copytree(src_path, os.getcwd(), dirs_exist_ok=True)
  160. console.print(f"✅ [bold green]Successfully created app from template '{template}'.[/bold green]")
  161. if template == "fly.io":
  162. setup_fly_io_app(extra_args)
  163. elif template == "modal.com":
  164. setup_modal_com_app(extra_args)
  165. elif template == "render.com":
  166. setup_render_com_app()
  167. elif template == "streamlit.io":
  168. setup_streamlit_io_app()
  169. elif template == "gradio.app":
  170. setup_gradio_app()
  171. elif template == "hf/gradio.app" or template == "hf/streamlit.io":
  172. setup_hf_app()
  173. else:
  174. raise ValueError(f"Unknown template '{template}'.")
  175. embedchain_config = {"provider": template}
  176. with open("embedchain.json", "w") as file:
  177. json.dump(embedchain_config, file, indent=4)
  178. console.print(
  179. f"🎉 [green]All done! Successfully created `embedchain.json` with '{template}' as provider.[/green]"
  180. )
  181. def run_dev_fly_io(debug, host, port):
  182. uvicorn_command = ["uvicorn", "app:app"]
  183. if debug:
  184. uvicorn_command.append("--reload")
  185. uvicorn_command.extend(["--host", host, "--port", str(port)])
  186. try:
  187. console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
  188. subprocess.run(uvicorn_command, check=True)
  189. except subprocess.CalledProcessError as e:
  190. console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
  191. except KeyboardInterrupt:
  192. console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  193. def run_dev_modal_com():
  194. modal_run_cmd = ["modal", "serve", "app"]
  195. try:
  196. console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(modal_run_cmd)}[/bold cyan]")
  197. subprocess.run(modal_run_cmd, check=True)
  198. except subprocess.CalledProcessError as e:
  199. console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
  200. except KeyboardInterrupt:
  201. console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  202. def run_dev_streamlit_io():
  203. streamlit_run_cmd = ["streamlit", "run", "app.py"]
  204. try:
  205. console.print(f"🚀 [bold cyan]Running Streamlit app with command: {' '.join(streamlit_run_cmd)}[/bold cyan]")
  206. subprocess.run(streamlit_run_cmd, check=True)
  207. except subprocess.CalledProcessError as e:
  208. console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
  209. except KeyboardInterrupt:
  210. console.print("\n🛑 [bold yellow]Streamlit server stopped[/bold yellow]")
  211. def run_dev_render_com(debug, host, port):
  212. uvicorn_command = ["uvicorn", "app:app"]
  213. if debug:
  214. uvicorn_command.append("--reload")
  215. uvicorn_command.extend(["--host", host, "--port", str(port)])
  216. try:
  217. console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
  218. subprocess.run(uvicorn_command, check=True)
  219. except subprocess.CalledProcessError as e:
  220. console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
  221. except KeyboardInterrupt:
  222. console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  223. def run_dev_gradio():
  224. gradio_run_cmd = ["gradio", "app.py"]
  225. try:
  226. console.print(f"🚀 [bold cyan]Running Gradio app with command: {' '.join(gradio_run_cmd)}[/bold cyan]")
  227. subprocess.run(gradio_run_cmd, check=True)
  228. except subprocess.CalledProcessError as e:
  229. console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
  230. except KeyboardInterrupt:
  231. console.print("\n🛑 [bold yellow]Gradio server stopped[/bold yellow]")
  232. @cli.command()
  233. @click.option("--debug", is_flag=True, help="Enable or disable debug mode.")
  234. @click.option("--host", default="127.0.0.1", help="The host address to run the FastAPI app on.")
  235. @click.option("--port", default=8000, help="The port to run the FastAPI app on.")
  236. def dev(debug, host, port):
  237. template = ""
  238. with open("embedchain.json", "r") as file:
  239. embedchain_config = json.load(file)
  240. template = embedchain_config["provider"]
  241. anonymous_telemetry.capture(event_name="ec_dev", properties={"template_used": template})
  242. if template == "fly.io":
  243. run_dev_fly_io(debug, host, port)
  244. elif template == "modal.com":
  245. run_dev_modal_com()
  246. elif template == "render.com":
  247. run_dev_render_com(debug, host, port)
  248. elif template == "streamlit.io" or template == "hf/streamlit.io":
  249. run_dev_streamlit_io()
  250. elif template == "gradio.app" or template == "hf/gradio.app":
  251. run_dev_gradio()
  252. else:
  253. raise ValueError(f"Unknown template '{template}'.")
  254. @cli.command()
  255. def deploy():
  256. # Check for platform-specific files
  257. template = ""
  258. ec_app_name = ""
  259. with open("embedchain.json", "r") as file:
  260. embedchain_config = json.load(file)
  261. ec_app_name = embedchain_config["name"] if "name" in embedchain_config else None
  262. template = embedchain_config["provider"]
  263. anonymous_telemetry.capture(event_name="ec_deploy", properties={"template_used": template})
  264. if template == "fly.io":
  265. deploy_fly()
  266. elif template == "modal.com":
  267. deploy_modal()
  268. elif template == "render.com":
  269. deploy_render()
  270. elif template == "streamlit.io":
  271. deploy_streamlit()
  272. elif template == "gradio.app":
  273. deploy_gradio_app()
  274. elif template.startswith("hf/"):
  275. deploy_hf_spaces(ec_app_name)
  276. else:
  277. console.print("❌ [bold red]No recognized deployment platform found.[/bold red]")