cli.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318
  1. import json
  2. import os
  3. import shutil
  4. import signal
  5. import subprocess
  6. import sys
  7. import tempfile
  8. import time
  9. import zipfile
  10. from pathlib import Path
  11. import click
  12. import requests
  13. from rich.console import Console
  14. from embedchain.telemetry.posthog import AnonymousTelemetry
  15. from embedchain.utils.cli import (deploy_fly, deploy_gradio_app,
  16. deploy_hf_spaces, deploy_modal,
  17. deploy_render, deploy_streamlit,
  18. get_pkg_path_from_name, setup_fly_io_app,
  19. setup_gradio_app, setup_hf_app,
  20. setup_modal_com_app, setup_render_com_app,
  21. setup_streamlit_io_app)
  # Shared Rich console for all user-facing output of this CLI.
  console = Console()
  # Module-level handles to the API / UI server subprocesses; signal_handler reads
  # these and terminates whichever ones are set. Both start out as None.
  api_process = ui_process = None
  # Telemetry client used to capture the ec_* command events below.
  anonymous_telemetry = AnonymousTelemetry()
  def signal_handler(sig, frame):
      """Terminate any running API/UI server subprocesses, then exit the CLI.

      Meant to be registered with ``signal.signal`` for SIGINT/SIGTERM; ``sig`` and
      ``frame`` are the standard handler arguments and are not used. Calls
      ``terminate()`` on each of the module-level ``api_process`` / ``ui_process``
      handles that is set (API first, then UI), then exits with status 0.
      """
      console.print("\n🛑 [bold yellow]Stopping servers...[/bold yellow]")
      servers = (
          (api_process, "🛑 [bold yellow]API server stopped.[/bold yellow]"),
          (ui_process, "🛑 [bold yellow]UI server stopped.[/bold yellow]"),
      )
      for proc, stopped_msg in servers:
          if not proc:
              continue
          proc.terminate()
          console.print(stopped_msg)
      sys.exit(0)
  @click.group()
  def cli():
      """Create, run, and deploy embedchain apps."""
      # Click uses this docstring as the group's `--help` description; the
      # subcommands are registered below with @cli.command().
  def _download_to_tempfile(url, timeout=60):
      """Download ``url`` into a named temporary file and return that file's path.

      The file is created with ``delete=False``; the caller must remove it.

      Raises:
          requests.RequestException: on connection errors, timeouts, or non-2xx responses.
      """
      # Without a timeout, requests can block forever on a stalled connection.
      response = requests.get(url, timeout=timeout)
      response.raise_for_status()
      with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_file:
          tmp_file.write(response.content)
      return tmp_file.name


  def _extract_zip_stripping_root(zip_path, dest_dir):
      """Extract ``zip_path`` into ``dest_dir``, dropping the archive's top-level directory.

      The archive's first entry is taken as its root directory and every member is
      written relative to it.

      Raises:
          zipfile.BadZipFile: if the archive is corrupt or empty.
          ValueError: if a member is not under the root directory, or would be written
              outside ``dest_dir``.
      """
      dest_dir = Path(dest_dir).resolve()
      with zipfile.ZipFile(zip_path, "r") as zip_ref:
          names = zip_ref.namelist()
          if not names:
              raise zipfile.BadZipFile("archive is empty")
          root_dir = Path(names[0])
          for member in zip_ref.infolist():
              # relative_to raises ValueError for a member outside the root directory.
              relative = Path(member.filename).relative_to(root_dir)
              target_file = (dest_dir / relative).resolve()
              # Refuse "../" members that would escape the destination ("zip slip").
              if target_file != dest_dir and dest_dir not in target_file.parents:
                  raise ValueError(f"Refusing to extract '{member.filename}' outside '{dest_dir}'")
              if member.is_dir():
                  target_file.mkdir(parents=True, exist_ok=True)
              else:
                  # Parent directory entries are not guaranteed to precede their files.
                  target_file.parent.mkdir(parents=True, exist_ok=True)
                  # Close both the archive member and the output file once copied.
                  with zip_ref.open(member, "r") as source_file, open(target_file, "wb") as file:
                      shutil.copyfileobj(source_file, file)


  @cli.command()
  @click.argument("app_name")
  @click.pass_context
  def create_app(ctx, app_name):
      """Create a new embedchain app in the new directory APP_NAME.

      Downloads the ec-admin template, extracts it into APP_NAME, and installs its
      Python and frontend requirements.
      """
      if Path(app_name).exists():
          console.print(
              f"❌ [red]Directory '{app_name}' already exists. Try using a new directory name, or remove it.[/red]"
          )
          return
      os.makedirs(app_name)
      # Everything below (extraction, install_reqs) works relative to the new app directory.
      os.chdir(app_name)

      # Step 1: Download the zip file. HTTPS, because this code is about to be
      # pip-installed and executed and must not be tampered with in transit.
      zip_url = "https://github.com/embedchain/ec-admin/archive/main.zip"
      console.print(f"Creating a new embedchain app in [green]{Path().resolve()}[/green]\n")
      try:
          zip_file_path = _download_to_tempfile(zip_url)
      except requests.RequestException as e:
          console.print(f"❌ [bold red]Failed to download zip file: {e}[/bold red]")
          anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
          return
      console.print("✅ [bold green]Fetched template successfully.[/bold green]")

      # Step 2: Extract the zip file into the app directory, always removing the temp zip.
      try:
          _extract_zip_stripping_root(zip_file_path, Path.cwd())
      except (zipfile.BadZipFile, ValueError):
          console.print("❌ [bold red]Error in extracting zip file. The file might be corrupted.[/bold red]")
          anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": False})
          return
      finally:
          Path(zip_file_path).unlink(missing_ok=True)
      console.print("✅ [bold green]Extracted zip file successfully.[/bold green]")
      anonymous_telemetry.capture(event_name="ec_create_app", properties={"success": True})

      ctx.invoke(install_reqs)
  @cli.command()
  def install_reqs():
      """Install the app's Python requirements (api/) and frontend dependencies (ui/)."""
      # Each installer runs with cwd= instead of os.chdir, so a failure can never leave
      # this process stranded inside api/ or ui/. A missing directory or executable
      # surfaces as OSError; a failing installer as CalledProcessError.
      try:
          console.print("Installing python requirements...\n")
          time.sleep(2)
          subprocess.run(["pip", "install", "-r", "requirements.txt"], cwd="api", check=True)
          console.print("\n ✅ [bold green]Installed API requirements successfully.[/bold green]\n")
      except (OSError, subprocess.SubprocessError) as e:
          console.print(f"❌ [bold red]Failed to install API requirements: {e}[/bold red]")
          anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
          return
      try:
          subprocess.run(["yarn"], cwd="ui", check=True)
          console.print("\n✅ [bold green]Successfully installed frontend requirements.[/bold green]")
          anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": True})
      except (OSError, subprocess.SubprocessError) as e:
          console.print(f"❌ [bold red]Failed to install frontend requirements. Error: {e}[/bold red]")
          anonymous_telemetry.capture(event_name="ec_install_reqs", properties={"success": False})
  @cli.command()
  def start():
      """Start the API server (api/) and the UI dev server (ui/) and wait for them."""
      # Bind the module-level handles, not locals: signal_handler reads these globals
      # to terminate the servers on SIGINT/SIGTERM.
      global api_process, ui_process
      # Set up signal handling
      signal.signal(signal.SIGINT, signal_handler)
      signal.signal(signal.SIGTERM, signal_handler)

      # Step 1: Start the API server. cwd= runs it inside api/ without changing our own cwd.
      try:
          api_process = subprocess.Popen(["python", "-m", "main"], cwd="api", stdout=None, stderr=None)
          console.print("✅ [bold green]API server started successfully.[/bold green]")
      except (OSError, subprocess.SubprocessError) as e:
          console.print(f"❌ [bold red]Failed to start the API server: {e}[/bold red]")
          anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})
          return

      # Sleep for 2 seconds to give the user time to read the message
      time.sleep(2)

      # Step 2: Install UI requirements and start the UI server
      try:
          subprocess.run(["yarn"], cwd="ui", check=True)
          ui_process = subprocess.Popen(["yarn", "dev"], cwd="ui")
          console.print("✅ [bold green]UI server started successfully.[/bold green]")
          anonymous_telemetry.capture(event_name="ec_start", properties={"success": True})
      except (OSError, subprocess.SubprocessError) as e:
          console.print(f"❌ [bold red]Failed to start the UI server: {e}[/bold red]")
          anonymous_telemetry.capture(event_name="ec_start", properties={"success": False})

      # Keep the script running until the servers exit or a kill signal arrives.
      # ui_process stays None when the UI failed to start, so skip it instead of crashing.
      try:
          for process in (api_process, ui_process):
              if process is not None:
                  process.wait()
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]Stopping server...[/bold yellow]")
  @cli.command()
  @click.option("--template", default="fly.io", help="The template to use.")
  @click.argument("extra_args", nargs=-1, type=click.UNPROCESSED)
  def create(template, extra_args):
      """Create an app in the current directory from a deployment template.

      Copies the template's files here, runs its setup step, and records the
      template as the provider in embedchain.json.
      """
      anonymous_telemetry.capture(event_name="ec_create", properties={"template_used": template})

      # Map each supported template to its setup step. Only fly.io and modal.com
      # forward the extra CLI arguments.
      setup_handlers = {
          "fly.io": lambda: setup_fly_io_app(extra_args),
          "modal.com": lambda: setup_modal_com_app(extra_args),
          "render.com": setup_render_com_app,
          "streamlit.io": setup_streamlit_io_app,
          "gradio.app": setup_gradio_app,
          "hf/gradio.app": setup_hf_app,
          "hf/streamlit.io": setup_hf_app,
      }
      setup_app = setup_handlers.get(template)
      # Validate before copying anything, so an unknown template leaves the directory untouched.
      if setup_app is None:
          raise ValueError(f"Unknown template '{template}'.")

      # For "hf/<name>" templates, the package path is looked up by "<name>".
      template_dir = template.split("/")[1] if "/" in template else template
      src_path = get_pkg_path_from_name(template_dir)
      shutil.copytree(src_path, os.getcwd(), dirs_exist_ok=True)
      console.print(f"✅ [bold green]Successfully created app from template '{template}'.[/bold green]")

      setup_app()

      embedchain_config = {"provider": template}
      with open("embedchain.json", "w") as file:
          json.dump(embedchain_config, file, indent=4)
      console.print(
          f"🎉 [green]All done! Successfully created `embedchain.json` with '{template}' as provider.[/green]"
      )
  def run_dev_fly_io(debug, host, port):
      """Serve the FastAPI app ``app:app`` locally with uvicorn (fly.io template).

      Args:
          debug: If truthy, pass uvicorn's ``--reload`` flag.
          host: Host address passed to ``--host``.
          port: Port passed to ``--port`` (converted with ``str``).

      Errors from the server process (non-zero exit, missing executable, Ctrl-C) are
      reported on the console rather than raised.
      """
      uvicorn_command = ["uvicorn", "app:app"]
      if debug:
          uvicorn_command.append("--reload")
      uvicorn_command.extend(["--host", host, "--port", str(port)])
      try:
          console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
          subprocess.run(uvicorn_command, check=True)
      except FileNotFoundError:
          # The executable is not installed / not on PATH: report it instead of a traceback.
          console.print(f"❌ [bold red]Command '{uvicorn_command[0]}' not found. Is it installed?[/bold red]")
      except subprocess.CalledProcessError as e:
          console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  def run_dev_modal_com():
      """Serve the modal.com template locally via ``modal serve app``.

      Errors from the command (non-zero exit, missing executable, Ctrl-C) are
      reported on the console rather than raised.
      """
      modal_run_cmd = ["modal", "serve", "app"]
      try:
          console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(modal_run_cmd)}[/bold cyan]")
          subprocess.run(modal_run_cmd, check=True)
      except FileNotFoundError:
          # The executable is not installed / not on PATH: report it instead of a traceback.
          console.print(f"❌ [bold red]Command '{modal_run_cmd[0]}' not found. Is it installed?[/bold red]")
      except subprocess.CalledProcessError as e:
          console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  def run_dev_streamlit_io():
      """Run the Streamlit app locally via ``streamlit run app.py``.

      Errors from the command (non-zero exit, missing executable, Ctrl-C) are
      reported on the console rather than raised.
      """
      streamlit_run_cmd = ["streamlit", "run", "app.py"]
      try:
          console.print(f"🚀 [bold cyan]Running Streamlit app with command: {' '.join(streamlit_run_cmd)}[/bold cyan]")
          subprocess.run(streamlit_run_cmd, check=True)
      except FileNotFoundError:
          # The executable is not installed / not on PATH: report it instead of a traceback.
          console.print(f"❌ [bold red]Command '{streamlit_run_cmd[0]}' not found. Is it installed?[/bold red]")
      except subprocess.CalledProcessError as e:
          console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]Streamlit server stopped[/bold yellow]")
  def run_dev_render_com(debug, host, port):
      """Serve the FastAPI app ``app:app`` locally with uvicorn (render.com template).

      Args:
          debug: If truthy, pass uvicorn's ``--reload`` flag.
          host: Host address passed to ``--host``.
          port: Port passed to ``--port`` (converted with ``str``).

      Errors from the server process (non-zero exit, missing executable, Ctrl-C) are
      reported on the console rather than raised.
      """
      uvicorn_command = ["uvicorn", "app:app"]
      if debug:
          uvicorn_command.append("--reload")
      uvicorn_command.extend(["--host", host, "--port", str(port)])
      try:
          console.print(f"🚀 [bold cyan]Running FastAPI app with command: {' '.join(uvicorn_command)}[/bold cyan]")
          subprocess.run(uvicorn_command, check=True)
      except FileNotFoundError:
          # The executable is not installed / not on PATH: report it instead of a traceback.
          console.print(f"❌ [bold red]Command '{uvicorn_command[0]}' not found. Is it installed?[/bold red]")
      except subprocess.CalledProcessError as e:
          console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]FastAPI server stopped[/bold yellow]")
  def run_dev_gradio():
      """Run the Gradio app locally via ``gradio app.py``.

      Errors from the command (non-zero exit, missing executable, Ctrl-C) are
      reported on the console rather than raised.
      """
      gradio_run_cmd = ["gradio", "app.py"]
      try:
          console.print(f"🚀 [bold cyan]Running Gradio app with command: {' '.join(gradio_run_cmd)}[/bold cyan]")
          subprocess.run(gradio_run_cmd, check=True)
      except FileNotFoundError:
          # The executable is not installed / not on PATH: report it instead of a traceback.
          console.print(f"❌ [bold red]Command '{gradio_run_cmd[0]}' not found. Is it installed?[/bold red]")
      except subprocess.CalledProcessError as e:
          console.print(f"❌ [bold red]An error occurred: {e}[/bold red]")
      except KeyboardInterrupt:
          console.print("\n🛑 [bold yellow]Gradio server stopped[/bold yellow]")
  @cli.command()
  @click.option("--debug", is_flag=True, help="Enable or disable debug mode.")
  @click.option("--host", default="127.0.0.1", help="The host address to run the FastAPI app on.")
  @click.option("--port", default=8000, help="The port to run the FastAPI app on.")
  def dev(debug, host, port):
      """Run the app locally using the provider recorded in embedchain.json."""
      with open("embedchain.json", "r") as file:
          embedchain_config = json.load(file)
      template = embedchain_config["provider"]
      anonymous_telemetry.capture(event_name="ec_dev", properties={"template_used": template})

      if template == "fly.io":
          run_dev_fly_io(debug, host, port)
      elif template == "modal.com":
          run_dev_modal_com()
      elif template == "render.com":
          run_dev_render_com(debug, host, port)
      # `create` records Hugging Face Streamlit apps as "hf/streamlit.io"; the former
      # spelling "hf/streamlit.app" is still accepted for backward compatibility.
      elif template in ("streamlit.io", "hf/streamlit.io", "hf/streamlit.app"):
          run_dev_streamlit_io()
      elif template in ("gradio.app", "hf/gradio.app"):
          run_dev_gradio()
      else:
          raise ValueError(f"Unknown template '{template}'.")
  @cli.command()
  def deploy():
      # Deploy the app to the provider recorded in embedchain.json. Exact provider names
      # map to their deploy helper; any "hf/..." provider goes to Hugging Face Spaces,
      # which also receives the optional "name" from the config (None when absent).
      with open("embedchain.json", "r") as config_file:
          config = json.load(config_file)
      app_name = None
      if "name" in config:
          app_name = config["name"]
      provider = config["provider"]
      anonymous_telemetry.capture(event_name="ec_deploy", properties={"template_used": provider})

      deployers = (
          ("fly.io", deploy_fly),
          ("modal.com", deploy_modal),
          ("render.com", deploy_render),
          ("streamlit.io", deploy_streamlit),
          ("gradio.app", deploy_gradio_app),
      )
      for known_provider, deploy_fn in deployers:
          if provider == known_provider:
              deploy_fn()
              break
      else:
          # No exact match: fall back to the Hugging Face prefix check.
          if provider.startswith("hf/"):
              deploy_hf_spaces(app_name)
          else:
              console.print("❌ [bold red]No recognized deployment platform found.[/bold red]")