// Copyright (c) EZSCALE. // SPDX-License-Identifier: MPL-2.0 package provider import ( "context" "encoding/json" "errors" "fmt" "terraform-provider-virtfusion/internal/client" "github.com/hashicorp/terraform-plugin-framework/attr" "github.com/hashicorp/terraform-plugin-framework/diag" "github.com/hashicorp/terraform-plugin-framework/resource" "github.com/hashicorp/terraform-plugin-framework/resource/schema" "github.com/hashicorp/terraform-plugin-framework/resource/schema/planmodifier" "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringdefault" "github.com/hashicorp/terraform-plugin-framework/resource/schema/stringplanmodifier" "github.com/hashicorp/terraform-plugin-framework/types" ) var ( _ resource.Resource = &ServerFirewallResource{} _ resource.ResourceWithConfigure = &ServerFirewallResource{} ) // NewServerFirewallResource returns a new resource for managing server firewalls. func NewServerFirewallResource() resource.Resource { return &ServerFirewallResource{} } // ServerFirewallResource defines the resource implementation. type ServerFirewallResource struct { client *client.Client } // ServerFirewallResourceModel describes the resource data model. type ServerFirewallResourceModel struct { ID types.String `tfsdk:"id"` ServerID types.Int64 `tfsdk:"server_id"` InterfaceName types.String `tfsdk:"interface_name"` Rules types.List `tfsdk:"rules"` } // FirewallRuleModel describes a single firewall rule. 
type FirewallRuleModel struct { Action types.String `tfsdk:"action"` Direction types.String `tfsdk:"direction"` Protocol types.String `tfsdk:"protocol"` Port types.String `tfsdk:"port"` IP types.String `tfsdk:"ip"` } func (r *ServerFirewallResource) Metadata(_ context.Context, req resource.MetadataRequest, resp *resource.MetadataResponse) { resp.TypeName = req.ProviderTypeName + "_server_firewall" } func (r *ServerFirewallResource) Schema(_ context.Context, _ resource.SchemaRequest, resp *resource.SchemaResponse) { resp.Schema = schema.Schema{ MarkdownDescription: "Manages a VirtFusion server firewall.", Attributes: map[string]schema.Attribute{ "id": schema.StringAttribute{ MarkdownDescription: "Composite identifier in the format `server_id/interface_name`.", Computed: true, PlanModifiers: []planmodifier.String{ stringplanmodifier.UseStateForUnknown(), }, }, "server_id": schema.Int64Attribute{ MarkdownDescription: "The ID of the server.", Required: true, }, "interface_name": schema.StringAttribute{ MarkdownDescription: "The network interface name. Defaults to `eth0`.", Optional: true, Computed: true, Default: stringdefault.StaticString("eth0"), }, "rules": schema.ListNestedAttribute{ MarkdownDescription: "The firewall rules.", Optional: true, NestedObject: schema.NestedAttributeObject{ Attributes: map[string]schema.Attribute{ "action": schema.StringAttribute{ MarkdownDescription: "The action for the rule (e.g. `accept`, `drop`).", Required: true, }, "direction": schema.StringAttribute{ MarkdownDescription: "The direction for the rule (e.g. `in`, `out`).", Required: true, }, "protocol": schema.StringAttribute{ MarkdownDescription: "The protocol for the rule (e.g. 
`tcp`, `udp`).", Required: true, }, "port": schema.StringAttribute{ MarkdownDescription: "The port or port range for the rule.", Required: true, }, "ip": schema.StringAttribute{ MarkdownDescription: "The IP address or CIDR for the rule.", Required: true, }, }, }, }, }, } } func (r *ServerFirewallResource) Configure(_ context.Context, req resource.ConfigureRequest, resp *resource.ConfigureResponse) { if req.ProviderData == nil { return } c, ok := req.ProviderData.(*client.Client) if !ok { resp.Diagnostics.AddError( "Unexpected Resource Configure Type", fmt.Sprintf("Expected *client.Client, got: %T.", req.ProviderData), ) return } r.client = c } func (r *ServerFirewallResource) Create(ctx context.Context, req resource.CreateRequest, resp *resource.CreateResponse) { var data ServerFirewallResourceModel resp.Diagnostics.Append(req.Plan.Get(ctx, &data)...) if resp.Diagnostics.HasError() { return } serverID := data.ServerID.ValueInt64() iface := data.InterfaceName.ValueString() // Enable the firewall _, err := r.client.Post(ctx, fmt.Sprintf("/servers/%d/firewall/%s/enable", serverID, iface), nil) if err != nil { resp.Diagnostics.AddError("Error enabling server firewall", err.Error()) return } // Set rules if provided rules, diags := r.extractRules(ctx, data) resp.Diagnostics.Append(diags...) if resp.Diagnostics.HasError() { return } if len(rules) > 0 { rulesReq := client.FirewallSetRulesRequest{Rules: rules} _, err = r.client.Post(ctx, fmt.Sprintf("/servers/%d/firewall/%s/rules", serverID, iface), rulesReq) if err != nil { resp.Diagnostics.AddError("Error setting firewall rules", err.Error()) return } } data.ID = types.StringValue(fmt.Sprintf("%d/%s", serverID, iface)) resp.Diagnostics.Append(resp.State.Set(ctx, &data)...) } func (r *ServerFirewallResource) Read(ctx context.Context, req resource.ReadRequest, resp *resource.ReadResponse) { var data ServerFirewallResourceModel resp.Diagnostics.Append(req.State.Get(ctx, &data)...) 
if resp.Diagnostics.HasError() { return } serverID := data.ServerID.ValueInt64() iface := data.InterfaceName.ValueString() result, err := r.client.Get(ctx, fmt.Sprintf("/servers/%d/firewall/%s", serverID, iface)) if err != nil { var apiErr *client.APIError if errors.As(err, &apiErr) && apiErr.IsNotFound() { resp.State.RemoveResource(ctx) return } resp.Diagnostics.AddError("Error reading server firewall", err.Error()) return } var fwResp client.FirewallResponse if err := json.Unmarshal(result, &fwResp); err != nil { resp.Diagnostics.AddError("Error parsing firewall response", err.Error()) return } // If the firewall is not enabled, remove from state if !fwResp.Data.Enabled { resp.State.RemoveResource(ctx) return } data.ID = types.StringValue(fmt.Sprintf("%d/%s", serverID, iface)) // Map API rules to the model ruleObjects := make([]attr.Value, len(fwResp.Data.Rules)) for i, rule := range fwResp.Data.Rules { ruleObj, diags := types.ObjectValue( firewallRuleAttrTypes(), map[string]attr.Value{ "action": types.StringValue(rule.Action), "direction": types.StringValue(rule.Direction), "protocol": types.StringValue(rule.Protocol), "port": types.StringValue(rule.Port), "ip": types.StringValue(rule.IP), }, ) resp.Diagnostics.Append(diags...) if resp.Diagnostics.HasError() { return } ruleObjects[i] = ruleObj } rulesList, diags := types.ListValue(types.ObjectType{AttrTypes: firewallRuleAttrTypes()}, ruleObjects) resp.Diagnostics.Append(diags...) if resp.Diagnostics.HasError() { return } data.Rules = rulesList resp.Diagnostics.Append(resp.State.Set(ctx, &data)...) } func (r *ServerFirewallResource) Update(ctx context.Context, req resource.UpdateRequest, resp *resource.UpdateResponse) { var data ServerFirewallResourceModel resp.Diagnostics.Append(req.Plan.Get(ctx, &data)...) if resp.Diagnostics.HasError() { return } serverID := data.ServerID.ValueInt64() iface := data.InterfaceName.ValueString() rules, diags := r.extractRules(ctx, data) resp.Diagnostics.Append(diags...) 
if resp.Diagnostics.HasError() { return } rulesReq := client.FirewallSetRulesRequest{Rules: rules} _, err := r.client.Post(ctx, fmt.Sprintf("/servers/%d/firewall/%s/rules", serverID, iface), rulesReq) if err != nil { resp.Diagnostics.AddError("Error updating firewall rules", err.Error()) return } data.ID = types.StringValue(fmt.Sprintf("%d/%s", serverID, iface)) resp.Diagnostics.Append(resp.State.Set(ctx, &data)...) } func (r *ServerFirewallResource) Delete(ctx context.Context, req resource.DeleteRequest, resp *resource.DeleteResponse) { var data ServerFirewallResourceModel resp.Diagnostics.Append(req.State.Get(ctx, &data)...) if resp.Diagnostics.HasError() { return } serverID := data.ServerID.ValueInt64() iface := data.InterfaceName.ValueString() _, err := r.client.Post(ctx, fmt.Sprintf("/servers/%d/firewall/%s/disable", serverID, iface), nil) if err != nil { var apiErr *client.APIError if errors.As(err, &apiErr) && apiErr.IsNotFound() { return } resp.Diagnostics.AddError("Error disabling server firewall", err.Error()) } } // extractRules converts the rules list from the model into client.FirewallRule slice. func (r *ServerFirewallResource) extractRules(ctx context.Context, data ServerFirewallResourceModel) ([]client.FirewallRule, diag.Diagnostics) { var diags diag.Diagnostics if data.Rules.IsNull() || data.Rules.IsUnknown() { return nil, diags } var ruleModels []FirewallRuleModel diags.Append(data.Rules.ElementsAs(ctx, &ruleModels, false)...) if diags.HasError() { return nil, diags } rules := make([]client.FirewallRule, len(ruleModels)) for i, rm := range ruleModels { rules[i] = client.FirewallRule{ Action: rm.Action.ValueString(), Direction: rm.Direction.ValueString(), Protocol: rm.Protocol.ValueString(), Port: rm.Port.ValueString(), IP: rm.IP.ValueString(), } } return rules, diags } // firewallRuleAttrTypes returns the attribute types for a firewall rule object. 
func firewallRuleAttrTypes() map[string]attr.Type {
	// Every rule attribute is a plain string; build a fresh map per call so
	// callers may hold or pass it without sharing mutable state.
	ruleFields := []string{"action", "direction", "protocol", "port", "ip"}
	attrTypes := make(map[string]attr.Type, len(ruleFields))
	for _, field := range ruleFields {
		attrTypes[field] = types.StringType
	}
	return attrTypes
}